From 9bcc98cb2965819ff80abcb3e6d95f08f22ffcd9 Mon Sep 17 00:00:00 2001 From: bolero-MURAKAMI Date: Fri, 10 Apr 2015 15:40:12 +0900 Subject: [PATCH] add random::poisson_distribution --- libs/random/test/poisson_distribution.cpp | 28 ++ libs/random/test/random.cpp | 2 + sprout/random/distribution.hpp | 1 + sprout/random/geometric_distribution.hpp | 2 - sprout/random/poisson_distribution.hpp | 417 ++++++++++++++++++++++ 5 files changed, 448 insertions(+), 2 deletions(-) create mode 100644 libs/random/test/poisson_distribution.cpp create mode 100644 sprout/random/poisson_distribution.hpp diff --git a/libs/random/test/poisson_distribution.cpp b/libs/random/test/poisson_distribution.cpp new file mode 100644 index 00000000..5ba0630d --- /dev/null +++ b/libs/random/test/poisson_distribution.cpp @@ -0,0 +1,28 @@ +/*============================================================================= + Copyright (c) 2011-2015 Bolero MURAKAMI + https://github.com/bolero-MURAKAMI/Sprout + + Distributed under the Boost Software License, Version 1.0. (See accompanying + file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +=============================================================================*/ +#ifndef SPROUT_LIBS_RANDOM_TEST_POISSON_DISTRIBUTION_CPP +#define SPROUT_LIBS_RANDOM_TEST_POISSON_DISTRIBUTION_CPP + +#include +#include "./distribution_generic.hpp" +#include + +namespace testspr { + void random_poisson_distribution_test() { + using namespace sprout; + + testspr::random_distribution_test_generic >(); + } +} // namespace testspr + +#ifndef TESTSPR_CPP_INCLUDE +# define TESTSPR_TEST_FUNCTION testspr::random_poisson_distribution_test +# include +#endif + +#endif // #ifndef SPROUT_LIBS_RANDOM_TEST_POISSON_DISTRIBUTION_CPP diff --git a/libs/random/test/random.cpp b/libs/random/test/random.cpp index 0a2eb6c9..97648707 100644 --- a/libs/random/test/random.cpp +++ b/libs/random/test/random.cpp @@ -26,6 +26,7 @@ #include "./bernoulli_distribution.cpp" #include "./binomial_distribution.cpp" #include "./geometric_distribution.cpp" +#include "./poisson_distribution.cpp" #include "./normal_distribution.cpp" #include @@ -50,6 +51,7 @@ namespace testspr { testspr::random_bernoulli_distribution_test(); testspr::random_binomial_distribution_test(); testspr::random_geometric_distribution_test(); + testspr::random_poisson_distribution_test(); testspr::random_normal_distribution_test(); } } // namespace testspr diff --git a/sprout/random/distribution.hpp b/sprout/random/distribution.hpp index bce7e120..224c0b77 100644 --- a/sprout/random/distribution.hpp +++ b/sprout/random/distribution.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #endif // #ifndef SPROUT_RANDOM_DISTRIBUTION_HPP diff --git a/sprout/random/geometric_distribution.hpp b/sprout/random/geometric_distribution.hpp index 7b0ac9f3..21fc1897 100644 --- a/sprout/random/geometric_distribution.hpp +++ b/sprout/random/geometric_distribution.hpp @@ -81,12 +81,10 @@ namespace sprout { } }; private: - public: static SPROUT_CONSTEXPR RealType init_log_1mp(RealType p) { return sprout::math::log(1 - p); } private: - public: RealType p_; RealType log_1mp_; private: diff --git a/sprout/random/poisson_distribution.hpp b/sprout/random/poisson_distribution.hpp new file mode 100644 index 00000000..674512a7 --- /dev/null +++ b/sprout/random/poisson_distribution.hpp @@ -0,0 +1,417 @@ +/*============================================================================= + Copyright (c) 2011-2015 Bolero MURAKAMI + https://github.com/bolero-MURAKAMI/Sprout + + Distributed under the Boost Software License, Version 1.0. (See accompanying + file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +=============================================================================*/ +#ifndef SPROUT_RANDOM_POISSON_DISTRIBUTION_HPP +#define SPROUT_RANDOM_POISSON_DISTRIBUTION_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef SPROUT_WORKAROUND_NOT_TERMINATE_RECURSIVE_CONSTEXPR_FUNCTION_TEMPLATE +# include +#endif + +namespace sprout { + namespace random { + namespace detail { +# define SPROUT_POISSON_TABLE_DEF \ + {{ \ + 0.0, \ + 0.0, \ + 0.69314718055994529, \ + 1.7917594692280550, \ + 3.1780538303479458, \ + 4.7874917427820458, \ + 6.5792512120101012, \ + 8.5251613610654147, \ + 10.604602902745251, \ + 12.801827480081469 \ + }} + + template + struct poisson_table { + public: + typedef sprout::array table_type; + public: + SPROUT_STATIC_CONSTEXPR table_type table + SPROUT_STATIC_CONSTEXPR_DATA_MEMBER_INNER(SPROUT_POISSON_TABLE_DEF) + ; + }; + template + SPROUT_CONSTEXPR_OR_CONST typename sprout::random::detail::poisson_table::table_type + sprout::random::detail::poisson_table::table + SPROUT_STATIC_CONSTEXPR_DATA_MEMBER_OUTER(SPROUT_POISSON_TABLE_DEF) + ; + +# undef SPROUT_POISSON_TABLE_DEF + } // namespace detail + // + // poisson_distribution + // + template + class poisson_distribution { + public: + typedef RealType input_type; + typedef IntType result_type; + public: + // + // param_type + // + class param_type { + public: + typedef poisson_distribution distribution_type; + private: + RealType mean_; + public: + SPROUT_CONSTEXPR param_type() + : mean_(RealType(1)) + {} + param_type(param_type const&) = default; + explicit SPROUT_CONSTEXPR param_type(RealType mean_arg) + : mean_((SPROUT_ASSERT(RealType(0) < mean_arg), mean_arg)) + {} + SPROUT_CONSTEXPR RealType mean() const SPROUT_NOEXCEPT { + return mean_; + } + template + friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( + std::basic_istream& lhs, + param_type& rhs + ) + { + RealType mean; + if (lhs >> mean) { + if (RealType(0) < mean) { + rhs.mean_ = mean; + } else { + lhs.setstate(std::ios_base::failbit); + } + } + return lhs; + } + template + friend SPROUT_NON_CONSTEXPR std::basic_ostream& operator<<( + std::basic_ostream& lhs, + param_type const& rhs + ) + { + return lhs << rhs.mean_; + } + friend SPROUT_CONSTEXPR bool operator==(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { + return lhs.mean_ == rhs.mean_; + } + friend SPROUT_CONSTEXPR bool operator!=(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { + return !(lhs == rhs); + } + }; + struct ptrd_type { + public: + RealType v_r; + RealType a; + RealType b; + RealType smu; + RealType inv_alpha; + }; + private: + static SPROUT_CONSTEXPR bool use_inversion_check(RealType mean) { + return mean < 10; + } + static SPROUT_CONSTEXPR ptrd_type init_ptrd_2(RealType smu, RealType b) { + return ptrd_type{ + 0.9277 - 3.6224 / (b - 2), + -0.059 + 0.02483 * b, + b, + smu, + 1.1239 + 1.1328 / (b - 3.4) + }; + } + static SPROUT_CONSTEXPR ptrd_type init_ptrd_1(RealType smu) { + return init_ptrd_2(smu, 0.931 + 2.53 * smu); + } + static SPROUT_CONSTEXPR ptrd_type init_ptrd(RealType mean) { + return use_inversion_check(mean) ? ptrd_type() + : init_ptrd_1(sprout::sqrt(mean)) + ; + } + static SPROUT_CONSTEXPR RealType init_exp_mean(RealType mean) { + return !use_inversion_check(mean) ? RealType() + : sprout::exp(-mean) + ; + } + static SPROUT_CONSTEXPR RealType log_factorial(IntType k) { + return sprout::random::detail::poisson_table::table[k]; + } + static SPROUT_CONSTEXPR RealType log_sqrt_2pi() { + return 0.91893853320467267; + } + static SPROUT_CONSTEXPR RealType generate_us(RealType u) { + return 0.5 - sprout::abs(u); + } + private: + RealType mean_; + ptrd_type ptrd_; + RealType exp_mean_; + private: + SPROUT_CONSTEXPR bool use_inversion() const { + return use_inversion_check(mean_); + } + template + SPROUT_CXX14_CONSTEXPR result_type do_invert(Engine& eng) const { + RealType u = sprout::random::uniform_01()(eng); + IntType x = 0; + RealType p = exp_mean_; + while (u > p) { + u -= p; + ++x; + p = p * mean_ / x; + } + return x; + } + template + SPROUT_CONSTEXPR sprout::random::random_result invert_2(Random const& rnd, RealType u, IntType x, RealType p) const { + return !(u > p) ? sprout::random::random_result(x, sprout::random::next(rnd).engine(), *this) + : invert_2(rnd, u - p, x + 1, p * mean_ / (x + 1)) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result invert_1(Random const& rnd, IntType x, RealType p) const { + return invert_2(rnd, sprout::random::result(rnd), x, p); + } + template + SPROUT_CONSTEXPR sprout::random::random_result invert(Engine const& eng) const { + return invert_1(sprout::random::uniform_01()(eng), 0, exp_mean_); + } + template + SPROUT_CXX14_CONSTEXPR result_type do_generate(Engine& eng) const { + for (; ; ) { + RealType v = sprout::random::uniform_01()(eng); + if (v <= 0.86 * ptrd_.v_r) { + RealType u = v / ptrd_.v_r - 0.43; + return static_cast(sprout::floor((2 * ptrd_.a / (0.5 - sprout::abs(u)) + ptrd_.b) * u + mean_ + 0.445)); + } + RealType u = RealType(); + if (v >= ptrd_.v_r) { + u = sprout::random::uniform_01()(eng) - 0.5; + } else { + u = v / ptrd_.v_r - 0.93; + u = ((u < 0) ? -0.5 : 0.5) - u; + v = sprout::random::uniform_01()(eng) * ptrd_.v_r; + } + RealType us = 0.5 - sprout::abs(u); + if (us < 0.013 && v > us) { + continue; + } + RealType k = sprout::floor((2 * ptrd_.a / us + ptrd_.b) * u + mean_ + 0.445); + v = v * ptrd_.inv_alpha / (ptrd_.a / (us * us) + ptrd_.b); + if ((k >= 10 && sprout::log(v * ptrd_.smu) <= (k + 0.5) * sprout::log(mean_ / k) - mean_ - log_sqrt_2pi() + k - (1 / 12. - (1 / 360. - 1 / (1260. * k * k)) / (k * k)) / k) + || (k >= 0 && sprout::log(v) <= k * sprout::log(mean_) - mean_ - log_factorial(static_cast(k))) + ) + { + return static_cast(k); + } + } + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_4(Random const& rnd, RealType u) const { + return sprout::random::random_result( + static_cast(sprout::floor((2 * ptrd_.a / (0.5 - sprout::abs(u)) + ptrd_.b) * u + mean_ + 0.445)), + sprout::random::next(rnd).engine(), + *this + ); + } +#ifdef SPROUT_WORKAROUND_NOT_TERMINATE_RECURSIVE_CONSTEXPR_FUNCTION_TEMPLATE + template + SPROUT_CONSTEXPR sprout::random::random_result generate_3(Random const& rnd, RealType k, RealType v) const { + return (k >= 10 && sprout::log(v * ptrd_.smu) <= (k + 0.5) * sprout::log(mean_ / k) - mean_ - log_sqrt_2pi() + k - (1 / 12. - (1 / 360. - 1 / (1260. * k * k)) / (k * k)) / k) + || (k >= 0 && sprout::log(v) <= k * sprout::log(mean_) - mean_ - log_factorial(static_cast(k))) + ? sprout::random::random_result(static_cast(k), sprout::random::next(rnd).engine(), *this) + : generate_1(rnd()) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_3(Random const&, RealType, RealType) const { + return sprout::throw_recursive_function_template_instantiation_exeeded(); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_2(Random const& rnd, RealType v, RealType u, RealType us) const { + return us < 0.013 && v > us ? generate_1(rnd()) + : generate_3( + rnd, + sprout::floor((2 * ptrd_.a / us + ptrd_.b) * u + mean_ + 0.445), + v * ptrd_.inv_alpha / (ptrd_.a / (us * us) + ptrd_.b) + ) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_2(Random const&, RealType, RealType, RealType) const { + return sprout::throw_recursive_function_template_instantiation_exeeded(); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_3(Random const& rnd, RealType v, RealType u) const { + return generate_2(rnd, sprout::random::result(rnd), ((u < 0) ? -0.5 : 0.5) - u, generate_us(((u < 0) ? -0.5 : 0.5) - u)); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_3(Random const&, RealType, RealType) const { + return sprout::throw_recursive_function_template_instantiation_exeeded(); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_2(Random const& rnd, RealType v) const { + return generate_2(rnd, v, sprout::random::result(rnd) - 0.5, generate_us(sprout::random::result(rnd) - 0.5)); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_2(Random const&, RealType) const { + return sprout::throw_recursive_function_template_instantiation_exeeded(); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { + return sprout::random::result(rnd) <= 0.86 * ptrd_.v_r ? generate_4(rnd, sprout::random::result(rnd) / ptrd_.v_r - 0.43) + : sprout::random::result(rnd) >= ptrd_.v_r + ? generate_1_2(rnd(), sprout::random::result(rnd)) + : generate_1_3(rnd(), sprout::random::result(rnd), sprout::random::result(rnd) / ptrd_.v_r - 0.93) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const&) const { + return sprout::throw_recursive_function_template_instantiation_exeeded(); + } +#else + template + SPROUT_CONSTEXPR sprout::random::random_result generate_3(Random const& rnd, RealType k, RealType v) const { + return (k >= 10 && sprout::log(v * ptrd_.smu) <= (k + 0.5) * sprout::log(mean_ / k) - mean_ - log_sqrt_2pi() + k - (1 / 12. - (1 / 360. - 1 / (1260. * k * k)) / (k * k)) / k) + || (k >= 0 && sprout::log(v) <= k * sprout::log(mean_) - mean_ - log_factorial(static_cast(k))) + ? sprout::random::random_result(static_cast(k), sprout::random::next(rnd).engine(), *this) + : generate_1(rnd()) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_2(Random const& rnd, RealType v, RealType u, RealType us) const { + return us < 0.013 && v > us ? generate_1(rnd()) + : generate_3( + rnd, + sprout::floor((2 * ptrd_.a / us + ptrd_.b) * u + mean_ + 0.445), + v * ptrd_.inv_alpha / (ptrd_.a / (us * us) + ptrd_.b) + ) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_3(Random const& rnd, RealType v, RealType u) const { + return generate_2(rnd, sprout::random::result(rnd), ((u < 0) ? -0.5 : 0.5) - u, generate_us(((u < 0) ? -0.5 : 0.5) - u)); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1_2(Random const& rnd, RealType v) const { + return generate_2(rnd, v, sprout::random::result(rnd) - 0.5, generate_us(sprout::random::result(rnd) - 0.5)); + } + template + SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { + return sprout::random::result(rnd) <= 0.86 * ptrd_.v_r ? generate_4(rnd, sprout::random::result(rnd) / ptrd_.v_r - 0.43) + : sprout::random::result(rnd) >= ptrd_.v_r + ? generate_1_2(rnd(), sprout::random::result(rnd)) + : generate_1_3(rnd(), sprout::random::result(rnd), sprout::random::result(rnd) / ptrd_.v_r - 0.93) + ; + } +#endif + template + SPROUT_CONSTEXPR sprout::random::random_result generate(Engine const& eng) const { + return generate_1(sprout::random::uniform_01()(eng)); + } + public: + SPROUT_CONSTEXPR poisson_distribution() + : mean_(RealType(1)) + , ptrd_(init_ptrd(RealType(1))) + , exp_mean_(init_exp_mean(RealType(1))) + {} + poisson_distribution(poisson_distribution const&) = default; + explicit SPROUT_CONSTEXPR poisson_distribution(RealType mean_arg) + : mean_((SPROUT_ASSERT(RealType(0) < mean_arg), mean_arg)) + , ptrd_(init_ptrd(mean_arg)) + , exp_mean_(init_exp_mean(mean_arg)) + {} + explicit SPROUT_CONSTEXPR poisson_distribution(param_type const& parm) + : mean_(parm.mean()) + , ptrd_(init_ptrd(parm.mean())) + , exp_mean_(init_exp_mean(parm.mean())) + {} + SPROUT_CONSTEXPR result_type mean() const SPROUT_NOEXCEPT { + return mean_; + } + SPROUT_CONSTEXPR result_type min() const SPROUT_NOEXCEPT { + return 0; + } + SPROUT_CONSTEXPR result_type max() const SPROUT_NOEXCEPT { + return sprout::numeric_limits::max(); + } + SPROUT_CXX14_CONSTEXPR void reset() SPROUT_NOEXCEPT {} + SPROUT_CONSTEXPR param_type param() const SPROUT_NOEXCEPT { + return param_type(mean_); + } + SPROUT_CXX14_CONSTEXPR void param(param_type const& parm) { + mean_ = parm.mean(); + ptrd_ = init_ptrd(mean_); + exp_mean_ = init_exp_mean(mean_); + } + template + SPROUT_CXX14_CONSTEXPR result_type operator()(Engine& eng) const { + return use_inversion() ? do_invert(eng) + : do_generate(eng) + ; + } + template + SPROUT_CONSTEXPR sprout::random::random_result const operator()(Engine const& eng) const { + return use_inversion() ? invert(eng) + : generate(eng) + ; + } + template + SPROUT_CXX14_CONSTEXPR result_type operator()(Engine& eng, param_type const& parm) const { + return poisson_distribution(parm)(eng); + } + template + SPROUT_CONSTEXPR sprout::random::random_result const operator()(Engine const& eng, param_type const& parm) const { + return poisson_distribution(parm)(eng); + } + template + friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( + std::basic_istream& lhs, + poisson_distribution& rhs + ) + { + param_type parm; + if (lhs >> parm) { + rhs.param(parm); + } + return lhs; + } + template + friend SPROUT_NON_CONSTEXPR std::basic_ostream& operator<<( + std::basic_ostream& lhs, + poisson_distribution const& rhs + ) + { + return lhs << rhs.param(); + } + friend SPROUT_CONSTEXPR bool operator==(poisson_distribution const& lhs, poisson_distribution const& rhs) SPROUT_NOEXCEPT { + return lhs.param() == rhs.param(); + } + friend SPROUT_CONSTEXPR bool operator!=(poisson_distribution const& lhs, poisson_distribution const& rhs) SPROUT_NOEXCEPT { + return !(lhs == rhs); + } + }; + } // namespace random + + using sprout::random::poisson_distribution; +} // namespace sprout + +#endif // #ifndef SPROUT_RANDOM_POISSON_DISTRIBUTION_HPP