From 9255ad680849aedf08a0831acaad11b2d5e7f4be Mon Sep 17 00:00:00 2001 From: bolero-MURAKAMI Date: Sun, 16 Oct 2011 23:38:40 +0900 Subject: [PATCH] =?UTF-8?q?random=20=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sprout/random/bernoulli_distribution.hpp | 3 ++ sprout/random/linear_congruential.hpp | 15 ++++++---- sprout/random/mersenne_twister.hpp | 26 ++++++++-------- sprout/random/uniform_01.hpp | 38 ++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 19 deletions(-) diff --git a/sprout/random/bernoulli_distribution.hpp b/sprout/random/bernoulli_distribution.hpp index 217603d3..6103a0c3 100644 --- a/sprout/random/bernoulli_distribution.hpp +++ b/sprout/random/bernoulli_distribution.hpp @@ -96,6 +96,9 @@ namespace sprout { SPROUT_CONSTEXPR explicit bernoulli_distribution(RealType p_arg = RealType(0.5)) : p_(arg_check(p_arg)) {} + SPROUT_CONSTEXPR explicit bernoulli_distribution(param_type const& parm) + : p_(parm.p()) + {} SPROUT_CONSTEXPR RealType p() const { return p_; } diff --git a/sprout/random/linear_congruential.hpp b/sprout/random/linear_congruential.hpp index a6a00385..861dd59e 100644 --- a/sprout/random/linear_congruential.hpp +++ b/sprout/random/linear_congruential.hpp @@ -29,14 +29,17 @@ namespace sprout { static_assert(m == 0 || a < m, "m == 0 || a < m"); static_assert(m == 0 || c < m, "m == 0 || c < m"); private: - static SPROUT_CONSTEXPR IntType init_seed_3(IntType const& x0) { - return x0 >= static_min() && x0 <= static_max() + static SPROUT_CONSTEXPR bool arg_check_nothrow(IntType const& x0) { + return x0 >= static_min() && x0 <= static_max(); + } + static SPROUT_CONSTEXPR IntType arg_check(IntType const& x0) { + return arg_check_nothrow(x0) ? x0 : throw "assert(x0 >= static_min() && x0 <= static_max())" ; } static SPROUT_CONSTEXPR IntType init_seed_2(IntType const& x0) { - return init_seed_3(increment == 0 && x0 == 0 ? 1 : x0); + return arg_check(increment == 0 && x0 == 0 ? 1 : x0); } static SPROUT_CONSTEXPR IntType init_seed_1(IntType const& x0) { return init_seed_2(x0 <= 0 && x0 != 0 ? x0 + modulus : x0); @@ -88,12 +91,12 @@ namespace sprout { template friend std::basic_istream& operator>>( std::basic_istream& lhs, - linear_congruential_engine const& rhs + linear_congruential_engine& rhs ) { IntType x; if(lhs >> x) { - if(x >= min() && x <= max()) { + if(arg_check_nothrow(x)) { rhs.x_ = x; } else { lhs.setstate(std::ios_base::failbit); @@ -192,7 +195,7 @@ namespace sprout { template friend std::basic_istream& operator>>( std::basic_istream& lhs, - rand48 const& rhs + rand48& rhs ) { return lhs >> rhs.lcf_; diff --git a/sprout/random/mersenne_twister.hpp b/sprout/random/mersenne_twister.hpp index 61381c65..ef31f85a 100644 --- a/sprout/random/mersenne_twister.hpp +++ b/sprout/random/mersenne_twister.hpp @@ -331,7 +331,7 @@ namespace sprout { ; } friend SPROUT_CONSTEXPR bool operator==(mersenne_twister_engine const& lhs, mersenne_twister_engine const& rhs) { - return lhs.i_ < lhs.i_ + return lhs.i_ < rhs.i_ ? lhs.equal_impl(rhs) : rhs.equal_impl(lhs) ; @@ -342,6 +342,18 @@ namespace sprout { template friend std::basic_istream& operator>>( std::basic_istream& lhs, + mersenne_twister_engine& rhs + ) + { + for (std::size_t i = 0; i < state_size; ++i) { + lhs >> rhs.x_[i] >> std::ws; + } + rhs.i_ = state_size; + return lhs; + } + template + friend std::basic_ostream& operator<<( + std::basic_ostream& lhs, mersenne_twister_engine const& rhs ) { @@ -358,18 +370,6 @@ namespace sprout { } return lhs; } - template - friend std::basic_ostream& operator<<( - std::basic_ostream& lhs, - mersenne_twister_engine const& rhs - ) - { - for (std::size_t i = 0; i < state_size; ++i) { - lhs >> rhs.x_[i] >> std::ws; - } - rhs.i_ = state_size; - return lhs; - } }; template SPROUT_CONSTEXPR std::size_t sprout::random::mersenne_twister_engine::word_size; diff --git a/sprout/random/uniform_01.hpp b/sprout/random/uniform_01.hpp index 6ccd2100..85147c7a 100644 --- a/sprout/random/uniform_01.hpp +++ b/sprout/random/uniform_01.hpp @@ -17,6 +17,37 @@ namespace sprout { public: typedef RealType input_type; typedef RealType result_type; + public: + // + // param_type + // + class param_type { + public: + typedef uniform_01 distribution_type; + public: + template + friend std::basic_ostream& operator>>( + std::basic_istream& lhs, + param_type const& rhs + ) + { + return lhs; + } + template + friend std::basic_ostream& operator<<( + std::basic_ostream& lhs, + param_type const& rhs + ) + { + return lhs; + } + SPROUT_CONSTEXPR friend bool operator==(param_type const& lhs, param_type const& rhs) { + return true; + } + SPROUT_CONSTEXPR friend bool operator!=(param_type const& lhs, param_type const& rhs) { + return !(lhs == rhs); + } + }; private: template SPROUT_CONSTEXPR sprout::random::random_result generate_1( @@ -50,12 +81,19 @@ namespace sprout { ); } public: + SPROUT_CONSTEXPR explicit uniform_01(param_type const& parm) + {} SPROUT_CONSTEXPR result_type min() const { return result_type(0); } SPROUT_CONSTEXPR result_type max() const { return result_type(1); } + SPROUT_CONSTEXPR param_type param() const { + return param_type(); + } + void param(param_type const& parm) { + } template SPROUT_CONSTEXPR sprout::random::random_result operator()(Engine const& eng) const { return generate(eng, eng());