gue: TX support for using remote checksum offload option
[firefly-linux-kernel-4.4.55.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 static DEFINE_SPINLOCK(fou_lock);
20 static LIST_HEAD(fou_list);
21
22 struct fou {
23         struct socket *sock;
24         u8 protocol;
25         u16 port;
26         struct udp_offload udp_offloads;
27         struct list_head list;
28 };
29
30 struct fou_cfg {
31         u16 type;
32         u8 protocol;
33         struct udp_port_cfg udp_config;
34 };
35
36 static inline struct fou *fou_from_sock(struct sock *sk)
37 {
38         return sk->sk_user_data;
39 }
40
41 static void fou_recv_pull(struct sk_buff *skb, size_t len)
42 {
43         struct iphdr *iph = ip_hdr(skb);
44
45         /* Remove 'len' bytes from the packet (UDP header and
46          * FOU header if present).
47          */
48         iph->tot_len = htons(ntohs(iph->tot_len) - len);
49         __skb_pull(skb, len);
50         skb_postpull_rcsum(skb, udp_hdr(skb), len);
51         skb_reset_transport_header(skb);
52 }
53
54 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
55 {
56         struct fou *fou = fou_from_sock(sk);
57
58         if (!fou)
59                 return 1;
60
61         fou_recv_pull(skb, sizeof(struct udphdr));
62
63         return -fou->protocol;
64 }
65
66 static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
67 {
68         /* No support yet */
69         kfree_skb(skb);
70         return 0;
71 }
72
73 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
74 {
75         struct fou *fou = fou_from_sock(sk);
76         size_t len, optlen, hdrlen;
77         struct guehdr *guehdr;
78         void *data;
79
80         if (!fou)
81                 return 1;
82
83         len = sizeof(struct udphdr) + sizeof(struct guehdr);
84         if (!pskb_may_pull(skb, len))
85                 goto drop;
86
87         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
88
89         optlen = guehdr->hlen << 2;
90         len += optlen;
91
92         if (!pskb_may_pull(skb, len))
93                 goto drop;
94
95         /* guehdr may change after pull */
96         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
97
98         hdrlen = sizeof(struct guehdr) + optlen;
99
100         if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))
101                 goto drop;
102
103         /* Pull UDP and GUE headers */
104         fou_recv_pull(skb, len);
105
106         data = &guehdr[1];
107
108         if (guehdr->flags & GUE_FLAG_PRIV) {
109                 data += GUE_LEN_PRIV;
110
111                 /* Process private flags */
112         }
113
114         if (unlikely(guehdr->control))
115                 return gue_control_message(skb, guehdr);
116
117         return -guehdr->proto_ctype;
118
119 drop:
120         kfree_skb(skb);
121         return 0;
122 }
123
124 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
125                                         struct sk_buff *skb)
126 {
127         const struct net_offload *ops;
128         struct sk_buff **pp = NULL;
129         u8 proto = NAPI_GRO_CB(skb)->proto;
130         const struct net_offload **offloads;
131
132         rcu_read_lock();
133         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
134         ops = rcu_dereference(offloads[proto]);
135         if (!ops || !ops->callbacks.gro_receive)
136                 goto out_unlock;
137
138         pp = ops->callbacks.gro_receive(head, skb);
139
140 out_unlock:
141         rcu_read_unlock();
142
143         return pp;
144 }
145
146 static int fou_gro_complete(struct sk_buff *skb, int nhoff)
147 {
148         const struct net_offload *ops;
149         u8 proto = NAPI_GRO_CB(skb)->proto;
150         int err = -ENOSYS;
151         const struct net_offload **offloads;
152
153         rcu_read_lock();
154         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
155         ops = rcu_dereference(offloads[proto]);
156         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
157                 goto out_unlock;
158
159         err = ops->callbacks.gro_complete(skb, nhoff);
160
161 out_unlock:
162         rcu_read_unlock();
163
164         return err;
165 }
166
167 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
168                                         struct sk_buff *skb)
169 {
170         const struct net_offload **offloads;
171         const struct net_offload *ops;
172         struct sk_buff **pp = NULL;
173         struct sk_buff *p;
174         struct guehdr *guehdr;
175         size_t len, optlen, hdrlen, off;
176         void *data;
177         int flush = 1;
178
179         off = skb_gro_offset(skb);
180         len = off + sizeof(*guehdr);
181
182         guehdr = skb_gro_header_fast(skb, off);
183         if (skb_gro_header_hard(skb, len)) {
184                 guehdr = skb_gro_header_slow(skb, len, off);
185                 if (unlikely(!guehdr))
186                         goto out;
187         }
188
189         optlen = guehdr->hlen << 2;
190         len += optlen;
191
192         if (skb_gro_header_hard(skb, len)) {
193                 guehdr = skb_gro_header_slow(skb, len, off);
194                 if (unlikely(!guehdr))
195                         goto out;
196         }
197
198         if (unlikely(guehdr->control) || guehdr->version != 0 ||
199             validate_gue_flags(guehdr, optlen))
200                 goto out;
201
202         hdrlen = sizeof(*guehdr) + optlen;
203
204         skb_gro_pull(skb, hdrlen);
205
206         /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
207         skb_gro_postpull_rcsum(skb, guehdr, hdrlen);
208
209         data = &guehdr[1];
210
211         if (guehdr->flags & GUE_FLAG_PRIV) {
212                 data += GUE_LEN_PRIV;
213
214                 /* Process private flags */
215         }
216
217         flush = 0;
218
219         for (p = *head; p; p = p->next) {
220                 const struct guehdr *guehdr2;
221
222                 if (!NAPI_GRO_CB(p)->same_flow)
223                         continue;
224
225                 guehdr2 = (struct guehdr *)(p->data + off);
226
227                 /* Compare base GUE header to be equal (covers
228                  * hlen, version, proto_ctype, and flags.
229                  */
230                 if (guehdr->word != guehdr2->word) {
231                         NAPI_GRO_CB(p)->same_flow = 0;
232                         continue;
233                 }
234
235                 /* Compare optional fields are the same. */
236                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
237                                            guehdr->hlen << 2)) {
238                         NAPI_GRO_CB(p)->same_flow = 0;
239                         continue;
240                 }
241         }
242
243         rcu_read_lock();
244         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
245         ops = rcu_dereference(offloads[guehdr->proto_ctype]);
246         if (WARN_ON(!ops || !ops->callbacks.gro_receive))
247                 goto out_unlock;
248
249         pp = ops->callbacks.gro_receive(head, skb);
250
251 out_unlock:
252         rcu_read_unlock();
253 out:
254         NAPI_GRO_CB(skb)->flush |= flush;
255
256         return pp;
257 }
258
259 static int gue_gro_complete(struct sk_buff *skb, int nhoff)
260 {
261         const struct net_offload **offloads;
262         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
263         const struct net_offload *ops;
264         unsigned int guehlen;
265         u8 proto;
266         int err = -ENOENT;
267
268         proto = guehdr->proto_ctype;
269
270         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
271
272         rcu_read_lock();
273         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
274         ops = rcu_dereference(offloads[proto]);
275         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
276                 goto out_unlock;
277
278         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
279
280 out_unlock:
281         rcu_read_unlock();
282         return err;
283 }
284
285 static int fou_add_to_port_list(struct fou *fou)
286 {
287         struct fou *fout;
288
289         spin_lock(&fou_lock);
290         list_for_each_entry(fout, &fou_list, list) {
291                 if (fou->port == fout->port) {
292                         spin_unlock(&fou_lock);
293                         return -EALREADY;
294                 }
295         }
296
297         list_add(&fou->list, &fou_list);
298         spin_unlock(&fou_lock);
299
300         return 0;
301 }
302
303 static void fou_release(struct fou *fou)
304 {
305         struct socket *sock = fou->sock;
306         struct sock *sk = sock->sk;
307
308         udp_del_offload(&fou->udp_offloads);
309
310         list_del(&fou->list);
311
312         /* Remove hooks into tunnel socket */
313         sk->sk_user_data = NULL;
314
315         sock_release(sock);
316
317         kfree(fou);
318 }
319
320 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
321 {
322         udp_sk(sk)->encap_rcv = fou_udp_recv;
323         fou->protocol = cfg->protocol;
324         fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
325         fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
326         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
327         fou->udp_offloads.ipproto = cfg->protocol;
328
329         return 0;
330 }
331
332 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
333 {
334         udp_sk(sk)->encap_rcv = gue_udp_recv;
335         fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
336         fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
337         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
338
339         return 0;
340 }
341
342 static int fou_create(struct net *net, struct fou_cfg *cfg,
343                       struct socket **sockp)
344 {
345         struct fou *fou = NULL;
346         int err;
347         struct socket *sock = NULL;
348         struct sock *sk;
349
350         /* Open UDP socket */
351         err = udp_sock_create(net, &cfg->udp_config, &sock);
352         if (err < 0)
353                 goto error;
354
355         /* Allocate FOU port structure */
356         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
357         if (!fou) {
358                 err = -ENOMEM;
359                 goto error;
360         }
361
362         sk = sock->sk;
363
364         fou->port = cfg->udp_config.local_udp_port;
365
366         /* Initial for fou type */
367         switch (cfg->type) {
368         case FOU_ENCAP_DIRECT:
369                 err = fou_encap_init(sk, fou, cfg);
370                 if (err)
371                         goto error;
372                 break;
373         case FOU_ENCAP_GUE:
374                 err = gue_encap_init(sk, fou, cfg);
375                 if (err)
376                         goto error;
377                 break;
378         default:
379                 err = -EINVAL;
380                 goto error;
381         }
382
383         udp_sk(sk)->encap_type = 1;
384         udp_encap_enable();
385
386         sk->sk_user_data = fou;
387         fou->sock = sock;
388
389         udp_set_convert_csum(sk, true);
390
391         sk->sk_allocation = GFP_ATOMIC;
392
393         if (cfg->udp_config.family == AF_INET) {
394                 err = udp_add_offload(&fou->udp_offloads);
395                 if (err)
396                         goto error;
397         }
398
399         err = fou_add_to_port_list(fou);
400         if (err)
401                 goto error;
402
403         if (sockp)
404                 *sockp = sock;
405
406         return 0;
407
408 error:
409         kfree(fou);
410         if (sock)
411                 sock_release(sock);
412
413         return err;
414 }
415
416 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
417 {
418         struct fou *fou;
419         u16 port = cfg->udp_config.local_udp_port;
420         int err = -EINVAL;
421
422         spin_lock(&fou_lock);
423         list_for_each_entry(fou, &fou_list, list) {
424                 if (fou->port == port) {
425                         udp_del_offload(&fou->udp_offloads);
426                         fou_release(fou);
427                         err = 0;
428                         break;
429                 }
430         }
431         spin_unlock(&fou_lock);
432
433         return err;
434 }
435
436 static struct genl_family fou_nl_family = {
437         .id             = GENL_ID_GENERATE,
438         .hdrsize        = 0,
439         .name           = FOU_GENL_NAME,
440         .version        = FOU_GENL_VERSION,
441         .maxattr        = FOU_ATTR_MAX,
442         .netnsok        = true,
443 };
444
445 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
446         [FOU_ATTR_PORT] = { .type = NLA_U16, },
447         [FOU_ATTR_AF] = { .type = NLA_U8, },
448         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
449         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
450 };
451
452 static int parse_nl_config(struct genl_info *info,
453                            struct fou_cfg *cfg)
454 {
455         memset(cfg, 0, sizeof(*cfg));
456
457         cfg->udp_config.family = AF_INET;
458
459         if (info->attrs[FOU_ATTR_AF]) {
460                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
461
462                 if (family != AF_INET && family != AF_INET6)
463                         return -EINVAL;
464
465                 cfg->udp_config.family = family;
466         }
467
468         if (info->attrs[FOU_ATTR_PORT]) {
469                 u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);
470
471                 cfg->udp_config.local_udp_port = port;
472         }
473
474         if (info->attrs[FOU_ATTR_IPPROTO])
475                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
476
477         if (info->attrs[FOU_ATTR_TYPE])
478                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
479
480         return 0;
481 }
482
483 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
484 {
485         struct fou_cfg cfg;
486         int err;
487
488         err = parse_nl_config(info, &cfg);
489         if (err)
490                 return err;
491
492         return fou_create(&init_net, &cfg, NULL);
493 }
494
495 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
496 {
497         struct fou_cfg cfg;
498
499         parse_nl_config(info, &cfg);
500
501         return fou_destroy(&init_net, &cfg);
502 }
503
504 static const struct genl_ops fou_nl_ops[] = {
505         {
506                 .cmd = FOU_CMD_ADD,
507                 .doit = fou_nl_cmd_add_port,
508                 .policy = fou_nl_policy,
509                 .flags = GENL_ADMIN_PERM,
510         },
511         {
512                 .cmd = FOU_CMD_DEL,
513                 .doit = fou_nl_cmd_rm_port,
514                 .policy = fou_nl_policy,
515                 .flags = GENL_ADMIN_PERM,
516         },
517 };
518
519 static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
520                           struct flowi4 *fl4, u8 *protocol, __be16 sport)
521 {
522         struct udphdr *uh;
523
524         skb_push(skb, sizeof(struct udphdr));
525         skb_reset_transport_header(skb);
526
527         uh = udp_hdr(skb);
528
529         uh->dest = e->dport;
530         uh->source = sport;
531         uh->len = htons(skb->len);
532         uh->check = 0;
533         udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
534                      fl4->saddr, fl4->daddr, skb->len);
535
536         *protocol = IPPROTO_UDP;
537 }
538
539 int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
540                      u8 *protocol, struct flowi4 *fl4)
541 {
542         bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
543         int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
544         __be16 sport;
545
546         skb = iptunnel_handle_offloads(skb, csum, type);
547
548         if (IS_ERR(skb))
549                 return PTR_ERR(skb);
550
551         sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
552                                                skb, 0, 0, false);
553         fou_build_udp(skb, e, fl4, protocol, sport);
554
555         return 0;
556 }
557 EXPORT_SYMBOL(fou_build_header);
558
559 int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
560                      u8 *protocol, struct flowi4 *fl4)
561 {
562         bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
563         int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
564         struct guehdr *guehdr;
565         size_t hdrlen, optlen = 0;
566         __be16 sport;
567         void *data;
568         bool need_priv = false;
569
570         if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
571             skb->ip_summed == CHECKSUM_PARTIAL) {
572                 csum = false;
573                 optlen += GUE_PLEN_REMCSUM;
574                 type |= SKB_GSO_TUNNEL_REMCSUM;
575                 need_priv = true;
576         }
577
578         optlen += need_priv ? GUE_LEN_PRIV : 0;
579
580         skb = iptunnel_handle_offloads(skb, csum, type);
581
582         if (IS_ERR(skb))
583                 return PTR_ERR(skb);
584
585         /* Get source port (based on flow hash) before skb_push */
586         sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
587                                                skb, 0, 0, false);
588
589         hdrlen = sizeof(struct guehdr) + optlen;
590
591         skb_push(skb, hdrlen);
592
593         guehdr = (struct guehdr *)skb->data;
594
595         guehdr->control = 0;
596         guehdr->version = 0;
597         guehdr->hlen = optlen >> 2;
598         guehdr->flags = 0;
599         guehdr->proto_ctype = *protocol;
600
601         data = &guehdr[1];
602
603         if (need_priv) {
604                 __be32 *flags = data;
605
606                 guehdr->flags |= GUE_FLAG_PRIV;
607                 *flags = 0;
608                 data += GUE_LEN_PRIV;
609
610                 if (type & SKB_GSO_TUNNEL_REMCSUM) {
611                         u16 csum_start = skb_checksum_start_offset(skb);
612                         __be16 *pd = data;
613
614                         if (csum_start < hdrlen)
615                                 return -EINVAL;
616
617                         csum_start -= hdrlen;
618                         pd[0] = htons(csum_start);
619                         pd[1] = htons(csum_start + skb->csum_offset);
620
621                         if (!skb_is_gso(skb)) {
622                                 skb->ip_summed = CHECKSUM_NONE;
623                                 skb->encapsulation = 0;
624                         }
625
626                         *flags |= GUE_PFLAG_REMCSUM;
627                         data += GUE_PLEN_REMCSUM;
628                 }
629
630         }
631
632         fou_build_udp(skb, e, fl4, protocol, sport);
633
634         return 0;
635 }
636 EXPORT_SYMBOL(gue_build_header);
637
638 static int __init fou_init(void)
639 {
640         int ret;
641
642         ret = genl_register_family_with_ops(&fou_nl_family,
643                                             fou_nl_ops);
644
645         return ret;
646 }
647
648 static void __exit fou_fini(void)
649 {
650         struct fou *fou, *next;
651
652         genl_unregister_family(&fou_nl_family);
653
654         /* Close all the FOU sockets */
655
656         spin_lock(&fou_lock);
657         list_for_each_entry_safe(fou, next, &fou_list, list)
658                 fou_release(fou);
659         spin_unlock(&fou_lock);
660 }
661
662 module_init(fou_init);
663 module_exit(fou_fini);
664 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
665 MODULE_LICENSE("GPL");