mpls: Functions for reading and wrinting mpls labels over netlink
[firefly-linux-kernel-4.4.55.git] / net / mpls / af_mpls.c
1 #include <linux/types.h>
2 #include <linux/skbuff.h>
3 #include <linux/socket.h>
4 #include <linux/sysctl.h>
5 #include <linux/net.h>
6 #include <linux/module.h>
7 #include <linux/if_arp.h>
8 #include <linux/ipv6.h>
9 #include <linux/mpls.h>
10 #include <net/ip.h>
11 #include <net/dst.h>
12 #include <net/sock.h>
13 #include <net/arp.h>
14 #include <net/ip_fib.h>
15 #include <net/netevent.h>
16 #include <net/netns/generic.h>
17 #include "internal.h"
18
19 #define LABEL_NOT_SPECIFIED (1<<20)
20 #define MAX_NEW_LABELS 2
21
22 /* This maximum ha length copied from the definition of struct neighbour */
23 #define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
24
25 struct mpls_route { /* next hop label forwarding entry */
26         struct net_device       *rt_dev;
27         struct rcu_head         rt_rcu;
28         u32                     rt_label[MAX_NEW_LABELS];
29         u8                      rt_protocol; /* routing protocol that set this entry */
30         u8                      rt_labels:2,
31                                 rt_via_alen:6;
32         unsigned short          rt_via_family;
33         u8                      rt_via[0];
34 };
35
36 static int zero = 0;
37 static int label_limit = (1 << 20) - 1;
38
39 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
40 {
41         struct mpls_route *rt = NULL;
42
43         if (index < net->mpls.platform_labels) {
44                 struct mpls_route __rcu **platform_label =
45                         rcu_dereference(net->mpls.platform_label);
46                 rt = rcu_dereference(platform_label[index]);
47         }
48         return rt;
49 }
50
51 static bool mpls_output_possible(const struct net_device *dev)
52 {
53         return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
54 }
55
56 static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
57 {
58         /* The size of the layer 2.5 labels to be added for this route */
59         return rt->rt_labels * sizeof(struct mpls_shim_hdr);
60 }
61
62 static unsigned int mpls_dev_mtu(const struct net_device *dev)
63 {
64         /* The amount of data the layer 2 frame can hold */
65         return dev->mtu;
66 }
67
68 static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
69 {
70         if (skb->len <= mtu)
71                 return false;
72
73         if (skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu)
74                 return false;
75
76         return true;
77 }
78
79 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
80                         struct mpls_entry_decoded dec)
81 {
82         /* RFC4385 and RFC5586 encode other packets in mpls such that
83          * they don't conflict with the ip version number, making
84          * decoding by examining the ip version correct in everything
85          * except for the strangest cases.
86          *
87          * The strange cases if we choose to support them will require
88          * manual configuration.
89          */
90         struct iphdr *hdr4 = ip_hdr(skb);
91         bool success = true;
92
93         if (hdr4->version == 4) {
94                 skb->protocol = htons(ETH_P_IP);
95                 csum_replace2(&hdr4->check,
96                               htons(hdr4->ttl << 8),
97                               htons(dec.ttl << 8));
98                 hdr4->ttl = dec.ttl;
99         }
100         else if (hdr4->version == 6) {
101                 struct ipv6hdr *hdr6 = ipv6_hdr(skb);
102                 skb->protocol = htons(ETH_P_IPV6);
103                 hdr6->hop_limit = dec.ttl;
104         }
105         else
106                 /* version 0 and version 1 are used by pseudo wires */
107                 success = false;
108         return success;
109 }
110
111 static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
112                         struct packet_type *pt, struct net_device *orig_dev)
113 {
114         struct net *net = dev_net(dev);
115         struct mpls_shim_hdr *hdr;
116         struct mpls_route *rt;
117         struct mpls_entry_decoded dec;
118         struct net_device *out_dev;
119         unsigned int hh_len;
120         unsigned int new_header_size;
121         unsigned int mtu;
122         int err;
123
124         /* Careful this entire function runs inside of an rcu critical section */
125
126         if (skb->pkt_type != PACKET_HOST)
127                 goto drop;
128
129         if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
130                 goto drop;
131
132         if (!pskb_may_pull(skb, sizeof(*hdr)))
133                 goto drop;
134
135         /* Read and decode the label */
136         hdr = mpls_hdr(skb);
137         dec = mpls_entry_decode(hdr);
138
139         /* Pop the label */
140         skb_pull(skb, sizeof(*hdr));
141         skb_reset_network_header(skb);
142
143         skb_orphan(skb);
144
145         rt = mpls_route_input_rcu(net, dec.label);
146         if (!rt)
147                 goto drop;
148
149         /* Find the output device */
150         out_dev = rt->rt_dev;
151         if (!mpls_output_possible(out_dev))
152                 goto drop;
153
154         if (skb_warn_if_lro(skb))
155                 goto drop;
156
157         skb_forward_csum(skb);
158
159         /* Verify ttl is valid */
160         if (dec.ttl <= 2)
161                 goto drop;
162         dec.ttl -= 1;
163
164         /* Verify the destination can hold the packet */
165         new_header_size = mpls_rt_header_size(rt);
166         mtu = mpls_dev_mtu(out_dev);
167         if (mpls_pkt_too_big(skb, mtu - new_header_size))
168                 goto drop;
169
170         hh_len = LL_RESERVED_SPACE(out_dev);
171         if (!out_dev->header_ops)
172                 hh_len = 0;
173
174         /* Ensure there is enough space for the headers in the skb */
175         if (skb_cow(skb, hh_len + new_header_size))
176                 goto drop;
177
178         skb->dev = out_dev;
179         skb->protocol = htons(ETH_P_MPLS_UC);
180
181         if (unlikely(!new_header_size && dec.bos)) {
182                 /* Penultimate hop popping */
183                 if (!mpls_egress(rt, skb, dec))
184                         goto drop;
185         } else {
186                 bool bos;
187                 int i;
188                 skb_push(skb, new_header_size);
189                 skb_reset_network_header(skb);
190                 /* Push the new labels */
191                 hdr = mpls_hdr(skb);
192                 bos = dec.bos;
193                 for (i = rt->rt_labels - 1; i >= 0; i--) {
194                         hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
195                         bos = false;
196                 }
197         }
198
199         err = neigh_xmit(rt->rt_via_family, out_dev, rt->rt_via, skb);
200         if (err)
201                 net_dbg_ratelimited("%s: packet transmission failed: %d\n",
202                                     __func__, err);
203         return 0;
204
205 drop:
206         kfree_skb(skb);
207         return NET_RX_DROP;
208 }
209
210 static struct packet_type mpls_packet_type __read_mostly = {
211         .type = cpu_to_be16(ETH_P_MPLS_UC),
212         .func = mpls_forward,
213 };
214
215 struct mpls_route_config {
216         u32             rc_protocol;
217         u32             rc_ifindex;
218         u16             rc_via_family;
219         u16             rc_via_alen;
220         u8              rc_via[MAX_VIA_ALEN];
221         u32             rc_label;
222         u32             rc_output_labels;
223         u32             rc_output_label[MAX_NEW_LABELS];
224         u32             rc_nlflags;
225         struct nl_info  rc_nlinfo;
226 };
227
228 static struct mpls_route *mpls_rt_alloc(size_t alen)
229 {
230         struct mpls_route *rt;
231
232         rt = kzalloc(GFP_KERNEL, sizeof(*rt) + alen);
233         if (rt)
234                 rt->rt_via_alen = alen;
235         return rt;
236 }
237
238 static void mpls_rt_free(struct mpls_route *rt)
239 {
240         if (rt)
241                 kfree_rcu(rt, rt_rcu);
242 }
243
244 static void mpls_route_update(struct net *net, unsigned index,
245                               struct net_device *dev, struct mpls_route *new,
246                               const struct nl_info *info)
247 {
248         struct mpls_route *rt, *old = NULL;
249
250         ASSERT_RTNL();
251
252         rt = net->mpls.platform_label[index];
253         if (!dev || (rt && (rt->rt_dev == dev))) {
254                 rcu_assign_pointer(net->mpls.platform_label[index], new);
255                 old = rt;
256         }
257
258         /* If we removed a route free it now */
259         mpls_rt_free(old);
260 }
261
262 static unsigned find_free_label(struct net *net)
263 {
264         unsigned index;
265         for (index = 16; index < net->mpls.platform_labels; index++) {
266                 if (!net->mpls.platform_label[index])
267                         return index;
268         }
269         return LABEL_NOT_SPECIFIED;
270 }
271
272 static int mpls_route_add(struct mpls_route_config *cfg)
273 {
274         struct net *net = cfg->rc_nlinfo.nl_net;
275         struct net_device *dev = NULL;
276         struct mpls_route *rt, *old;
277         unsigned index;
278         int i;
279         int err = -EINVAL;
280
281         index = cfg->rc_label;
282
283         /* If a label was not specified during insert pick one */
284         if ((index == LABEL_NOT_SPECIFIED) &&
285             (cfg->rc_nlflags & NLM_F_CREATE)) {
286                 index = find_free_label(net);
287         }
288
289         /* The first 16 labels are reserved, and may not be set */
290         if (index < 16)
291                 goto errout;
292
293         /* The full 20 bit range may not be supported. */
294         if (index >= net->mpls.platform_labels)
295                 goto errout;
296
297         /* Ensure only a supported number of labels are present */
298         if (cfg->rc_output_labels > MAX_NEW_LABELS)
299                 goto errout;
300
301         err = -ENODEV;
302         dev = dev_get_by_index(net, cfg->rc_ifindex);
303         if (!dev)
304                 goto errout;
305
306         /* For now just support ethernet devices */
307         err = -EINVAL;
308         if ((dev->type != ARPHRD_ETHER) && (dev->type != ARPHRD_LOOPBACK))
309                 goto errout;
310
311         err = -EINVAL;
312         if ((cfg->rc_via_family == AF_PACKET) &&
313             (dev->addr_len != cfg->rc_via_alen))
314                 goto errout;
315
316         /* Append makes no sense with mpls */
317         err = -EINVAL;
318         if (cfg->rc_nlflags & NLM_F_APPEND)
319                 goto errout;
320
321         err = -EEXIST;
322         old = net->mpls.platform_label[index];
323         if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
324                 goto errout;
325
326         err = -EEXIST;
327         if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
328                 goto errout;
329
330         err = -ENOENT;
331         if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
332                 goto errout;
333
334         err = -ENOMEM;
335         rt = mpls_rt_alloc(cfg->rc_via_alen);
336         if (!rt)
337                 goto errout;
338
339         rt->rt_labels = cfg->rc_output_labels;
340         for (i = 0; i < rt->rt_labels; i++)
341                 rt->rt_label[i] = cfg->rc_output_label[i];
342         rt->rt_protocol = cfg->rc_protocol;
343         rt->rt_dev = dev;
344         rt->rt_via_family = cfg->rc_via_family;
345         memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
346
347         mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);
348
349         dev_put(dev);
350         return 0;
351
352 errout:
353         if (dev)
354                 dev_put(dev);
355         return err;
356 }
357
358 static int mpls_route_del(struct mpls_route_config *cfg)
359 {
360         struct net *net = cfg->rc_nlinfo.nl_net;
361         unsigned index;
362         int err = -EINVAL;
363
364         index = cfg->rc_label;
365
366         /* The first 16 labels are reserved, and may not be removed */
367         if (index < 16)
368                 goto errout;
369
370         /* The full 20 bit range may not be supported */
371         if (index >= net->mpls.platform_labels)
372                 goto errout;
373
374         mpls_route_update(net, index, NULL, NULL, &cfg->rc_nlinfo);
375
376         err = 0;
377 errout:
378         return err;
379 }
380
381 static void mpls_ifdown(struct net_device *dev)
382 {
383         struct net *net = dev_net(dev);
384         unsigned index;
385
386         for (index = 0; index < net->mpls.platform_labels; index++) {
387                 struct mpls_route *rt = net->mpls.platform_label[index];
388                 if (!rt)
389                         continue;
390                 if (rt->rt_dev != dev)
391                         continue;
392                 rt->rt_dev = NULL;
393         }
394 }
395
396 static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
397                            void *ptr)
398 {
399         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
400
401         switch(event) {
402         case NETDEV_UNREGISTER:
403                 mpls_ifdown(dev);
404                 break;
405         }
406         return NOTIFY_OK;
407 }
408
409 static struct notifier_block mpls_dev_notifier = {
410         .notifier_call = mpls_dev_notify,
411 };
412
413 int nla_put_labels(struct sk_buff *skb, int attrtype,
414                    u8 labels, const u32 label[])
415 {
416         struct nlattr *nla;
417         struct mpls_shim_hdr *nla_label;
418         bool bos;
419         int i;
420         nla = nla_reserve(skb, attrtype, labels*4);
421         if (!nla)
422                 return -EMSGSIZE;
423
424         nla_label = nla_data(nla);
425         bos = true;
426         for (i = labels - 1; i >= 0; i--) {
427                 nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
428                 bos = false;
429         }
430
431         return 0;
432 }
433
434 int nla_get_labels(const struct nlattr *nla,
435                    u32 max_labels, u32 *labels, u32 label[])
436 {
437         unsigned len = nla_len(nla);
438         unsigned nla_labels;
439         struct mpls_shim_hdr *nla_label;
440         bool bos;
441         int i;
442
443         /* len needs to be an even multiple of 4 (the label size) */
444         if (len & 3)
445                 return -EINVAL;
446
447         /* Limit the number of new labels allowed */
448         nla_labels = len/4;
449         if (nla_labels > max_labels)
450                 return -EINVAL;
451
452         nla_label = nla_data(nla);
453         bos = true;
454         for (i = nla_labels - 1; i >= 0; i--, bos = false) {
455                 struct mpls_entry_decoded dec;
456                 dec = mpls_entry_decode(nla_label + i);
457
458                 /* Ensure the bottom of stack flag is properly set
459                  * and ttl and tc are both clear.
460                  */
461                 if ((dec.bos != bos) || dec.ttl || dec.tc)
462                         return -EINVAL;
463
464                 label[i] = dec.label;
465         }
466         *labels = nla_labels;
467         return 0;
468 }
469
470 static int resize_platform_label_table(struct net *net, size_t limit)
471 {
472         size_t size = sizeof(struct mpls_route *) * limit;
473         size_t old_limit;
474         size_t cp_size;
475         struct mpls_route __rcu **labels = NULL, **old;
476         struct mpls_route *rt0 = NULL, *rt2 = NULL;
477         unsigned index;
478
479         if (size) {
480                 labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
481                 if (!labels)
482                         labels = vzalloc(size);
483
484                 if (!labels)
485                         goto nolabels;
486         }
487
488         /* In case the predefined labels need to be populated */
489         if (limit > LABEL_IPV4_EXPLICIT_NULL) {
490                 struct net_device *lo = net->loopback_dev;
491                 rt0 = mpls_rt_alloc(lo->addr_len);
492                 if (!rt0)
493                         goto nort0;
494                 rt0->rt_dev = lo;
495                 rt0->rt_protocol = RTPROT_KERNEL;
496                 rt0->rt_via_family = AF_PACKET;
497                 memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
498         }
499         if (limit > LABEL_IPV6_EXPLICIT_NULL) {
500                 struct net_device *lo = net->loopback_dev;
501                 rt2 = mpls_rt_alloc(lo->addr_len);
502                 if (!rt2)
503                         goto nort2;
504                 rt2->rt_dev = lo;
505                 rt2->rt_protocol = RTPROT_KERNEL;
506                 rt2->rt_via_family = AF_PACKET;
507                 memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
508         }
509
510         rtnl_lock();
511         /* Remember the original table */
512         old = net->mpls.platform_label;
513         old_limit = net->mpls.platform_labels;
514
515         /* Free any labels beyond the new table */
516         for (index = limit; index < old_limit; index++)
517                 mpls_route_update(net, index, NULL, NULL, NULL);
518
519         /* Copy over the old labels */
520         cp_size = size;
521         if (old_limit < limit)
522                 cp_size = old_limit * sizeof(struct mpls_route *);
523
524         memcpy(labels, old, cp_size);
525
526         /* If needed set the predefined labels */
527         if ((old_limit <= LABEL_IPV6_EXPLICIT_NULL) &&
528             (limit > LABEL_IPV6_EXPLICIT_NULL)) {
529                 labels[LABEL_IPV6_EXPLICIT_NULL] = rt2;
530                 rt2 = NULL;
531         }
532
533         if ((old_limit <= LABEL_IPV4_EXPLICIT_NULL) &&
534             (limit > LABEL_IPV4_EXPLICIT_NULL)) {
535                 labels[LABEL_IPV4_EXPLICIT_NULL] = rt0;
536                 rt0 = NULL;
537         }
538
539         /* Update the global pointers */
540         net->mpls.platform_labels = limit;
541         net->mpls.platform_label = labels;
542
543         rtnl_unlock();
544
545         mpls_rt_free(rt2);
546         mpls_rt_free(rt0);
547
548         if (old) {
549                 synchronize_rcu();
550                 kvfree(old);
551         }
552         return 0;
553
554 nort2:
555         mpls_rt_free(rt0);
556 nort0:
557         kvfree(labels);
558 nolabels:
559         return -ENOMEM;
560 }
561
562 static int mpls_platform_labels(struct ctl_table *table, int write,
563                                 void __user *buffer, size_t *lenp, loff_t *ppos)
564 {
565         struct net *net = table->data;
566         int platform_labels = net->mpls.platform_labels;
567         int ret;
568         struct ctl_table tmp = {
569                 .procname       = table->procname,
570                 .data           = &platform_labels,
571                 .maxlen         = sizeof(int),
572                 .mode           = table->mode,
573                 .extra1         = &zero,
574                 .extra2         = &label_limit,
575         };
576
577         ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
578
579         if (write && ret == 0)
580                 ret = resize_platform_label_table(net, platform_labels);
581
582         return ret;
583 }
584
585 static struct ctl_table mpls_table[] = {
586         {
587                 .procname       = "platform_labels",
588                 .data           = NULL,
589                 .maxlen         = sizeof(int),
590                 .mode           = 0644,
591                 .proc_handler   = mpls_platform_labels,
592         },
593         { }
594 };
595
596 static int mpls_net_init(struct net *net)
597 {
598         struct ctl_table *table;
599
600         net->mpls.platform_labels = 0;
601         net->mpls.platform_label = NULL;
602
603         table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
604         if (table == NULL)
605                 return -ENOMEM;
606
607         table[0].data = net;
608         net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
609         if (net->mpls.ctl == NULL)
610                 return -ENOMEM;
611
612         return 0;
613 }
614
615 static void mpls_net_exit(struct net *net)
616 {
617         struct ctl_table *table;
618         unsigned int index;
619
620         table = net->mpls.ctl->ctl_table_arg;
621         unregister_net_sysctl_table(net->mpls.ctl);
622         kfree(table);
623
624         /* An rcu grace period haselapsed since there was a device in
625          * the network namespace (and thus the last in fqlight packet)
626          * left this network namespace.  This is because
627          * unregister_netdevice_many and netdev_run_todo has completed
628          * for each network device that was in this network namespace.
629          *
630          * As such no additional rcu synchronization is necessary when
631          * freeing the platform_label table.
632          */
633         rtnl_lock();
634         for (index = 0; index < net->mpls.platform_labels; index++) {
635                 struct mpls_route *rt = net->mpls.platform_label[index];
636                 rcu_assign_pointer(net->mpls.platform_label[index], NULL);
637                 mpls_rt_free(rt);
638         }
639         rtnl_unlock();
640
641         kvfree(net->mpls.platform_label);
642 }
643
644 static struct pernet_operations mpls_net_ops = {
645         .init = mpls_net_init,
646         .exit = mpls_net_exit,
647 };
648
649 static int __init mpls_init(void)
650 {
651         int err;
652
653         BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
654
655         err = register_pernet_subsys(&mpls_net_ops);
656         if (err)
657                 goto out;
658
659         err = register_netdevice_notifier(&mpls_dev_notifier);
660         if (err)
661                 goto out_unregister_pernet;
662
663         dev_add_pack(&mpls_packet_type);
664
665         err = 0;
666 out:
667         return err;
668
669 out_unregister_pernet:
670         unregister_pernet_subsys(&mpls_net_ops);
671         goto out;
672 }
673 module_init(mpls_init);
674
675 static void __exit mpls_exit(void)
676 {
677         dev_remove_pack(&mpls_packet_type);
678         unregister_netdevice_notifier(&mpls_dev_notifier);
679         unregister_pernet_subsys(&mpls_net_ops);
680 }
681 module_exit(mpls_exit);
682
683 MODULE_DESCRIPTION("MultiProtocol Label Switching");
684 MODULE_LICENSE("GPL v2");
685 MODULE_ALIAS_NETPROTO(PF_MPLS);