SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif
-
-#ifdef OPENSSL_NPN_NEGOTIATED
- Random::seed(nextProtocolPicker_);
-#endif
}
SSLContext::~SSLContext() {
dst += protoLength;
}
total_weight += item.weight;
+ advertised_item.probability = item.weight;
advertisedNextProtocols_.push_back(advertised_item);
- advertisedNextProtocolWeights_.push_back(item.weight);
}
if (total_weight == 0) {
deleteNextProtocolsStrings();
return false;
}
- nextProtocolDistribution_ =
- std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
- advertisedNextProtocolWeights_.end());
+ for (auto& advertised_item : advertisedNextProtocols_) {
+ advertised_item.probability /= total_weight;
+ }
if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
SSL_CTX_set_next_protos_advertised_cb(
ctx_, advertisedNextProtocolCallback, this);
delete[] protocols.protocols;
}
advertisedNextProtocols_.clear();
- advertisedNextProtocolWeights_.clear();
}
void SSLContext::unsetNextProtocols() {
}
size_t SSLContext::pickNextProtocols() {
- CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
- return nextProtocolDistribution_(nextProtocolPicker_);
+ unsigned char random_byte;
+ RAND_bytes(&random_byte, 1);
+ double random_value = random_byte / 255.0;
+ double sum = 0;
+ for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
+ sum += advertisedNextProtocols_[i].probability;
+ if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
+ continue;
+ }
+ return i;
+ }
+ CHECK(false) << "Failed to pickNextProtocols";
}
int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
#include <vector>
#include <memory>
#include <string>
-#include <random>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <folly/folly-config.h>
#endif
-#include <folly/Random.h>
-
namespace folly {
/**
std::list<std::string> protocols;
};
+ struct AdvertisedNextProtocolsItem {
+ unsigned char* protocols;
+ unsigned length;
+ double probability;
+ };
+
// Function that selects a client protocol given the server's list
using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
const unsigned char*, unsigned int);
static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED
-
- struct AdvertisedNextProtocolsItem {
- unsigned char* protocols;
- unsigned length;
- };
-
/**
* Wire-format list of advertised protocols for use in NPN.
*/
std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
- std::vector<int> advertisedNextProtocolWeights_;
- std::discrete_distribution<int> nextProtocolDistribution_;
- Random::DefaultGenerator nextProtocolPicker_;
-
static int sNextProtocolsExDataIndex_;
static int advertisedNextProtocolCallback(SSL* ssl,