X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSocket.h;h=e99300fb238a6ae491f453bdbc886b405037adda;hb=fbc4c23895b0ee3874d9a36401d580a2a8957ba9;hp=973e493cb81f3adcaf1334654cb8d1e645088754;hpb=b4a27a035cc092f444df9a952e371ed4a201c2f4;p=folly.git diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 973e493c..e99300fb 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -1,5 +1,5 @@ /* - * Copyright 2016 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,8 +31,8 @@ #include #include -#include #include +#include namespace folly { @@ -64,6 +64,14 @@ namespace folly { * responding and no further progress can be made sending the data. */ +#if defined __linux__ && !defined SO_NO_TRANSPARENT_TLS +#define SO_NO_TRANSPARENT_TLS 200 +#endif + +#if defined __linux__ && !defined SO_NO_TSOCKS +#define SO_NO_TSOCKS 201 +#endif + #ifdef _MSC_VER // We do a dynamic_cast on this, in // AsyncTransportWrapper::getUnderlyingTransport so be safe and @@ -94,6 +102,118 @@ class AsyncSocket : virtual public AsyncTransportWrapper { noexcept = 0; }; + class EvbChangeCallback { + public: + virtual ~EvbChangeCallback() = default; + + // Called when the socket has been attached to a new EVB + // and is called from within that EVB thread + virtual void evbAttached(AsyncSocket* socket) = 0; + + // Called when the socket is detached from an EVB and + // is called from the EVB thread being detached + virtual void evbDetached(AsyncSocket* socket) = 0; + }; + + /** + * This interface is implemented only for platforms supporting + * per-socket error queues. + */ + class ErrMessageCallback { + public: + virtual ~ErrMessageCallback() = default; + + /** + * errMessage() will be invoked when kernel puts a message to + * the error queue associated with the socket. 
+ * + * @param cmsg Reference to cmsghdr structure describing + * a message read from the error queue associated + * with the socket. + */ + virtual void + errMessage(const cmsghdr& cmsg) noexcept = 0; + + /** + * errMessageError() will be invoked if an error occurs reading a message + * from the socket error stream. + * + * @param ex An exception describing the error that occurred. + */ + virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0; + }; + + class SendMsgParamsCallback { + public: + virtual ~SendMsgParamsCallback() = default; + + /** + * getFlags() will be invoked to retrieve the desired flags to be passed + * to ::sendmsg() system call. This method was intentionally declared + * non-virtual, so there is no way to override it. Instead, feel free to + * override the getFlagsImpl(flags, defaultFlags) method, and enjoy + * the convenience of defaultFlags passed there. + * + * @param flags Write flags requested for the given write operation + */ + int getFlags(folly::WriteFlags flags, bool zeroCopyEnabled) noexcept { + return getFlagsImpl(flags, getDefaultFlags(flags, zeroCopyEnabled)); + } + + /** + * getAncillaryData() will be invoked to initialize ancillary data + * buffer referred to by "msg_control" field of msghdr structure passed to + * ::sendmsg() system call. The function assumes that the size of buffer + * is not smaller than the value returned by getAncillaryDataSize() method + * for the same combination of flags. + * + * @param flags Write flags requested for the given write operation + * @param data Pointer to ancillary data buffer to initialize.
+ */ + virtual void getAncillaryData( + folly::WriteFlags /*flags*/, + void* /*data*/) noexcept {} + + /** + * getAncillaryDataSize() will be invoked to retrieve the size of + * ancillary data buffer which should be passed to ::sendmsg() system call + * + * @param flags Write flags requested for the given write operation + */ + virtual uint32_t getAncillaryDataSize(folly::WriteFlags /*flags*/) + noexcept { + return 0; + } + + static const size_t maxAncillaryDataSize{0x5000}; + + private: + /** + * getFlagsImpl() will be invoked by getFlags(folly::WriteFlags flags) + * method to retrieve the flags to be passed to ::sendmsg() system call. + * SendMsgParamsCallback::getFlags() calls this method and returns + * its results directly to the caller in AsyncSocket. + * Classes inheriting from SendMsgParamsCallback are welcome to override + * this method to force SendMsgParamsCallback to return its own set + * of flags. + * + * @param flags Write flags requested for the given write operation + * @param defaultFlags A set of message flags returned by getDefaultFlags() + * method for the given "flags" mask. + */ + virtual int getFlagsImpl(folly::WriteFlags /*flags*/, int defaultFlags) { + return defaultFlags; + } + + /** + * getDefaultFlags() will be invoked by getFlags(folly::WriteFlags flags) + * to retrieve the default set of flags, and pass them to getFlagsImpl(...) + * + * @param flags Write flags requested for the given write operation + */ + int getDefaultFlags(folly::WriteFlags flags, bool zeroCopyEnabled) noexcept; + }; + explicit AsyncSocket(); /** * Create a new unconnected AsyncSocket. @@ -144,6 +264,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { */ AsyncSocket(EventBase* evb, int fd); + /** + * Create an AsyncSocket from a different, already connected AsyncSocket. + * + * Similar to AsyncSocket(evb, fd) when fd was previously owned by an + * AsyncSocket.
+ */ + explicit AsyncSocket(AsyncSocket::UniquePtr); + /** * Helper function to create a shared_ptr. * @@ -196,7 +324,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * This prevents callers from deleting a AsyncSocket while it is invoking a * callback. */ - virtual void destroy() override; + void destroy() override; /** * Get the EventBase used by this socket. @@ -219,6 +347,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * error. The AsyncSocket may no longer be used after the file descriptor * has been extracted. * + * This method should be used with care as the resulting fd is not guaranteed + * to perfectly reflect the state of the AsyncSocket (security state, + * pre-received data, etc.). + * * Returns the file descriptor. The caller assumes ownership of the * descriptor, and it will not be closed when the AsyncSocket is destroyed. */ @@ -336,10 +468,57 @@ class AsyncSocket : virtual public AsyncTransportWrapper { return maxReadsPerEvent_; } + /** + * Set a pointer to ErrMessageCallback implementation which will be + * receiving notifications for messages posted to the error queue + * associated with the socket. + * ErrMessageCallback is implemented only for platforms with + * per-socket error message queues support (recvmsg() system call must + * support the MSG_ERRQUEUE flag) + * + */ + virtual void setErrMessageCB(ErrMessageCallback* callback); + + /** + * Get a pointer to ErrMessageCallback implementation currently + * registered with this socket. + * + */ + virtual ErrMessageCallback* getErrMessageCallback() const; + + /** + * Set a pointer to SendMsgParamsCallback implementation which + * will be used to form ::sendmsg() system call parameters + * + */ + virtual void setSendMsgParamCB(SendMsgParamsCallback* callback); + + /** + * Get a pointer to SendMsgParamsCallback implementation currently + * registered with this socket.
+ * + */ + virtual SendMsgParamsCallback* getSendMsgParamsCB() const; + // Read and write methods void setReadCB(ReadCallback* callback) override; ReadCallback* getReadCallback() const override; + static const size_t kDefaultZeroCopyThreshold = 32768; // 32KB + + bool setZeroCopy(bool enable); + bool getZeroCopy() const { + return zeroCopyEnabled_; + } + + void setZeroCopyWriteChainThreshold(size_t threshold); + size_t getZeroCopyWriteChainThreshold() const { + return zeroCopyWriteChainThreshold_; + } + + bool isZeroCopyMsg(const cmsghdr& cmsg) const; + void processZeroCopyMsg(const cmsghdr& cmsg); + void write(WriteCallback* callback, const void* buf, size_t bytes, WriteFlags flags = WriteFlags::NONE) override; void writev(WriteCallback* callback, const iovec* vec, size_t count, @@ -362,6 +541,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { void shutdownWriteNow() override; bool readable() const override; + bool writable() const override; bool isPending() const override; virtual bool hangup() const; bool good() const override; @@ -375,9 +555,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper { void getPeerAddress( folly::SocketAddress* address) const override; - bool isEorTrackingEnabled() const override { return false; } + bool isEorTrackingEnabled() const override { + return trackEor_; + } - void setEorTracking(bool /*track*/) override {} + void setEorTracking(bool track) override { + trackEor_ = track; + } bool connecting() const override { return (state_ == StateEnum::CONNECTING); @@ -534,8 +718,51 @@ class AsyncSocket : virtual public AsyncTransportWrapper { return setsockopt(fd_, level, optname, optval, sizeof(T)); } - virtual void setPeek(bool peek) { - peek_ = peek; + /** + * Virtual method for reading a socket option returning integer + * value, which is the most typical case. Convenient for overriding + * and mocking. + * + * @param level same as the "level" parameter in getsockopt(). 
+ * @param optname same as the "optname" parameter in getsockopt(). + * @param optval same as "optval" parameter in getsockopt(). + * @param optlen same as "optlen" parameter in getsockopt(). + * @return same as the return value of getsockopt(). + */ + virtual int + getSockOptVirtual(int level, int optname, void* optval, socklen_t* optlen) { + return getsockopt(fd_, level, optname, optval, optlen); + } + + /** + * Virtual method for setting a socket option accepting integer + * value, which is the most typical case. Convenient for overriding + * and mocking. + * + * @param level same as the "level" parameter in setsockopt(). + * @param optname same as the "optname" parameter in setsockopt(). + * @param optval same as "optval" parameter in setsockopt(). + * @param optlen same as "optlen" parameter in setsockopt(). + * @return same as the return value of setsockopt(). + */ + virtual int setSockOptVirtual( + int level, + int optname, + void const* optval, + socklen_t optlen) { + return setsockopt(fd_, level, optname, optval, optlen); + } + + /** + * Set pre-received data, to be returned to read callback before any data + * from the socket. + */ + virtual void setPreReceivedData(std::unique_ptr data) { + if (preReceivedData_) { + preReceivedData_->prependChain(std::move(data)); + } else { + preReceivedData_ = std::move(data); + } } /** @@ -549,6 +776,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { #endif } + void disableTransparentTls() { + noTransparentTls_ = true; + } + + void disableTSocks() { + noTSocks_ = true; + } + enum class StateEnum : uint8_t { UNINIT, CONNECTING, @@ -560,6 +795,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper { void setBufferCallback(BufferCallback* cb); + // Callers should set this prior to connecting the socket for the safest + // behavior. 
+ void setEvbChangedCallback(std::unique_ptr cb) { + evbChangeCb_ = std::move(cb); + } + + /** + * Attempt to cache the current local and peer addresses (if not already + * cached) so that they are available from getPeerAddress() and + * getLocalAddress() even after the socket is closed. + */ + void cacheAddresses(); + /** * writeReturn is the total number of bytes written, or WRITE_ERROR on error. * If no data has been written, 0 is returned. @@ -668,7 +916,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * destroy() instead. (See the documentation in DelayedDestruction.h for * more details.) */ - ~AsyncSocket(); + ~AsyncSocket() override; friend std::ostream& operator << (std::ostream& os, const StateEnum& state); @@ -701,7 +949,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { : AsyncTimeout(eventBase) , socket_(socket) {} - virtual void timeoutExpired() noexcept { + void timeoutExpired() noexcept override { socket_->timeoutExpired(); } @@ -718,7 +966,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { : EventHandler(eventBase, fd) , socket_(socket) {} - virtual void handlerReady(uint16_t events) noexcept { + void handlerReady(uint16_t events) noexcept override { socket_->ioReady(events); } @@ -768,6 +1016,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { virtual void checkForImmediateRead() noexcept; virtual void handleInitialReadWrite() noexcept; virtual void prepareReadBuffer(void** buf, size_t* buflen); + virtual void handleErrMessages() noexcept; virtual void handleRead() noexcept; virtual void handleWrite() noexcept; virtual void handleConnect() noexcept; @@ -882,6 +1131,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { void fail(const char* fn, const AsyncSocketException& ex); void failConnect(const char* fn, const AsyncSocketException& ex); void failRead(const char* fn, const AsyncSocketException& ex); + void failErrMessageRead(const char* fn, const AsyncSocketException& ex); void 
failWrite(const char* fn, WriteCallback* callback, size_t bytesWritten, const AsyncSocketException& ex); void failWrite(const char* fn, const AsyncSocketException& ex); @@ -889,37 +1139,72 @@ class AsyncSocket : virtual public AsyncTransportWrapper { virtual void invokeConnectErr(const AsyncSocketException& ex); virtual void invokeConnectSuccess(); void invalidState(ConnectCallback* callback); + void invalidState(ErrMessageCallback* callback); void invalidState(ReadCallback* callback); void invalidState(WriteCallback* callback); std::string withAddr(const std::string& s); - StateEnum state_; ///< StateEnum describing current state - uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags) - uint16_t eventFlags_; ///< EventBase::HandlerFlags settings - int fd_; ///< The socket file descriptor + void cacheLocalAddress() const; + void cachePeerAddress() const; + + bool isZeroCopyRequest(WriteFlags flags); + uint32_t getNextZeroCopyBuffId() { + return zeroCopyBuffId_++; + } + void adjustZeroCopyFlags(folly::IOBuf* buf, folly::WriteFlags& flags); + void adjustZeroCopyFlags( + const iovec* vec, + uint32_t count, + folly::WriteFlags& flags); + void addZeroCopyBuff(std::unique_ptr&& buf); + void addZeroCopyBuff(folly::IOBuf* ptr); + void setZeroCopyBuff(std::unique_ptr&& buf); + bool containsZeroCopyBuff(folly::IOBuf* ptr); + void releaseZeroCopyBuff(uint32_t id); + + // a folly::IOBuf can be used in multiple partial requests + // so we keep a map that maps a buffer id to a raw folly::IOBuf ptr + // and one more map that adds a ref count for a folly::IOBuf that is either + // the original ptr or nullptr + uint32_t zeroCopyBuffId_{0}; + std::unordered_map idZeroCopyBufPtrMap_; + std::unordered_map< + folly::IOBuf*, + std::pair>> + idZeroCopyBufPtrToBufMap_; + + StateEnum state_; ///< StateEnum describing current state + uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags) + uint16_t eventFlags_; ///< EventBase::HandlerFlags settings + int fd_; ///< The socket file 
descriptor mutable folly::SocketAddress addr_; ///< The address we tried to connect to mutable folly::SocketAddress localAddr_; - ///< The address we are connecting from - uint32_t sendTimeout_; ///< The send timeout, in milliseconds - uint16_t maxReadsPerEvent_; ///< Max reads per event loop iteration - EventBase* eventBase_; ///< The EventBase - WriteTimeout writeTimeout_; ///< A timeout for connect and write - IoHandler ioHandler_; ///< A EventHandler to monitor the fd + ///< The address we are connecting from + uint32_t sendTimeout_; ///< The send timeout, in milliseconds + uint16_t maxReadsPerEvent_; ///< Max reads per event loop iteration + EventBase* eventBase_; ///< The EventBase + WriteTimeout writeTimeout_; ///< A timeout for connect and write + IoHandler ioHandler_; ///< An EventHandler to monitor the fd ImmediateReadCB immediateReadHandler_; ///< LoopCallback for checking read - ConnectCallback* connectCallback_; ///< ConnectCallback - ReadCallback* readCallback_; ///< ReadCallback - WriteRequest* writeReqHead_; ///< Chain of WriteRequests - WriteRequest* writeReqTail_; ///< End of WriteRequest chain + ConnectCallback* connectCallback_; ///< ConnectCallback + ErrMessageCallback* errMessageCallback_; ///< ErrMessageCallback + SendMsgParamsCallback* ///< Callback for retrieving + sendMsgParamCallback_; ///< ::sendmsg() parameters + ReadCallback* readCallback_; ///< ReadCallback + WriteRequest* writeReqHead_; ///< Chain of WriteRequests + WriteRequest* writeReqTail_; ///< End of WriteRequest chain ShutdownSocketSet* shutdownSocketSet_; - size_t appBytesReceived_; ///< Num of bytes received from socket - size_t appBytesWritten_; ///< Num of bytes written to socket + size_t appBytesReceived_; ///< Num of bytes received from socket + size_t appBytesWritten_; ///< Num of bytes written to socket bool isBufferMovable_{false}; - bool peek_{false}; // Peek bytes. + // Pre-received data, to be returned to read callback before any data from the + // socket.
+ std::unique_ptr preReceivedData_; - int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any. + int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any std::chrono::steady_clock::time_point connectStartTime_; std::chrono::steady_clock::time_point connectEndTime_; @@ -930,9 +1215,18 @@ class AsyncSocket : virtual public AsyncTransportWrapper { bool tfoEnabled_{false}; bool tfoAttempted_{false}; bool tfoFinished_{false}; + bool noTransparentTls_{false}; + bool noTSocks_{false}; + // Whether to track EOR or not. + bool trackEor_{false}; + bool zeroCopyEnabled_{false}; + bool zeroCopyVal_{false}; + size_t zeroCopyWriteChainThreshold_{kDefaultZeroCopyThreshold}; + + std::unique_ptr evbChangeCb_{nullptr}; }; #ifdef _MSC_VER #pragma vtordisp(pop) #endif -} // folly +} // namespace folly