Add a couple more things to the socket portability layer
[folly.git] / folly / portability / Sockets.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/portability/Sockets.h>
18
19 #ifdef _MSC_VER
20
21 #include <errno.h>
22 #include <fcntl.h>
23
24 #include <MSWSock.h>
25
26 #include <folly/ScopeGuard.h>
27
28 namespace folly {
29 namespace portability {
30 namespace sockets {
31
32 // We have to startup WSA.
33 static struct FSPInit {
34   FSPInit() {
35     WSADATA dat;
36     WSAStartup(MAKEWORD(2, 2), &dat);
37   }
38   ~FSPInit() { WSACleanup(); }
39 } fspInit;
40
41 bool is_fh_socket(int fh) {
42   SOCKET h = fd_to_socket(fh);
43   constexpr long kDummyEvents = 0xABCDEF12;
44   WSANETWORKEVENTS e;
45   e.lNetworkEvents = kDummyEvents;
46   WSAEnumNetworkEvents(h, nullptr, &e);
47   return e.lNetworkEvents != kDummyEvents;
48 }
49
50 SOCKET fd_to_socket(int fd) {
51   // We do this in a roundabout way to allow us to compile even if
52   // we're doing a bit of trickery to ensure that things aren't
53   // being implicitly converted to a SOCKET by temporarily
54   // adjusting the windows headers to define SOCKET as a
55   // structure.
56   static_assert(sizeof(HANDLE) == sizeof(SOCKET), "Handle size mismatch.");
57   HANDLE tmp = (HANDLE)_get_osfhandle(fd);
58   return *(SOCKET*)&tmp;
59 }
60
61 int socket_to_fd(SOCKET s) {
62   return _open_osfhandle((intptr_t)s, O_RDWR | O_BINARY);
63 }
64
65 template <class R, class F, class... Args>
66 static R wrapSocketFunction(F f, int s, Args... args) {
67   SOCKET h = fd_to_socket(s);
68   R ret = f(h, args...);
69   errno = WSAGetLastError();
70   return ret;
71 }
72
73 int accept(int s, struct sockaddr* addr, socklen_t* addrlen) {
74   return socket_to_fd(wrapSocketFunction<SOCKET>(::accept, s, addr, addrlen));
75 }
76
77 int bind(int s, const struct sockaddr* name, socklen_t namelen) {
78   return wrapSocketFunction<int>(::bind, s, name, namelen);
79 }
80
81 int connect(int s, const struct sockaddr* name, socklen_t namelen) {
82   return wrapSocketFunction<int>(::connect, s, name, namelen);
83 }
84
85 int getpeername(int s, struct sockaddr* name, socklen_t* namelen) {
86   return wrapSocketFunction<int>(::getpeername, s, name, namelen);
87 }
88
89 int getsockname(int s, struct sockaddr* name, socklen_t* namelen) {
90   return wrapSocketFunction<int>(::getsockname, s, name, namelen);
91 }
92
93 int getsockopt(int s, int level, int optname, char* optval, socklen_t* optlen) {
94   return wrapSocketFunction<int>(
95       ::getsockopt, s, level, optname, (char*)optval, optlen);
96 }
97
98 int getsockopt(int s, int level, int optname, void* optval, socklen_t* optlen) {
99   return wrapSocketFunction<int>(
100       ::getsockopt, s, level, optname, (char*)optval, optlen);
101 }
102
103 int inet_aton(const char* cp, struct in_addr* inp) {
104   inp->s_addr = inet_addr(cp);
105   return inp->s_addr == INADDR_NONE ? 0 : 1;
106 }
107
108 const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) {
109   return ::inet_ntop(af, (char*)src, dst, size);
110 }
111
112 int listen(int s, int backlog) {
113   return wrapSocketFunction<int>(::listen, s, backlog);
114 }
115
116 int poll(struct pollfd fds[], nfds_t nfds, int timeout) {
117   // TODO: Allow both file descriptors and SOCKETs in this.
118   for (int i = 0; i < nfds; i++) {
119     fds[i].fd = fd_to_socket(fds[i].fd);
120   }
121   return ::WSAPoll(fds, (ULONG)nfds, timeout);
122 }
123
124 ssize_t recv(int s, void* buf, size_t len, int flags) {
125   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, (int)len, flags);
126 }
127
128 ssize_t recv(int s, char* buf, int len, int flags) {
129   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, len, flags);
130 }
131
132 ssize_t recv(int s, void* buf, int len, int flags) {
133   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, len, flags);
134 }
135
136 ssize_t recvfrom(
137     int s,
138     void* buf,
139     size_t len,
140     int flags,
141     struct sockaddr* from,
142     socklen_t* fromlen) {
143   return wrapSocketFunction<ssize_t>(
144       ::recvfrom, s, (char*)buf, (int)len, flags, from, fromlen);
145 }
146
147 ssize_t recvfrom(
148     int s,
149     char* buf,
150     int len,
151     int flags,
152     struct sockaddr* from,
153     socklen_t* fromlen) {
154   return wrapSocketFunction<ssize_t>(
155       ::recvfrom, s, (char*)buf, len, flags, from, fromlen);
156 }
157
158 ssize_t recvfrom(
159     int s,
160     void* buf,
161     int len,
162     int flags,
163     struct sockaddr* from,
164     socklen_t* fromlen) {
165   return wrapSocketFunction<ssize_t>(
166       ::recvfrom, s, (char*)buf, len, flags, from, fromlen);
167 }
168
169 ssize_t recvmsg(int s, struct msghdr* message, int fl) {
170   SOCKET h = fd_to_socket(s);
171
172   // Don't currently support the name translation.
173   if (message->msg_name != nullptr || message->msg_namelen != 0) {
174     return (ssize_t)-1;
175   }
176   WSAMSG msg;
177   msg.name = nullptr;
178   msg.namelen = 0;
179   msg.Control.buf = (CHAR*)message->msg_control;
180   msg.Control.len = (ULONG)message->msg_controllen;
181   msg.dwFlags = 0;
182   msg.dwBufferCount = (DWORD)message->msg_iovlen;
183   msg.lpBuffers = new WSABUF[message->msg_iovlen];
184   SCOPE_EXIT { delete[] msg.lpBuffers; };
185   for (size_t i = 0; i < message->msg_iovlen; i++) {
186     msg.lpBuffers[i].buf = (CHAR*)message->msg_iov[i].iov_base;
187     msg.lpBuffers[i].len = (ULONG)message->msg_iov[i].iov_len;
188   }
189
190   // WSARecvMsg is an extension, so we don't get
191   // the convenience of being able to call it directly, even though
192   // WSASendMsg is part of the normal API -_-...
193   LPFN_WSARECVMSG WSARecvMsg;
194   GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
195   DWORD recMsgBytes;
196   WSAIoctl(
197       h,
198       SIO_GET_EXTENSION_FUNCTION_POINTER,
199       &WSARecgMsg_GUID,
200       sizeof(WSARecgMsg_GUID),
201       &WSARecvMsg,
202       sizeof(WSARecvMsg),
203       &recMsgBytes,
204       nullptr,
205       nullptr);
206
207   DWORD bytesReceived;
208   int res = WSARecvMsg(h, &msg, &bytesReceived, nullptr, nullptr);
209   return res == 0 ? (ssize_t)bytesReceived : -1;
210 }
211
212 ssize_t send(int s, const void* buf, size_t len, int flags) {
213   return wrapSocketFunction<ssize_t>(::send, s, (char*)buf, (int)len, flags);
214 }
215
216 ssize_t send(int s, const char* buf, int len, int flags) {
217   return wrapSocketFunction<ssize_t>(::send, s, (char*)buf, len, flags);
218 }
219
220 ssize_t send(int s, const void* buf, int len, int flags) {
221   return wrapSocketFunction<ssize_t>(::send, s, (char*)buf, len, flags);
222 }
223
224 ssize_t sendmsg(int s, const struct msghdr* message, int fl) {
225   SOCKET h = fd_to_socket(s);
226
227   // Don't currently support the name translation.
228   if (message->msg_name != nullptr || message->msg_namelen != 0) {
229     return (ssize_t)-1;
230   }
231   WSAMSG msg;
232   msg.name = nullptr;
233   msg.namelen = 0;
234   msg.Control.buf = (CHAR*)message->msg_control;
235   msg.Control.len = (ULONG)message->msg_controllen;
236   msg.dwFlags = 0;
237   msg.dwBufferCount = (DWORD)message->msg_iovlen;
238   msg.lpBuffers = new WSABUF[message->msg_iovlen];
239   SCOPE_EXIT { delete[] msg.lpBuffers; };
240   for (size_t i = 0; i < message->msg_iovlen; i++) {
241     msg.lpBuffers[i].buf = (CHAR*)message->msg_iov[i].iov_base;
242     msg.lpBuffers[i].len = (ULONG)message->msg_iov[i].iov_len;
243   }
244
245   DWORD bytesSent;
246   int res = WSASendMsg(h, &msg, 0, &bytesSent, nullptr, nullptr);
247   return res == 0 ? (ssize_t)bytesSent : -1;
248 }
249
250 ssize_t sendto(
251     int s,
252     const void* buf,
253     size_t len,
254     int flags,
255     const sockaddr* to,
256     socklen_t tolen) {
257   return wrapSocketFunction<ssize_t>(
258       ::sendto, s, (char*)buf, (int)len, flags, to, tolen);
259 }
260
261 ssize_t sendto(
262     int s,
263     const char* buf,
264     int len,
265     int flags,
266     const sockaddr* to,
267     socklen_t tolen) {
268   return wrapSocketFunction<ssize_t>(
269       ::sendto, s, (char*)buf, len, flags, to, tolen);
270 }
271
272 ssize_t sendto(
273     int s,
274     const void* buf,
275     int len,
276     int flags,
277     const sockaddr* to,
278     socklen_t tolen) {
279   return wrapSocketFunction<ssize_t>(
280       ::sendto, s, (char*)buf, len, flags, to, tolen);
281 }
282
283 int setsockopt(
284     int s,
285     int level,
286     int optname,
287     const char* optval,
288     socklen_t optlen) {
289   return wrapSocketFunction<int>(
290       ::setsockopt, s, level, optname, (char*)optval, optlen);
291 }
292
293 int setsockopt(
294     int s,
295     int level,
296     int optname,
297     const void* optval,
298     socklen_t optlen) {
299   return wrapSocketFunction<int>(
300       ::setsockopt, s, level, optname, (char*)optval, optlen);
301 }
302
303 int shutdown(int s, int how) {
304   return wrapSocketFunction<int>(::shutdown, s, how);
305 }
306
307 int socket(int af, int type, int protocol) {
308   return socket_to_fd(::socket(af, type, protocol));
309 }
310
311 int socketpair(int domain, int type, int protocol, int sv[2]) {
312   // Stub this out for now, to get things compiling.
313   return -1;
314 }
315 }
316 }
317 }
318 #endif