mpls: Add a sysctl to control the size of the mpls label table
[firefly-linux-kernel-4.4.55.git] / net / mpls / af_mpls.c
1 #include <linux/types.h>
2 #include <linux/skbuff.h>
3 #include <linux/socket.h>
4 #include <linux/sysctl.h>
5 #include <linux/net.h>
6 #include <linux/module.h>
7 #include <linux/if_arp.h>
8 #include <linux/ipv6.h>
9 #include <linux/mpls.h>
10 #include <net/ip.h>
11 #include <net/dst.h>
12 #include <net/sock.h>
13 #include <net/arp.h>
14 #include <net/ip_fib.h>
15 #include <net/netevent.h>
16 #include <net/netns/generic.h>
17 #include "internal.h"
18
19 #define MAX_NEW_LABELS 2
20
21 /* This maximum ha length copied from the definition of struct neighbour */
22 #define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
23
24 struct mpls_route { /* next hop label forwarding entry */
25         struct net_device       *rt_dev;
26         struct rcu_head         rt_rcu;
27         u32                     rt_label[MAX_NEW_LABELS];
28         u8                      rt_protocol; /* routing protocol that set this entry */
29         u8                      rt_labels:2,
30                                 rt_via_alen:6;
31         unsigned short          rt_via_family;
32         u8                      rt_via[0];
33 };
34
35 static int zero = 0;
36 static int label_limit = (1 << 20) - 1;
37
38 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
39 {
40         struct mpls_route *rt = NULL;
41
42         if (index < net->mpls.platform_labels) {
43                 struct mpls_route __rcu **platform_label =
44                         rcu_dereference(net->mpls.platform_label);
45                 rt = rcu_dereference(platform_label[index]);
46         }
47         return rt;
48 }
49
50 static bool mpls_output_possible(const struct net_device *dev)
51 {
52         return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
53 }
54
55 static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
56 {
57         /* The size of the layer 2.5 labels to be added for this route */
58         return rt->rt_labels * sizeof(struct mpls_shim_hdr);
59 }
60
61 static unsigned int mpls_dev_mtu(const struct net_device *dev)
62 {
63         /* The amount of data the layer 2 frame can hold */
64         return dev->mtu;
65 }
66
67 static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
68 {
69         if (skb->len <= mtu)
70                 return false;
71
72         if (skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu)
73                 return false;
74
75         return true;
76 }
77
78 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
79                         struct mpls_entry_decoded dec)
80 {
81         /* RFC4385 and RFC5586 encode other packets in mpls such that
82          * they don't conflict with the ip version number, making
83          * decoding by examining the ip version correct in everything
84          * except for the strangest cases.
85          *
86          * The strange cases if we choose to support them will require
87          * manual configuration.
88          */
89         struct iphdr *hdr4 = ip_hdr(skb);
90         bool success = true;
91
92         if (hdr4->version == 4) {
93                 skb->protocol = htons(ETH_P_IP);
94                 csum_replace2(&hdr4->check,
95                               htons(hdr4->ttl << 8),
96                               htons(dec.ttl << 8));
97                 hdr4->ttl = dec.ttl;
98         }
99         else if (hdr4->version == 6) {
100                 struct ipv6hdr *hdr6 = ipv6_hdr(skb);
101                 skb->protocol = htons(ETH_P_IPV6);
102                 hdr6->hop_limit = dec.ttl;
103         }
104         else
105                 /* version 0 and version 1 are used by pseudo wires */
106                 success = false;
107         return success;
108 }
109
110 static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
111                         struct packet_type *pt, struct net_device *orig_dev)
112 {
113         struct net *net = dev_net(dev);
114         struct mpls_shim_hdr *hdr;
115         struct mpls_route *rt;
116         struct mpls_entry_decoded dec;
117         struct net_device *out_dev;
118         unsigned int hh_len;
119         unsigned int new_header_size;
120         unsigned int mtu;
121         int err;
122
123         /* Careful this entire function runs inside of an rcu critical section */
124
125         if (skb->pkt_type != PACKET_HOST)
126                 goto drop;
127
128         if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
129                 goto drop;
130
131         if (!pskb_may_pull(skb, sizeof(*hdr)))
132                 goto drop;
133
134         /* Read and decode the label */
135         hdr = mpls_hdr(skb);
136         dec = mpls_entry_decode(hdr);
137
138         /* Pop the label */
139         skb_pull(skb, sizeof(*hdr));
140         skb_reset_network_header(skb);
141
142         skb_orphan(skb);
143
144         rt = mpls_route_input_rcu(net, dec.label);
145         if (!rt)
146                 goto drop;
147
148         /* Find the output device */
149         out_dev = rt->rt_dev;
150         if (!mpls_output_possible(out_dev))
151                 goto drop;
152
153         if (skb_warn_if_lro(skb))
154                 goto drop;
155
156         skb_forward_csum(skb);
157
158         /* Verify ttl is valid */
159         if (dec.ttl <= 2)
160                 goto drop;
161         dec.ttl -= 1;
162
163         /* Verify the destination can hold the packet */
164         new_header_size = mpls_rt_header_size(rt);
165         mtu = mpls_dev_mtu(out_dev);
166         if (mpls_pkt_too_big(skb, mtu - new_header_size))
167                 goto drop;
168
169         hh_len = LL_RESERVED_SPACE(out_dev);
170         if (!out_dev->header_ops)
171                 hh_len = 0;
172
173         /* Ensure there is enough space for the headers in the skb */
174         if (skb_cow(skb, hh_len + new_header_size))
175                 goto drop;
176
177         skb->dev = out_dev;
178         skb->protocol = htons(ETH_P_MPLS_UC);
179
180         if (unlikely(!new_header_size && dec.bos)) {
181                 /* Penultimate hop popping */
182                 if (!mpls_egress(rt, skb, dec))
183                         goto drop;
184         } else {
185                 bool bos;
186                 int i;
187                 skb_push(skb, new_header_size);
188                 skb_reset_network_header(skb);
189                 /* Push the new labels */
190                 hdr = mpls_hdr(skb);
191                 bos = dec.bos;
192                 for (i = rt->rt_labels - 1; i >= 0; i--) {
193                         hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
194                         bos = false;
195                 }
196         }
197
198         err = neigh_xmit(rt->rt_via_family, out_dev, rt->rt_via, skb);
199         if (err)
200                 net_dbg_ratelimited("%s: packet transmission failed: %d\n",
201                                     __func__, err);
202         return 0;
203
204 drop:
205         kfree_skb(skb);
206         return NET_RX_DROP;
207 }
208
209 static struct packet_type mpls_packet_type __read_mostly = {
210         .type = cpu_to_be16(ETH_P_MPLS_UC),
211         .func = mpls_forward,
212 };
213
214 static struct mpls_route *mpls_rt_alloc(size_t alen)
215 {
216         struct mpls_route *rt;
217
218         rt = kzalloc(GFP_KERNEL, sizeof(*rt) + alen);
219         if (rt)
220                 rt->rt_via_alen = alen;
221         return rt;
222 }
223
224 static void mpls_rt_free(struct mpls_route *rt)
225 {
226         if (rt)
227                 kfree_rcu(rt, rt_rcu);
228 }
229
230 static void mpls_route_update(struct net *net, unsigned index,
231                               struct net_device *dev, struct mpls_route *new,
232                               const struct nl_info *info)
233 {
234         struct mpls_route *rt, *old = NULL;
235
236         ASSERT_RTNL();
237
238         rt = net->mpls.platform_label[index];
239         if (!dev || (rt && (rt->rt_dev == dev))) {
240                 rcu_assign_pointer(net->mpls.platform_label[index], new);
241                 old = rt;
242         }
243
244         /* If we removed a route free it now */
245         mpls_rt_free(old);
246 }
247
248 static void mpls_ifdown(struct net_device *dev)
249 {
250         struct net *net = dev_net(dev);
251         unsigned index;
252
253         for (index = 0; index < net->mpls.platform_labels; index++) {
254                 struct mpls_route *rt = net->mpls.platform_label[index];
255                 if (!rt)
256                         continue;
257                 if (rt->rt_dev != dev)
258                         continue;
259                 rt->rt_dev = NULL;
260         }
261 }
262
263 static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
264                            void *ptr)
265 {
266         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
267
268         switch(event) {
269         case NETDEV_UNREGISTER:
270                 mpls_ifdown(dev);
271                 break;
272         }
273         return NOTIFY_OK;
274 }
275
276 static struct notifier_block mpls_dev_notifier = {
277         .notifier_call = mpls_dev_notify,
278 };
279
280 static int resize_platform_label_table(struct net *net, size_t limit)
281 {
282         size_t size = sizeof(struct mpls_route *) * limit;
283         size_t old_limit;
284         size_t cp_size;
285         struct mpls_route __rcu **labels = NULL, **old;
286         struct mpls_route *rt0 = NULL, *rt2 = NULL;
287         unsigned index;
288
289         if (size) {
290                 labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
291                 if (!labels)
292                         labels = vzalloc(size);
293
294                 if (!labels)
295                         goto nolabels;
296         }
297
298         /* In case the predefined labels need to be populated */
299         if (limit > LABEL_IPV4_EXPLICIT_NULL) {
300                 struct net_device *lo = net->loopback_dev;
301                 rt0 = mpls_rt_alloc(lo->addr_len);
302                 if (!rt0)
303                         goto nort0;
304                 rt0->rt_dev = lo;
305                 rt0->rt_protocol = RTPROT_KERNEL;
306                 rt0->rt_via_family = AF_PACKET;
307                 memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
308         }
309         if (limit > LABEL_IPV6_EXPLICIT_NULL) {
310                 struct net_device *lo = net->loopback_dev;
311                 rt2 = mpls_rt_alloc(lo->addr_len);
312                 if (!rt2)
313                         goto nort2;
314                 rt2->rt_dev = lo;
315                 rt2->rt_protocol = RTPROT_KERNEL;
316                 rt2->rt_via_family = AF_PACKET;
317                 memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
318         }
319
320         rtnl_lock();
321         /* Remember the original table */
322         old = net->mpls.platform_label;
323         old_limit = net->mpls.platform_labels;
324
325         /* Free any labels beyond the new table */
326         for (index = limit; index < old_limit; index++)
327                 mpls_route_update(net, index, NULL, NULL, NULL);
328
329         /* Copy over the old labels */
330         cp_size = size;
331         if (old_limit < limit)
332                 cp_size = old_limit * sizeof(struct mpls_route *);
333
334         memcpy(labels, old, cp_size);
335
336         /* If needed set the predefined labels */
337         if ((old_limit <= LABEL_IPV6_EXPLICIT_NULL) &&
338             (limit > LABEL_IPV6_EXPLICIT_NULL)) {
339                 labels[LABEL_IPV6_EXPLICIT_NULL] = rt2;
340                 rt2 = NULL;
341         }
342
343         if ((old_limit <= LABEL_IPV4_EXPLICIT_NULL) &&
344             (limit > LABEL_IPV4_EXPLICIT_NULL)) {
345                 labels[LABEL_IPV4_EXPLICIT_NULL] = rt0;
346                 rt0 = NULL;
347         }
348
349         /* Update the global pointers */
350         net->mpls.platform_labels = limit;
351         net->mpls.platform_label = labels;
352
353         rtnl_unlock();
354
355         mpls_rt_free(rt2);
356         mpls_rt_free(rt0);
357
358         if (old) {
359                 synchronize_rcu();
360                 kvfree(old);
361         }
362         return 0;
363
364 nort2:
365         mpls_rt_free(rt0);
366 nort0:
367         kvfree(labels);
368 nolabels:
369         return -ENOMEM;
370 }
371
372 static int mpls_platform_labels(struct ctl_table *table, int write,
373                                 void __user *buffer, size_t *lenp, loff_t *ppos)
374 {
375         struct net *net = table->data;
376         int platform_labels = net->mpls.platform_labels;
377         int ret;
378         struct ctl_table tmp = {
379                 .procname       = table->procname,
380                 .data           = &platform_labels,
381                 .maxlen         = sizeof(int),
382                 .mode           = table->mode,
383                 .extra1         = &zero,
384                 .extra2         = &label_limit,
385         };
386
387         ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
388
389         if (write && ret == 0)
390                 ret = resize_platform_label_table(net, platform_labels);
391
392         return ret;
393 }
394
395 static struct ctl_table mpls_table[] = {
396         {
397                 .procname       = "platform_labels",
398                 .data           = NULL,
399                 .maxlen         = sizeof(int),
400                 .mode           = 0644,
401                 .proc_handler   = mpls_platform_labels,
402         },
403         { }
404 };
405
406 static int mpls_net_init(struct net *net)
407 {
408         struct ctl_table *table;
409
410         net->mpls.platform_labels = 0;
411         net->mpls.platform_label = NULL;
412
413         table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
414         if (table == NULL)
415                 return -ENOMEM;
416
417         table[0].data = net;
418         net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
419         if (net->mpls.ctl == NULL)
420                 return -ENOMEM;
421
422         return 0;
423 }
424
425 static void mpls_net_exit(struct net *net)
426 {
427         struct ctl_table *table;
428         unsigned int index;
429
430         table = net->mpls.ctl->ctl_table_arg;
431         unregister_net_sysctl_table(net->mpls.ctl);
432         kfree(table);
433
434         /* An rcu grace period haselapsed since there was a device in
435          * the network namespace (and thus the last in fqlight packet)
436          * left this network namespace.  This is because
437          * unregister_netdevice_many and netdev_run_todo has completed
438          * for each network device that was in this network namespace.
439          *
440          * As such no additional rcu synchronization is necessary when
441          * freeing the platform_label table.
442          */
443         rtnl_lock();
444         for (index = 0; index < net->mpls.platform_labels; index++) {
445                 struct mpls_route *rt = net->mpls.platform_label[index];
446                 rcu_assign_pointer(net->mpls.platform_label[index], NULL);
447                 mpls_rt_free(rt);
448         }
449         rtnl_unlock();
450
451         kvfree(net->mpls.platform_label);
452 }
453
454 static struct pernet_operations mpls_net_ops = {
455         .init = mpls_net_init,
456         .exit = mpls_net_exit,
457 };
458
459 static int __init mpls_init(void)
460 {
461         int err;
462
463         BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
464
465         err = register_pernet_subsys(&mpls_net_ops);
466         if (err)
467                 goto out;
468
469         err = register_netdevice_notifier(&mpls_dev_notifier);
470         if (err)
471                 goto out_unregister_pernet;
472
473         dev_add_pack(&mpls_packet_type);
474
475         err = 0;
476 out:
477         return err;
478
479 out_unregister_pernet:
480         unregister_pernet_subsys(&mpls_net_ops);
481         goto out;
482 }
483 module_init(mpls_init);
484
485 static void __exit mpls_exit(void)
486 {
487         dev_remove_pack(&mpls_packet_type);
488         unregister_netdevice_notifier(&mpls_dev_notifier);
489         unregister_pernet_subsys(&mpls_net_ops);
490 }
491 module_exit(mpls_exit);
492
493 MODULE_DESCRIPTION("MultiProtocol Label Switching");
494 MODULE_LICENSE("GPL v2");
495 MODULE_ALIAS_NETPROTO(PF_MPLS);