From 015f5dc525643e73a517cf22046ded72c5b6224d Mon Sep 17 00:00:00 2001 From: David Goldblatt Date: Fri, 15 Apr 2016 10:14:13 -0700 Subject: [PATCH] Correctly deduce RNG type in folly::Random Summary:Fix a bug in folly where it can't infer any RNG type except the default-supplied parameter. Differential Revision: D3163421 fb-gh-sync-id: 23d3963ba19dac93fa3407f3a2dfd1d9aa39ea44 fbshipit-source-id: 23d3963ba19dac93fa3407f3a2dfd1d9aa39ea44 --- folly/Random-inl.h | 8 ++--- folly/Random.h | 69 +++++++++++++++++++++++---------------- folly/test/RandomTest.cpp | 27 +++++++++++++++ 3 files changed, 72 insertions(+), 32 deletions(-) diff --git a/folly/Random-inl.h b/folly/Random-inl.h index 0ff458e8..d2a67a82 100644 --- a/folly/Random-inl.h +++ b/folly/Random-inl.h @@ -118,15 +118,15 @@ struct SeedData { } // namespace detail -template -void Random::seed(ValidRNG& rng) { +template +void Random::seed(RNG& rng) { detail::SeedData sd; std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData)); rng.seed(s); } -template -auto Random::create() -> ValidRNG { +template +auto Random::create() -> RNG { detail::SeedData sd; std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData)); return RNG(s); diff --git a/folly/Random.h b/folly/Random.h index 0cfc5efe..0962d07e 100644 --- a/folly/Random.h +++ b/folly/Random.h @@ -76,10 +76,10 @@ class ThreadLocalPRNG { class Random { private: - template + template using ValidRNG = typename std::enable_if< - std::is_unsigned::type>::value, - RNG>::type; + std::is_unsigned::type>::value, + RNG>::type; public: // Default generator type. @@ -115,8 +115,8 @@ class Random { * to create a RNG with a good seed in production, and seed it yourself * in test. */ - template - static void seed(ValidRNG& rng); + template > + static void seed(RNG& rng); /** * Create a new RNG, seeded with a good seed. @@ -126,14 +126,22 @@ class Random { * to create a RNG with a good seed in production, and seed it yourself * in test. */ - template - static ValidRNG create(); + template > + static RNG create(); /** * Returns a random uint32_t */ - template - static uint32_t rand32(ValidRNG rng = RNG()) { + static uint32_t rand32() { + ThreadLocalPRNG prng; + return rand32(prng); + } + + /** + * Returns a random uint32_t given a specific RNG + */ + template > + static uint32_t rand32(RNG rng) { uint32_t r = rng.operator()(); return r; } @@ -141,8 +149,17 @@ class Random { /** * Returns a random uint32_t in [0, max). If max == 0, returns 0. */ - template - static uint32_t rand32(uint32_t max, ValidRNG rng = RNG()) { + static uint32_t rand32(uint32_t max) { + ThreadLocalPRNG prng; + return rand32(max, prng); + } + + /** + * Returns a random uint32_t in [0, max) given a specific RNG. + * If max == 0, returns 0. + */ + template > + static uint32_t rand32(uint32_t max, RNG rng = RNG()) { if (max == 0) { return 0; } @@ -153,10 +170,8 @@ class Random { /** * Returns a random uint32_t in [min, max). If min == max, returns 0. */ - template - static uint32_t rand32(uint32_t min, - uint32_t max, - ValidRNG rng = RNG()) { + template > + static uint32_t rand32(uint32_t min, uint32_t max, RNG rng = RNG()) { if (min == max) { return 0; } @@ -167,16 +182,16 @@ class Random { /** * Returns a random uint64_t */ - template - static uint64_t rand64(ValidRNG rng = RNG()) { + template > + static uint64_t rand64(RNG rng = RNG()) { return ((uint64_t) rng() << 32) | rng(); } /** * Returns a random uint64_t in [0, max). If max == 0, returns 0. */ - template - static uint64_t rand64(uint64_t max, ValidRNG rng = RNG()) { + template > + static uint64_t rand64(uint64_t max, RNG rng = RNG()) { if (max == 0) { return 0; } @@ -187,10 +202,8 @@ class Random { /** * Returns a random uint64_t in [min, max). If min == max, returns 0. */ - template - static uint64_t rand64(uint64_t min, - uint64_t max, - ValidRNG rng = RNG()) { + template > + static uint64_t rand64(uint64_t min, uint64_t max, RNG rng = RNG()) { if (min == max) { return 0; } @@ -201,7 +214,7 @@ class Random { /** * Returns true 1/n of the time. If n == 0, always returns false */ - template + template > static bool oneIn(uint32_t n, ValidRNG rng = RNG()) { if (n == 0) { return false; @@ -213,8 +226,8 @@ class Random { /** * Returns a double in [0, 1) */ - template - static double randDouble01(ValidRNG rng = RNG()) { + template > + static double randDouble01(RNG rng = RNG()) { return std::generate_canonical::digits> (rng); } @@ -222,8 +235,8 @@ class Random { /** * Returns a double in [min, max), if min == max, returns 0. */ - template - static double randDouble(double min, double max, ValidRNG rng = RNG()) { + template > + static double randDouble(double min, double max, RNG rng = RNG()) { if (std::fabs(max - min) < std::numeric_limits::epsilon()) { return 0; } diff --git a/folly/test/RandomTest.cpp b/folly/test/RandomTest.cpp index 2f3c97d5..69447872 100644 --- a/folly/test/RandomTest.cpp +++ b/folly/test/RandomTest.cpp @@ -47,6 +47,33 @@ TEST(Random, Simple) { } } +TEST(Random, FixedSeed) { + // clang-format off + struct ConstantRNG { + typedef uint32_t result_type; + result_type operator()() { + return 4; // chosen by fair dice roll. + // guaranteed to be random. + } + result_type min() { + return std::numeric_limits::min(); + } + result_type max() { + return std::numeric_limits::max(); + } + }; + // clang-format on + + ConstantRNG gen; + // Loop to make sure it really is constant. + for (int i = 0; i < 1024; ++i) { + auto result = Random::rand32(10, gen); + // TODO: This is a little bit brittle; standard library changes could break + // it, if it starts implementing distribution types differently. + EXPECT_EQ(0, result); + } +} + TEST(Random, MultiThreaded) { const int n = 100; std::vector seeds(n); -- 2.34.1