SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+ Random::seed(nextProtocolPicker_);
+#endif
}
SSLContext::~SSLContext() {
dst += protoLength;
}
total_weight += item.weight;
- advertised_item.probability = item.weight;
advertisedNextProtocols_.push_back(advertised_item);
+ advertisedNextProtocolWeights_.push_back(item.weight);
}
if (total_weight == 0) {
deleteNextProtocolsStrings();
return false;
}
- for (auto& advertised_item : advertisedNextProtocols_) {
- advertised_item.probability /= total_weight;
- }
+ nextProtocolDistribution_ =
+ std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
+ advertisedNextProtocolWeights_.end());
if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
SSL_CTX_set_next_protos_advertised_cb(
ctx_, advertisedNextProtocolCallback, this);
delete[] protocols.protocols;
}
advertisedNextProtocols_.clear();
+ advertisedNextProtocolWeights_.clear();
}
void SSLContext::unsetNextProtocols() {
}
size_t SSLContext::pickNextProtocols() {
- unsigned char random_byte;
- RAND_bytes(&random_byte, 1);
- double random_value = random_byte / 255.0;
- double sum = 0;
- for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
- sum += advertisedNextProtocols_[i].probability;
- if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
- continue;
- }
- return i;
- }
- CHECK(false) << "Failed to pickNextProtocols";
+ CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
+ return nextProtocolDistribution_(nextProtocolPicker_);
}
int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
#include <vector>
#include <memory>
#include <string>
+#include <random>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <folly/folly-config.h>
#endif
+#include <folly/Random.h>
+
namespace folly {
/**
std::list<std::string> protocols;
};
- struct AdvertisedNextProtocolsItem {
- unsigned char* protocols;
- unsigned length;
- double probability;
- };
-
// Function that selects a client protocol given the server's list
using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
const unsigned char*, unsigned int);
static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED
+
+ struct AdvertisedNextProtocolsItem {
+ unsigned char* protocols;
+ unsigned length;
+ };
+
/**
* Wire-format list of advertised protocols for use in NPN.
*/
std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
+ std::vector<int> advertisedNextProtocolWeights_;
+ std::discrete_distribution<int> nextProtocolDistribution_;
+ Random::DefaultGenerator nextProtocolPicker_;
+
static int sNextProtocolsExDataIndex_;
static int advertisedNextProtocolCallback(SSL* ssl,