/*============================================================================= Copyright (c) 2011-2019 Bolero MURAKAMI https://github.com/bolero-MURAKAMI/Sprout Distributed under the Boost Software License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) =============================================================================*/ #ifndef SPROUT_RANDOM_POISSON_DISTRIBUTION_HPP #define SPROUT_RANDOM_POISSON_DISTRIBUTION_HPP #include #include #include #include #include #include #include #include #include #include #include #include #ifdef SPROUT_WORKAROUND_NOT_TERMINATE_RECURSIVE_CONSTEXPR_FUNCTION_TEMPLATE # include #endif namespace sprout { namespace random { namespace detail { # define SPROUT_POISSON_TABLE_DEF \ {{ \ 0.0, \ 0.0, \ 0.69314718055994529, \ 1.7917594692280550, \ 3.1780538303479458, \ 4.7874917427820458, \ 6.5792512120101012, \ 8.5251613610654147, \ 10.604602902745251, \ 12.801827480081469 \ }} template struct poisson_table { public: typedef sprout::array table_type; public: SPROUT_STATIC_CONSTEXPR table_type table SPROUT_STATIC_CONSTEXPR_DATA_MEMBER_INNER(SPROUT_POISSON_TABLE_DEF) ; }; template SPROUT_CONSTEXPR_OR_CONST typename sprout::random::detail::poisson_table::table_type sprout::random::detail::poisson_table::table SPROUT_STATIC_CONSTEXPR_DATA_MEMBER_OUTER(SPROUT_POISSON_TABLE_DEF) ; # undef SPROUT_POISSON_TABLE_DEF } // namespace detail // // poisson_distribution // template class poisson_distribution { public: typedef RealType input_type; typedef IntType result_type; public: // // param_type // class param_type { public: typedef poisson_distribution distribution_type; private: RealType mean_; public: SPROUT_CONSTEXPR param_type() : mean_(RealType(1)) {} param_type(param_type const&) = default; explicit SPROUT_CONSTEXPR param_type(RealType mean_arg) : mean_((SPROUT_ASSERT(RealType(0) < mean_arg), mean_arg)) {} SPROUT_CONSTEXPR RealType mean() const SPROUT_NOEXCEPT { return mean_; } template friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( std::basic_istream& lhs, param_type& rhs ) { RealType mean; if (lhs >> mean) { if (RealType(0) < mean) { rhs.mean_ = mean; } else { lhs.setstate(std::ios_base::failbit); } } return lhs; } template friend SPROUT_NON_CONSTEXPR std::basic_ostream& operator<<( std::basic_ostream& lhs, param_type const& rhs ) { return lhs << rhs.mean_; } friend SPROUT_CONSTEXPR bool operator==(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { return lhs.mean_ == rhs.mean_; } friend SPROUT_CONSTEXPR bool operator!=(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } }; struct ptrd_type { public: RealType v_r; RealType a; RealType b; RealType smu; RealType inv_alpha; }; private: static SPROUT_CONSTEXPR bool use_inversion_check(RealType mean) { return mean < 10; } static SPROUT_CONSTEXPR ptrd_type init_ptrd_2(RealType smu, RealType b) { return ptrd_type{ 0.9277 - 3.6224 / (b - 2), -0.059 + 0.02483 * b, b, smu, 1.1239 + 1.1328 / (b - 3.4) }; } static SPROUT_CONSTEXPR ptrd_type init_ptrd_1(RealType smu) { return init_ptrd_2(smu, 0.931 + 2.53 * smu); } static SPROUT_CONSTEXPR ptrd_type init_ptrd(RealType mean) { return use_inversion_check(mean) ? ptrd_type() : init_ptrd_1(sprout::sqrt(mean)) ; } static SPROUT_CONSTEXPR RealType init_exp_mean(RealType mean) { return !use_inversion_check(mean) ? RealType() : sprout::exp(-mean) ; } static SPROUT_CONSTEXPR RealType log_factorial(IntType k) { return sprout::random::detail::poisson_table::table[k]; } static SPROUT_CONSTEXPR RealType log_sqrt_2pi() { return 0.91893853320467267; } static SPROUT_CONSTEXPR RealType generate_us(RealType u) { return 0.5 - sprout::abs(u); } private: RealType mean_; ptrd_type ptrd_; RealType exp_mean_; private: SPROUT_CONSTEXPR bool use_inversion() const { return use_inversion_check(mean_); } template SPROUT_CXX14_CONSTEXPR result_type do_invert(Engine& eng) const { RealType u = sprout::random::uniform_01()(eng); IntType x = 0; RealType p = exp_mean_; while (u > p) { u -= p; ++x; p = p * mean_ / x; } return x; } template SPROUT_CONSTEXPR sprout::random::random_result invert_2(Random const& rnd, RealType u, IntType x, RealType p) const { return !(u > p) ? sprout::random::random_result(x, sprout::random::next(rnd).engine(), *this) : invert_2(rnd, u - p, x + 1, p * mean_ / (x + 1)) ; } template SPROUT_CONSTEXPR sprout::random::random_result invert_1(Random const& rnd, IntType x, RealType p) const { return invert_2(rnd, sprout::random::result(rnd), x, p); } template SPROUT_CONSTEXPR sprout::random::random_result invert(Engine const& eng) const { return invert_1(sprout::random::uniform_01()(eng), 0, exp_mean_); } template SPROUT_CXX14_CONSTEXPR result_type do_generate(Engine& eng) const { for (; ; ) { RealType v = sprout::random::uniform_01()(eng); if (v <= 0.86 * ptrd_.v_r) { RealType u = v / ptrd_.v_r - 0.43; return static_cast(sprout::floor((2 * ptrd_.a / (0.5 - sprout::abs(u)) + ptrd_.b) * u + mean_ + 0.445)); } RealType u = RealType(); if (v >= ptrd_.v_r) { u = sprout::random::uniform_01()(eng) - 0.5; } else { u = v / ptrd_.v_r - 0.93; u = ((u < 0) ? -0.5 : 0.5) - u; v = sprout::random::uniform_01()(eng) * ptrd_.v_r; } RealType us = 0.5 - sprout::abs(u); if (us < 0.013 && v > us) { continue; } RealType k = sprout::floor((2 * ptrd_.a / us + ptrd_.b) * u + mean_ + 0.445); v = v * ptrd_.inv_alpha / (ptrd_.a / (us * us) + ptrd_.b); if ((k >= 10 && sprout::log(v * ptrd_.smu) <= (k + 0.5) * sprout::log(mean_ / k) - mean_ - log_sqrt_2pi() + k - (1 / 12. - (1 / 360. - 1 / (1260. * k * k)) / (k * k)) / k) || (k >= 0 && sprout::log(v) <= k * sprout::log(mean_) - mean_ - log_factorial(static_cast(k))) ) { return static_cast(k); } } } template SPROUT_CONSTEXPR sprout::random::random_result generate_4(Random const& rnd, RealType u) const { return sprout::random::random_result( static_cast(sprout::floor((2 * ptrd_.a / (0.5 - sprout::abs(u)) + ptrd_.b) * u + mean_ + 0.445)), sprout::random::next(rnd).engine(), *this ); } #ifdef SPROUT_WORKAROUND_NOT_TERMINATE_RECURSIVE_CONSTEXPR_FUNCTION_TEMPLATE template SPROUT_CONSTEXPR sprout::random::random_result generate_3(Random const& rnd, RealType k, RealType v) const { return (k >= 10 && sprout::log(v * ptrd_.smu) <= (k + 0.5) * sprout::log(mean_ / k) - mean_ - log_sqrt_2pi() + k - (1 / 12. - (1 / 360. - 1 / (1260. * k * k)) / (k * k)) / k) || (k >= 0 && sprout::log(v) <= k * sprout::log(mean_) - mean_ - log_factorial(static_cast(k))) ? sprout::random::random_result(static_cast(k), sprout::random::next(rnd).engine(), *this) : generate_1(rnd()) ; } template SPROUT_CONSTEXPR sprout::random::random_result generate_3(Random const&, RealType, RealType) const { return sprout::throw_recursive_function_template_instantiation_exeeded(); } template SPROUT_CONSTEXPR sprout::random::random_result generate_2(Random const& rnd, RealType v, RealType u, RealType us) const { return us < 0.013 && v > us ? generate_1(rnd()) : generate_3( rnd, sprout::floor((2 * ptrd_.a / us + ptrd_.b) * u + mean_ + 0.445), v * ptrd_.inv_alpha / (ptrd_.a / (us * us) + ptrd_.b) ) ; } template SPROUT_CONSTEXPR sprout::random::random_result generate_2(Random const&, RealType, RealType, RealType) const { return sprout::throw_recursive_function_template_instantiation_exeeded(); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_3(Random const& rnd, RealType, RealType u) const { return generate_2(rnd, sprout::random::result(rnd), ((u < 0) ? -0.5 : 0.5) - u, generate_us(((u < 0) ? -0.5 : 0.5) - u)); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_3(Random const&, RealType, RealType) const { return sprout::throw_recursive_function_template_instantiation_exeeded(); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_2(Random const& rnd, RealType v) const { return generate_2(rnd, v, sprout::random::result(rnd) - 0.5, generate_us(sprout::random::result(rnd) - 0.5)); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_2(Random const&, RealType) const { return sprout::throw_recursive_function_template_instantiation_exeeded(); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { return sprout::random::result(rnd) <= 0.86 * ptrd_.v_r ? generate_4(rnd, sprout::random::result(rnd) / ptrd_.v_r - 0.43) : sprout::random::result(rnd) >= ptrd_.v_r ? generate_1_2(rnd(), sprout::random::result(rnd)) : generate_1_3(rnd(), sprout::random::result(rnd), sprout::random::result(rnd) / ptrd_.v_r - 0.93) ; } template SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const&) const { return sprout::throw_recursive_function_template_instantiation_exeeded(); } #else template SPROUT_CONSTEXPR sprout::random::random_result generate_3(Random const& rnd, RealType k, RealType v) const { return (k >= 10 && sprout::log(v * ptrd_.smu) <= (k + 0.5) * sprout::log(mean_ / k) - mean_ - log_sqrt_2pi() + k - (1 / 12. - (1 / 360. - 1 / (1260. * k * k)) / (k * k)) / k) || (k >= 0 && sprout::log(v) <= k * sprout::log(mean_) - mean_ - log_factorial(static_cast(k))) ? sprout::random::random_result(static_cast(k), sprout::random::next(rnd).engine(), *this) : generate_1(rnd()) ; } template SPROUT_CONSTEXPR sprout::random::random_result generate_2(Random const& rnd, RealType v, RealType u, RealType us) const { return us < 0.013 && v > us ? generate_1(rnd()) : generate_3( rnd, sprout::floor((2 * ptrd_.a / us + ptrd_.b) * u + mean_ + 0.445), v * ptrd_.inv_alpha / (ptrd_.a / (us * us) + ptrd_.b) ) ; } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_3(Random const& rnd, RealType v, RealType u) const { return generate_2(rnd, sprout::random::result(rnd), ((u < 0) ? -0.5 : 0.5) - u, generate_us(((u < 0) ? -0.5 : 0.5) - u)); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1_2(Random const& rnd, RealType v) const { return generate_2(rnd, v, sprout::random::result(rnd) - 0.5, generate_us(sprout::random::result(rnd) - 0.5)); } template SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { return sprout::random::result(rnd) <= 0.86 * ptrd_.v_r ? generate_4(rnd, sprout::random::result(rnd) / ptrd_.v_r - 0.43) : sprout::random::result(rnd) >= ptrd_.v_r ? generate_1_2(rnd(), sprout::random::result(rnd)) : generate_1_3(rnd(), sprout::random::result(rnd), sprout::random::result(rnd) / ptrd_.v_r - 0.93) ; } #endif template SPROUT_CONSTEXPR sprout::random::random_result generate(Engine const& eng) const { return generate_1(sprout::random::uniform_01()(eng)); } public: SPROUT_CONSTEXPR poisson_distribution() : mean_(RealType(1)) , ptrd_(init_ptrd(RealType(1))) , exp_mean_(init_exp_mean(RealType(1))) {} poisson_distribution(poisson_distribution const&) = default; explicit SPROUT_CONSTEXPR poisson_distribution(RealType mean_arg) : mean_((SPROUT_ASSERT(RealType(0) < mean_arg), mean_arg)) , ptrd_(init_ptrd(mean_arg)) , exp_mean_(init_exp_mean(mean_arg)) {} explicit SPROUT_CONSTEXPR poisson_distribution(param_type const& parm) : mean_(parm.mean()) , ptrd_(init_ptrd(parm.mean())) , exp_mean_(init_exp_mean(parm.mean())) {} SPROUT_CONSTEXPR result_type mean() const SPROUT_NOEXCEPT { return mean_; } SPROUT_CONSTEXPR result_type min() const SPROUT_NOEXCEPT { return 0; } SPROUT_CONSTEXPR result_type max() const SPROUT_NOEXCEPT { return sprout::numeric_limits::max(); } SPROUT_CXX14_CONSTEXPR void reset() SPROUT_NOEXCEPT {} SPROUT_CONSTEXPR param_type param() const SPROUT_NOEXCEPT { return param_type(mean_); } SPROUT_CXX14_CONSTEXPR void param(param_type const& parm) { mean_ = parm.mean(); ptrd_ = init_ptrd(mean_); exp_mean_ = init_exp_mean(mean_); } template SPROUT_CXX14_CONSTEXPR result_type operator()(Engine& eng) const { return use_inversion() ? do_invert(eng) : do_generate(eng) ; } template SPROUT_CONSTEXPR sprout::random::random_result const operator()(Engine const& eng) const { return use_inversion() ? invert(eng) : generate(eng) ; } template SPROUT_CXX14_CONSTEXPR result_type operator()(Engine& eng, param_type const& parm) const { return poisson_distribution(parm)(eng); } template SPROUT_CONSTEXPR sprout::random::random_result const operator()(Engine const& eng, param_type const& parm) const { return poisson_distribution(parm)(eng); } template friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( std::basic_istream& lhs, poisson_distribution& rhs ) { param_type parm; if (lhs >> parm) { rhs.param(parm); } return lhs; } template friend SPROUT_NON_CONSTEXPR std::basic_ostream& operator<<( std::basic_ostream& lhs, poisson_distribution const& rhs ) { return lhs << rhs.param(); } friend SPROUT_CONSTEXPR bool operator==(poisson_distribution const& lhs, poisson_distribution const& rhs) SPROUT_NOEXCEPT { return lhs.param() == rhs.param(); } friend SPROUT_CONSTEXPR bool operator!=(poisson_distribution const& lhs, poisson_distribution const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } }; } // namespace random using sprout::random::poisson_distribution; } // namespace sprout #endif // #ifndef SPROUT_RANDOM_POISSON_DISTRIBUTION_HPP