#ifndef SPROUT_RANDOM_GEOMETRIC_DISTRIBUTION_HPP #define SPROUT_RANDOM_GEOMETRIC_DISTRIBUTION_HPP #include #include #include #include #include #include #include #include namespace sprout { namespace random { // // geometric_distribution // template class geometric_distribution { public: typedef RealType input_type; typedef IntType result_type; private: static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType p_arg) { return RealType(0) < p_arg && p_arg < RealType(1); } static SPROUT_CONSTEXPR RealType arg_check(RealType p_arg) { return arg_check_nothrow(p_arg) ? p_arg : throw std::invalid_argument("geometric_distribution<>: invalid argument (0 < p_arg && p_arg < 1)") ; } public: // // param_type // class param_type { public: typedef geometric_distribution distribution_type; private: static SPROUT_CONSTEXPR bool arg_check_nothrow(RealType p_arg) { return distribution_type::arg_check_nothrow(p_arg); } private: RealType p_; public: SPROUT_CONSTEXPR param_type() : p_(RealType(0.5)) {} explicit SPROUT_CONSTEXPR param_type(RealType p_arg) : p_(arg_check(p_arg)) {} SPROUT_CONSTEXPR RealType p() const SPROUT_NOEXCEPT { return p_; } template friend std::basic_istream& operator>>( std::basic_istream& lhs, param_type& rhs ) { RealType p; if (lhs >> p) { if (arg_check_nothrow(p)) { rhs.p_ = p; } else { lhs.setstate(std::ios_base::failbit); } } return lhs; } template friend std::basic_ostream& operator<<( std::basic_ostream& lhs, param_type const& rhs ) { return lhs << rhs.p_; } friend SPROUT_CONSTEXPR bool operator==(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { return lhs.p_ == rhs.p_; } friend SPROUT_CONSTEXPR bool operator!=(param_type const& lhs, param_type const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } }; private: public: static SPROUT_CONSTEXPR RealType init_log_1mp(RealType p) { return sprout::log(1 - p); } private: public: RealType p_; RealType log_1mp_; private: template SPROUT_CONSTEXPR sprout::random::random_result generate_1(Random const& rnd) const { return sprout::random::random_result( static_cast(sprout::floor(sprout::log(RealType(1) - rnd.result()) / log_1mp_)), rnd.engine(), *this ); } template SPROUT_CONSTEXPR sprout::random::random_result generate(Engine const& eng) const { return generate_1(sprout::random::uniform_01()(eng)); } public: SPROUT_CONSTEXPR geometric_distribution() : p_(RealType(0.5)) , log_1mp_(init_log_1mp(RealType(0.5))) {} explicit SPROUT_CONSTEXPR geometric_distribution(RealType p_arg) : p_(arg_check(p_arg)) , log_1mp_(init_log_1mp(p_arg)) {} explicit SPROUT_CONSTEXPR geometric_distribution(param_type const& parm) : p_(parm.p()) , log_1mp_(init_log_1mp(parm.p())) {} SPROUT_CONSTEXPR result_type p() const SPROUT_NOEXCEPT { return p_; } SPROUT_CONSTEXPR result_type min() const SPROUT_NOEXCEPT { return 0; } SPROUT_CONSTEXPR result_type max() const SPROUT_NOEXCEPT { return std::numeric_limits::max(); } SPROUT_CONSTEXPR param_type param() const SPROUT_NOEXCEPT { return param_type(p_); } void param(param_type const& parm) { p_ = parm.p(); log_1mp_ = init_log_1mp(p_); } template SPROUT_CONSTEXPR sprout::random::random_result operator()(Engine const& eng) const { return generate(eng); } template friend std::basic_istream& operator>>( std::basic_istream& lhs, geometric_distribution& rhs ) { param_type parm; if (lhs >> parm) { rhs.param(parm); } return lhs; } template friend std::basic_ostream& operator<<( std::basic_ostream& lhs, geometric_distribution const& rhs ) { return lhs << rhs.param(); } friend SPROUT_CONSTEXPR bool operator==(geometric_distribution const& lhs, geometric_distribution const& rhs) SPROUT_NOEXCEPT { return lhs.param() == rhs.param(); } friend SPROUT_CONSTEXPR bool operator!=(geometric_distribution const& lhs, geometric_distribution const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } }; } // namespace random using sprout::random::geometric_distribution; } // namespace sprout #endif // #ifndef SPROUT_RANDOM_GEOMETRIC_DISTRIBUTION_HPP