From 7ae0072f2e7e5a34d4742aecfc53b4e1ec0a1851 Mon Sep 17 00:00:00 2001 From: bolero-MURAKAMI Date: Fri, 14 Oct 2011 21:23:13 +0900 Subject: [PATCH] =?UTF-8?q?random/geometric=5Fdistribution.hpp=20=E8=BF=BD?= =?UTF-8?q?=E5=8A=A0=20distribution=20=E3=81=AE=E3=82=B9=E3=83=88=E3=83=AA?= =?UTF-8?q?=E3=83=BC=E3=83=A0=E5=85=A5=E5=8A=9B=E3=81=AE=E3=83=90=E3=82=B0?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sprout/random.hpp | 1 + sprout/random/bernoulli_distribution.hpp | 15 +- sprout/random/binomial_distribution.hpp | 20 ++- sprout/random/geometric_distribution.hpp | 180 ++++++++++++++++++++ sprout/random/uniform_int_distribution.hpp | 19 ++- sprout/random/uniform_real_distribution.hpp | 19 ++- sprout/random/uniform_smallint.hpp | 19 ++- 7 files changed, 259 insertions(+), 14 deletions(-) create mode 100644 sprout/random/geometric_distribution.hpp diff --git a/sprout/random.hpp b/sprout/random.hpp index 390727cf..9f8f0bd1 100644 --- a/sprout/random.hpp +++ b/sprout/random.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include diff --git a/sprout/random/bernoulli_distribution.hpp b/sprout/random/bernoulli_distribution.hpp index bedaa84d..217603d3 100644 --- a/sprout/random/bernoulli_distribution.hpp +++ b/sprout/random/bernoulli_distribution.hpp @@ -16,8 +16,11 @@ namespace sprout { typedef int input_type; typedef bool result_type; private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType p_arg) { + return p_arg >= 0 && p_arg <= 1; + } static SPROUT_CONSTEXPR RealType arg_check(RealType p_arg) { - return p_arg >= 0 && p_arg <= 1 + return arg_check_nothrow(p_arg) ? p_arg : throw "assert(p_arg >= 0 && p_arg <= 1)" ; @@ -47,7 +50,15 @@ namespace sprout { param_type const& rhs ) { - return lhs >> rhs.p_; + RealType p; + if (lhs >> p) { + if (arg_check_nothrow(p)) { + rhs.p_ = p; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; } template friend std::basic_ostream& operator<<( diff --git a/sprout/random/binomial_distribution.hpp b/sprout/random/binomial_distribution.hpp index 610d78d9..001a7987 100644 --- a/sprout/random/binomial_distribution.hpp +++ b/sprout/random/binomial_distribution.hpp @@ -1,6 +1,7 @@ #ifndef SPROUT_RANDOM_BINOMIAL_DISTRIBUTION_HPP #define SPROUT_RANDOM_BINOMIAL_DISTRIBUTION_HPP +#include #include #include #include @@ -52,8 +53,11 @@ namespace sprout { RealType u_rv_r; }; private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(IntType t_arg, RealType p_arg) { + return t_arg >= IntType(0) && RealType(0) <= p_arg && p_arg <= RealType(1); + } static SPROUT_CONSTEXPR IntType arg_check(IntType t_arg, RealType p_arg) { - return t_arg >= IntType(0) && RealType(0) <= p_arg && p_arg <= RealType(1) + return arg_check_nothrow(t_arg, p_arg) ? t_arg : throw "assert(t_arg >= IntType(0) && RealType(0) <= p_arg && p_arg <= RealType(1))" ; @@ -89,7 +93,17 @@ namespace sprout { param_type const& rhs ) { - return lhs >> rhs.t_ >> std::ws >> rhs.p_; + IntType t; + RealType p; + if (lhs >> t >> std::ws >> p) { + if (arg_check_nothrow(t, p)) { + rhs.t_ = t; + rhs.p_ = p; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; } template friend std::basic_ostream& operator<<( @@ -397,7 +411,7 @@ namespace sprout { ) { param_type parm; - return lhs >> parm; + lhs >> parm; param(parm); return lhs; } diff --git a/sprout/random/geometric_distribution.hpp b/sprout/random/geometric_distribution.hpp new file mode 100644 index 00000000..0b69328e --- /dev/null +++ b/sprout/random/geometric_distribution.hpp @@ -0,0 +1,180 @@ +#ifndef SPROUT_RANDOM_GEOMETRIC_DISTRIBUTION_HPP +#define SPROUT_RANDOM_GEOMETRIC_DISTRIBUTION_HPP + +#include +#include +#include +#include +#include +#include + +namespace sprout { + namespace detail { + template + SPROUT_CONSTEXPR T floor(T x) { + return x >= T(0) ? std::floor(x) : -std::ceil(-x); + } + template + SPROUT_CONSTEXPR T ceil(T x) { + return x >= T(0) ? std::ceil(x) : -std::floor(-x); + } + } // namespace detail + namespace random { + // + // geometric_distribution + // + template + class geometric_distribution { + public: + typedef RealType input_type; + typedef IntType result_type; + private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType p_arg) { + return RealType(0) < p_arg && p_arg < RealType(1); + } + static SPROUT_CONSTEXPR RealType arg_check(RealType p_arg) { + return arg_check_nothrow(p_arg) + ? p_arg + : throw "assert(RealType(0) < p_arg && p_arg < RealType(1))" + ; + } + public: + // + // param_type + // + class param_type { + public: + typedef geometric_distribution distribution_type; + private: + RealType p_; + public: + SPROUT_CONSTEXPR param_type() + : p_(RealType(0.5)) + {} + SPROUT_CONSTEXPR explicit param_type(RealType p_arg) + : p_(arg_check(p_arg)) + {} + SPROUT_CONSTEXPR RealType p() const { + return p_; + } + template + friend std::basic_ostream& operator>>( + std::basic_istream& lhs, + param_type const& rhs + ) + { + RealType p; + if (lhs >> p) { + if (arg_check_nothrow(p)) { + rhs.p_ = p; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; + } + template + friend std::basic_ostream& operator<<( + std::basic_ostream& lhs, + param_type const& rhs + ) + { + return lhs << rhs.p_; + } + SPROUT_CONSTEXPR friend bool operator==(param_type const& lhs, param_type const& rhs) { + return lhs.p_ == rhs.p_; + } + SPROUT_CONSTEXPR friend bool operator!=(param_type const& lhs, param_type const& rhs) { + return !(lhs == rhs); + } + }; + private: + public: + static SPROUT_CONSTEXPR RealType init_log_1mp(RealType p) { + using std::log; + return log(1 - p); + } + private: + public: + RealType p_; + RealType log_1mp_; + private: + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { + using std::log; + using std::floor; + return sprout::random::random_result( + static_cast(floor(log(RealType(1) - rnd.result()) / log_1mp_)), + rnd.engine(), + *this + ); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate(Engine const& eng) const { + return generate_1(sprout::random::uniform_01()(eng)); + } + public: + SPROUT_CONSTEXPR geometric_distribution() + : p_(RealType(0.5)) + , log_1mp_(init_log_1mp(RealType(0.5))) + {} + SPROUT_CONSTEXPR explicit geometric_distribution(RealType p_arg) + : p_(arg_check(p_arg)) + , log_1mp_(init_log_1mp(p_arg)) + {} + SPROUT_CONSTEXPR explicit geometric_distribution(param_type const& parm) + : p_(parm.p()) + , log_1mp_(init_log_1mp(parm.p())) + {} + SPROUT_CONSTEXPR result_type p() const { + return p_; + } + SPROUT_CONSTEXPR result_type min() const { + return 0; + } + SPROUT_CONSTEXPR result_type max() const { + return std::numeric_limits::max(); + } + SPROUT_CONSTEXPR param_type param() const { + return param_type(p_); + } + void param(param_type const& parm) { + p_ = parm.p(); + log_1mp_ = init_log_1mp(p_); + } + template + SPROUT_CONSTEXPR sprout::random::random_result operator()(Engine const& eng) const { + return generate(eng); + } + template + friend std::basic_ostream& operator>>( + std::basic_istream& lhs, + geometric_distribution const& rhs + ) + { + param_type parm; + lhs >> parm; + param(parm); + return lhs; + } + template + friend std::basic_ostream& operator<<( + std::basic_ostream& lhs, + geometric_distribution const& rhs + ) + { + return lhs << param(); + } + SPROUT_CONSTEXPR friend bool operator==(geometric_distribution const& lhs, geometric_distribution const& rhs) { + return lhs.param() == rhs.param(); + } + SPROUT_CONSTEXPR friend bool operator!=(geometric_distribution const& lhs, geometric_distribution const& rhs) { + return !(lhs == rhs); + } + }; + } // namespace random + + using sprout::random::geometric_distribution; +} // namespace sprout + +#endif // #ifndef SPROUT_RANDOM_GEOMETRIC_DISTRIBUTION_HPP diff --git a/sprout/random/uniform_int_distribution.hpp b/sprout/random/uniform_int_distribution.hpp index 19567520..aa1e6db7 100644 --- a/sprout/random/uniform_int_distribution.hpp +++ b/sprout/random/uniform_int_distribution.hpp @@ -390,8 +390,11 @@ namespace sprout { typedef IntType input_type; typedef IntType result_type; private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(IntType min_arg, IntType max_arg) { + return min_arg <= max_arg; + } static SPROUT_CONSTEXPR IntType arg_check(IntType min_arg, IntType max_arg) { - return min_arg <= max_arg + return arg_check_nothrow(min_arg, max_arg) ? min_arg : throw "assert(min_arg <= max_arg)" ; @@ -427,7 +430,17 @@ namespace sprout { param_type const& rhs ) { - return lhs >> rhs.min_ >> std::ws >> rhs.max_; + IntType min; + IntType max; + if (lhs >> min >> std::ws >> max) { + if (arg_check_nothrow(min, max)) { + rhs.min_ = min; + rhs.max_ = max; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; } template friend std::basic_ostream& operator<<( @@ -499,7 +512,7 @@ namespace sprout { ) { param_type parm; - return lhs >> parm; + lhs >> parm; param(parm); return lhs; } diff --git a/sprout/random/uniform_real_distribution.hpp b/sprout/random/uniform_real_distribution.hpp index b6912ff9..a77b3990 100644 --- a/sprout/random/uniform_real_distribution.hpp +++ b/sprout/random/uniform_real_distribution.hpp @@ -182,8 +182,11 @@ namespace sprout { typedef RealType input_type; typedef RealType result_type; private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType min_arg, RealType max_arg) { + return min_arg <= max_arg; + } static SPROUT_CONSTEXPR RealType arg_check(RealType min_arg, RealType max_arg) { - return min_arg <= max_arg + return arg_check_nothrow(min_arg, max_arg) ? min_arg : throw "assert(min_arg <= max_arg)" ; @@ -219,7 +222,17 @@ namespace sprout { param_type const& rhs ) { - return lhs >> rhs.min_ >> std::ws >> rhs.max_; + RealType min; + RealType max; + if (lhs >> min >> std::ws >> max) { + if (arg_check_nothrow(min, max)) { + rhs.min_ = min; + rhs.max_ = max; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; } template friend std::basic_ostream& operator<<( @@ -291,7 +304,7 @@ namespace sprout { ) { param_type parm; - return lhs >> parm; + lhs >> parm; param(parm); return lhs; } diff --git a/sprout/random/uniform_smallint.hpp b/sprout/random/uniform_smallint.hpp index 891bb820..54b7cde2 100644 --- a/sprout/random/uniform_smallint.hpp +++ b/sprout/random/uniform_smallint.hpp @@ -20,8 +20,11 @@ namespace sprout { typedef IntType input_type; typedef IntType result_type; private: + static SPROUT_CONSTEXPR bool arg_check_nothrow(IntType min_arg, IntType max_arg) { + return min_arg <= max_arg; + } static SPROUT_CONSTEXPR IntType arg_check(IntType min_arg, IntType max_arg) { - return min_arg <= max_arg + return arg_check_nothrow(min_arg, max_arg) ? min_arg : throw "assert(min_arg <= max_arg)" ; @@ -57,7 +60,17 @@ namespace sprout { param_type const& rhs ) { - return lhs >> rhs.min_ >> std::ws >> rhs.max_; + IntType min; + IntType max; + if (lhs >> min >> std::ws >> max) { + if (arg_check_nothrow(min, max)) { + rhs.min_ = min; + rhs.max_ = max; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; } template friend std::basic_ostream& operator<<( @@ -227,7 +240,7 @@ namespace sprout { ) { param_type parm; - return lhs >> parm; + lhs >> parm; param(parm); return lhs; }