/*============================================================================= Copyright (c) 2011-2019 Bolero MURAKAMI https://github.com/bolero-MURAKAMI/Sprout Distributed under the Boost Software License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) =============================================================================*/ #ifndef SPROUT_RANDOM_NORMAL_DISTRIBUTION_HPP #define SPROUT_RANDOM_NORMAL_DISTRIBUTION_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include namespace sprout { namespace random { // // normal_distribution // template class normal_distribution { public: typedef RealType input_type; typedef RealType result_type; private: struct private_construct_t {}; public: // // param_type // class param_type { public: typedef normal_distribution distribution_type; private: RealType mean_; RealType sigma_; public: SPROUT_CONSTEXPR param_type() : mean_(RealType(0.0)) , sigma_(RealType(1.0)) {} param_type(param_type const&) = default; explicit SPROUT_CONSTEXPR param_type(RealType mean_arg, RealType sigma_arg = RealType(1.0)) : mean_(mean_arg) , sigma_((SPROUT_ASSERT(sigma_arg >= RealType(0)), sigma_arg)) {} SPROUT_CONSTEXPR RealType mean() const SPROUT_NOEXCEPT { return mean_; } SPROUT_CONSTEXPR RealType sigma() const SPROUT_NOEXCEPT { return sigma_; } SPROUT_CONSTEXPR RealType stddev() const SPROUT_NOEXCEPT { return sigma_; } template friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( std::basic_istream& lhs, param_type& rhs ) { RealType mean; RealType sigma; if (lhs >> mean >> std::ws >> sigma) { if (sigma >= RealType(0)) { rhs.mean_ = mean; rhs.sigma_ = sigma; } else { lhs.setstate(std::ios_base::failbit); } } return lhs; } template friend SPROUT_NON_CONSTEXPR std::basic_ostream& operator<<( std::basic_ostream& lhs, param_type const& rhs ) { return lhs << rhs.mean_ << " " << rhs.sigma_; } friend SPROUT_CONSTEXPR bool operator==(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { return lhs.mean_ == rhs.mean_ && lhs.sigma_ == rhs.sigma_; } friend SPROUT_CONSTEXPR bool operator!=(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } }; private: RealType mean_; RealType sigma_; RealType r1_; RealType r2_; RealType cached_rho_; bool valid_; private: SPROUT_CONSTEXPR normal_distribution( RealType mean, RealType sigma, RealType r1, RealType r2, RealType cached_rho, bool valid, private_construct_t ) : mean_(mean) , sigma_(sigma) , r1_(r1) , r2_(r2) , cached_rho_(cached_rho) , valid_(valid) {} template SPROUT_CONSTEXPR sprout::random::random_result generate_2(Engine const& eng, RealType r1, RealType r2, RealType cached_rho, bool valid) const { return sprout::random::random_result( cached_rho * (valid ? sprout::cos(sprout::math::two_pi() * r1) : sprout::sin(sprout::math::two_pi() * r1) ) * sigma_ + mean_, eng, normal_distribution( mean_, sigma_, r1, r2, cached_rho, valid, private_construct_t() ) ); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_1(RealType r1, Random const& rnd) const { return generate_2( sprout::random::next(rnd).engine(), r1, sprout::random::result(rnd), sprout::sqrt(-result_type(2) * sprout::math::log(result_type(1) - sprout::random::result(rnd))), true ); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { return generate_1_1(sprout::random::result(rnd), sprout::random::next(rnd)()); } template SPROUT_CONSTEXPR sprout::random::random_result generate(Engine const& eng) const { return !valid_ ? generate_1(sprout::random::uniform_01()(eng)) : generate_2(eng, r1_, r2_, cached_rho_, false) ; } public: SPROUT_CONSTEXPR normal_distribution() : mean_(RealType(0.0)) , sigma_(RealType(1.0)) , r1_(0) , r2_(0) , cached_rho_(0) , valid_(false) {} normal_distribution(normal_distribution const&) = default; explicit SPROUT_CONSTEXPR normal_distribution(RealType mean_arg, RealType sigma_arg = RealType(1.0)) : mean_(mean_arg) , sigma_((SPROUT_ASSERT(sigma_arg >= RealType(0)), sigma_arg)) , r1_(0) , r2_(0) , cached_rho_(0) , valid_(false) {} explicit SPROUT_CONSTEXPR normal_distribution(param_type const& parm) : mean_(parm.mean()) , sigma_(parm.sigma()) , r1_(0) , r2_(0) , cached_rho_(0) , valid_(false) {} SPROUT_CONSTEXPR result_type mean() const SPROUT_NOEXCEPT { return mean_; } SPROUT_CONSTEXPR result_type sigma() const SPROUT_NOEXCEPT { return sigma_; } SPROUT_CONSTEXPR result_type stddev() const SPROUT_NOEXCEPT { return sigma_; } SPROUT_CONSTEXPR result_type min() const SPROUT_NOEXCEPT { return -sprout::numeric_limits::infinity(); } SPROUT_CONSTEXPR result_type max() const SPROUT_NOEXCEPT { return sprout::numeric_limits::infinity(); } SPROUT_CXX14_CONSTEXPR void reset() SPROUT_NOEXCEPT { valid_ = false; } SPROUT_CONSTEXPR param_type param() const SPROUT_NOEXCEPT { return param_type(mean_, sigma_); } SPROUT_CXX14_CONSTEXPR void param(param_type const& parm) { mean_ = parm.mean(); sigma_ = parm.sigma(); valid_ = false; } template SPROUT_CXX14_CONSTEXPR result_type operator()(Engine& eng) { if (!valid_) { r1_ = static_cast(sprout::random::uniform_01()(eng)); r2_ = static_cast(sprout::random::uniform_01()(eng)); cached_rho_ = sprout::math::sqrt(-result_type(2) * sprout::math::log(result_type(1) - r2_)); valid_ = true; } else { valid_ = false; } return cached_rho_ * (valid_ ? sprout::math::cos(sprout::math::two_pi() * r1_) : sprout::math::sin(sprout::math::two_pi() * r1_)) * sigma_ + mean_ ; } template SPROUT_CONSTEXPR sprout::random::random_result const operator()(Engine const& eng) const { return generate(eng); } template SPROUT_CXX14_CONSTEXPR result_type operator()(Engine& eng, param_type const& parm) const { return normal_distribution(parm)(eng); } template SPROUT_CONSTEXPR sprout::random::random_result const operator()(Engine const& eng, param_type const& parm) const { return normal_distribution(parm)(eng); } template friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( std::basic_istream& lhs, normal_distribution& rhs ) { param_type parm; bool valid; RealType cached_rho; RealType r1; RealType r2; if (lhs >> parm >> std::ws >> valid >> std::ws >> cached_rho >> std::ws >> r1 >> std::ws >> r2) { rhs.param(parm); rhs.valid_ = valid; rhs.cached_rho_ = cached_rho; rhs.r1_ = r1; rhs.r2_ = r2; } return lhs; } template friend SPROUT_NON_CONSTEXPR std::basic_ostream& operator<<( std::basic_ostream& lhs, normal_distribution const& rhs ) { return lhs << rhs.param() << " " << rhs.valid_ << " " << rhs.cached_rho_ << " " << rhs.r1_ << " " << rhs.r2_; } friend SPROUT_CONSTEXPR bool operator==(normal_distribution const& lhs, normal_distribution const& rhs) SPROUT_NOEXCEPT { return lhs.param() == rhs.param() && lhs.valid_ == rhs.valid_ && lhs.cached_rho_ == rhs.cached_rho_ && lhs.r1_ == rhs.r1_ && lhs.r2_ == rhs.r2_ ; } friend SPROUT_CONSTEXPR bool operator!=(normal_distribution const& lhs, normal_distribution const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } }; } // namespace random using sprout::random::normal_distribution; } // namespace sprout #endif // #ifndef SPROUT_RANDOM_NORMAL_DISTRIBUTION_HPP