Merge tag 'xfs-for-linus-4.3' of git://git.kernel.org/pub/scm/linux/kernel/git/dgc...
[firefly-linux-kernel-4.4.55.git] / net / mpls / af_mpls.c
index 6e669114f829d696dd77a79f0e81877b8593fc13..bb185a28de9890d2f4b3c57d1ca7af7600f9b2aa 100644 (file)
 #include <net/ip_fib.h>
 #include <net/netevent.h>
 #include <net/netns/generic.h>
+#if IS_ENABLED(CONFIG_IPV6)
+#include <net/ipv6.h>
+#include <net/addrconf.h>
+#endif
 #include "internal.h"
 
 #define LABEL_NOT_SPECIFIED (1<<20)
 /* This maximum ha length copied from the definition of struct neighbour */
 #define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
 
+enum mpls_payload_type {
+       MPT_UNSPEC, /* IPv4 or IPv6 */
+       MPT_IPV4 = 4,
+       MPT_IPV6 = 6,
+
+       /* Other types not implemented:
+        *  - Pseudo-wire with or without control word (RFC4385)
+        *  - GAL (RFC5586)
+        */
+};
+
 struct mpls_route { /* next hop label forwarding entry */
        struct net_device __rcu *rt_dev;
        struct rcu_head         rt_rcu;
        u32                     rt_label[MAX_NEW_LABELS];
        u8                      rt_protocol; /* routing protocol that set this entry */
+       u8                      rt_payload_type;
        u8                      rt_labels;
        u8                      rt_via_alen;
        u8                      rt_via_table;
@@ -92,16 +108,8 @@ EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
                        struct mpls_entry_decoded dec)
 {
-       /* RFC4385 and RFC5586 encode other packets in mpls such that
-        * they don't conflict with the ip version number, making
-        * decoding by examining the ip version correct in everything
-        * except for the strangest cases.
-        *
-        * The strange cases if we choose to support them will require
-        * manual configuration.
-        */
-       struct iphdr *hdr4;
-       bool success = true;
+       enum mpls_payload_type payload_type;
+       bool success = false;
 
        /* The IPv4 code below accesses through the IPv4 header
         * checksum, which is 12 bytes into the packet.
@@ -116,23 +124,32 @@ static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
        if (!pskb_may_pull(skb, 12))
                return false;
 
-       /* Use ip_hdr to find the ip protocol version */
-       hdr4 = ip_hdr(skb);
-       if (hdr4->version == 4) {
+       payload_type = rt->rt_payload_type;
+       if (payload_type == MPT_UNSPEC)
+               payload_type = ip_hdr(skb)->version;
+
+       switch (payload_type) {
+       case MPT_IPV4: {
+               struct iphdr *hdr4 = ip_hdr(skb);
                skb->protocol = htons(ETH_P_IP);
                csum_replace2(&hdr4->check,
                              htons(hdr4->ttl << 8),
                              htons(dec.ttl << 8));
                hdr4->ttl = dec.ttl;
+               success = true;
+               break;
        }
-       else if (hdr4->version == 6) {
+       case MPT_IPV6: {
                struct ipv6hdr *hdr6 = ipv6_hdr(skb);
                skb->protocol = htons(ETH_P_IPV6);
                hdr6->hop_limit = dec.ttl;
+               success = true;
+               break;
+       }
+       case MPT_UNSPEC:
+               break;
        }
-       else
-               /* version 0 and version 1 are used by pseudo wires */
-               success = false;
+
        return success;
 }
 
@@ -251,16 +268,17 @@ static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
 };
 
 struct mpls_route_config {
-       u32             rc_protocol;
-       u32             rc_ifindex;
-       u16             rc_via_table;
-       u16             rc_via_alen;
-       u8              rc_via[MAX_VIA_ALEN];
-       u32             rc_label;
-       u32             rc_output_labels;
-       u32             rc_output_label[MAX_NEW_LABELS];
-       u32             rc_nlflags;
-       struct nl_info  rc_nlinfo;
+       u32                     rc_protocol;
+       u32                     rc_ifindex;
+       u16                     rc_via_table;
+       u16                     rc_via_alen;
+       u8                      rc_via[MAX_VIA_ALEN];
+       u32                     rc_label;
+       u32                     rc_output_labels;
+       u32                     rc_output_label[MAX_NEW_LABELS];
+       u32                     rc_nlflags;
+       enum mpls_payload_type  rc_payload_type;
+       struct nl_info          rc_nlinfo;
 };
 
 static struct mpls_route *mpls_rt_alloc(size_t alen)
@@ -289,7 +307,7 @@ static void mpls_notify_route(struct net *net, unsigned index,
        struct mpls_route *rt = new ? new : old;
        unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
        /* Ignore reserved labels for now */
-       if (rt && (index >= 16))
+       if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
                rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
 }
 
@@ -323,13 +341,96 @@ static unsigned find_free_label(struct net *net)
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
-       for (index = 16; index < platform_labels; index++) {
+       for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
+            index++) {
                if (!rtnl_dereference(platform_label[index]))
                        return index;
        }
        return LABEL_NOT_SPECIFIED;
 }
 
+#if IS_ENABLED(CONFIG_INET)
+static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
+{
+       struct net_device *dev;
+       struct rtable *rt;
+       struct in_addr daddr;
+
+       memcpy(&daddr, addr, sizeof(struct in_addr));
+       rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
+       if (IS_ERR(rt))
+               return ERR_CAST(rt);
+
+       dev = rt->dst.dev;
+       dev_hold(dev);
+
+       ip_rt_put(rt);
+
+       return dev;
+}
+#else
+static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
+{
+       return ERR_PTR(-EAFNOSUPPORT);
+}
+#endif
+
+#if IS_ENABLED(CONFIG_IPV6)
+static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
+{
+       struct net_device *dev;
+       struct dst_entry *dst;
+       struct flowi6 fl6;
+       int err;
+
+       if (!ipv6_stub)
+               return ERR_PTR(-EAFNOSUPPORT);
+
+       memset(&fl6, 0, sizeof(fl6));
+       memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
+       err = ipv6_stub->ipv6_dst_lookup(net, NULL, &dst, &fl6);
+       if (err)
+               return ERR_PTR(err);
+
+       dev = dst->dev;
+       dev_hold(dev);
+       dst_release(dst);
+
+       return dev;
+}
+#else
+static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
+{
+       return ERR_PTR(-EAFNOSUPPORT);
+}
+#endif
+
+static struct net_device *find_outdev(struct net *net,
+                                     struct mpls_route_config *cfg)
+{
+       struct net_device *dev = NULL;
+
+       if (!cfg->rc_ifindex) {
+               switch (cfg->rc_via_table) {
+               case NEIGH_ARP_TABLE:
+                       dev = inet_fib_lookup_dev(net, cfg->rc_via);
+                       break;
+               case NEIGH_ND_TABLE:
+                       dev = inet6_fib_lookup_dev(net, cfg->rc_via);
+                       break;
+               case NEIGH_LINK_TABLE:
+                       break;
+               }
+       } else {
+               dev = dev_get_by_index(net, cfg->rc_ifindex);
+       }
+
+       if (!dev)
+               return ERR_PTR(-ENODEV);
+
+       return dev;
+}
+
 static int mpls_route_add(struct mpls_route_config *cfg)
 {
        struct mpls_route __rcu **platform_label;
@@ -348,8 +449,8 @@ static int mpls_route_add(struct mpls_route_config *cfg)
                index = find_free_label(net);
        }
 
-       /* The first 16 labels are reserved, and may not be set */
-       if (index < 16)
+       /* Reserved labels may not be set */
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
                goto errout;
 
        /* The full 20 bit range may not be supported. */
@@ -360,10 +461,12 @@ static int mpls_route_add(struct mpls_route_config *cfg)
        if (cfg->rc_output_labels > MAX_NEW_LABELS)
                goto errout;
 
-       err = -ENODEV;
-       dev = dev_get_by_index(net, cfg->rc_ifindex);
-       if (!dev)
+       dev = find_outdev(net, cfg);
+       if (IS_ERR(dev)) {
+               err = PTR_ERR(dev);
+               dev = NULL;
                goto errout;
+       }
 
        /* Ensure this is a supported device */
        err = -EINVAL;
@@ -404,6 +507,7 @@ static int mpls_route_add(struct mpls_route_config *cfg)
                rt->rt_label[i] = cfg->rc_output_label[i];
        rt->rt_protocol = cfg->rc_protocol;
        RCU_INIT_POINTER(rt->rt_dev, dev);
+       rt->rt_payload_type = cfg->rc_payload_type;
        rt->rt_via_table = cfg->rc_via_table;
        memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
 
@@ -426,8 +530,8 @@ static int mpls_route_del(struct mpls_route_config *cfg)
 
        index = cfg->rc_label;
 
-       /* The first 16 labels are reserved, and may not be removed */
-       if (index < 16)
+       /* Reserved labels may not be removed */
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
                goto errout;
 
        /* The full 20 bit range may not be supported */
@@ -745,8 +849,8 @@ static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
                                           &cfg->rc_label))
                                goto errout;
 
-                       /* The first 16 labels are reserved, and may not be set */
-                       if (cfg->rc_label < 16)
+                       /* Reserved labels may not be set */
+                       if (cfg->rc_label < MPLS_LABEL_FIRST_UNRESERVED)
                                goto errout;
 
                        break;
@@ -871,8 +975,8 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
        ASSERT_RTNL();
 
        index = cb->args[0];
-       if (index < 16)
-               index = 16;
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
+               index = MPLS_LABEL_FIRST_UNRESERVED;
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
@@ -958,6 +1062,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
                        goto nort0;
                RCU_INIT_POINTER(rt0->rt_dev, lo);
                rt0->rt_protocol = RTPROT_KERNEL;
+               rt0->rt_payload_type = MPT_IPV4;
                rt0->rt_via_table = NEIGH_LINK_TABLE;
                memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
        }
@@ -968,6 +1073,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
                        goto nort2;
                RCU_INIT_POINTER(rt2->rt_dev, lo);
                rt2->rt_protocol = RTPROT_KERNEL;
+               rt2->rt_payload_type = MPT_IPV6;
                rt2->rt_via_table = NEIGH_LINK_TABLE;
                memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
        }
@@ -1071,8 +1177,10 @@ static int mpls_net_init(struct net *net)
 
        table[0].data = net;
        net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
-       if (net->mpls.ctl == NULL)
+       if (net->mpls.ctl == NULL) {
+               kfree(table);
                return -ENOMEM;
+       }
 
        return 0;
 }