1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
12 #include <net/protocol.h>
14 #include <net/udp_tunnel.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
/* Guards fou_list below; every add/remove/walk of FOU instances takes it */
static DEFINE_SPINLOCK(fou_lock);
/* All active FOU/GUE encapsulation sockets created by this module */
static LIST_HEAD(fou_list);
26 struct udp_offload udp_offloads;
27 struct list_head list;
33 struct udp_port_cfg udp_config;
36 static inline struct fou *fou_from_sock(struct sock *sk)
38 return sk->sk_user_data;
/* Strip 'len' bytes of outer encapsulation from an ingress skb: shrink the
 * IP total length accordingly, update the checksum state for the pulled
 * region, and reset the transport header to the inner packet.
 *
 * NOTE(review): this excerpt appears to be missing lines from the original
 * (braces and the __skb_pull() of the data itself) — compare upstream.
 */
static void fou_recv_pull(struct sk_buff *skb, size_t len)
	struct iphdr *iph = ip_hdr(skb);

	/* Remove 'len' bytes from the packet (UDP header and
	 * FOU header if present).
	 */
	iph->tot_len = htons(ntohs(iph->tot_len) - len);
	skb_postpull_rcsum(skb, udp_hdr(skb), len);
	skb_reset_transport_header(skb);
/* encap_rcv handler for plain FOU sockets: strip the UDP header and return
 * the configured inner protocol, negated — a negative return tells the UDP
 * input path to resubmit the packet to that IP protocol handler.
 *
 * NOTE(review): the original's NULL check of 'fou' (drop path) and braces
 * appear to be missing from this excerpt.
 */
static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
	struct fou *fou = fou_from_sock(sk);

	fou_recv_pull(skb, sizeof(struct udphdr));

	return -fou->protocol;
/* Non-GRO handling of the GUE remote-checksum-offload option: complete the
 * packet checksum from 'start' and patch it in at 'offset' (both carried as
 * two 16-bit fields in the option payload).
 *
 * NOTE(review): the declaration of 'pd' (presumably __be16 *pd = data;),
 * error returns, and closing braces are missing from this excerpt.
 */
static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
				  void *data, size_t hdrlen, u8 ipproto)
	size_t start = ntohs(pd[0]);
	size_t offset = ntohs(pd[1]);
	/* Must be able to pull up to the furthest byte the option touches */
	size_t plen = hdrlen + max_t(size_t, offset + sizeof(u16), start);

	if (skb->remcsum_offload) {
		/* Already processed in GRO path */
		skb->remcsum_offload = 0;
	if (!pskb_may_pull(skb, plen))
	/* skb head may have been reallocated by the pull — reload pointer */
	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	skb_remcsum_process(skb, (void *)guehdr + hdrlen, start, offset);
/* Handle a GUE control message (guehdr->control set).
 * NOTE(review): the body is missing from this excerpt; upstream simply
 * frees the skb since control messages are not yet supported — confirm.
 */
static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
/* encap_rcv handler for GUE sockets: validate the GUE header, process the
 * private-flags block (remote checksum offload), strip the encapsulation
 * and return the inner proto_ctype negated so UDP resubmits the packet.
 *
 * NOTE(review): this excerpt is missing the NULL check of 'fou', the
 * 'len += optlen' before the second pull, the 'data'/'doffset'
 * declarations, the drop: label/error returns, and various braces.
 */
static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
	struct fou *fou = fou_from_sock(sk);
	size_t len, optlen, hdrlen;
	struct guehdr *guehdr;

	/* Need the fixed UDP + GUE headers before hlen can be read */
	len = sizeof(struct udphdr) + sizeof(struct guehdr);
	if (!pskb_may_pull(skb, len))

	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	optlen = guehdr->hlen << 2;	/* hlen counts 32-bit words */
	if (!pskb_may_pull(skb, len))

	/* guehdr may change after pull */
	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	hdrlen = sizeof(struct guehdr) + optlen;

	if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))

	/* NOTE(review): exact duplicate of the assignment above — redundant
	 * (removed upstream by a later cleanup).
	 */
	hdrlen = sizeof(struct guehdr) + optlen;

	ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);

	/* Pull csum through the guehdr now . This can be used if
	 * there is a remote checksum offload.
	 */
	skb_postpull_rcsum(skb, udp_hdr(skb), len);

	if (guehdr->flags & GUE_FLAG_PRIV) {
		__be32 flags = *(__be32 *)(data + doffset);

		doffset += GUE_LEN_PRIV;

		if (flags & GUE_PFLAG_REMCSUM) {
			/* gue_remcsum may relocate the header; use its result */
			guehdr = gue_remcsum(skb, guehdr, data + doffset,
					     hdrlen, guehdr->proto_ctype);

			doffset += GUE_PLEN_REMCSUM;

	if (unlikely(guehdr->control))
		return gue_control_message(skb, guehdr);

	__skb_pull(skb, sizeof(struct udphdr) + hdrlen);
	skb_reset_transport_header(skb);

	return -guehdr->proto_ctype;
/* GRO receive for plain FOU: pick the inner protocol's offload ops from the
 * IPv4 or IPv6 table (per the outer header) and delegate GRO to it.
 *
 * NOTE(review): the 'struct sk_buff *skb' parameter line, the out: label,
 * and braces are missing from this excerpt. RCU read lock is presumably
 * held by the GRO caller — verify.
 */
static struct sk_buff **fou_gro_receive(struct sk_buff **head,
					struct udp_offload *uoff)
	const struct net_offload *ops;
	struct sk_buff **pp = NULL;
	u8 proto = NAPI_GRO_CB(skb)->proto;
	const struct net_offload **offloads;

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	if (!ops || !ops->callbacks.gro_receive)

	pp = ops->callbacks.gro_receive(head, skb);
/* GRO complete for plain FOU: finish the UDP tunnel GRO bookkeeping, then
 * hand completion to the inner protocol's offload ops.
 *
 * NOTE(review): the 'err' declaration, out: label/return, and braces are
 * missing from this excerpt.
 */
static int fou_gro_complete(struct sk_buff *skb, int nhoff,
			    struct udp_offload *uoff)
	const struct net_offload *ops;
	u8 proto = NAPI_GRO_CB(skb)->proto;
	const struct net_offload **offloads;

	udp_tunnel_gro_complete(skb, nhoff);

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	/* Missing complete callback here is a driver bug — warn loudly */
	if (WARN_ON(!ops || !ops->callbacks.gro_complete))

	err = ops->callbacks.gro_complete(skb, nhoff);
/* GRO-path handling of GUE remote checksum offload: verify the GRO-computed
 * checksum, fold it in at the offset carried by the option, and mark the
 * skb so the non-GRO path (gue_remcsum) skips reprocessing.
 *
 * NOTE(review): the 'pd' declaration, early-return NULLs, and closing
 * braces are missing from this excerpt.
 */
static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
				      struct guehdr *guehdr, void *data,
				      size_t hdrlen, u8 ipproto)
	size_t start = ntohs(pd[0]);
	size_t offset = ntohs(pd[1]);
	size_t plen = hdrlen + max_t(size_t, offset + sizeof(u16), start);

	if (skb->remcsum_offload)

	/* Only trust a checksum GRO has actually validated */
	if (!NAPI_GRO_CB(skb)->csum_valid)

	/* Pull checksum that will be written */
	if (skb_gro_header_hard(skb, off + plen)) {
		guehdr = skb_gro_header_slow(skb, off + plen, off);

	skb_gro_remcsum_process(skb, (void *)guehdr + hdrlen, start, offset);

	/* Tell the regular receive path this was already handled */
	skb->remcsum_offload = 1;
/* GRO receive for GUE: parse and validate the GUE header in the GRO
 * scratch area, handle remote checksum offload, match candidate packets on
 * an identical GUE header (base word + options), then delegate to the
 * inner protocol's gro_receive.
 *
 * NOTE(review): this excerpt is missing the 'struct sk_buff *skb'
 * parameter, the 'p'/'flush'/'data'/'doffset' declarations, the
 * 'len += optlen' before the second header pull, the out/out_unlock
 * labels, and many braces.
 */
static struct sk_buff **gue_gro_receive(struct sk_buff **head,
					struct udp_offload *uoff)
	const struct net_offload **offloads;
	const struct net_offload *ops;
	struct sk_buff **pp = NULL;
	struct guehdr *guehdr;
	size_t len, optlen, hdrlen, off;

	off = skb_gro_offset(skb);
	len = off + sizeof(*guehdr);

	/* Fast path: header already linear in the GRO area */
	guehdr = skb_gro_header_fast(skb, off);
	if (skb_gro_header_hard(skb, len)) {
		guehdr = skb_gro_header_slow(skb, len, off);
		if (unlikely(!guehdr))

	optlen = guehdr->hlen << 2;	/* hlen counts 32-bit words */

	if (skb_gro_header_hard(skb, len)) {
		guehdr = skb_gro_header_slow(skb, len, off);
		if (unlikely(!guehdr))

	/* Control messages and unknown versions are not GRO candidates */
	if (unlikely(guehdr->control) || guehdr->version != 0 ||
	    validate_gue_flags(guehdr, optlen))

	hdrlen = sizeof(*guehdr) + optlen;

	/* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
	 * this is needed if there is a remote checkcsum offload.
	 */
	skb_gro_postpull_rcsum(skb, guehdr, hdrlen);

	if (guehdr->flags & GUE_FLAG_PRIV) {
		__be32 flags = *(__be32 *)(data + doffset);

		doffset += GUE_LEN_PRIV;

		if (flags & GUE_PFLAG_REMCSUM) {
			guehdr = gue_gro_remcsum(skb, off, guehdr,
						 data + doffset, hdrlen,
						 guehdr->proto_ctype);

			doffset += GUE_PLEN_REMCSUM;

	skb_gro_pull(skb, hdrlen);

	/* Flows only aggregate when the whole GUE header matches */
	for (p = *head; p; p = p->next) {
		const struct guehdr *guehdr2;

		if (!NAPI_GRO_CB(p)->same_flow)

		guehdr2 = (struct guehdr *)(p->data + off);

		/* Compare base GUE header to be equal (covers
		 * hlen, version, proto_ctype, and flags.
		 */
		if (guehdr->word != guehdr2->word) {
			NAPI_GRO_CB(p)->same_flow = 0;

		/* Compare optional fields are the same. */
		if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
					   guehdr->hlen << 2)) {
			NAPI_GRO_CB(p)->same_flow = 0;

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[guehdr->proto_ctype]);
	if (WARN_ON(!ops || !ops->callbacks.gro_receive))

	pp = ops->callbacks.gro_receive(head, skb);

	NAPI_GRO_CB(skb)->flush |= flush;
/* GRO complete for GUE: read the inner protocol from the assembled GUE
 * header and invoke that protocol's gro_complete past the GUE header.
 *
 * NOTE(review): the 'err'/'proto' declarations, out: label/return, and
 * braces are missing from this excerpt.
 */
static int gue_gro_complete(struct sk_buff *skb, int nhoff,
			    struct udp_offload *uoff)
	const struct net_offload **offloads;
	struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
	const struct net_offload *ops;
	unsigned int guehlen;

	proto = guehdr->proto_ctype;

	guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	if (WARN_ON(!ops || !ops->callbacks.gro_complete))

	err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
/* Insert a new FOU instance into the global list under fou_lock, rejecting
 * a second instance bound to the same UDP port.
 *
 * NOTE(review): the 'fout' declaration, the -EALREADY (or similar) return
 * on duplicate, 'return 0;', and braces are missing from this excerpt.
 */
static int fou_add_to_port_list(struct fou *fou)
	spin_lock(&fou_lock);
	list_for_each_entry(fout, &fou_list, list) {
		if (fou->port == fout->port) {
			spin_unlock(&fou_lock);

	list_add(&fou->list, &fou_list);
	spin_unlock(&fou_lock);
/* Tear down one FOU instance: unregister its GRO offloads, unlink it from
 * fou_list, and detach it from the tunnel socket.
 *
 * NOTE(review): the socket release / kfree(fou) lines and braces appear to
 * be missing from this excerpt; caller presumably holds fou_lock for the
 * list_del — verify against upstream.
 */
static void fou_release(struct fou *fou)
	struct socket *sock = fou->sock;
	struct sock *sk = sock->sk;

	udp_del_offload(&fou->udp_offloads);

	list_del(&fou->list);

	/* Remove hooks into tunnel socket */
	sk->sk_user_data = NULL;
414 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
416 udp_sk(sk)->encap_rcv = fou_udp_recv;
417 fou->protocol = cfg->protocol;
418 fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
419 fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
420 fou->udp_offloads.port = cfg->udp_config.local_udp_port;
421 fou->udp_offloads.ipproto = cfg->protocol;
426 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
428 udp_sk(sk)->encap_rcv = gue_udp_recv;
429 fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
430 fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
431 fou->udp_offloads.port = cfg->udp_config.local_udp_port;
/* Create a FOU listener: open a kernel UDP socket per cfg, attach a struct
 * fou to it, install the per-type encap hooks, register GRO offloads
 * (IPv4 only), and add the instance to the global port list.
 *
 * NOTE(review): this excerpt is missing the 'sk'/'err' declarations, error
 * gotos/labels, the switch statement head, the default case, and the final
 * *sockp assignment/return — compare upstream before relying on flow.
 */
static int fou_create(struct net *net, struct fou_cfg *cfg,
		      struct socket **sockp)
	struct fou *fou = NULL;
	struct socket *sock = NULL;

	/* Open UDP socket */
	err = udp_sock_create(net, &cfg->udp_config, &sock);

	/* Allocate FOU port structure */
	fou = kzalloc(sizeof(*fou), GFP_KERNEL);

	fou->port = cfg->udp_config.local_udp_port;

	/* Initial for fou type */
	case FOU_ENCAP_DIRECT:
		err = fou_encap_init(sk, fou, cfg);
		err = gue_encap_init(sk, fou, cfg);

	/* Mark socket as an encapsulation socket (non-zero encap_type) */
	udp_sk(sk)->encap_type = 1;

	sk->sk_user_data = fou;

	inet_inc_convert_csum(sk);

	/* Receive path may allocate from atomic context */
	sk->sk_allocation = GFP_ATOMIC;

	if (cfg->udp_config.family == AF_INET) {
		err = udp_add_offload(&fou->udp_offloads);

	err = fou_add_to_port_list(fou);
/* Destroy the FOU instance bound to cfg's local UDP port, if any.
 *
 * NOTE(review): the 'fou'/'err' declarations, the fou_release() call and
 * loop break, the return, and braces are missing from this excerpt.
 */
static int fou_destroy(struct net *net, struct fou_cfg *cfg)
	u16 port = cfg->udp_config.local_udp_port;

	spin_lock(&fou_lock);
	list_for_each_entry(fou, &fou_list, list) {
		if (fou->port == port) {
			udp_del_offload(&fou->udp_offloads);

	spin_unlock(&fou_lock);
/* Generic netlink family for configuring FOU ports from userspace.
 * NOTE(review): the .hdrsize/.netns_ok members and closing '};' appear to
 * be missing from this excerpt.
 */
static struct genl_family fou_nl_family = {
	.id = GENL_ID_GENERATE,
	.name = FOU_GENL_NAME,
	.version = FOU_GENL_VERSION,
	.maxattr = FOU_ATTR_MAX,
539 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
540 [FOU_ATTR_PORT] = { .type = NLA_U16, },
541 [FOU_ATTR_AF] = { .type = NLA_U8, },
542 [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
543 [FOU_ATTR_TYPE] = { .type = NLA_U8, },
/* Translate validated netlink attributes into a struct fou_cfg. Defaults
 * to AF_INET; rejects address families other than IPv4/IPv6.
 *
 * NOTE(review): the 'struct fou_cfg *cfg' parameter line, the -EINVAL
 * return for a bad family, 'return 0;', and braces are missing from this
 * excerpt.
 */
static int parse_nl_config(struct genl_info *info,
	memset(cfg, 0, sizeof(*cfg));

	cfg->udp_config.family = AF_INET;

	if (info->attrs[FOU_ATTR_AF]) {
		u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);

		if (family != AF_INET && family != AF_INET6)

		cfg->udp_config.family = family;

	if (info->attrs[FOU_ATTR_PORT]) {
		u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);

		cfg->udp_config.local_udp_port = port;

	if (info->attrs[FOU_ATTR_IPPROTO])
		cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);

	if (info->attrs[FOU_ATTR_TYPE])
		cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
/* FOU_CMD_ADD handler: parse the request into a fou_cfg and create the
 * listener in init_net.
 *
 * NOTE(review): the 'cfg'/'err' declarations, the error return after
 * parsing, and braces are missing from this excerpt. Use of init_net
 * (rather than the caller's netns) matches this era of the driver.
 */
static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
	err = parse_nl_config(info, &cfg);

	return fou_create(&init_net, &cfg, NULL);
/* FOU_CMD_DEL handler: parse the request and destroy the matching
 * listener. parse_nl_config's return is deliberately ignored here since
 * destroy only needs the port/family fields.
 *
 * NOTE(review): the 'cfg' declaration and braces are missing from this
 * excerpt.
 */
static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
	parse_nl_config(info, &cfg);

	return fou_destroy(&init_net, &cfg);
/* Generic netlink operations: add/remove a FOU port. Both require
 * CAP_NET_ADMIN (GENL_ADMIN_PERM).
 * NOTE(review): the '.cmd = FOU_CMD_ADD/DEL,' members, the per-entry
 * braces, and the terminating '};' are missing from this excerpt.
 */
static const struct genl_ops fou_nl_ops[] = {
		.doit = fou_nl_cmd_add_port,
		.policy = fou_nl_policy,
		.flags = GENL_ADMIN_PERM,
		.doit = fou_nl_cmd_rm_port,
		.policy = fou_nl_policy,
		.flags = GENL_ADMIN_PERM,
613 size_t fou_encap_hlen(struct ip_tunnel_encap *e)
615 return sizeof(struct udphdr);
617 EXPORT_SYMBOL(fou_encap_hlen);
619 size_t gue_encap_hlen(struct ip_tunnel_encap *e)
622 bool need_priv = false;
624 len = sizeof(struct udphdr) + sizeof(struct guehdr);
626 if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
627 len += GUE_PLEN_REMCSUM;
631 len += need_priv ? GUE_LEN_PRIV : 0;
635 EXPORT_SYMBOL(gue_encap_hlen);
/* Prepend and fill the outer UDP header on an egress tunnel skb, compute
 * or offload its checksum per the encap flags, and report IPPROTO_UDP to
 * the IP tunnel layer via *protocol.
 *
 * NOTE(review): the 'uh' declaration, uh->dest/uh->source assignments, and
 * braces are missing from this excerpt.
 */
static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
			  struct flowi4 *fl4, u8 *protocol, __be16 sport)
	skb_push(skb, sizeof(struct udphdr));
	skb_reset_transport_header(skb);

	uh->len = htons(skb->len);
	/* nocheck unless the user asked for an outer UDP checksum */
	udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
		     fl4->saddr, fl4->daddr, skb->len);

	*protocol = IPPROTO_UDP;
/* ip_tunnel build_header hook for direct FOU: prepare GSO offload state,
 * pick a flow-hash-based source port (unless pinned by config), then add
 * the outer UDP header.
 *
 * NOTE(review): the 'sport' declaration, the IS_ERR check on the
 * iptunnel_handle_offloads result, the udp_flow_src_port continuation
 * arguments, 'return 0;', and braces are missing from this excerpt.
 */
int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
		     u8 *protocol, struct flowi4 *fl4)
	bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
	int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;

	skb = iptunnel_handle_offloads(skb, csum, type);

	sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),

	fou_build_udp(skb, e, fl4, protocol, sport);
/* ip_tunnel build_header hook for GUE: size and build the GUE header
 * (including the private-flags word and remote-checksum-offload option
 * when enabled), then add the outer UDP header via fou_build_udp.
 *
 * NOTE(review): this excerpt is missing the 'sport'/'data'/'pd'
 * declarations, the IS_ERR check after iptunnel_handle_offloads, the
 * need_priv = true assignment, several zeroing/guard lines, the error
 * return for csum_start < hdrlen, 'return 0;', and many braces.
 */
int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
		     u8 *protocol, struct flowi4 *fl4)
	bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
	int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
	struct guehdr *guehdr;
	size_t hdrlen, optlen = 0;
	bool need_priv = false;

	/* Remote checksum offload only applies to locally-checksummed skbs */
	if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
	    skb->ip_summed == CHECKSUM_PARTIAL) {
		optlen += GUE_PLEN_REMCSUM;
		type |= SKB_GSO_TUNNEL_REMCSUM;

	optlen += need_priv ? GUE_LEN_PRIV : 0;

	skb = iptunnel_handle_offloads(skb, csum, type);

	/* Get source port (based on flow hash) before skb_push */
	sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),

	hdrlen = sizeof(struct guehdr) + optlen;

	skb_push(skb, hdrlen);

	guehdr = (struct guehdr *)skb->data;

	guehdr->hlen = optlen >> 2;	/* options length in 32-bit words */
	guehdr->proto_ctype = *protocol;

		__be32 *flags = data;

		guehdr->flags |= GUE_FLAG_PRIV;
		data += GUE_LEN_PRIV;

		if (type & SKB_GSO_TUNNEL_REMCSUM) {
			u16 csum_start = skb_checksum_start_offset(skb);

			/* Checksum must start inside the inner packet */
			if (csum_start < hdrlen)

			csum_start -= hdrlen;
			pd[0] = htons(csum_start);
			pd[1] = htons(csum_start + skb->csum_offset);

			if (!skb_is_gso(skb)) {
				skb->ip_summed = CHECKSUM_NONE;
				skb->encapsulation = 0;

			*flags |= GUE_PFLAG_REMCSUM;
			data += GUE_PLEN_REMCSUM;

	fou_build_udp(skb, e, fl4, protocol, sport);
756 #ifdef CONFIG_NET_FOU_IP_TUNNELS
758 static const struct ip_tunnel_encap_ops __read_mostly fou_iptun_ops = {
759 .encap_hlen = fou_encap_hlen,
760 .build_header = fou_build_header,
763 static const struct ip_tunnel_encap_ops __read_mostly gue_iptun_ops = {
764 .encap_hlen = gue_encap_hlen,
765 .build_header = gue_build_header,
/* Register both encapsulation ops with the ip_tunnel layer; on failure of
 * the second registration, roll back the first so no partial state leaks.
 *
 * NOTE(review): the 'goto out'/'out:' label, 'return ret;', and braces are
 * missing from this excerpt.
 */
static int ip_tunnel_encap_add_fou_ops(void)
	ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
		pr_err("can't add fou ops\n");

	ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
		pr_err("can't add gue ops\n");
		/* Roll back FOU registration so state stays consistent */
		ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
788 static void ip_tunnel_encap_del_fou_ops(void)
790 ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
791 ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
/* Stubs for the !CONFIG_NET_FOU_IP_TUNNELS build (the #else branch of the
 * ifdef above — not visible in this excerpt).
 * NOTE(review): the stub bodies ({ return 0; } and { }) are missing here.
 */
static int ip_tunnel_encap_add_fou_ops(void)

static void ip_tunnel_encap_del_fou_ops(void)
/* Module init: register the generic netlink family/ops, then hook the
 * ip_tunnel encapsulation ops, unwinding the genl registration on failure.
 *
 * NOTE(review): the 'ret' declaration, the genl ops arguments, the early
 * exit after registration failure, 'return ret;', and braces are missing
 * from this excerpt.
 */
static int __init fou_init(void)
	ret = genl_register_family_with_ops(&fou_nl_family,

	ret = ip_tunnel_encap_add_fou_ops();
		/* Unwind genl registration if tunnel ops could not be added */
		genl_unregister_family(&fou_nl_family);
/* Module exit: unhook tunnel encapsulation ops, drop the netlink family,
 * then release every remaining FOU socket under fou_lock.
 *
 * NOTE(review): the loop body (presumably fou_release(fou)) and braces are
 * missing from this excerpt.
 */
static void __exit fou_fini(void)
	struct fou *fou, *next;

	ip_tunnel_encap_del_fou_ops();

	genl_unregister_family(&fou_nl_family);

	/* Close all the FOU sockets */

	spin_lock(&fou_lock);
	list_for_each_entry_safe(fou, next, &fou_list, list)
	spin_unlock(&fou_lock);
/* Module registration and metadata */
module_init(fou_init);
module_exit(fou_fini);
MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
MODULE_LICENSE("GPL");