1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
12 #include <net/protocol.h>
14 #include <net/udp_tunnel.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
/* fou_lock protects fou_list, the global list of active FOU ports. */
19 static DEFINE_SPINLOCK(fou_lock);
20 static LIST_HEAD(fou_list);
/* NOTE(review): the enclosing struct declarations are elided in this view.
 * The first two members appear to belong to struct fou (per-port state:
 * GRO offload hooks and the fou_list linkage); udp_config appears to belong
 * to struct fou_cfg (socket parameters for udp_sock_create) — confirm
 * against the full file.
 */
26 struct udp_offload udp_offloads;
27 struct list_head list;
33 struct udp_port_cfg udp_config;
/* Return the struct fou hung off a tunnel socket via sk->sk_user_data
 * (installed in fou_create; cleared again in fou_release).
 */
36 static inline struct fou *fou_from_sock(struct sock *sk)
38 return sk->sk_user_data;
/* Strip 'len' bytes of encapsulation (outer UDP header plus any FOU/GUE
 * header) from the front of the packet and hand the inner payload,
 * identified by 'protocol', back to the IP stack.
 * NOTE(review): the tail of this function (protocol rewrite and the
 * re-dispatch mentioned in the comment below) is elided in this view.
 */
41 static int fou_udp_encap_recv_deliver(struct sk_buff *skb,
42 u8 protocol, size_t len)
44 struct iphdr *iph = ip_hdr(skb);
46 /* Remove 'len' bytes from the packet (UDP header and
47 * FOU header if present), modify the protocol to the one
48 * we found, and then call rcv_encap.
/* Shrink the outer IP total length to account for the stripped bytes. */
50 iph->tot_len = htons(ntohs(iph->tot_len) - len);
/* Keep skb->csum consistent after pulling the encap header. */
52 skb_postpull_rcsum(skb, udp_hdr(skb), len);
53 skb_reset_transport_header(skb);
/* encap_rcv callback for plain (direct) FOU sockets: strip only the UDP
 * header and deliver the inner packet as the protocol configured for this
 * port (fou->protocol). Error/NULL handling lines are elided in this view.
 */
58 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
60 struct fou *fou = fou_from_sock(sk);
65 return fou_udp_encap_recv_deliver(skb, fou->protocol,
66 sizeof(struct udphdr));
/* encap_rcv callback for GUE sockets: validate and strip the UDP header
 * plus the GUE header (fixed part and hlen<<2 bytes of optional fields),
 * then deliver the inner packet as guehdr->next_hdr.
 * NOTE(review): the error-return lines after each check are elided here.
 */
69 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
71 struct fou *fou = fou_from_sock(sk);
73 struct guehdr *guehdr;
/* Make sure the fixed UDP + GUE headers are in the linear area. */
79 len = sizeof(struct udphdr) + sizeof(struct guehdr);
80 if (!pskb_may_pull(skb, len))
/* GUE header sits immediately after the UDP header ('uh'). */
84 guehdr = (struct guehdr *)&uh[1];
/* hlen counts optional fields in 32-bit words. */
86 len += guehdr->hlen << 2;
87 if (!pskb_may_pull(skb, len))
/* Only GUE version 0 is supported. */
90 if (guehdr->version != 0)
98 return fou_udp_encap_recv_deliver(skb, guehdr->next_hdr, len);
/* GRO receive hook for direct FOU: look up the inner protocol's offload
 * ops (IPv4 or IPv6 table, per NAPI_GRO_CB(skb)->is_ipv6) and delegate
 * aggregation to its gro_receive callback.
 * NOTE(review): rcu_read_lock/unlock and the out/flush labels are elided
 * in this view; rcu_dereference below implies an RCU read section.
 */
104 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
107 const struct net_offload *ops;
108 struct sk_buff **pp = NULL;
109 u8 proto = NAPI_GRO_CB(skb)->proto;
110 const struct net_offload **offloads;
113 offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
114 ops = rcu_dereference(offloads[proto]);
/* No offload handler for the inner protocol: cannot aggregate. */
115 if (!ops || !ops->callbacks.gro_receive)
118 pp = ops->callbacks.gro_receive(head, skb);
/* GRO complete hook for direct FOU: finish aggregation by calling the
 * inner protocol's gro_complete at the same header offset.
 * WARN_ON fires if the offload handler disappeared between receive and
 * complete, which is unexpected.
 */
126 static int fou_gro_complete(struct sk_buff *skb, int nhoff)
128 const struct net_offload *ops;
129 u8 proto = NAPI_GRO_CB(skb)->proto;
131 const struct net_offload **offloads;
134 offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
135 ops = rcu_dereference(offloads[proto]);
136 if (WARN_ON(!ops || !ops->callbacks.gro_complete))
139 err = ops->callbacks.gro_complete(skb, nhoff);
/* GRO receive hook for GUE: parse the GUE header, verify an offload
 * handler exists for next_hdr, mark packets from other flows whose GUE
 * headers differ as not-same-flow, then pull the GUE header and delegate
 * to the inner protocol's gro_receive.
 * NOTE(review): several lines (error gotos, flush handling, rcu markers)
 * are elided in this view.
 */
147 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
150 const struct net_offload **offloads;
151 const struct net_offload *ops;
152 struct sk_buff **pp = NULL;
155 struct guehdr *guehdr;
156 unsigned int hlen, guehlen;
/* First pass: make the fixed-size GUE header visible to GRO. */
160 off = skb_gro_offset(skb);
161 hlen = off + sizeof(*guehdr);
162 guehdr = skb_gro_header_fast(skb, off);
163 if (skb_gro_header_hard(skb, hlen)) {
164 guehdr = skb_gro_header_slow(skb, hlen, off);
165 if (unlikely(!guehdr))
169 proto = guehdr->next_hdr;
172 offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
173 ops = rcu_dereference(offloads[proto]);
174 if (WARN_ON(!ops || !ops->callbacks.gro_receive))
/* Second pass: now that hlen is known, pull in the optional fields too. */
177 guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
179 hlen = off + guehlen;
180 if (skb_gro_header_hard(skb, hlen)) {
181 guehdr = skb_gro_header_slow(skb, hlen, off);
182 if (unlikely(!guehdr))
/* Walk the held packets; any with a different GUE header cannot be
 * merged into this flow.
 */
188 for (p = *head; p; p = p->next) {
189 const struct guehdr *guehdr2;
191 if (!NAPI_GRO_CB(p)->same_flow)
194 guehdr2 = (struct guehdr *)(p->data + off);
196 /* Compare base GUE header to be equal (covers
197 * hlen, version, next_hdr, and flags.
199 if (guehdr->word != guehdr2->word) {
200 NAPI_GRO_CB(p)->same_flow = 0;
204 /* Compare optional fields are the same. */
205 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
206 guehdr->hlen << 2)) {
207 NAPI_GRO_CB(p)->same_flow = 0;
/* Strip the GUE header from GRO's view and keep the checksum state
 * consistent before handing off to the inner protocol.
 */
212 skb_gro_pull(skb, guehlen);
214 /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
215 skb_gro_postpull_rcsum(skb, guehdr, guehlen);
217 pp = ops->callbacks.gro_receive(head, skb);
222 NAPI_GRO_CB(skb)->flush |= flush;
/* GRO complete hook for GUE: re-read the GUE header at nhoff, then call
 * the inner protocol's gro_complete past the GUE header (fixed part plus
 * hlen<<2 bytes of options).
 */
227 static int gue_gro_complete(struct sk_buff *skb, int nhoff)
229 const struct net_offload **offloads;
230 struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
231 const struct net_offload *ops;
232 unsigned int guehlen;
236 proto = guehdr->next_hdr;
238 guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
241 offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
242 ops = rcu_dereference(offloads[proto]);
243 if (WARN_ON(!ops || !ops->callbacks.gro_complete))
246 err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
/* Register a new fou instance on the global list, rejecting a duplicate
 * UDP port. The early spin_unlock path presumably returns an -EALREADY
 * style error — that return line is elided in this view.
 */
253 static int fou_add_to_port_list(struct fou *fou)
257 spin_lock(&fou_lock);
258 list_for_each_entry(fout, &fou_list, list) {
259 if (fou->port == fout->port) {
260 spin_unlock(&fou_lock);
265 list_add(&fou->list, &fou_list);
266 spin_unlock(&fou_lock);
/* Tear down one FOU port: unregister its GRO offload, unlink it from
 * fou_list, and detach the fou state from the tunnel socket.
 * NOTE(review): socket release and kfree of the fou struct are elided in
 * this view; callers appear to hold fou_lock for the list_del (see
 * fou_fini/fou_destroy).
 */
271 static void fou_release(struct fou *fou)
273 struct socket *sock = fou->sock;
274 struct sock *sk = sock->sk;
276 udp_del_offload(&fou->udp_offloads);
278 list_del(&fou->list);
280 /* Remove hooks into tunnel socket */
281 sk->sk_user_data = NULL;
/* Configure a socket for direct FOU encapsulation: install the FOU
 * receive callback and fill in the GRO offload descriptor for this port
 * and inner protocol.
 */
288 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
290 udp_sk(sk)->encap_rcv = fou_udp_recv;
291 fou->protocol = cfg->protocol;
292 fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
293 fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
294 fou->udp_offloads.port = cfg->udp_config.local_udp_port;
295 fou->udp_offloads.ipproto = cfg->protocol;
/* Configure a socket for GUE encapsulation: install the GUE receive
 * callback and GRO hooks. Unlike fou_encap_init, no fixed inner protocol
 * is recorded — GUE carries it per-packet in guehdr->next_hdr.
 */
300 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
302 udp_sk(sk)->encap_rcv = gue_udp_recv;
303 fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
304 fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
305 fou->udp_offloads.port = cfg->udp_config.local_udp_port;
/* Create a FOU port from a parsed configuration: open the listening UDP
 * socket, allocate per-port state, wire up the encap type (direct FOU or
 * GUE), enable UDP encapsulation on the socket, register the GRO offload
 * (IPv4 only here), and publish the port on fou_list.
 * NOTE(review): error-handling gotos, the switch statement around
 * FOU_ENCAP_DIRECT/GUE, and the success return are elided in this view.
 */
310 static int fou_create(struct net *net, struct fou_cfg *cfg,
311 struct socket **sockp)
313 struct fou *fou = NULL;
315 struct socket *sock = NULL;
318 /* Open UDP socket */
319 err = udp_sock_create(net, &cfg->udp_config, &sock);
323 /* Allocate FOU port structure */
324 fou = kzalloc(sizeof(*fou), GFP_KERNEL);
332 fou->port = cfg->udp_config.local_udp_port;
334 /* Initial for fou type */
336 case FOU_ENCAP_DIRECT:
337 err = fou_encap_init(sk, fou, cfg);
342 err = gue_encap_init(sk, fou, cfg);
/* Non-zero encap_type makes the UDP stack divert packets to encap_rcv. */
351 udp_sk(sk)->encap_type = 1;
354 sk->sk_user_data = fou;
357 udp_set_convert_csum(sk, true);
/* Socket is used from softirq/receive context; avoid sleeping allocs. */
359 sk->sk_allocation = GFP_ATOMIC;
361 if (cfg->udp_config.family == AF_INET) {
362 err = udp_add_offload(&fou->udp_offloads);
367 err = fou_add_to_port_list(fou);
/* Remove the FOU port matching cfg's local UDP port, if any. Walks
 * fou_list under fou_lock and unregisters the matching port's offload.
 * NOTE(review): the fou_release call and the found/err bookkeeping inside
 * the match branch are elided in this view.
 */
384 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
387 u16 port = cfg->udp_config.local_udp_port;
390 spin_lock(&fou_lock);
391 list_for_each_entry(fou, &fou_list, list) {
392 if (fou->port == port) {
393 udp_del_offload(&fou->udp_offloads);
399 spin_unlock(&fou_lock);
/* Generic netlink family for configuring FOU ports from userspace.
 * GENL_ID_GENERATE asks the genetlink core to assign the family id.
 */
404 static struct genl_family fou_nl_family = {
405 .id = GENL_ID_GENERATE,
407 .name = FOU_GENL_NAME,
408 .version = FOU_GENL_VERSION,
409 .maxattr = FOU_ATTR_MAX,
/* Netlink attribute validation policy: UDP port, address family, inner
 * IP protocol, and encapsulation type (see parse_nl_config).
 */
413 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
414 [FOU_ATTR_PORT] = { .type = NLA_U16, },
415 [FOU_ATTR_AF] = { .type = NLA_U8, },
416 [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
417 [FOU_ATTR_TYPE] = { .type = NLA_U8, },
/* Translate genetlink attributes into a fou_cfg. Family defaults to
 * AF_INET and only AF_INET/AF_INET6 are accepted; port, inner protocol
 * and encap type are optional attributes. The error return for a bad
 * family and the final success return are elided in this view.
 */
420 static int parse_nl_config(struct genl_info *info,
423 memset(cfg, 0, sizeof(*cfg));
425 cfg->udp_config.family = AF_INET;
427 if (info->attrs[FOU_ATTR_AF]) {
428 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
430 if (family != AF_INET && family != AF_INET6)
433 cfg->udp_config.family = family;
436 if (info->attrs[FOU_ATTR_PORT]) {
437 u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);
439 cfg->udp_config.local_udp_port = port;
442 if (info->attrs[FOU_ATTR_IPPROTO])
443 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
445 if (info->attrs[FOU_ATTR_TYPE])
446 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
/* genetlink doit handler: parse attributes and create a FOU port in
 * init_net. The err check after parse_nl_config is elided in this view.
 */
451 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
456 err = parse_nl_config(info, &cfg);
460 return fou_create(&init_net, &cfg, NULL);
/* genetlink doit handler: parse attributes and destroy the matching FOU
 * port. NOTE(review): parse_nl_config's return value is ignored here,
 * unlike in fou_nl_cmd_add_port — worth confirming this is intentional.
 */
463 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
467 parse_nl_config(info, &cfg);
469 return fou_destroy(&init_net, &cfg);
/* genetlink operations table: add/remove a FOU port. Both require
 * CAP_NET_ADMIN (GENL_ADMIN_PERM). The .cmd fields are elided in this
 * view.
 */
472 static const struct genl_ops fou_nl_ops[] = {
475 .doit = fou_nl_cmd_add_port,
476 .policy = fou_nl_policy,
477 .flags = GENL_ADMIN_PERM,
481 .doit = fou_nl_cmd_rm_port,
482 .policy = fou_nl_policy,
483 .flags = GENL_ADMIN_PERM,
/* Module init: register the FOU generic netlink family and its ops.
 * The remaining arguments and the return are elided in this view.
 */
487 static int __init fou_init(void)
491 ret = genl_register_family_with_ops(&fou_nl_family,
/* Module exit: unregister the netlink family, then release every
 * remaining FOU port under fou_lock (the fou_release call in the loop
 * body is elided in this view; list_for_each_entry_safe permits removal
 * while iterating).
 */
497 static void __exit fou_fini(void)
499 struct fou *fou, *next;
501 genl_unregister_family(&fou_nl_family);
503 /* Close all the FOU sockets */
505 spin_lock(&fou_lock);
506 list_for_each_entry_safe(fou, next, &fou_list, list)
508 spin_unlock(&fou_lock);
/* Standard kernel module registration and metadata. */
511 module_init(fou_init);
512 module_exit(fou_fini);
513 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
514 MODULE_LICENSE("GPL");