From 44b67d28e208bec9d5a32f895282a4cd22dbe5fd Mon Sep 17 00:00:00 2001 From: bolero-MURAKAMI Date: Wed, 19 Oct 2011 22:47:59 +0900 Subject: [PATCH] =?UTF-8?q?random/normal=5Fdistribution.hpp=20=E8=BF=BD?= =?UTF-8?q?=E5=8A=A0=20numeric/iota.hpp=20=E8=BF=BD=E5=8A=A0=20algorithm/s?= =?UTF-8?q?huffle.hpp=20=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sprout/algorithm/fit.hpp | 1 + sprout/algorithm/fit/shuffle.hpp | 43 ++++ sprout/algorithm/fixed.hpp | 1 + sprout/algorithm/fixed/shuffle.hpp | 118 ++++++++++ sprout/algorithm/shuffle.hpp | 8 + sprout/numeric.hpp | 8 + sprout/numeric/fit.hpp | 7 + sprout/numeric/fit/iota.hpp | 42 ++++ sprout/numeric/fixed.hpp | 7 + sprout/numeric/fixed/iota.hpp | 56 +++++ sprout/numeric/iota.hpp | 8 + sprout/random.hpp | 1 + sprout/random/bernoulli_distribution.hpp | 5 +- sprout/random/binomial_distribution.hpp | 5 +- sprout/random/geometric_distribution.hpp | 5 +- sprout/random/normal_distribution.hpp | 246 ++++++++++++++++++++ sprout/random/uniform_int_distribution.hpp | 9 +- sprout/random/uniform_real_distribution.hpp | 9 +- sprout/random/uniform_smallint.hpp | 9 +- 19 files changed, 570 insertions(+), 18 deletions(-) create mode 100644 sprout/algorithm/fit/shuffle.hpp create mode 100644 sprout/algorithm/fixed/shuffle.hpp create mode 100644 sprout/algorithm/shuffle.hpp create mode 100644 sprout/numeric.hpp create mode 100644 sprout/numeric/fit.hpp create mode 100644 sprout/numeric/fit/iota.hpp create mode 100644 sprout/numeric/fixed.hpp create mode 100644 sprout/numeric/fixed/iota.hpp create mode 100644 sprout/numeric/iota.hpp create mode 100644 sprout/random/normal_distribution.hpp diff --git a/sprout/algorithm/fit.hpp b/sprout/algorithm/fit.hpp index 25507e55..f6530e11 100644 --- a/sprout/algorithm/fit.hpp +++ b/sprout/algorithm/fit.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/sprout/algorithm/fit/shuffle.hpp b/sprout/algorithm/fit/shuffle.hpp new file mode 100644 index 00000000..af64c5ab --- /dev/null +++ b/sprout/algorithm/fit/shuffle.hpp @@ -0,0 +1,43 @@ +#ifndef SPROUT_ALGORITHM_FIT_SHUFFLE_HPP +#define SPROUT_ALGORITHM_FIT_SHUFFLE_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace sprout { + namespace fit { + namespace detail { + template + SPROUT_CONSTEXPR inline typename sprout::fit::result_of::algorithm::type shuffle_impl( + Container const& cont, + UniformRandomNumberGenerator&& g, + typename sprout::fixed_container_traits::difference_type offset + ) + { + return sprout::sub_copy( + sprout::get_fixed(sprout::fixed::shuffle(cont, sprout::forward(g))), + offset, + offset + sprout::size(cont) + ); + } + } // namespace detail + // + // shuffle + // + template + SPROUT_CONSTEXPR inline typename sprout::fit::result_of::algorithm::type shuffle( + Container const& cont, + UniformRandomNumberGenerator&& g + ) + { + return sprout::fit::detail::shuffle_impl(cont, sprout::forward(g), sprout::fixed_begin_offset(cont)); + } + } // namespace fit +} // namespace sprout + +#endif // #ifndef SPROUT_ALGORITHM_FIT_SHUFFLE_HPP diff --git a/sprout/algorithm/fixed.hpp b/sprout/algorithm/fixed.hpp index 4352f602..384293b7 100644 --- a/sprout/algorithm/fixed.hpp +++ b/sprout/algorithm/fixed.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/sprout/algorithm/fixed/shuffle.hpp b/sprout/algorithm/fixed/shuffle.hpp new file mode 100644 index 00000000..292bb75c --- /dev/null +++ b/sprout/algorithm/fixed/shuffle.hpp @@ -0,0 +1,118 @@ +#ifndef SPROUT_ALGORITHM_FIXED_SHUFFLE_HPP +#define SPROUT_ALGORITHM_FIXED_SHUFFLE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sprout { + namespace fixed { + namespace detail { + template + SPROUT_CONSTEXPR inline sprout::array make_shuffle_indexes_1( + std::ptrdiff_t n, + Random const& rnd, + sprout::array const& arr, + std::ptrdiff_t i + ) + { + return i < n + ? sprout::fixed::detail::make_shuffle_indexes_1( + n, + rnd(), + sprout::fixed::swap_element(arr, arr.begin() + i, arr.begin() + rnd.result()), + i + 1 + ) + : arr + ; + } + template + SPROUT_CONSTEXPR inline sprout::array make_shuffle_indexes( + std::ptrdiff_t n, + UniformRandomNumberGenerator&& g + ) + { + + return n > 0 + ? sprout::fixed::detail::make_shuffle_indexes_1( + n, + sprout::random::uniform_int_distribution(0, n - 1)(sprout::forward(g)), + sprout::fixed::iota(sprout::null_array >(), 0), + 0 + ) + : sprout::array{{}} + ; + } + template + SPROUT_CONSTEXPR inline typename sprout::fixed::result_of::algorithm::type shuffle_impl_1( + Container const& cont, + sprout::index_tuple, + Shuffled const& shuffled, + typename sprout::fixed_container_traits::difference_type offset, + typename sprout::fixed_container_traits::size_type size + ) + { + return sprout::remake_clone( + cont, + sprout::size(cont), + (Indexes >= offset && Indexes < offset + size + ? *sprout::next(sprout::begin(cont), shuffled[Indexes - offset]) + : *sprout::next(sprout::fixed_begin(cont), Indexes) + )... + ); + } + template + SPROUT_CONSTEXPR inline typename sprout::fixed::result_of::algorithm::type shuffle_impl( + Container const& cont, + sprout::index_tuple indexes, + UniformRandomNumberGenerator&& g, + typename sprout::fixed_container_traits::difference_type offset, + typename sprout::fixed_container_traits::size_type size + ) + { + return sprout::fixed::detail::shuffle_impl_1( + cont, + indexes, + sprout::fixed::detail::make_shuffle_indexes::fixed_size>( + size, + sprout::forward(g) + ), + offset, + size + ); + } + } // namespace detail + // + // shuffle + // + template + SPROUT_CONSTEXPR inline typename sprout::fixed::result_of::algorithm::type shuffle( + Container const& cont, + UniformRandomNumberGenerator&& g + ) + { + return sprout::fixed::detail::shuffle_impl( + cont, + typename sprout::index_range<0, sprout::fixed_container_traits::fixed_size>::type(), + sprout::forward(g), + sprout::fixed_begin_offset(cont), + sprout::size(cont) + ); + } + } // namespace fixed + + using sprout::fixed::shuffle; +} // namespace sprout + +#endif // #ifndef SPROUT_ALGORITHM_FIXED_SHUFFLE_HPP diff --git a/sprout/algorithm/shuffle.hpp b/sprout/algorithm/shuffle.hpp new file mode 100644 index 00000000..ff6b77d9 --- /dev/null +++ b/sprout/algorithm/shuffle.hpp @@ -0,0 +1,8 @@ +#ifndef SPROUT_ALGORITHM_SHUFFLE_HPP +#define SPROUT_ALGORITHM_SHUFFLE_HPP + +#include +#include +#include + +#endif // #ifndef SPROUT_ALGORITHM_SHUFFLE_HPP diff --git a/sprout/numeric.hpp b/sprout/numeric.hpp new file mode 100644 index 00000000..4d55f62b --- /dev/null +++ b/sprout/numeric.hpp @@ -0,0 +1,8 @@ +#ifndef SPROUT_NUMERIC_HPP +#define SPROUT_NUMERIC_HPP + +#include +#include +#include + +#endif // #ifndef SPROUT_NUMERIC_HPP diff --git a/sprout/numeric/fit.hpp b/sprout/numeric/fit.hpp new file mode 100644 index 00000000..e57351a5 --- /dev/null +++ b/sprout/numeric/fit.hpp @@ -0,0 +1,7 @@ +#ifndef SPROUT_NUMERIC_FIT_HPP +#define SPROUT_NUMERIC_FIT_HPP + +#include +#include + +#endif // #ifndef SPROUT_NUMERIC_FIT_HPP diff --git a/sprout/numeric/fit/iota.hpp b/sprout/numeric/fit/iota.hpp new file mode 100644 index 00000000..b48a149c --- /dev/null +++ b/sprout/numeric/fit/iota.hpp @@ -0,0 +1,42 @@ +#ifndef SPROUT_NUMERIC_FIT_IOTA_HPP +#define SPROUT_NUMERIC_FIT_IOTA_HPP + +#include +#include +#include +#include +#include +#include + +namespace sprout { + namespace fit { + namespace detail { + template + SPROUT_CONSTEXPR inline typename sprout::fit::result_of::algorithm::type iota_impl( + Container const& cont, + T const& value, + typename sprout::fixed_container_traits::difference_type offset + ) + { + return sprout::sub_copy( + sprout::get_fixed(sprout::fixed::iota(cont, value)), + offset, + offset + sprout::size(cont) + ); + } + } // namespace detail + // + // iota + // + template + SPROUT_CONSTEXPR inline typename sprout::fit::result_of::algorithm::type iota( + Container const& cont, + T const& value + ) + { + return sprout::fit::detail::iota_impl(cont, value, sprout::fixed_begin_offset(cont)); + } + } // namespace fit +} // namespace sprout + +#endif // #ifndef SPROUT_NUMERIC_FIT_IOTA_HPP diff --git a/sprout/numeric/fixed.hpp b/sprout/numeric/fixed.hpp new file mode 100644 index 00000000..083ecfa9 --- /dev/null +++ b/sprout/numeric/fixed.hpp @@ -0,0 +1,7 @@ +#ifndef SPROUT_NUMERIC_FIXED_HPP +#define SPROUT_NUMERIC_FIXED_HPP + +#include +#include + +#endif // #ifndef SPROUT_NUMERIC_FIXED_HPP diff --git a/sprout/numeric/fixed/iota.hpp b/sprout/numeric/fixed/iota.hpp new file mode 100644 index 00000000..8b7634a5 --- /dev/null +++ b/sprout/numeric/fixed/iota.hpp @@ -0,0 +1,56 @@ +#ifndef SPROUT_NUMERIC_FIXED_IOTA_HPP +#define SPROUT_NUMERIC_FIXED_IOTA_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace sprout { + namespace fixed { + namespace detail { + template + SPROUT_CONSTEXPR inline typename sprout::fixed::result_of::algorithm::type iota_impl( + Container const& cont, + sprout::index_tuple, + T value, + typename sprout::fixed_container_traits::difference_type offset, + typename sprout::fixed_container_traits::size_type size + ) + { + return sprout::remake_clone( + cont, + sprout::size(cont), + (Indexes >= offset && Indexes < offset + size + ? value + (Indexes - offset) + : *sprout::next(sprout::fixed_begin(cont), Indexes) + )... + ); + } + } // namespace detail + // + // iota + // + template + SPROUT_CONSTEXPR inline typename sprout::fixed::result_of::algorithm::type iota( + Container const& cont, + T value + ) + { + return sprout::fixed::detail::iota_impl( + cont, + typename sprout::index_range<0, sprout::fixed_container_traits::fixed_size>::type(), + value, + sprout::fixed_begin_offset(cont), + sprout::size(cont) + ); + } + } // namespace fixed + + using sprout::fixed::iota; +} // namespace sprout + +#endif // #ifndef SPROUT_NUMERIC_FIXED_IOTA_HPP diff --git a/sprout/numeric/iota.hpp b/sprout/numeric/iota.hpp new file mode 100644 index 00000000..61088fb0 --- /dev/null +++ b/sprout/numeric/iota.hpp @@ -0,0 +1,8 @@ +#ifndef SPROUT_NUMERIC_IOTA_HPP +#define SPROUT_NUMERIC_IOTA_HPP + +#include +#include +#include + +#endif // #ifndef SPROUT_NUMERIC_IOTA_HPP diff --git a/sprout/random.hpp b/sprout/random.hpp index f288c9f7..ff00afbf 100644 --- a/sprout/random.hpp +++ b/sprout/random.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/sprout/random/bernoulli_distribution.hpp b/sprout/random/bernoulli_distribution.hpp index 61c1ad84..d819822b 100644 --- a/sprout/random/bernoulli_distribution.hpp +++ b/sprout/random/bernoulli_distribution.hpp @@ -132,8 +132,9 @@ namespace sprout { ) { param_type parm; - return lhs >> parm; - rhs.param(parm); + if (lhs >> parm) { + rhs.param(parm); + } return lhs; } template diff --git a/sprout/random/binomial_distribution.hpp b/sprout/random/binomial_distribution.hpp index 0bf6b1cd..925c2b89 100644 --- a/sprout/random/binomial_distribution.hpp +++ b/sprout/random/binomial_distribution.hpp @@ -415,8 +415,9 @@ namespace sprout { ) { param_type parm; - lhs >> parm; - rhs.param(parm); + if (lhs >> parm) { + rhs.param(parm); + } return lhs; } template diff --git a/sprout/random/geometric_distribution.hpp b/sprout/random/geometric_distribution.hpp index fd181ba1..ffe5467f 100644 --- a/sprout/random/geometric_distribution.hpp +++ b/sprout/random/geometric_distribution.hpp @@ -157,8 +157,9 @@ namespace sprout { ) { param_type parm; - lhs >> parm; - rhs.param(parm); + if (lhs >> parm) { + rhs.param(parm); + } return lhs; } template diff --git a/sprout/random/normal_distribution.hpp b/sprout/random/normal_distribution.hpp new file mode 100644 index 00000000..0041267d --- /dev/null +++ b/sprout/random/normal_distribution.hpp @@ -0,0 +1,246 @@ +#ifndef SPROUT_RANDOM_NORMAL_DISTRIBUTION_HPP +#define SPROUT_RANDOM_NORMAL_DISTRIBUTION_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace sprout { + namespace random { + // + // normal_distribution + // + template + class normal_distribution { + public: + typedef RealType input_type; + typedef RealType result_type; + private: + struct private_constructor_tag {}; + private: + SPROUT_STATIC_CONSTEXPR result_type pi = result_type(3.14159265358979323846); + private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType mean_arg, RealType sigma_arg) { + return sigma_arg >= RealType(0); + } + static SPROUT_CONSTEXPR RealType arg_check(RealType mean_arg, RealType sigma_arg) { + return arg_check_nothrow(mean_arg, sigma_arg) + ? mean_arg + : throw "assert(sigma_arg >= RealType(0))" + ; + } + public: + // + // param_type + // + class param_type { + public: + typedef normal_distribution distribution_type; + private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType mean_arg, RealType sigma_arg) { + return distribution_type::arg_check_nothrow(mean_arg, sigma_arg); + } + private: + RealType mean_; + RealType sigma_; + public: + SPROUT_CONSTEXPR param_type() + : mean_(RealType(0.0)) + , sigma_(RealType(1.0)) + {} + SPROUT_CONSTEXPR explicit param_type(RealType mean_arg, RealType sigma_arg = RealType(1.0)) + : mean_(arg_check(mean_arg, sigma_arg)) + , sigma_(sigma_arg) + {} + SPROUT_CONSTEXPR RealType mean() const { + return mean_; + } + SPROUT_CONSTEXPR RealType sigma() const { + return sigma_; + } + template + friend std::basic_istream& operator>>( + std::basic_istream& lhs, + param_type& rhs + ) + { + RealType mean; + RealType sigma; + if (lhs >> mean >> std::ws >> sigma) { + if (arg_check_nothrow(mean, sigma)) { + rhs.mean_ = mean; + rhs.sigma_ = sigma; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; + } + template + friend std::basic_ostream& operator<<( + std::basic_ostream& lhs, + param_type const& rhs + ) + { + return lhs << rhs.mean_ << " " << rhs.sigma_; + } + SPROUT_CONSTEXPR friend bool operator==(param_type const& lhs, param_type const& rhs) { + return lhs.mean_ == rhs.mean_ && lhs.sigma_ == rhs.sigma_; + } + SPROUT_CONSTEXPR friend bool operator!=(param_type const& lhs, param_type const& rhs) { + return !(lhs == rhs); + } + }; + private: + RealType mean_; + RealType sigma_; + RealType r1_; + RealType r2_; + RealType cached_rho_; + bool valid_; + private: + SPROUT_CONSTEXPR normal_distribution( + RealType mean, + RealType sigma, + RealType r1, + RealType r2, + RealType cached_rho, + bool valid, + private_constructor_tag + ) + : mean_(mean) + , sigma_(sigma) + , r1_(r1) + , r2_(r2) + , cached_rho_(cached_rho) + , valid_(valid) + {} + template + SPROUT_CONSTEXPR sprout::random::random_result generate_2(Engine const& eng, RealType r1, RealType r2, RealType cached_rho, bool valid) const { + using std::sin; + using std::cos; + return sprout::random::random_result( + cached_rho * (valid ? cos(result_type(2) * pi * r1) : sin(result_type(2) * pi * r1)) * sigma_ + mean_, + eng, + normal_distribution( + mean_, + sigma_, + r1, + r2, + cached_rho, + valid, + private_constructor_tag() + ) + ); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_1(RealType r1, Random const& rnd) const { + using std::sqrt; + using std::log; + return generate_2(rnd.engine(), r1, rnd.result(), sqrt(-result_type(2) * log(result_type(1) - rnd.result())), true); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { + return generate_1_1(rnd.result(), rnd()); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate(Engine const& eng) const { + return !valid_ + ? generate_1(sprout::random::uniform_01()(eng)) + : generate_2(eng, r1_, r2_, cached_rho_, false) + ; + } + public: + SPROUT_CONSTEXPR normal_distribution() + : mean_(RealType(0.0)) + , sigma_(RealType(1.0)) + , r1_(0) + , r2_(0) + , cached_rho_(0) + , valid_(false) + {} + SPROUT_CONSTEXPR explicit normal_distribution(RealType mean_arg, RealType sigma_arg = RealType(1.0)) + : mean_(arg_check(mean_arg, sigma_arg)) + , sigma_(sigma_arg) + , r1_(0) + , r2_(0) + , cached_rho_(0) + , valid_(false) + {} + SPROUT_CONSTEXPR explicit normal_distribution(param_type const& parm) + : mean_(parm.mean()) + , sigma_(parm.sigma()) + , r1_(0) + , r2_(0) + , cached_rho_(0) + , valid_(false) + {} + SPROUT_CONSTEXPR result_type mean() const { + return mean_; + } + SPROUT_CONSTEXPR result_type sigma() const { + return sigma_; + } + SPROUT_CONSTEXPR result_type min() const { + return -std::numeric_limits::infinity(); + } + SPROUT_CONSTEXPR result_type max() const { + return std::numeric_limits::infinity(); + } + SPROUT_CONSTEXPR param_type param() const { + return param_type(mean_, sigma_); + } + void param(param_type const& parm) { + mean_ = parm.mean(); + sigma_ = parm.sigma(); + valid_ = false; + } + template + SPROUT_CONSTEXPR sprout::random::random_result operator()(Engine const& eng) const { + return generate(eng); + } + template + friend std::basic_istream& operator>>( + std::basic_istream& lhs, + normal_distribution& rhs + ) + { + param_type parm; + bool valid; + RealType cached_rho; + RealType r1; + RealType r2; + if (lhs >> parm >> std::ws >> valid >> std::ws >> cached_rho >> std::ws >> r1 >> std::ws >> r2) { + rhs.param(parm); + rhs.valid_ = valid; + rhs.cached_rho_ = cached_rho; + rhs.r1_ = r1; + rhs.r2_ = r2; + } + return lhs; + } + template + friend std::basic_ostream& operator<<( + std::basic_ostream& lhs, + normal_distribution const& rhs + ) + { + return lhs << rhs.param() << " " << rhs.valid_ << " " << rhs.cached_rho_ << " " << rhs.r1_ << " " << rhs.r2_; + } + SPROUT_CONSTEXPR friend bool operator==(normal_distribution const& lhs, normal_distribution const& rhs) { + return lhs.param() == rhs.param() && lhs.valid_ == rhs.valid_ && lhs.cached_rho_ == rhs.cached_rho_ && lhs.r1_ == rhs.r1_ && lhs.r2_ == rhs.r2_; + } + SPROUT_CONSTEXPR friend bool operator!=(normal_distribution const& lhs, normal_distribution const& rhs) { + return !(lhs == rhs); + } + }; + } // namespace random + + using sprout::random::normal_distribution; +} // namespace sprout + +#endif // #ifndef SPROUT_RANDOM_NORMAL_DISTRIBUTION_HPP diff --git a/sprout/random/uniform_int_distribution.hpp b/sprout/random/uniform_int_distribution.hpp index 4338938f..02c0deb5 100644 --- a/sprout/random/uniform_int_distribution.hpp +++ b/sprout/random/uniform_int_distribution.hpp @@ -462,8 +462,8 @@ namespace sprout { } }; private: - result_type min_; - result_type max_; + IntType min_; + IntType max_; private: template SPROUT_CONSTEXPR sprout::random::random_result generate(Result const& rnd) const { @@ -516,8 +516,9 @@ namespace sprout { ) { param_type parm; - lhs >> parm; - rhs.param(parm); + if (lhs >> parm) { + rhs.param(parm); + } return lhs; } template diff --git a/sprout/random/uniform_real_distribution.hpp b/sprout/random/uniform_real_distribution.hpp index 8e4be241..45e5bbd3 100644 --- a/sprout/random/uniform_real_distribution.hpp +++ b/sprout/random/uniform_real_distribution.hpp @@ -254,8 +254,8 @@ namespace sprout { } }; private: - result_type min_; - result_type max_; + RealType min_; + RealType max_; private: template SPROUT_CONSTEXPR sprout::random::random_result generate(Result const& rnd) const { @@ -308,8 +308,9 @@ namespace sprout { ) { param_type parm; - lhs >> parm; - rhs.param(parm); + if (lhs >> parm) { + rhs.param(parm); + } return lhs; } template diff --git a/sprout/random/uniform_smallint.hpp b/sprout/random/uniform_smallint.hpp index e3605848..867b2d56 100644 --- a/sprout/random/uniform_smallint.hpp +++ b/sprout/random/uniform_smallint.hpp @@ -92,8 +92,8 @@ namespace sprout { } }; private: - result_type min_; - result_type max_; + IntType min_; + IntType max_; private: template SPROUT_CONSTEXPR sprout::random::random_result generate_true_2( @@ -244,8 +244,9 @@ namespace sprout { ) { param_type parm; - lhs >> parm; - rhs.param(parm); + if (lhs >> parm) { + rhs.param(parm); + } return lhs; } template