int64_t startTime_;
protected:
- virtual ~AsyncSSLSocketConnector() {
- }
+ ~AsyncSSLSocketConnector() override {}
public:
AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
std::chrono::steady_clock::now().time_since_epoch()).count()) {
}
- virtual void connectSuccess() noexcept {
+ void connectSuccess() noexcept override {
VLOG(7) << "client socket connected";
int64_t timeoutLeft = 0;
sslSocket_->sslConn(this, timeoutLeft);
}
- virtual void connectErr(const AsyncSocketException& ex) noexcept {
+ void connectErr(const AsyncSocketException& ex) noexcept override {
LOG(ERROR) << "TCP connect failed: " << ex.what();
fail(ex);
delete this;
}
- virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
+ void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
VLOG(7) << "client handshake success";
if (callback_) {
callback_->connectSuccess();
delete this;
}
- virtual void handshakeErr(AsyncSSLSocket *socket,
- const AsyncSocketException& ex) noexcept {
+ void handshakeErr(AsyncSSLSocket* socket,
+ const AsyncSocketException& ex) noexcept override {
LOG(ERROR) << "client handshakeErr: " << ex.what();
fail(ex);
delete this;
// Do this here to ensure we initialize this once before any use of
// AsyncSSLSocket instances and not as part of library load.
static const auto eorAwareBioMethodInitializer = initEorBioMethod();
+ (void)eorAwareBioMethodInitializer;
+
setup_SSL_CTX(ctx_->getSSLCtx());
}
DestructorGuard dg(this);
- if (handshakeCallback_) {
- AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
- "SSL connection closed locally");
- HandshakeCB* callback = handshakeCallback_;
- handshakeCallback_ = nullptr;
- callback->handshakeErr(this, ex);
- }
+ invokeHandshakeErr(
+ AsyncSocketException(
+ AsyncSocketException::END_OF_FILE,
+ "SSL connection closed locally"));
if (ssl_ != nullptr) {
SSL_free(ssl_);
AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
"sslAccept() called with socket in invalid state");
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (callback) {
callback->handshakeErr(this, ex);
}
handshakeCallback_ != nullptr) {
return invalidState(callback);
}
+ handshakeStartTime_ = std::chrono::steady_clock::now();
+ // Make end time at least >= start time.
+ handshakeEndTime_ = handshakeStartTime_;
sslState_ = STATE_ACCEPTING;
handshakeCallback_ = callback;
}
}
-int AsyncSSLSocket::sslExDataIndex_ = -1;
-std::mutex AsyncSSLSocket::mutex_;
-
int AsyncSSLSocket::getSSLExDataIndex() {
- if (sslExDataIndex_ < 0) {
- std::lock_guard<std::mutex> g(mutex_);
- if (sslExDataIndex_ < 0) {
- sslExDataIndex_ = SSL_get_ex_new_index(0,
- (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
- }
- }
- return sslExDataIndex_;
+ static auto index = SSL_get_ex_new_index(
+ 0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
+ return index;
}
AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
void AsyncSSLSocket::failHandshake(const char* fn,
const AsyncSocketException& ex) {
startFail();
-
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
+ invokeHandshakeErr(ex);
+ finishFail();
+}
+
+void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeCallback_ != nullptr) {
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
-
- finishFail();
}
void AsyncSSLSocket::invokeHandshakeCB() {
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
return invalidState(callback);
}
+ handshakeStartTime_ = std::chrono::steady_clock::now();
+ // Make end time at least >= start time.
+ handshakeEndTime_ = handshakeStartTime_;
+
sslState_ = STATE_CONNECTING;
handshakeCallback_ = callback;
AsyncSocket::handleInitialReadWrite();
}
+void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+ // turn on the buffer movable in openssl
+ if (!isBufferMovable_ && callback != nullptr && callback->isBufferMovable()) {
+ SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
+ isBufferMovable_ = true;
+ }
+#endif
+
+ AsyncSocket::setReadCB(callback);
+}
+
+void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+ CHECK(readCallback_);
+ if (isBufferMovable_) {
+ *buf = nullptr;
+ *buflen = 0;
+ } else {
+ // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
+ readCallback_->getReadBuffer(buf, buflen);
+ }
+}
+
void
AsyncSSLSocket::handleRead() noexcept {
VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
}
ssize_t
-AsyncSSLSocket::performRead(void* buf, size_t buflen) {
+AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
+ VLOG(4) << "AsyncSSLSocket::performRead() this=" << this
+ << ", buf=" << *buf << ", buflen=" << *buflen;
+
if (sslState_ == STATE_UNENCRYPTED) {
- return AsyncSocket::performRead(buf, buflen);
+ return AsyncSocket::performRead(buf, buflen, offset);
}
errno = 0;
- ssize_t bytes = SSL_read(ssl_, buf, buflen);
+ ssize_t bytes = 0;
+ if (!isBufferMovable_) {
+ bytes = SSL_read(ssl_, *buf, *buflen);
+ }
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+ else {
+ bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
+ }
+#endif
+
if (server_ && renegotiateAttempted_) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslstate=" << sslState_ << ", events=" << eventFlags_
if (cursor.totalLength() > 0) {
uint16_t extensionsLength = cursor.readBE<uint16_t>();
while (extensionsLength) {
+ TLSExtension extensionType = static_cast<TLSExtension>(
+ cursor.readBE<uint16_t>());
sock->clientHelloInfo_->
- clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
+ clientHelloExtensions_.push_back(extensionType);
extensionsLength -= 2;
uint16_t extensionDataLength = cursor.readBE<uint16_t>();
extensionsLength -= 2;
- cursor.skip(extensionDataLength);
- extensionsLength -= extensionDataLength;
+
+ if (extensionType == TLSExtension::SIGNATURE_ALGORITHMS) {
+ cursor.skip(2);
+ extensionDataLength -= 2;
+ while (extensionDataLength) {
+ HashAlgorithm hashAlg = static_cast<HashAlgorithm>(
+ cursor.readBE<uint8_t>());
+ SignatureAlgorithm sigAlg = static_cast<SignatureAlgorithm>(
+ cursor.readBE<uint8_t>());
+ extensionDataLength -= 2;
+ sock->clientHelloInfo_->
+ clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
+ }
+ } else {
+ cursor.skip(extensionDataLength);
+ extensionsLength -= extensionDataLength;
+ }
}
}
} catch (std::out_of_range& e) {