From: Christopher Small Date: Tue, 27 Dec 2016 00:45:14 +0000 (-0800) Subject: pass RNG by reference so state is updated on each call X-Git-Tag: v2017.03.06.00~161 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=f2987ecd74d921af398f7592762fbf1ec67bab1a;p=folly.git pass RNG by reference so state is updated on each call Summary: folly::Random was taking the RNG by value (not reference) so it was not updating the RNG's state on each invocation -- so the RNG would not advance to the next value in the sequence. Reviewed By: yfeldblum, nbronson Differential Revision: D4362999 fbshipit-source-id: f93fc11911b92e230ac0cc2406151474d15f85af --- diff --git a/folly/Random.h b/folly/Random.h index 0e52772a..3cb553da 100644 --- a/folly/Random.h +++ b/folly/Random.h @@ -133,25 +133,22 @@ class Random { * Returns a random uint32_t */ static uint32_t rand32() { - ThreadLocalPRNG prng; - return rand32(prng); + return rand32(ThreadLocalPRNG()); } /** * Returns a random uint32_t given a specific RNG */ template > - static uint32_t rand32(RNG rng) { - uint32_t r = rng.operator()(); - return r; + static uint32_t rand32(RNG&& rng) { + return rng(); } /** * Returns a random uint32_t in [0, max). If max == 0, returns 0. */ static uint32_t rand32(uint32_t max) { - ThreadLocalPRNG prng; - return rand32(max, prng); + return rand32(0, max, ThreadLocalPRNG()); } /** @@ -159,84 +156,123 @@ class Random { * If max == 0, returns 0. */ template > - static uint32_t rand32(uint32_t max, RNG rng = RNG()) { - if (max == 0) { - return 0; - } - - return std::uniform_int_distribution(0, max - 1)(rng); + static uint32_t rand32(uint32_t max, RNG&& rng) { + return rand32(0, max, rng); } /** * Returns a random uint32_t in [min, max). If min == max, returns 0. */ + static uint32_t rand32(uint32_t min, uint32_t max) { + return rand32(min, max, ThreadLocalPRNG()); + } + + /** + * Returns a random uint32_t in [min, max) given a specific RNG. + * If min == max, returns 0. + */ template > - static uint32_t rand32(uint32_t min, uint32_t max, RNG rng = RNG()) { + static uint32_t rand32(uint32_t min, uint32_t max, RNG&& rng) { if (min == max) { return 0; } - return std::uniform_int_distribution(min, max - 1)(rng); } + /** + * Returns a random uint64_t + */ + static uint64_t rand64() { + return rand64(ThreadLocalPRNG()); + } + /** * Returns a random uint64_t */ template > - static uint64_t rand64(RNG rng = RNG()) { - return ((uint64_t) rng() << 32) | rng(); + static uint64_t rand64(RNG&& rng) { + return ((uint64_t)rng() << 32) | rng(); + } + + /** + * Returns a random uint64_t in [0, max). If max == 0, returns 0. + */ + static uint64_t rand64(uint64_t max) { + return rand64(0, max, ThreadLocalPRNG()); } /** * Returns a random uint64_t in [0, max). If max == 0, returns 0. */ template > - static uint64_t rand64(uint64_t max, RNG rng = RNG()) { - if (max == 0) { - return 0; - } + static uint64_t rand64(uint64_t max, RNG&& rng) { + return rand64(0, max, rng); + } - return std::uniform_int_distribution(0, max - 1)(rng); + /** + * Returns a random uint64_t in [min, max). If min == max, returns 0. + */ + static uint64_t rand64(uint64_t min, uint64_t max) { + return rand64(min, max, ThreadLocalPRNG()); } /** * Returns a random uint64_t in [min, max). If min == max, returns 0. */ template > - static uint64_t rand64(uint64_t min, uint64_t max, RNG rng = RNG()) { + static uint64_t rand64(uint64_t min, uint64_t max, RNG&& rng) { if (min == max) { return 0; } - return std::uniform_int_distribution(min, max - 1)(rng); } + /** + * Returns true 1/n of the time. If n == 0, always returns false + */ + static bool oneIn(uint32_t n) { + return oneIn(n, ThreadLocalPRNG()); + } + /** * Returns true 1/n of the time. If n == 0, always returns false */ template > - static bool oneIn(uint32_t n, ValidRNG rng = RNG()) { + static bool oneIn(uint32_t n, RNG&& rng) { if (n == 0) { return false; } + return rand32(0, n, rng) == 0; + } - return rand32(n, rng) == 0; + /** + * Returns a double in [0, 1) + */ + static double randDouble01() { + return randDouble01(ThreadLocalPRNG()); } /** * Returns a double in [0, 1) */ template > - static double randDouble01(RNG rng = RNG()) { - return std::generate_canonical::digits> - (rng); + static double randDouble01(RNG&& rng) { + return std::generate_canonical::digits>( + rng); + } + + /** + * Returns a double in [min, max), if min == max, returns 0. + */ + static double randDouble(double min, double max) { + return randDouble(min, max, ThreadLocalPRNG()); } /** * Returns a double in [min, max), if min == max, returns 0. */ template > - static double randDouble(double min, double max, RNG rng = RNG()) { + static double randDouble(double min, double max, RNG&& rng) { if (std::fabs(max - min) < std::numeric_limits::epsilon()) { return 0; } diff --git a/folly/test/RandomTest.cpp b/folly/test/RandomTest.cpp index 7d175b6e..da906f69 100644 --- a/folly/test/RandomTest.cpp +++ b/folly/test/RandomTest.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -94,3 +95,46 @@ TEST(Random, MultiThreaded) { EXPECT_LT(seeds[i], seeds[i+1]); } } + +TEST(Random, sanity) { + // edge cases + EXPECT_EQ(folly::Random::rand32(0), 0); + EXPECT_EQ(folly::Random::rand32(12, 12), 0); + EXPECT_EQ(folly::Random::rand64(0), 0); + EXPECT_EQ(folly::Random::rand64(12, 12), 0); + + // 32-bit repeatability, uniqueness + constexpr int kTestSize = 1000; + { + std::vector vals; + folly::Random::DefaultGenerator rng; + rng.seed(0xdeadbeef); + for (int i = 0; i < kTestSize; ++i) { + vals.push_back(folly::Random::rand32(rng)); + } + rng.seed(0xdeadbeef); + for (int i = 0; i < kTestSize; ++i) { + EXPECT_EQ(vals[i], folly::Random::rand32(rng)); + } + EXPECT_EQ( + vals.size(), + std::unordered_set(vals.begin(), vals.end()).size()); + } + + // 64-bit repeatability, uniqueness + { + std::vector vals; + folly::Random::DefaultGenerator rng; + rng.seed(0xdeadbeef); + for (int i = 0; i < kTestSize; ++i) { + vals.push_back(folly::Random::rand64(rng)); + } + rng.seed(0xdeadbeef); + for (int i = 0; i < kTestSize; ++i) { + EXPECT_EQ(vals[i], folly::Random::rand64(rng)); + } + EXPECT_EQ( + vals.size(), + std::unordered_set(vals.begin(), vals.end()).size()); + } +}