#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <folly/portability/OpenSSL.h>
-#include <folly/portability/Unistd.h>
using folly::SocketAddress;
using folly::SSLContext;
using namespace folly::ssl;
using folly::ssl::OpenSSLUtils;
+
// We have one single dummy SSL context so that we can implement attach
// and detach methods in a thread safe fashion without modifying opnessl.
static SSLContext *dummyCtx = nullptr;
int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
struct msghdr msg;
struct iovec iov;
- int flags = 0;
AsyncSSLSocket* tsslSock;
iov.iov_base = const_cast<char*>(in);
tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
CHECK(tsslSock);
+ WriteFlags flags = WriteFlags::NONE;
if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
- flags = MSG_EOR;
+ flags |= WriteFlags::EOR;
}
-#ifdef MSG_NOSIGNAL
- flags |= MSG_NOSIGNAL;
-#endif
-
-#ifdef MSG_MORE
if (tsslSock->corkCurrentWrite_) {
- flags |= MSG_MORE;
+ flags |= WriteFlags::CORK;
+ }
+
+ int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags);
+ msg.msg_controllen =
+ tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
+ CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+ msg.msg_controllen);
+ if (msg.msg_controllen != 0) {
+ msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+ tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
}
-#endif
auto result = tsslSock->sendSocketMessage(
- OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
+ OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
BIO_clear_retry_flags(b);
if (!result.exception && result.writeReturn <= 0) {
if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
struct iovec writeOps_[]; ///< write operation(s) list
};
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
+ noexcept {
+ int msg_flags = MSG_DONTWAIT;
+
+#ifdef MSG_NOSIGNAL // Linux-only
+ msg_flags |= MSG_NOSIGNAL;
+#ifdef MSG_MORE
+ if (isSet(flags, WriteFlags::CORK)) {
+ // MSG_MORE tells the kernel we have more data to send, so wait for us to
+ // give it the rest of the data rather than immediately sending a partial
+ // frame, even when TCP_NODELAY is enabled.
+ msg_flags |= MSG_MORE;
+ }
+#endif // MSG_MORE
+#endif // MSG_NOSIGNAL
+ if (isSet(flags, WriteFlags::EOR)) {
+ // marks that this is the last byte of a record (response)
+ msg_flags |= MSG_EOR;
+ }
+
+ return msg_flags;
+}
+
+namespace {
+static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
+}
+
AsyncSocket::AsyncSocket()
: eventBase_(nullptr),
writeTimeout_(this, nullptr),
shutdownSocketSet_ = nullptr;
appBytesWritten_ = 0;
appBytesReceived_ = 0;
+ sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
}
AsyncSocket::~AsyncSocket() {
return errMessageCallback_;
}
+void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
+ sendMsgParamCallback_ = callback;
+}
+
+AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
+ return sendMsgParamCallback_;
+}
+
void AsyncSocket::setReadCB(ReadCallback *callback) {
VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", state=" << state_;
}
void AsyncSocket::ioReady(uint16_t events) noexcept {
- VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
+ VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_
<< ", events=" << std::hex << events << ", state=" << state_;
DestructorGuard dg(this);
assert(events & EventHandler::READ_WRITE);
msg.msg_namelen = 0;
msg.msg_iov = const_cast<iovec *>(vec);
msg.msg_iovlen = std::min<size_t>(count, kIovMax);
- msg.msg_control = nullptr;
- msg.msg_controllen = 0;
msg.msg_flags = 0;
+ msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
+ CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+ msg.msg_controllen);
- int msg_flags = MSG_DONTWAIT;
-
-#ifdef MSG_NOSIGNAL // Linux-only
- msg_flags |= MSG_NOSIGNAL;
- if (isSet(flags, WriteFlags::CORK)) {
- // MSG_MORE tells the kernel we have more data to send, so wait for us to
- // give it the rest of the data rather than immediately sending a partial
- // frame, even when TCP_NODELAY is enabled.
- msg_flags |= MSG_MORE;
- }
-#endif
- if (isSet(flags, WriteFlags::EOR)) {
- // marks that this is the last byte of a record (response)
- msg_flags |= MSG_EOR;
+ if (msg.msg_controllen != 0) {
+ msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+ sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
+ } else {
+ msg.msg_control = nullptr;
}
+ int msg_flags = sendMsgParamCallback_->getFlags(flags);
+
auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0) {
virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
};
+ class SendMsgParamsCallback {
+ public:
+ virtual ~SendMsgParamsCallback() = default;
+
+ /**
+ * getFlags() will be invoked to retrieve the desired flags to be passed
+ * to ::sendmsg() system call. This method was intentionally declared
+ * non-virtual, so there is no way to override it. Instead feel free to
+ * override getFlagsImpl(flags, defaultFlags) method instead, and enjoy
+ * the convenience of defaultFlags passed there.
+ *
+ * @param flags Write flags requested for the given write operation
+ */
+ int getFlags(folly::WriteFlags flags) noexcept {
+ return getFlagsImpl(flags, getDefaultFlags(flags));
+ }
+
+ /**
+ * getAncillaryData() will be invoked to initialize ancillary data
+ * buffer referred by "msg_control" field of msghdr structure passed to
+ * ::sendmsg() system call. The function assumes that the size of buffer
+ * is not smaller than the value returned by getAncillaryDataSize() method
+ * for the same combination of flags.
+ *
+ * @param flags Write flags requested for the given write operation
+ * @param data Pointer to ancillary data buffer to initialize.
+ */
+ virtual void getAncillaryData(
+ folly::WriteFlags /*flags*/,
+ void* /*data*/) noexcept {}
+
+ /**
+ * getAncillaryDataSize() will be invoked to retrieve the size of
+ * ancillary data buffer which should be passed to ::sendmsg() system call
+ *
+ * @param flags Write flags requested for the given write operation
+ */
+ virtual uint32_t getAncillaryDataSize(folly::WriteFlags /*flags*/)
+ noexcept {
+ return 0;
+ }
+
+ static const size_t maxAncillaryDataSize{0x5000};
+
+ private:
+ /**
+ * getFlagsImpl() will be invoked by getFlags(folly::WriteFlags flags)
+ * method to retrieve the flags to be passed to ::sendmsg() system call.
+ * SendMsgParamsCallback::getFlags() is calling this method, and returns
+ * its results directly to the caller in AsyncSocket.
+ * Classes inheriting from SendMsgParamsCallback are welcome to override
+ * this method to force SendMsgParamsCallback to return its own set
+ * of flags.
+ *
+ * @param flags Write flags requested for the given write operation
+ * @param defaultflags A set of message flags returned by getDefaultFlags()
+ * method for the given "flags" mask.
+ */
+ virtual int getFlagsImpl(folly::WriteFlags /*flags*/, int defaultFlags) {
+ return defaultFlags;
+ }
+
+ /**
+ * getDefaultFlags() will be invoked by getFlags(folly::WriteFlags flags)
+ * to retrieve the default set of flags, and pass them to getFlagsImpl(...)
+ *
+ * @param flags Write flags requested for the given write operation
+ */
+ int getDefaultFlags(folly::WriteFlags flags) noexcept;
+ };
+
explicit AsyncSocket();
/**
* Create a new unconnected AsyncSocket.
*/
ErrMessageCallback* getErrMessageCallback() const;
+ /**
+ * Set a pointer to SendMsgParamsCallback implementation which
+ * will be used to form ::sendmsg() system call parameters
+ *
+ */
+ void setSendMsgParamCB(SendMsgParamsCallback* callback);
+
+ /**
+ * Get a pointer to SendMsgParamsCallback implementation currently
+ * registered with this socket.
+ *
+ */
+ SendMsgParamsCallback* getSendMsgParamsCB() const;
+
// Read and write methods
void setReadCB(ReadCallback* callback) override;
ReadCallback* getReadCallback() const override;
ConnectCallback* connectCallback_; ///< ConnectCallback
ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
+ SendMsgParamsCallback* ///< Callback for retreaving
+ sendMsgParamCallback_; ///< ::sendmsg() parameters
ReadCallback* readCallback_; ///< ReadCallback
WriteRequest* writeReqHead_; ///< Chain of WriteRequests
WriteRequest* writeReqTail_; ///< End of WriteRequest chain
// socket.
std::unique_ptr<IOBuf> preReceivedData_;
- int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
+ int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any
std::chrono::steady_clock::time_point connectStartTime_;
std::chrono::steady_clock::time_point connectEndTime_;
#include <folly/io/Cursor.h>
#include <openssl/bio.h>
#include <sys/types.h>
+#include <sys/utsname.h>
#include <fstream>
#include <iostream>
#include <list>
serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
}
+/**
+ * Test overriding the flags passed to "sendmsg()" system call,
+ * and verifying that write requests fail properly.
+ */
+TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
+ // Start listening on a local port
+ SendMsgFlagsCallback msgCallback;
+ ExpectWriteErrorCallback writeCallback(&msgCallback);
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ socket->open();
+
+ // Setting flags to "-1" to trigger "Invalid argument" error
+ // on attempt to use this flags in sendmsg() system call.
+ msgCallback.resetFlags(-1);
+
+ // write()
+ std::vector<uint8_t> buf(128, 'a');
+ ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
+
+ // close()
+ socket->close();
+
+ cerr << "SendMsgParamsCallback test completed" << endl;
+}
+
+#ifdef MSG_ERRQUEUE
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server.
+ */
+TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
+ // This test requires Linux kernel v4.6 or later
+ struct utsname s_uname;
+ memset(&s_uname, 0, sizeof(s_uname));
+ ASSERT_EQ(uname(&s_uname), 0);
+ int major, minor;
+ folly::StringPiece extra;
+ if (folly::split<false>(
+ '.', std::string(s_uname.release) + ".", major, minor, extra)) {
+ if (major < 4 || (major == 4 && minor < 6)) {
+ LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
+ << "kernel ver. " << s_uname.release << " detected).";
+ return;
+ }
+ }
+
+ // Start listening on a local port
+ SendMsgDataCallback msgCallback;
+ WriteCheckTimestampCallback writeCallback(&msgCallback);
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ socket->open();
+
+ // Adding MSG_EOR flag to the message flags - it'll trigger
+ // timestamp generation for the last byte of the message.
+ msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
+
+ // Init ancillary data buffer to trigger timestamp notification
+ union {
+ uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
+ struct cmsghdr cmsg;
+ } u;
+ u.cmsg.cmsg_level = SOL_SOCKET;
+ u.cmsg.cmsg_type = SO_TIMESTAMPING;
+ u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
+ uint32_t flags =
+ SOF_TIMESTAMPING_TX_SCHED |
+ SOF_TIMESTAMPING_TX_SOFTWARE |
+ SOF_TIMESTAMPING_TX_ACK;
+ memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
+ std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
+ memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
+ msgCallback.resetData(std::move(ctrl));
+
+ // write()
+ std::vector<uint8_t> buf(128, 'a');
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::vector<uint8_t> readbuf(buf.size());
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, buf.size());
+ EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+
+ writeCallback.checkForTimestampNotifications();
+
+ // close()
+ socket->close();
+
+ cerr << "SendMsgDataCallback test completed" << endl;
+}
+#endif // MSG_ERRQUEUE
+
#endif
} // namespace
// are responsible for setting the succeeded state properly before the
// destructors are called.
+class SendMsgParamsCallbackBase :
+ public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+ SendMsgParamsCallbackBase() {}
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) {
+ socket_ = socket;
+ oldCallback_ = socket_->getSendMsgParamsCB();
+ socket_->setSendMsgParamCB(this);
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ return oldCallback_->getFlags(flags);
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ oldCallback_->getAncillaryData(flags, data);
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ return oldCallback_->getAncillaryDataSize(flags);
+ }
+
+ std::shared_ptr<AsyncSSLSocket> socket_;
+ folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
+};
+
+class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
+ public:
+ SendMsgFlagsCallback() {}
+
+ void resetFlags(int flags) {
+ flags_ = flags;
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ if (flags_) {
+ return flags_;
+ } else {
+ return oldCallback_->getFlags(flags);
+ }
+ }
+
+ int flags_{0};
+};
+
+class SendMsgDataCallback : public SendMsgFlagsCallback {
+ public:
+ SendMsgDataCallback() {}
+
+ void resetData(std::vector<char>&& data) {
+ ancillaryData_.swap(data);
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ if (ancillaryData_.size()) {
+ std::cerr << "getAncillaryData: copying data" << std::endl;
+ memcpy(data, ancillaryData_.data(), ancillaryData_.size());
+ } else {
+ oldCallback_->getAncillaryData(flags, data);
+ }
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ if (ancillaryData_.size()) {
+ std::cerr << "getAncillaryDataSize: returning size" << std::endl;
+ return ancillaryData_.size();
+ } else {
+ return oldCallback_->getAncillaryDataSize(flags);
+ }
+ }
+
+ std::vector<char> ancillaryData_;
+};
+
class WriteCallbackBase :
public AsyncTransportWrapper::WriteCallback {
public:
- WriteCallbackBase()
+ explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
: state(STATE_WAITING)
, bytesWritten(0)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
+ , exception(AsyncSocketException::UNKNOWN, "none")
+ , mcb_(mcb) {}
~WriteCallbackBase() {
EXPECT_EQ(STATE_SUCCEEDED, state);
}
- void setSocket(
+ virtual void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
+ if (mcb_) {
+ mcb_->setSocket(socket);
+ }
}
- void writeSuccess() noexcept override {
+ virtual void writeSuccess() noexcept override {
std::cerr << "writeSuccess" << std::endl;
state = STATE_SUCCEEDED;
}
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
+ SendMsgParamsCallbackBase* mcb_;
+};
+
+class ExpectWriteErrorCallback :
+public WriteCallbackBase {
+public:
+ explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+ : WriteCallbackBase(mcb) {}
+
+ ~ExpectWriteErrorCallback() {
+ EXPECT_EQ(STATE_FAILED, state);
+ EXPECT_EQ(exception.type_,
+ AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
+ EXPECT_EQ(exception.errno_, 22);
+ // Suppress the assert in ~WriteCallbackBase()
+ state = STATE_SUCCEEDED;
+ }
+};
+
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+ SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+ SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+ SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+ SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+ SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+ SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+};
+
+class WriteCheckTimestampCallback :
+ public WriteCallbackBase {
+public:
+ explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+ : WriteCallbackBase(mcb) {}
+
+ ~WriteCheckTimestampCallback() {
+ EXPECT_EQ(STATE_SUCCEEDED, state);
+ EXPECT_TRUE(gotTimestamp_);
+ EXPECT_TRUE(gotByteSeq_);
+ }
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) override {
+ WriteCallbackBase::setSocket(socket);
+
+ EXPECT_NE(socket_->getFd(), 0);
+ int flags = SOF_TIMESTAMPING_OPT_ID
+ | SOF_TIMESTAMPING_OPT_TSONLY
+ | SOF_TIMESTAMPING_SOFTWARE;
+ AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+ int ret = tstampingOpt.apply(socket_->getFd(), flags);
+ EXPECT_EQ(ret, 0);
+ }
+
+ void checkForTimestampNotifications() noexcept {
+ int fd = socket_->getFd();
+ std::vector<char> ctrl(1024, 0);
+ unsigned char data;
+ struct msghdr msg;
+ iovec entry;
+
+ memset(&msg, 0, sizeof(msg));
+ entry.iov_base = &data;
+ entry.iov_len = sizeof(data);
+ msg.msg_iov = &entry;
+ msg.msg_iovlen = 1;
+ msg.msg_control = ctrl.data();
+ msg.msg_controllen = ctrl.size();
+
+ int ret;
+ while (true) {
+ ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+ if (ret < 0) {
+ if (errno != EAGAIN) {
+ auto errnoCopy = errno;
+ std::cerr << "::recvmsg exited with code " << ret
+ << ", errno: " << errnoCopy << std::endl;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ "recvmsg() failed",
+ errnoCopy);
+ exception = ex;
+ }
+ return;
+ }
+
+ for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg != nullptr && cmsg->cmsg_len != 0;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_TIMESTAMPING) {
+ gotTimestamp_ = true;
+ continue;
+ }
+
+ if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
+ (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
+ gotByteSeq_ = true;
+ continue;
+ }
+ }
+ }
+ }
+
+ bool gotTimestamp_{false};
+ bool gotByteSeq_{false};
};
+#endif // MSG_ERRQUEUE
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
bool gotByteSeq_{false};
};
+class TestSendMsgParamsCallback :
+ public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+ TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
+ : flags_(flags),
+ writeFlags_(folly::WriteFlags::NONE),
+ dataSize_(dataSize),
+ data_(data),
+ queriedFlags_(false),
+ queriedData_(false)
+ {}
+
+ void reset(int flags) {
+ flags_ = flags;
+ writeFlags_ = folly::WriteFlags::NONE;
+ queriedFlags_ = false;
+ queriedData_ = false;
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ queriedFlags_ = true;
+ if (writeFlags_ == folly::WriteFlags::NONE) {
+ writeFlags_ = flags;
+ } else {
+ assert(flags == writeFlags_);
+ }
+ return flags_;
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ queriedData_ = true;
+ if (writeFlags_ == folly::WriteFlags::NONE) {
+ writeFlags_ = flags;
+ } else {
+ assert(flags == writeFlags_);
+ }
+ assert(data != nullptr);
+ memcpy(data, data_, dataSize_);
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ if (writeFlags_ == folly::WriteFlags::NONE) {
+ writeFlags_ = flags;
+ } else {
+ assert(flags == writeFlags_);
+ }
+ return dataSize_;
+ }
+
+ int flags_;
+ folly::WriteFlags writeFlags_;
+ uint32_t dataSize_;
+ void* data_;
+ bool queriedFlags_;
+ bool queriedData_;
+};
+
class TestServer {
public:
// Create a TestServer.
/* copied from include/uapi/linux/net_tstamp.h */
/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
enum SOF_TIMESTAMPING {
- // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
- // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
- // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
- // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
- // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
- // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
SOF_TIMESTAMPING_OPT_ID = (1 << 7),
SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
- // SOF_TIMESTAMPING_TX_ACK = (1 << 9),
SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
-
- // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
- // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
};
TEST(AsyncSocketTest, ErrMessageCallback) {
TestServer server;
evb.loop();
}
+
+TEST(AsyncSocketTest, SendMessageFlags) {
+ TestServer server;
+ TestSendMsgParamsCallback sendMsgCB(
+ MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+ evb.loop();
+ ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
+
+ // Set SendMsgParamsCallback
+ socket->setSendMsgParamCB(&sendMsgCB);
+ ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
+
+ // Write the first portion of data. This data is expected to be
+ // sent out immediately.
+ std::vector<uint8_t> buf(128, 'a');
+ WriteCallback wcb;
+ sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
+ socket->write(&wcb, buf.data(), buf.size());
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+ ASSERT_TRUE(sendMsgCB.queriedFlags_);
+ ASSERT_FALSE(sendMsgCB.queriedData_);
+
+ // Using different flags for the second write operation.
+ // MSG_MORE flag is expected to delay sending this
+ // data to the wire.
+ sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
+ socket->write(&wcb, buf.data(), buf.size());
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+ ASSERT_TRUE(sendMsgCB.queriedFlags_);
+ ASSERT_FALSE(sendMsgCB.queriedData_);
+
+ // Make sure the accepted socket saw only the data from
+ // the first write request.
+ std::vector<uint8_t> readbuf(2 * buf.size());
+ uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
+ ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+ ASSERT_EQ(bytesRead, buf.size());
+
+ // Make sure the server got a connection and received the data
+ acceptedSocket->close();
+ socket->close();
+
+ ASSERT_TRUE(socket->isClosedBySelf());
+ ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+TEST(AsyncSocketTest, SendMessageAncillaryData) {
+ struct sockaddr_un addr = {AF_UNIX,
+ "AsyncSocketTest.SendMessageAncillaryData\0"};
+
+ // Clean up the name in the name space we're going to use
+ ASSERT_FALSE(remove(addr.sun_path) == -1 && errno != ENOENT);
+
+ // Set up listening socket
+ int lfd = fsp::socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_NE(lfd, -1);
+ ASSERT_NE(bind(lfd, (struct sockaddr*)&addr, sizeof(addr)), -1)
+ << "Bind failed: " << errno;
+
+ // Create the connecting socket
+ int csd = fsp::socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_NE(csd, -1);
+
+ // Listen for incoming connect
+ ASSERT_NE(listen(lfd, 5), -1);
+
+ // Connect to the listening socket
+ ASSERT_NE(fsp::connect(csd, (struct sockaddr*)&addr, sizeof(addr)), -1)
+ << "Connect request failed: " << errno;
+
+ // Accept the connection
+ int sfd = accept(lfd, nullptr, nullptr);
+ ASSERT_NE(sfd, -1);
+
+ // Instantiate AsyncSocket object for the connected socket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, csd);
+
+ // Open a temporary file and write a magic string to it
+ // We'll transfer the file handle to test the message parameters
+ // callback logic.
+ int tmpfd = open("/var/tmp", O_RDWR | O_TMPFILE);
+ ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
+ std::string magicString("Magic string");
+ ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
+ magicString.length());
+
+ // Send message
+ union {
+ // Space large enough to hold an 'int'
+ char control[CMSG_SPACE(sizeof(int))];
+ struct cmsghdr cmh;
+ } s_u;
+ s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
+ s_u.cmh.cmsg_level = SOL_SOCKET;
+ s_u.cmh.cmsg_type = SCM_RIGHTS;
+ memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
+
+ // Set up the callback providing message parameters
+ TestSendMsgParamsCallback sendMsgCB(
+ MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
+ socket->setSendMsgParamCB(&sendMsgCB);
+
+ // We must transmit at least 1 byte of real data in order
+ // to send ancillary data
+ int s_data = 12345;
+ WriteCallback wcb;
+ socket->write(&wcb, &s_data, sizeof(s_data));
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Receive the message
+ union {
+ // Space large enough to hold an 'int'
+ char control[CMSG_SPACE(sizeof(int))];
+ struct cmsghdr cmh;
+ } r_u;
+ struct msghdr msgh;
+ struct iovec iov;
+ int r_data = 0;
+
+ msgh.msg_control = r_u.control;
+ msgh.msg_controllen = sizeof(r_u.control);
+ msgh.msg_name = nullptr;
+ msgh.msg_namelen = 0;
+ msgh.msg_iov = &iov;
+ msgh.msg_iovlen = 1;
+ iov.iov_base = &r_data;
+ iov.iov_len = sizeof(r_data);
+
+ // Receive data
+ ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
+
+ // Validate the received message
+ ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
+ ASSERT_EQ(r_data, s_data);
+ int fd = 0;
+ memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
+ ASSERT_NE(fd, 0);
+
+ std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
+
+ // Reposition to the beginning of the file
+ ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
+
+ // Read the magic string back, and compare it with the original
+ ASSERT_EQ(
+ magicString.length(),
+ read(fd, transferredMagicString.data(), transferredMagicString.size()));
+ ASSERT_TRUE(std::equal(
+ magicString.begin(),
+ magicString.end(),
+ transferredMagicString.begin()));
+}