diff --git a/sprout/random/mersenne_twister.hpp b/sprout/random/mersenne_twister.hpp index 3e00eb76..b0d714aa 100644 --- a/sprout/random/mersenne_twister.hpp +++ b/sprout/random/mersenne_twister.hpp @@ -249,7 +249,7 @@ namespace sprout { return rewind_1(data, last, z, x_[m - 1] ^ x_[n - 1]); } SPROUT_CONSTEXPR bool - equal_impl_2(mersenne_twister_engine const& other, sprout::array back, std::size_t offset, std::size_t i = 0) const { + equal_impl_2(mersenne_twister_engine const& other, sprout::array const& back, std::size_t offset, std::size_t i = 0) const { return i < offset ? back[i + n - offset] != other.x_[i] ? false @@ -258,7 +258,7 @@ namespace sprout { ; } SPROUT_CONSTEXPR bool - equal_impl_1(mersenne_twister_engine const& other, sprout::array back, std::size_t offset, std::size_t i = 0) const { + equal_impl_1(mersenne_twister_engine const& other, sprout::array const& back, std::size_t offset, std::size_t i = 0) const { return i + offset < n ? x_[i] != other.x_[i + offset] ? false @@ -372,6 +372,70 @@ namespace sprout { SPROUT_CONSTEXPR mersenne_twister_engine twist() const { return twist_1(0); } + SPROUT_CXX14_CONSTEXPR void do_twist() { + UIntType const upper_mask = (~static_cast(0)) << r; + UIntType const lower_mask = ~upper_mask; + std::size_t const unroll_factor = 6; + std::size_t const unroll_extra1 = (n - m) % unroll_factor; + std::size_t const unroll_extra2 = (m - 1) % unroll_factor; + for (std::size_t j = 0; j < n - m - unroll_extra1; ++j) { + UIntType y = (x_[j] & upper_mask) | (x_[j + 1] & lower_mask); + x_[j] = x_[j+m] ^ (y >> 1) ^ ((x_[j + 1] & 1) * a); + } + for (std::size_t j = n - m - unroll_extra1; j < n-m; ++j) { + UIntType y = (x_[j] & upper_mask) | (x_[j + 1] & lower_mask); + x_[j] = x_[j + m] ^ (y >> 1) ^ ((x_[j + 1] & 1) * a); + } + for (std::size_t j = n - m; j < n - 1 - unroll_extra2; ++j) { + UIntType y = (x_[j] & upper_mask) | (x_[j + 1] & lower_mask); + x_[j] = x_[j - (n - m)] ^ (y >> 1) ^ ((x_[j + 1] & 1) * a); + } + for (std::size_t j = n - 1 - unroll_extra2; j < n - 1; ++j) { + UIntType y = (x_[j] & upper_mask) | (x_[j + 1] & lower_mask); + x_[j] = x_[j - (n - m)] ^ (y >> 1) ^ ((x_[j + 1] & 1) * a); + } + { + UIntType y = (x_[n - 1] & upper_mask) | (x_[0] & lower_mask); + x_[n - 1] = x_[m - 1] ^ (y >> 1) ^ ((x_[0] & 1) * a); + } + i_ = 0; + } + SPROUT_CXX14_CONSTEXPR void do_rewind(UIntType* last, std::size_t z) const { + UIntType const upper_mask = (~static_cast(0)) << r; + UIntType const lower_mask = ~upper_mask; + UIntType y0 = x_[m - 1] ^ x_[n - 1]; + if (y0 & (static_cast(1) << (w - 1))) { + y0 = ((y0 ^ a) << 1) | 1; + } else { + y0 = y0 << 1; + } + for (std::size_t sz = 0; sz < z; ++sz) { + UIntType y1 = rewind_find(last, sz, m - 1) ^ rewind_find(last, sz, n - 1); + if (y1 & (static_cast(1) << (w - 1))) { + y1 = ((y1 ^ a) << 1) | 1; + } else { + y1 = y1 << 1; + } + *(last - sz) = (y0 & upper_mask) | (y1 & lower_mask); + y0 = y1; + } + } + SPROUT_CXX14_CONSTEXPR bool do_equal_impl(mersenne_twister_engine const& other) const { + UIntType back[n] = {}; + std::size_t offset = other.i_ - i_; + for (std::size_t j = 0; j + offset < n; ++j) { + if (x_[j] != other.x_[j + offset]) { + return false; + } + } + do_rewind(&back[n - 1], offset); + for (std::size_t j = 0; j < offset; ++j) { + if (back[j + n - offset] != other.x_[j]) { + return false; + } + } + return true; + } public: SPROUT_CONSTEXPR mersenne_twister_engine() : x_(init_seed(default_seed)) @@ -388,12 +452,35 @@ namespace sprout { SPROUT_CONSTEXPR result_type max() const SPROUT_NOEXCEPT { return static_max(); } + SPROUT_CXX14_CONSTEXPR result_type operator()() { + if (i_ == n) { + do_twist(); + } + UIntType z = x_[i_]; + ++i_; + z ^= ((z >> u) & d); + z ^= ((z << s) & b); + z ^= ((z << t) & c); + z ^= (z >> l); + return z; + } SPROUT_CONSTEXPR sprout::random::random_result const operator()() const { return i_ == n ? twist().generate() : generate() ; } +#ifndef SPROUT_CONFIG_DISABLE_CXX14_CONSTEXPR + friend SPROUT_CXX14_CONSTEXPR bool operator==(mersenne_twister_engine const& lhs, mersenne_twister_engine const& rhs) SPROUT_NOEXCEPT { + return lhs.i_ < rhs.i_ + ? lhs.do_equal_impl(rhs) + : rhs.do_equal_impl(lhs) + ; + } + friend SPROUT_CXX14_CONSTEXPR bool operator!=(mersenne_twister_engine const& lhs, mersenne_twister_engine const& rhs) SPROUT_NOEXCEPT { + return !(lhs == rhs); + } +#else // #ifndef SPROUT_CONFIG_DISABLE_CXX14_CONSTEXPR friend SPROUT_CONSTEXPR bool operator==(mersenne_twister_engine const& lhs, mersenne_twister_engine const& rhs) SPROUT_NOEXCEPT { return lhs.i_ < rhs.i_ ? lhs.equal_impl(rhs) @@ -403,6 +490,7 @@ namespace sprout { friend SPROUT_CONSTEXPR bool operator!=(mersenne_twister_engine const& lhs, mersenne_twister_engine const& rhs) SPROUT_NOEXCEPT { return !(lhs == rhs); } +#endif // #ifndef SPROUT_CONFIG_DISABLE_CXX14_CONSTEXPR template friend SPROUT_NON_CONSTEXPR std::basic_istream& operator>>( std::basic_istream& lhs,