}
unsigned char *client;
- int client_len;
- if (ctx->advertisedNextProtocols_.empty()) {
- client = (unsigned char *) "";
- client_len = 0;
- } else {
- client = ctx->advertisedNextProtocols_[0].protocols;
- client_len = ctx->advertisedNextProtocols_[0].length;
+ unsigned int client_len;
+ bool filtered = false;
+ auto cpf = ctx->getClientProtocolFilterCallback();
+ if (cpf) {
+ filtered = (*cpf)(&client, &client_len, server, server_len);
+ }
+
+ if (!filtered) {
+ if (ctx->advertisedNextProtocols_.empty()) {
+ client = (unsigned char *) "";
+ client_len = 0;
+ } else {
+ client = ctx->advertisedNextProtocols_[0].protocols;
+ client_len = ctx->advertisedNextProtocols_[0].length;
+ }
}
int retval = SSL_select_next_proto(out, outlen, server, server_len,
double probability;
};
+ // Function that selects a client protocol given the server's list
+ using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
+ const unsigned char*, unsigned int);
+
/**
* Convenience function to call getErrors() with the current errno value.
*
bool setRandomizedAdvertisedNextProtocols(
const std::list<NextProtocolsItem>& items);
+ void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) {
+ clientProtoFilter_ = cb;
+ }
+
+ ClientProtocolFilterCallback getClientProtocolFilterCallback() {
+ return clientProtoFilter_;
+ }
/**
* Disables NPN on this SSL context.
*/
std::vector<ClientHelloCallback> clientHelloCbs_;
#endif
+ ClientProtocolFilterCallback clientProtoFilter_{nullptr};
+
static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED
// (*serverSock)->setSendTimeout(100);
}
+// client protocol filters
+bool clientProtoFilterPickPony(unsigned char** client,
+ unsigned int* client_len, const unsigned char*, unsigned int ) {
+ //the protocol string in length prefixed byte string. the
+ //length byte is not included in the length
+ static unsigned char p[7] = {6,'p','o','n','i','e','s'};
+ *client = p;
+ *client_len = 7;
+ return true;
+}
+
+bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
+ const unsigned char*, unsigned int) {
+ return false;
+}
/**
* Test connecting to, writing to, reading from, and closing the
EXPECT_EQ(selected.compare("blub"), 0);
}
+TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+
+ clientCtx->setAdvertisedNextProtocols({"blub"});
+ clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
+ serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ NpnClient client(std::move(clientSock));
+ NpnServer server(std::move(serverSock));
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.nextProtoLength != 0);
+ EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+ EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+ server.nextProtoLength), 0);
+ string selected((const char*)client.nextProto, client.nextProtoLength);
+ EXPECT_EQ(selected.compare("ponies"), 0);
+}
+
+TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+
+ clientCtx->setAdvertisedNextProtocols({"blub"});
+ clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
+ serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ NpnClient client(std::move(clientSock));
+ NpnServer server(std::move(serverSock));
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.nextProtoLength != 0);
+ EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+ EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+ server.nextProtoLength), 0);
+ string selected((const char*)client.nextProto, client.nextProtoLength);
+ EXPECT_EQ(selected.compare("blub"), 0);
+}
+
TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
// Probability that this test will fail is 2^-64, which could be considered
// as negligible.