Correctly deduce RNG type in folly::Random
authorDavid Goldblatt <davidgoldblatt@fb.com>
Fri, 15 Apr 2016 17:14:13 +0000 (10:14 -0700)
committerFacebook Github Bot 9 <facebook-github-bot-9-bot@fb.com>
Fri, 15 Apr 2016 17:20:24 +0000 (10:20 -0700)
Summary:Fix a bug in folly where it can't infer any RNG type except the
default-supplied parameter.

Differential Revision: D3163421

fb-gh-sync-id: 23d3963ba19dac93fa3407f3a2dfd1d9aa39ea44
fbshipit-source-id: 23d3963ba19dac93fa3407f3a2dfd1d9aa39ea44

folly/Random-inl.h
folly/Random.h
folly/test/RandomTest.cpp

index 0ff458e852abf9dce94ad9778506db15be508354..d2a67a82dfbe61756931c944386d70df496097fa 100644 (file)
@@ -118,15 +118,15 @@ struct SeedData {
 
 }  // namespace detail
 
-template <class RNG>
-void Random::seed(ValidRNG<RNG>& rng) {
+template <class RNG, class /* EnableIf */>
+void Random::seed(RNG& rng) {
   detail::SeedData<RNG> sd;
   std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData));
   rng.seed(s);
 }
 
-template <class RNG>
-auto Random::create() -> ValidRNG<RNG> {
+template <class RNG, class /* EnableIf */>
+auto Random::create() -> RNG {
   detail::SeedData<RNG> sd;
   std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData));
   return RNG(s);
index 0cfc5efe8a991a2b19507e8f7dfa6c7039c63d57..0962d07e15445c6ae47385585073d680ea87ceed 100644 (file)
@@ -76,10 +76,10 @@ class ThreadLocalPRNG {
 class Random {
 
  private:
-  template<class RNG>
+  template <class RNG>
   using ValidRNG = typename std::enable_if<
-   std::is_unsigned<typename std::result_of<RNG&()>::type>::value,
-   RNG>::type;
+      std::is_unsigned<typename std::result_of<RNG&()>::type>::value,
+      RNG>::type;
 
  public:
   // Default generator type.
@@ -115,8 +115,8 @@ class Random {
    * to create a RNG with a good seed in production, and seed it yourself
    * in test.
    */
-  template <class RNG = DefaultGenerator>
-  static void seed(ValidRNG<RNG>& rng);
+  template <class RNG = DefaultGenerator, class /* EnableIf */ = ValidRNG<RNG>>
+  static void seed(RNG& rng);
 
   /**
    * Create a new RNG, seeded with a good seed.
@@ -126,14 +126,22 @@ class Random {
    * to create a RNG with a good seed in production, and seed it yourself
    * in test.
    */
-  template <class RNG = DefaultGenerator>
-  static ValidRNG<RNG> create();
+  template <class RNG = DefaultGenerator, class /* EnableIf */ = ValidRNG<RNG>>
+  static RNG create();
 
   /**
    * Returns a random uint32_t
    */
-  template<class RNG = ThreadLocalPRNG>
-  static uint32_t rand32(ValidRNG<RNG> rng = RNG()) {
+  static uint32_t rand32() {
+    ThreadLocalPRNG prng;
+    return rand32(prng);
+  }
+
+  /**
+   * Returns a random uint32_t given a specific RNG
+   */
+  template <class RNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static uint32_t rand32(RNG rng) {
     uint32_t r = rng.operator()();
     return r;
   }
@@ -141,8 +149,17 @@ class Random {
   /**
    * Returns a random uint32_t in [0, max). If max == 0, returns 0.
    */
-  template<class RNG = ThreadLocalPRNG>
-  static uint32_t rand32(uint32_t max, ValidRNG<RNG> rng = RNG()) {
+  static uint32_t rand32(uint32_t max) {
+    ThreadLocalPRNG prng;
+    return rand32(max, prng);
+  }
+
+  /**
+   * Returns a random uint32_t in [0, max) given a specific RNG.
+   * If max == 0, returns 0.
+   */
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static uint32_t rand32(uint32_t max, RNG rng = RNG()) {
     if (max == 0) {
       return 0;
     }
@@ -153,10 +170,8 @@ class Random {
   /**
    * Returns a random uint32_t in [min, max). If min == max, returns 0.
    */
-  template<class RNG = ThreadLocalPRNG>
-  static uint32_t rand32(uint32_t min,
-                         uint32_t max,
-                         ValidRNG<RNG> rng = RNG()) {
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static uint32_t rand32(uint32_t min, uint32_t max, RNG rng = RNG()) {
     if (min == max) {
       return 0;
     }
@@ -167,16 +182,16 @@ class Random {
   /**
    * Returns a random uint64_t
    */
-  template<class RNG = ThreadLocalPRNG>
-  static uint64_t rand64(ValidRNG<RNG> rng = RNG()) {
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static uint64_t rand64(RNG rng = RNG()) {
     return ((uint64_t) rng() << 32) | rng();
   }
 
   /**
    * Returns a random uint64_t in [0, max). If max == 0, returns 0.
    */
-  template<class RNG = ThreadLocalPRNG>
-  static uint64_t rand64(uint64_t max, ValidRNG<RNG> rng = RNG()) {
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static uint64_t rand64(uint64_t max, RNG rng = RNG()) {
     if (max == 0) {
       return 0;
     }
@@ -187,10 +202,8 @@ class Random {
   /**
    * Returns a random uint64_t in [min, max). If min == max, returns 0.
    */
-  template<class RNG = ThreadLocalPRNG>
-  static uint64_t rand64(uint64_t min,
-                         uint64_t max,
-                         ValidRNG<RNG> rng = RNG()) {
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static uint64_t rand64(uint64_t min, uint64_t max, RNG rng = RNG()) {
     if (min == max) {
       return 0;
     }
@@ -201,7 +214,7 @@ class Random {
   /**
    * Returns true 1/n of the time. If n == 0, always returns false
    */
-  template<class RNG = ThreadLocalPRNG>
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
   static bool oneIn(uint32_t n, ValidRNG<RNG> rng = RNG()) {
     if (n == 0) {
       return false;
@@ -213,8 +226,8 @@ class Random {
   /**
    * Returns a double in [0, 1)
    */
-  template<class RNG = ThreadLocalPRNG>
-  static double randDouble01(ValidRNG<RNG> rng = RNG()) {
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static double randDouble01(RNG rng = RNG()) {
     return std::generate_canonical<double, std::numeric_limits<double>::digits>
       (rng);
   }
@@ -222,8 +235,8 @@ class Random {
   /**
     * Returns a double in [min, max), if min == max, returns 0.
     */
-  template<class RNG = ThreadLocalPRNG>
-  static double randDouble(double min, double max, ValidRNG<RNG> rng = RNG()) {
+  template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
+  static double randDouble(double min, double max, RNG rng = RNG()) {
     if (std::fabs(max - min) < std::numeric_limits<double>::epsilon()) {
       return 0;
     }
index 2f3c97d5cc03180e185f986886fb69c5354d91de..69447872efdd0b5e23166025c6b15fda56483941 100644 (file)
@@ -47,6 +47,33 @@ TEST(Random, Simple) {
   }
 }
 
+TEST(Random, FixedSeed) {
+  // clang-format off
+  struct ConstantRNG {
+    typedef uint32_t result_type;
+    result_type operator()() {
+      return 4; // chosen by fair dice roll.
+                // guaranteed to be random.
+    }
+    result_type min() {
+      return std::numeric_limits<result_type>::min();
+    }
+    result_type max() {
+      return std::numeric_limits<result_type>::max();
+    }
+  };
+  // clang-format on
+
+  ConstantRNG gen;
+  // Loop to make sure it really is constant.
+  for (int i = 0; i < 1024; ++i) {
+    auto result = Random::rand32(10, gen);
+    // TODO: This is a little bit brittle; standard library changes could break
+    // it, if it starts implementing distribution types differently.
+    EXPECT_EQ(0, result);
+  }
+}
+
 TEST(Random, MultiThreaded) {
   const int n = 100;
   std::vector<uint32_t> seeds(n);