Implement more of the sockets API
authorChristopher Dykes <cdykes@fb.com>
Tue, 6 Sep 2016 17:07:44 +0000 (10:07 -0700)
committerFacebook Github Bot 9 <facebook-github-bot-9-bot@fb.com>
Tue, 6 Sep 2016 17:09:05 +0000 (10:09 -0700)
Summary:
This gets the socket portability layer functional enough that most of the socket tests are passing, the few that are left are depending on specific error messages.
This also switches all of the additional overloads for some of the socket functions to forward to a single implementation, to make adjustments to these functions easier in the future. (most of the functions already needed adjustments and would have had to change regardless)

Reviewed By: yfeldblum

Differential Revision: D3814011

fbshipit-source-id: c6793ee74a91d9e164775a2d52c96f54b28b9f24

folly/portability/Sockets.cpp
folly/portability/Sockets.h

index 61cf6d0f3840f9b22d77fa352edd42a3d9aaaf45..14f31fd0065d0a76dd1e779e0e9a8dc57d2b0a1c 100755 (executable)
@@ -70,11 +70,20 @@ int socket_to_fd(SOCKET s) {
   return _open_osfhandle((intptr_t)s, O_RDWR | O_BINARY);
 }
 
+int translate_wsa_error(int wsaErr) {
+  switch (wsaErr) {
+    case WSAEWOULDBLOCK:
+      return EAGAIN;
+    default:
+      return wsaErr;
+  }
+}
+
 template <class R, class F, class... Args>
 static R wrapSocketFunction(F f, int s, Args... args) {
   SOCKET h = fd_to_socket(s);
   R ret = f(h, args...);
-  errno = WSAGetLastError();
+  errno = translate_wsa_error(WSAGetLastError());
   return ret;
 }
 
@@ -88,7 +97,7 @@ int bind(int s, const struct sockaddr* name, socklen_t namelen) {
 
 int connect(int s, const struct sockaddr* name, socklen_t namelen) {
   auto r = wrapSocketFunction<int>(::connect, s, name, namelen);
-  if (r == -1 && errno == WSAEWOULDBLOCK) {
+  if (r == -1 && WSAGetLastError() == WSAEWOULDBLOCK) {
     errno = EINPROGRESS;
   }
   return r;
@@ -103,13 +112,20 @@ int getsockname(int s, struct sockaddr* name, socklen_t* namelen) {
 }
 
 int getsockopt(int s, int level, int optname, char* optval, socklen_t* optlen) {
-  return wrapSocketFunction<int>(
-      ::getsockopt, s, level, optname, (char*)optval, optlen);
+  return getsockopt(s, level, optname, (void*)optval, optlen);
 }
 
 int getsockopt(int s, int level, int optname, void* optval, socklen_t* optlen) {
-  return wrapSocketFunction<int>(
-      ::getsockopt, s, level, optname, (char*)optval, optlen);
+  auto ret = wrapSocketFunction<int>(
+      ::getsockopt, s, level, optname, (char*)optval, (int*)optlen);
+  if (optname == TCP_NODELAY && *optlen == 1) {
+    // Windows is weird about this value, and documents it as a
+    // BOOL (ie. int) but expects the variable to be bool (1-byte),
+    // so we get to adapt the interface to work that way.
+    *(int*)optval = *(uint8_t*)optval;
+    *optlen = sizeof(int);
+  }
+  return ret;
 }
 
 int inet_aton(const char* cp, struct in_addr* inp) {
@@ -134,15 +150,34 @@ int poll(struct pollfd fds[], nfds_t nfds, int timeout) {
 }
 
 ssize_t recv(int s, void* buf, size_t len, int flags) {
+  if ((flags & MSG_DONTWAIT) == MSG_DONTWAIT) {
+    flags &= ~MSG_DONTWAIT;
+
+    u_long pendingRead = 0;
+    if (ioctlsocket(fd_to_socket(s), FIONREAD, &pendingRead)) {
+      errno = translate_wsa_error(WSAGetLastError());
+      return -1;
+    }
+
+    fd_set readSet;
+    FD_ZERO(&readSet);
+    FD_SET(fd_to_socket(s), &readSet);
+    timeval timeout{0, 0};
+    auto ret = select(1, &readSet, nullptr, nullptr, &timeout);
+    if (ret == 0) {
+      errno = EWOULDBLOCK;
+      return -1;
+    }
+  }
   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, (int)len, flags);
 }
 
 ssize_t recv(int s, char* buf, int len, int flags) {
-  return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, len, flags);
+  return recv(s, (void*)buf, (size_t)len, flags);
 }
 
 ssize_t recv(int s, void* buf, int len, int flags) {
-  return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, len, flags);
+  return recv(s, (void*)buf, (size_t)len, flags);
 }
 
 ssize_t recvfrom(
@@ -152,8 +187,53 @@ ssize_t recvfrom(
     int flags,
     struct sockaddr* from,
     socklen_t* fromlen) {
+  if ((flags & MSG_TRUNC) == MSG_TRUNC) {
+    SOCKET h = fd_to_socket(s);
+
+    WSABUF wBuf{};
+    wBuf.buf = (CHAR*)buf;
+    wBuf.len = len;
+    WSAMSG wMsg{};
+    wMsg.dwBufferCount = 1;
+    wMsg.lpBuffers = &wBuf;
+    wMsg.name = from;
+    if (fromlen != nullptr) {
+      wMsg.namelen = *fromlen;
+    }
+
+    // WSARecvMsg is an extension, so we don't get
+    // the convenience of being able to call it directly, even though
+    // WSASendMsg is part of the normal API -_-...
+    LPFN_WSARECVMSG WSARecvMsg;
+    GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
+    DWORD recMsgBytes;
+    WSAIoctl(
+        h,
+        SIO_GET_EXTENSION_FUNCTION_POINTER,
+        &WSARecgMsg_GUID,
+        sizeof(WSARecgMsg_GUID),
+        &WSARecvMsg,
+        sizeof(WSARecvMsg),
+        &recMsgBytes,
+        nullptr,
+        nullptr);
+
+    DWORD bytesReceived;
+    int res = WSARecvMsg(h, &wMsg, &bytesReceived, nullptr, nullptr);
+    errno = translate_wsa_error(WSAGetLastError());
+    if (res == 0) {
+      return bytesReceived;
+    }
+    if (fromlen != nullptr) {
+      *fromlen = wMsg.namelen;
+    }
+    if ((wMsg.dwFlags & MSG_TRUNC) == MSG_TRUNC) {
+      return wBuf.len + 1;
+    }
+    return -1;
+  }
   return wrapSocketFunction<ssize_t>(
-      ::recvfrom, s, (char*)buf, (int)len, flags, from, fromlen);
+      ::recvfrom, s, (char*)buf, (int)len, flags, from, (int*)fromlen);
 }
 
 ssize_t recvfrom(
@@ -163,8 +243,7 @@ ssize_t recvfrom(
     int flags,
     struct sockaddr* from,
     socklen_t* fromlen) {
-  return wrapSocketFunction<ssize_t>(
-      ::recvfrom, s, (char*)buf, len, flags, from, fromlen);
+  return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
 }
 
 ssize_t recvfrom(
@@ -174,8 +253,7 @@ ssize_t recvfrom(
     int flags,
     struct sockaddr* from,
     socklen_t* fromlen) {
-  return wrapSocketFunction<ssize_t>(
-      ::recvfrom, s, (char*)buf, len, flags, from, fromlen);
+  return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
 }
 
 ssize_t recvmsg(int s, struct msghdr* message, int fl) {
@@ -218,40 +296,48 @@ ssize_t recvmsg(int s, struct msghdr* message, int fl) {
 
   DWORD bytesReceived;
   int res = WSARecvMsg(h, &msg, &bytesReceived, nullptr, nullptr);
+  errno = translate_wsa_error(WSAGetLastError());
   return res == 0 ? (ssize_t)bytesReceived : -1;
 }
 
 ssize_t send(int s, const void* buf, size_t len, int flags) {
-  return wrapSocketFunction<ssize_t>(::send, s, (char*)buf, (int)len, flags);
+  return wrapSocketFunction<ssize_t>(
+      ::send, s, (const char*)buf, (int)len, flags);
 }
 
 ssize_t send(int s, const char* buf, int len, int flags) {
-  return wrapSocketFunction<ssize_t>(::send, s, (char*)buf, len, flags);
+  return send(s, (const void*)buf, (size_t)len, flags);
 }
 
 ssize_t send(int s, const void* buf, int len, int flags) {
-  return wrapSocketFunction<ssize_t>(::send, s, (char*)buf, len, flags);
+  return send(s, (const void*)buf, (size_t)len, flags);
 }
 
 ssize_t sendmsg(int s, const struct msghdr* message, int fl) {
   SOCKET h = fd_to_socket(s);
 
-  // Don't currently support the name translation.
-  if (message->msg_name != nullptr || message->msg_namelen != 0) {
-    return (ssize_t)-1;
-  }
-
   // Unfortunately, WSASendMsg requires the socket to have been opened
   // as either SOCK_DGRAM or SOCK_RAW, but sendmsg has no such requirement,
   // so we have to implement it based on send instead :(
   ssize_t bytesSent = 0;
   for (size_t i = 0; i < message->msg_iovlen; i++) {
-    auto r = ::send(
-        h,
-        (const char*)message->msg_iov[i].iov_base,
-        message->msg_iov[i].iov_len,
-        message->msg_flags);
-    if (r == -1) {
+    int r = -1;
+    if (message->msg_name != nullptr) {
+      r = ::sendto(
+          h,
+          (const char*)message->msg_iov[i].iov_base,
+          (int)message->msg_iov[i].iov_len,
+          message->msg_flags,
+          (const sockaddr*)message->msg_name,
+          (int)message->msg_namelen);
+    } else {
+      r = ::send(
+          h,
+          (const char*)message->msg_iov[i].iov_base,
+          (int)message->msg_iov[i].iov_len,
+          message->msg_flags);
+    }
+    if (r == -1 || r != message->msg_iov[i].iov_len) {
       return -1;
     }
     bytesSent += r;
@@ -267,7 +353,7 @@ ssize_t sendto(
     const sockaddr* to,
     socklen_t tolen) {
   return wrapSocketFunction<ssize_t>(
-      ::sendto, s, (char*)buf, (int)len, flags, to, tolen);
+      ::sendto, s, (const char*)buf, (int)len, flags, to, (int)tolen);
 }
 
 ssize_t sendto(
@@ -277,8 +363,7 @@ ssize_t sendto(
     int flags,
     const sockaddr* to,
     socklen_t tolen) {
-  return wrapSocketFunction<ssize_t>(
-      ::sendto, s, (char*)buf, len, flags, to, tolen);
+  return sendto(s, (const void*)buf, (size_t)len, flags, to, tolen);
 }
 
 ssize_t sendto(
@@ -288,16 +373,24 @@ ssize_t sendto(
     int flags,
     const sockaddr* to,
     socklen_t tolen) {
-  return wrapSocketFunction<ssize_t>(
-      ::sendto, s, (char*)buf, len, flags, to, tolen);
+  return sendto(s, buf, (size_t)len, flags, to, tolen);
 }
 
 int setsockopt(
     int s,
     int level,
     int optname,
-    const char* optval,
+    const void* optval,
     socklen_t optlen) {
+  if (optname == SO_REUSEADDR) {
+    // We don't have an equivelent to the Linux & OSX meaning of this
+    // on Windows, so ignore it.
+    return 0;
+  } else if (optname == SO_REUSEPORT) {
+    // Windows's SO_REUSEADDR option is closer to SO_REUSEPORT than
+    // it is to the Linux & OSX meaning of SO_REUSEADDR.
+    return -1;
+  }
   return wrapSocketFunction<int>(
       ::setsockopt, s, level, optname, (char*)optval, optlen);
 }
@@ -306,10 +399,9 @@ int setsockopt(
     int s,
     int level,
     int optname,
-    const void* optval,
+    const char* optval,
     socklen_t optlen) {
-  return wrapSocketFunction<int>(
-      ::setsockopt, s, level, optname, (char*)optval, optlen);
+  return setsockopt(s, level, optname, (const void*)optval, optlen);
 }
 
 int shutdown(int s, int how) {
index 83e1b92e5c24a712dc9d1ebcf22bf0b174b8c339..1763006cd3f54458710ef90a2d0c5a4eca6df124 100755 (executable)
@@ -37,7 +37,7 @@ using sa_family_t = ADDRESS_FAMILY;
 
 // We don't actually support either of these flags
 // currently.
-#define MSG_DONTWAIT 0
+#define MSG_DONTWAIT 0x1000
 #define MSG_EOR 0
 struct msghdr {
   void* msg_name;
@@ -61,7 +61,10 @@ struct sockaddr_un {
 // These are the same, but PF_LOCAL
 // isn't defined by WinSock.
 #define PF_LOCAL PF_UNIX
-#define SO_REUSEPORT SO_REUSEADDR
+
+// This isn't defined by Windows, and we need to
+// distinguish it from SO_REUSEADDR
+#define SO_REUSEPORT 0x7001
 
 // Someone thought it would be a good idea
 // to define a field via a macro...
@@ -94,6 +97,7 @@ using ::socket;
 bool is_fh_socket(int fh);
 SOCKET fd_to_socket(int fd);
 int socket_to_fd(SOCKET s);
+int translate_wsa_error(int wsaErr);
 
 // These aren't additional overloads, but rather other functions that
 // are referenced that we need to wrap, or, in the case of inet_aton,