1
0
Fork 0
mirror of https://github.com/bolero-MURAKAMI/Sprout synced 2025-08-03 12:49:50 +00:00

add examble: perseptron/g3.cpp brainfuck/x86_compile.cpp

This commit is contained in:
bolero-MURAKAMI 2014-12-27 13:56:41 +09:00
parent 5ccbc4e903
commit 03b4eda323
15 changed files with 1411 additions and 8 deletions

204
example/perceptron/g3.cpp Normal file
View file

@ -0,0 +1,204 @@
/*=============================================================================
Copyright (c) 2011-2014 Bolero MURAKAMI
https://github.com/bolero-MURAKAMI/Sprout
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
=============================================================================*/
#include <cstddef>
#include <sprout/config.hpp>
#include <sprout/array.hpp>
#include <sprout/algorithm/copy.hpp>
#include <sprout/iterator/operation.hpp>
#include <sprout/container/functions.hpp>
#include <sprout/math/sigmoid.hpp>
#include <sprout/random/uniform_01.hpp>
#include <sprout/random/generate_array.hpp>
#include <sprout/assert.hpp>
//
// perceptron
//
template<typename FloatType, std::size_t In, std::size_t Hid, std::size_t Out>
class perceptron {
public:
typedef FloatType value_type;
private:
struct worker {
public:
// 入力
sprout::array<value_type, In + 1> xi1;
sprout::array<value_type, Hid + 1> xi2;
sprout::array<value_type, Out> xi3;
// 出力
sprout::array<value_type, In + 1> o1;
sprout::array<value_type, Hid + 1> o2;
sprout::array<value_type, Out> o3;
};
private:
// 誤差
sprout::array<value_type, Hid + 1> d2;
sprout::array<value_type, Out> d3;
// 重み
sprout::array<value_type, (In + 1) * Hid> w1;
sprout::array<value_type, (Hid + 1) * Out> w2;
private:
// 順伝播
template<typename ForwardIterator>
SPROUT_CXX14_CONSTEXPR void
forward_propagation(ForwardIterator in_first, ForwardIterator in_last, worker& work) const {
// 入力層の順伝播
sprout::copy(in_first, in_last, sprout::begin(work.xi1));
work.xi1[In] = 1;
sprout::copy(sprout::begin(work.xi1), sprout::end(work.xi1), sprout::begin(work.o1));
// 隠れ層の順伝播
for (std::size_t i = 0; i != Hid; ++i) {
work.xi2[i] = 0;
for (std::size_t j = 0; j != In + 1; ++j) {
work.xi2[i] += w1[j * Hid + i] * work.o1[j];
}
work.o2[i] = sprout::math::sigmoid(work.xi2[i]);
}
work.o2[Hid] = 1;
// 出力層の順伝播
for (std::size_t i = 0; i != Hid; ++i) {
work.xi3[i] = 0;
for (std::size_t j = 0; j != In + 1; ++j) {
work.xi3[i] += w2[j * Out + i] * work.o2[j];
}
work.o3[i] = work.xi3[i];
}
}
public:
template<typename RandomNumberGenerator>
explicit SPROUT_CXX14_CONSTEXPR perceptron(RandomNumberGenerator& rng)
: d2{{}}, d3{{}}
, w1(sprout::random::generate_array<(In + 1) * Hid>(rng, sprout::random::uniform_01<value_type>()))
, w2(sprout::random::generate_array<(Hid + 1) * Out>(rng, sprout::random::uniform_01<value_type>()))
{}
// ニューラルネットの訓練
// [in_first, in_last) : 訓練データ (N*In 個)
// [t_first, t_last) : 教師データ (N 個)
template<typename ForwardIterator1, typename ForwardIterator2>
SPROUT_CXX14_CONSTEXPR void
train(
ForwardIterator1 in_first, ForwardIterator1 in_last,
ForwardIterator2 t_first, ForwardIterator2 t_last,
std::size_t repeat = 1000, value_type eta = value_type(0.1)
)
{
SPROUT_ASSERT(sprout::distance(in_first, in_last) % In == 0);
SPROUT_ASSERT(sprout::distance(in_first, in_last) / In == sprout::distance(t_first, t_last));
worker work{};
for (std::size_t times = 0; times != repeat; ++times) {
ForwardIterator1 in_it = in_first;
ForwardIterator2 t_it = t_first;
for (; in_it != in_last; sprout::advance(in_it, In), ++t_it) {
// 順伝播
forward_propagation(in_it, sprout::next(in_it, In), work);
// 出力層の誤差計算
for (std::size_t i = 0; i != Out; ++i) {
d3[i] = *t_it == i ? work.o3[i] - 1
: work.o3[i]
;
}
// 出力層の重み更新
for (std::size_t i = 0; i != Hid + 1; ++i) {
for (std::size_t j = 0; j != Out; ++j) {
w2[i * Out + j] -= eta * d3[j] * work.o2[i];
}
}
// 隠れ層の誤差計算
for (std::size_t i = 0; i != Hid + 1; ++i) {
d2[i] = 0;
for (std::size_t j = 0; j != Out; ++j) {
d2[i] += w2[i * Out + j] * d3[j];
}
d2[i] *= sprout::math::d_sigmoid(work.xi2[i]);
}
// 隠れ層の重み更新
for (std::size_t i = 0; i != In + 1; ++i) {
for (std::size_t j = 0; j != Hid; ++j) {
w1[i * Hid + j] -= eta * d2[j] * work.o1[i];
}
}
}
}
}
// 与えられたデータに対して最も可能性の高いクラスを返す
template<typename ForwardIterator>
SPROUT_CXX14_CONSTEXPR std::size_t
predict(ForwardIterator in_first, ForwardIterator in_last) const {
SPROUT_ASSERT(sprout::distance(in_first, in_last) == In);
worker work{};
// 順伝播による予測
forward_propagation(in_first, in_last, work);
// 出力が最大になるクラスを判定
return sprout::distance(
sprout::begin(work.o3),
sprout::max_element(sprout::begin(work.o3), sprout::end(work.o3))
);
}
};
#include <cstddef>
#include <iostream>
#include <sprout/config.hpp>
#include <sprout/random/default_random_engine.hpp>
#include <sprout/random/unique_seed.hpp>
#include <sprout/static_assert.hpp>
// 訓練データ
SPROUT_CONSTEXPR auto train_data = sprout::make_array<double>(
# include "g3_train.csv"
);
// 教師データ
SPROUT_CONSTEXPR auto teach_data = sprout::make_array<std::size_t>(
# include "g3_teach.csv"
);
SPROUT_STATIC_ASSERT(train_data.size() % 2 == 0);
SPROUT_STATIC_ASSERT(train_data.size() / 2 == teach_data.size());
// 訓練済みパーセプトロンを生成
template<typename FloatType, std::size_t In, std::size_t Hid, std::size_t Out>
SPROUT_CXX14_CONSTEXPR ::perceptron<FloatType, In, Hid, Out>
make_trained_perceptron() {
// 乱数生成器
sprout::random::default_random_engine rng(SPROUT_UNIQUE_SEED);
// パーセプトロン
::perceptron<FloatType, In, Hid, Out> per(rng);
// 訓練
per.train(
train_data.begin(), train_data.end(),
teach_data.begin(), teach_data.end(),
500, 0.1
);
return per;
}
int main() {
// パーセプトロンを生成入力2 隠れ3 出力3
SPROUT_CXX14_CONSTEXPR auto per = ::make_trained_perceptron<double, 2, 3, 3>();
// 結果の表示
for (auto it = train_data.begin(), last = train_data.end(); it != last; it += 2) {
std::cout << per.predict(it, it + 2) << std::endl;
}
}

View file

@ -0,0 +1,150 @@
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2
Can't render this file because it has a wrong number of fields in line 150.

View file

@ -0,0 +1,150 @@
2.838248351,-0.057127115,
2.84962986,0.30984388,
3.069875981,0.069165603,
3.033612564,0.237406102,
3.162983688,0.334387888,
3.163831849,0.487806933,
2.997986085,0.061074061,
2.799347441,0.521430032,
3.137156052,0.266393153,
2.870857355,0.267030982,
3.43934308,0.170662698,
3.035552198,0.251317543,
2.886143741,0.125075012,
3.221207099,0.030493356,
3.050238248,0.007279557,
2.733967037,0.235668448,
3.153513285,-0.010778431,
3.311231402,0.343325475,
3.298245949,0.065857414,
2.848633328,-0.004875479,
3.070831457,0.211490982,
3.082296334,0.31440897,
2.926651655,-0.240105616,
2.978054772,0.031004937,
2.35804167,0.239646483,
2.645048857,0.159202643,
2.932230736,0.104463435,
3.097376157,0.150507296,
2.803701193,0.042427284,
2.953016638,0.28536828,
3.238065532,0.346081055,
3.055363223,0.445131721,
2.995237434,0.507696093,
2.804479208,0.245421171,
3.012620496,0.384000638,
2.693411745,0.312330244,
3.102647408,0.343171956,
2.896095389,0.036131501,
2.718440992,-0.01150165,
3.170644375,-0.048096883,
3.011452172,0.034237664,
2.872141309,0.07956016,
2.865096742,0.008559658,
3.279540305,0.302372444,
2.860098477,0.218127155,
3.269317565,-0.005452377,
2.761144719,0.673127168,
2.902469266,0.361832898,
3.11297363,0.50154383,
3.025469568,0.399359182,
1.840419314,1.905105063,
1.813240148,1.813839986,
2.123039813,2.132571185,
1.946192802,1.773046885,
1.969172236,1.805946832,
2.258526017,2.000929535,
1.864577066,2.144995459,
1.911976195,1.980181472,
1.990154583,1.919792502,
1.769958598,2.105082362,
2.137578551,1.944983188,
2.00439482,2.072477229,
1.927718762,1.694463892,
2.207399026,1.851368634,
2.495855335,1.78918729,
2.381107166,2.040305134,
2.05695779,2.458901875,
1.782594688,2.133780789,
1.731552089,2.188212826,
2.157206504,2.249346956,
1.771550549,2.388096625,
1.596862739,2.285617904,
2.116253295,2.043142148,
2.051844958,2.033143301,
1.931137989,2.426396361,
1.987232105,2.141510946,
2.265269874,2.150853439,
2.02603174,2.015376408,
1.95851869,2.243280964,
1.982323697,1.705817048,
1.860646929,2.221006838,
1.850979072,2.557549512,
2.245893751,1.851633418,
2.225460386,1.968984437,
1.817314806,2.306871098,
2.28503916,2.17214804,
1.974994597,1.313041694,
2.031987688,2.022965729,
1.953448998,2.261221941,
2.167787963,1.950768156,
2.325040244,1.533073263,
2.058211845,2.0741863,
2.047485215,2.098219462,
2.228436076,1.920446541,
2.328240588,2.151820146,
2.183441475,2.292478916,
2.190968871,1.760135247,
1.716478165,2.078620922,
1.878852022,2.15075819,
2.085908678,1.96855502,
3.798041143,2.925502959,
3.835899929,3.004531872,
4.226026357,3.327911034,
3.63906142,2.764149093,
3.976732793,2.765971538,
3.735533807,3.211145512,
3.922743188,2.908463331,
3.959548749,2.92371863,
4.220358369,3.028638464,
4.101855229,2.860451817,
3.818023487,2.821524824,
3.944688596,3.251039443,
4.139128323,3.071444196,
4.235165326,3.005161748,
4.248626372,3.156115907,
4.286060111,2.952368701,
3.974827507,2.924985851,
4.140698521,3.042471222,
4.31503215,2.912715016,
4.41718743,3.064667839,
3.894120786,2.925224863,
3.953199682,3.348500249,
4.003736459,3.208318835,
3.872457887,2.698607706,
3.889675701,2.977512411,
3.974258739,3.099134752,
3.794940485,3.033699013,
4.020409067,2.784229901,
3.851699136,2.972769818,
4.036060943,3.066876325,
3.590866735,3.062219821,
4.000802489,3.041278742,
4.254308353,2.947953876,
4.126845488,2.932815351,
3.807568611,3.145490402,
3.986677755,3.378673859,
3.684653059,2.911221769,
4.259008498,3.374089548,
3.964802614,3.072807287,
4.182124157,2.950380575,
3.948426268,3.099264552,
4.053742816,3.195652528,
4.340603893,3.098341999,
3.67327559,2.94969907,
4.043519751,3.311304107,
3.97872525,3.166533465,
4.19023214,2.962725503,
3.717902917,2.75762358,
3.947763326,3.160248033,
4.191797555,3.162533112
Can't render this file because it has a wrong number of fields in line 150.