Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/sage/ceph...
[firefly-linux-kernel-4.4.55.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 static DEFINE_SPINLOCK(fou_lock);
20 static LIST_HEAD(fou_list);
21
22 struct fou {
23         struct socket *sock;
24         u8 protocol;
25         u16 port;
26         struct udp_offload udp_offloads;
27         struct list_head list;
28 };
29
30 struct fou_cfg {
31         u16 type;
32         u8 protocol;
33         struct udp_port_cfg udp_config;
34 };
35
36 static inline struct fou *fou_from_sock(struct sock *sk)
37 {
38         return sk->sk_user_data;
39 }
40
41 static int fou_udp_encap_recv_deliver(struct sk_buff *skb,
42                                       u8 protocol, size_t len)
43 {
44         struct iphdr *iph = ip_hdr(skb);
45
46         /* Remove 'len' bytes from the packet (UDP header and
47          * FOU header if present), modify the protocol to the one
48          * we found, and then call rcv_encap.
49          */
50         iph->tot_len = htons(ntohs(iph->tot_len) - len);
51         __skb_pull(skb, len);
52         skb_postpull_rcsum(skb, udp_hdr(skb), len);
53         skb_reset_transport_header(skb);
54
55         return -protocol;
56 }
57
58 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
59 {
60         struct fou *fou = fou_from_sock(sk);
61
62         if (!fou)
63                 return 1;
64
65         return fou_udp_encap_recv_deliver(skb, fou->protocol,
66                                           sizeof(struct udphdr));
67 }
68
69 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
70 {
71         struct fou *fou = fou_from_sock(sk);
72         size_t len;
73         struct guehdr *guehdr;
74         struct udphdr *uh;
75
76         if (!fou)
77                 return 1;
78
79         len = sizeof(struct udphdr) + sizeof(struct guehdr);
80         if (!pskb_may_pull(skb, len))
81                 goto drop;
82
83         uh = udp_hdr(skb);
84         guehdr = (struct guehdr *)&uh[1];
85
86         len += guehdr->hlen << 2;
87         if (!pskb_may_pull(skb, len))
88                 goto drop;
89
90         if (guehdr->version != 0)
91                 goto drop;
92
93         if (guehdr->flags) {
94                 /* No support yet */
95                 goto drop;
96         }
97
98         return fou_udp_encap_recv_deliver(skb, guehdr->next_hdr, len);
99 drop:
100         kfree_skb(skb);
101         return 0;
102 }
103
104 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
105                                         struct sk_buff *skb)
106 {
107         const struct net_offload *ops;
108         struct sk_buff **pp = NULL;
109         u8 proto = NAPI_GRO_CB(skb)->proto;
110         const struct net_offload **offloads;
111
112         rcu_read_lock();
113         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
114         ops = rcu_dereference(offloads[proto]);
115         if (!ops || !ops->callbacks.gro_receive)
116                 goto out_unlock;
117
118         pp = ops->callbacks.gro_receive(head, skb);
119
120 out_unlock:
121         rcu_read_unlock();
122
123         return pp;
124 }
125
126 static int fou_gro_complete(struct sk_buff *skb, int nhoff)
127 {
128         const struct net_offload *ops;
129         u8 proto = NAPI_GRO_CB(skb)->proto;
130         int err = -ENOSYS;
131         const struct net_offload **offloads;
132
133         rcu_read_lock();
134         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
135         ops = rcu_dereference(offloads[proto]);
136         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
137                 goto out_unlock;
138
139         err = ops->callbacks.gro_complete(skb, nhoff);
140
141 out_unlock:
142         rcu_read_unlock();
143
144         return err;
145 }
146
147 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
148                                         struct sk_buff *skb)
149 {
150         const struct net_offload **offloads;
151         const struct net_offload *ops;
152         struct sk_buff **pp = NULL;
153         struct sk_buff *p;
154         u8 proto;
155         struct guehdr *guehdr;
156         unsigned int hlen, guehlen;
157         unsigned int off;
158         int flush = 1;
159
160         off = skb_gro_offset(skb);
161         hlen = off + sizeof(*guehdr);
162         guehdr = skb_gro_header_fast(skb, off);
163         if (skb_gro_header_hard(skb, hlen)) {
164                 guehdr = skb_gro_header_slow(skb, hlen, off);
165                 if (unlikely(!guehdr))
166                         goto out;
167         }
168
169         proto = guehdr->next_hdr;
170
171         rcu_read_lock();
172         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
173         ops = rcu_dereference(offloads[proto]);
174         if (WARN_ON(!ops || !ops->callbacks.gro_receive))
175                 goto out_unlock;
176
177         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
178
179         hlen = off + guehlen;
180         if (skb_gro_header_hard(skb, hlen)) {
181                 guehdr = skb_gro_header_slow(skb, hlen, off);
182                 if (unlikely(!guehdr))
183                         goto out_unlock;
184         }
185
186         flush = 0;
187
188         for (p = *head; p; p = p->next) {
189                 const struct guehdr *guehdr2;
190
191                 if (!NAPI_GRO_CB(p)->same_flow)
192                         continue;
193
194                 guehdr2 = (struct guehdr *)(p->data + off);
195
196                 /* Compare base GUE header to be equal (covers
197                  * hlen, version, next_hdr, and flags.
198                  */
199                 if (guehdr->word != guehdr2->word) {
200                         NAPI_GRO_CB(p)->same_flow = 0;
201                         continue;
202                 }
203
204                 /* Compare optional fields are the same. */
205                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
206                                            guehdr->hlen << 2)) {
207                         NAPI_GRO_CB(p)->same_flow = 0;
208                         continue;
209                 }
210         }
211
212         skb_gro_pull(skb, guehlen);
213
214         /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
215         skb_gro_postpull_rcsum(skb, guehdr, guehlen);
216
217         pp = ops->callbacks.gro_receive(head, skb);
218
219 out_unlock:
220         rcu_read_unlock();
221 out:
222         NAPI_GRO_CB(skb)->flush |= flush;
223
224         return pp;
225 }
226
227 static int gue_gro_complete(struct sk_buff *skb, int nhoff)
228 {
229         const struct net_offload **offloads;
230         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
231         const struct net_offload *ops;
232         unsigned int guehlen;
233         u8 proto;
234         int err = -ENOENT;
235
236         proto = guehdr->next_hdr;
237
238         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
239
240         rcu_read_lock();
241         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
242         ops = rcu_dereference(offloads[proto]);
243         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
244                 goto out_unlock;
245
246         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
247
248 out_unlock:
249         rcu_read_unlock();
250         return err;
251 }
252
253 static int fou_add_to_port_list(struct fou *fou)
254 {
255         struct fou *fout;
256
257         spin_lock(&fou_lock);
258         list_for_each_entry(fout, &fou_list, list) {
259                 if (fou->port == fout->port) {
260                         spin_unlock(&fou_lock);
261                         return -EALREADY;
262                 }
263         }
264
265         list_add(&fou->list, &fou_list);
266         spin_unlock(&fou_lock);
267
268         return 0;
269 }
270
271 static void fou_release(struct fou *fou)
272 {
273         struct socket *sock = fou->sock;
274         struct sock *sk = sock->sk;
275
276         udp_del_offload(&fou->udp_offloads);
277
278         list_del(&fou->list);
279
280         /* Remove hooks into tunnel socket */
281         sk->sk_user_data = NULL;
282
283         sock_release(sock);
284
285         kfree(fou);
286 }
287
288 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
289 {
290         udp_sk(sk)->encap_rcv = fou_udp_recv;
291         fou->protocol = cfg->protocol;
292         fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
293         fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
294         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
295         fou->udp_offloads.ipproto = cfg->protocol;
296
297         return 0;
298 }
299
300 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
301 {
302         udp_sk(sk)->encap_rcv = gue_udp_recv;
303         fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
304         fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
305         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
306
307         return 0;
308 }
309
310 static int fou_create(struct net *net, struct fou_cfg *cfg,
311                       struct socket **sockp)
312 {
313         struct fou *fou = NULL;
314         int err;
315         struct socket *sock = NULL;
316         struct sock *sk;
317
318         /* Open UDP socket */
319         err = udp_sock_create(net, &cfg->udp_config, &sock);
320         if (err < 0)
321                 goto error;
322
323         /* Allocate FOU port structure */
324         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
325         if (!fou) {
326                 err = -ENOMEM;
327                 goto error;
328         }
329
330         sk = sock->sk;
331
332         fou->port = cfg->udp_config.local_udp_port;
333
334         /* Initial for fou type */
335         switch (cfg->type) {
336         case FOU_ENCAP_DIRECT:
337                 err = fou_encap_init(sk, fou, cfg);
338                 if (err)
339                         goto error;
340                 break;
341         case FOU_ENCAP_GUE:
342                 err = gue_encap_init(sk, fou, cfg);
343                 if (err)
344                         goto error;
345                 break;
346         default:
347                 err = -EINVAL;
348                 goto error;
349         }
350
351         udp_sk(sk)->encap_type = 1;
352         udp_encap_enable();
353
354         sk->sk_user_data = fou;
355         fou->sock = sock;
356
357         udp_set_convert_csum(sk, true);
358
359         sk->sk_allocation = GFP_ATOMIC;
360
361         if (cfg->udp_config.family == AF_INET) {
362                 err = udp_add_offload(&fou->udp_offloads);
363                 if (err)
364                         goto error;
365         }
366
367         err = fou_add_to_port_list(fou);
368         if (err)
369                 goto error;
370
371         if (sockp)
372                 *sockp = sock;
373
374         return 0;
375
376 error:
377         kfree(fou);
378         if (sock)
379                 sock_release(sock);
380
381         return err;
382 }
383
384 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
385 {
386         struct fou *fou;
387         u16 port = cfg->udp_config.local_udp_port;
388         int err = -EINVAL;
389
390         spin_lock(&fou_lock);
391         list_for_each_entry(fou, &fou_list, list) {
392                 if (fou->port == port) {
393                         udp_del_offload(&fou->udp_offloads);
394                         fou_release(fou);
395                         err = 0;
396                         break;
397                 }
398         }
399         spin_unlock(&fou_lock);
400
401         return err;
402 }
403
404 static struct genl_family fou_nl_family = {
405         .id             = GENL_ID_GENERATE,
406         .hdrsize        = 0,
407         .name           = FOU_GENL_NAME,
408         .version        = FOU_GENL_VERSION,
409         .maxattr        = FOU_ATTR_MAX,
410         .netnsok        = true,
411 };
412
413 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
414         [FOU_ATTR_PORT] = { .type = NLA_U16, },
415         [FOU_ATTR_AF] = { .type = NLA_U8, },
416         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
417         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
418 };
419
420 static int parse_nl_config(struct genl_info *info,
421                            struct fou_cfg *cfg)
422 {
423         memset(cfg, 0, sizeof(*cfg));
424
425         cfg->udp_config.family = AF_INET;
426
427         if (info->attrs[FOU_ATTR_AF]) {
428                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
429
430                 if (family != AF_INET && family != AF_INET6)
431                         return -EINVAL;
432
433                 cfg->udp_config.family = family;
434         }
435
436         if (info->attrs[FOU_ATTR_PORT]) {
437                 u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);
438
439                 cfg->udp_config.local_udp_port = port;
440         }
441
442         if (info->attrs[FOU_ATTR_IPPROTO])
443                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
444
445         if (info->attrs[FOU_ATTR_TYPE])
446                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
447
448         return 0;
449 }
450
451 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
452 {
453         struct fou_cfg cfg;
454         int err;
455
456         err = parse_nl_config(info, &cfg);
457         if (err)
458                 return err;
459
460         return fou_create(&init_net, &cfg, NULL);
461 }
462
463 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
464 {
465         struct fou_cfg cfg;
466
467         parse_nl_config(info, &cfg);
468
469         return fou_destroy(&init_net, &cfg);
470 }
471
472 static const struct genl_ops fou_nl_ops[] = {
473         {
474                 .cmd = FOU_CMD_ADD,
475                 .doit = fou_nl_cmd_add_port,
476                 .policy = fou_nl_policy,
477                 .flags = GENL_ADMIN_PERM,
478         },
479         {
480                 .cmd = FOU_CMD_DEL,
481                 .doit = fou_nl_cmd_rm_port,
482                 .policy = fou_nl_policy,
483                 .flags = GENL_ADMIN_PERM,
484         },
485 };
486
487 static int __init fou_init(void)
488 {
489         int ret;
490
491         ret = genl_register_family_with_ops(&fou_nl_family,
492                                             fou_nl_ops);
493
494         return ret;
495 }
496
497 static void __exit fou_fini(void)
498 {
499         struct fou *fou, *next;
500
501         genl_unregister_family(&fou_nl_family);
502
503         /* Close all the FOU sockets */
504
505         spin_lock(&fou_lock);
506         list_for_each_entry_safe(fou, next, &fou_list, list)
507                 fou_release(fou);
508         spin_unlock(&fou_lock);
509 }
510
511 module_init(fou_init);
512 module_exit(fou_fini);
513 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
514 MODULE_LICENSE("GPL");