Merge tag 'xfs-for-linus-4.3' of git://git.kernel.org/pub/scm/linux/kernel/git/dgc...
[firefly-linux-kernel-4.4.55.git] / net / mpls / af_mpls.c
index 88cfaa241c07b72650344da9097363f7fd85d5f2..bb185a28de9890d2f4b3c57d1ca7af7600f9b2aa 100644 (file)
 /* This maximum ha length copied from the definition of struct neighbour */
 #define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
 
+enum mpls_payload_type {
+       MPT_UNSPEC, /* IPv4 or IPv6 */
+       MPT_IPV4 = 4,
+       MPT_IPV6 = 6,
+
+       /* Other types not implemented:
+        *  - Pseudo-wire with or without control word (RFC4385)
+        *  - GAL (RFC5586)
+        */
+};
+
 struct mpls_route { /* next hop label forwarding entry */
        struct net_device __rcu *rt_dev;
        struct rcu_head         rt_rcu;
        u32                     rt_label[MAX_NEW_LABELS];
        u8                      rt_protocol; /* routing protocol that set this entry */
+       u8                      rt_payload_type;
        u8                      rt_labels;
        u8                      rt_via_alen;
        u8                      rt_via_table;
@@ -96,16 +108,8 @@ EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
                        struct mpls_entry_decoded dec)
 {
-       /* RFC4385 and RFC5586 encode other packets in mpls such that
-        * they don't conflict with the ip version number, making
-        * decoding by examining the ip version correct in everything
-        * except for the strangest cases.
-        *
-        * The strange cases if we choose to support them will require
-        * manual configuration.
-        */
-       struct iphdr *hdr4;
-       bool success = true;
+       enum mpls_payload_type payload_type;
+       bool success = false;
 
        /* The IPv4 code below accesses through the IPv4 header
         * checksum, which is 12 bytes into the packet.
@@ -120,23 +124,32 @@ static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
        if (!pskb_may_pull(skb, 12))
                return false;
 
-       /* Use ip_hdr to find the ip protocol version */
-       hdr4 = ip_hdr(skb);
-       if (hdr4->version == 4) {
+       payload_type = rt->rt_payload_type;
+       if (payload_type == MPT_UNSPEC)
+               payload_type = ip_hdr(skb)->version;
+
+       switch (payload_type) {
+       case MPT_IPV4: {
+               struct iphdr *hdr4 = ip_hdr(skb);
                skb->protocol = htons(ETH_P_IP);
                csum_replace2(&hdr4->check,
                              htons(hdr4->ttl << 8),
                              htons(dec.ttl << 8));
                hdr4->ttl = dec.ttl;
+               success = true;
+               break;
        }
-       else if (hdr4->version == 6) {
+       case MPT_IPV6: {
                struct ipv6hdr *hdr6 = ipv6_hdr(skb);
                skb->protocol = htons(ETH_P_IPV6);
                hdr6->hop_limit = dec.ttl;
+               success = true;
+               break;
+       }
+       case MPT_UNSPEC:
+               break;
        }
-       else
-               /* version 0 and version 1 are used by pseudo wires */
-               success = false;
+
        return success;
 }
 
@@ -255,16 +268,17 @@ static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
 };
 
 struct mpls_route_config {
-       u32             rc_protocol;
-       u32             rc_ifindex;
-       u16             rc_via_table;
-       u16             rc_via_alen;
-       u8              rc_via[MAX_VIA_ALEN];
-       u32             rc_label;
-       u32             rc_output_labels;
-       u32             rc_output_label[MAX_NEW_LABELS];
-       u32             rc_nlflags;
-       struct nl_info  rc_nlinfo;
+       u32                     rc_protocol;
+       u32                     rc_ifindex;
+       u16                     rc_via_table;
+       u16                     rc_via_alen;
+       u8                      rc_via[MAX_VIA_ALEN];
+       u32                     rc_label;
+       u32                     rc_output_labels;
+       u32                     rc_output_label[MAX_NEW_LABELS];
+       u32                     rc_nlflags;
+       enum mpls_payload_type  rc_payload_type;
+       struct nl_info          rc_nlinfo;
 };
 
 static struct mpls_route *mpls_rt_alloc(size_t alen)
@@ -293,7 +307,7 @@ static void mpls_notify_route(struct net *net, unsigned index,
        struct mpls_route *rt = new ? new : old;
        unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
        /* Ignore reserved labels for now */
-       if (rt && (index >= 16))
+       if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
                rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
 }
 
@@ -327,7 +341,8 @@ static unsigned find_free_label(struct net *net)
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
-       for (index = 16; index < platform_labels; index++) {
+       for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
+            index++) {
                if (!rtnl_dereference(platform_label[index]))
                        return index;
        }
@@ -337,14 +352,14 @@ static unsigned find_free_label(struct net *net)
 #if IS_ENABLED(CONFIG_INET)
 static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
 {
-       struct net_device *dev = NULL;
+       struct net_device *dev;
        struct rtable *rt;
        struct in_addr daddr;
 
        memcpy(&daddr, addr, sizeof(struct in_addr));
        rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
        if (IS_ERR(rt))
-               goto errout;
+               return ERR_CAST(rt);
 
        dev = rt->dst.dev;
        dev_hold(dev);
@@ -352,8 +367,6 @@ static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
        ip_rt_put(rt);
 
        return dev;
-errout:
-       return ERR_PTR(-ENODEV);
 }
 #else
 static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
@@ -365,7 +378,7 @@ static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
 #if IS_ENABLED(CONFIG_IPV6)
 static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
 {
-       struct net_device *dev = NULL;
+       struct net_device *dev;
        struct dst_entry *dst;
        struct flowi6 fl6;
        int err;
@@ -377,16 +390,13 @@ static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
        memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
        err = ipv6_stub->ipv6_dst_lookup(net, NULL, &dst, &fl6);
        if (err)
-               goto errout;
+               return ERR_PTR(err);
 
        dev = dst->dev;
        dev_hold(dev);
        dst_release(dst);
 
        return dev;
-
-errout:
-       return ERR_PTR(err);
 }
 #else
 static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
@@ -415,6 +425,9 @@ static struct net_device *find_outdev(struct net *net,
                dev = dev_get_by_index(net, cfg->rc_ifindex);
        }
 
+       if (!dev)
+               return ERR_PTR(-ENODEV);
+
        return dev;
 }
 
@@ -436,8 +449,8 @@ static int mpls_route_add(struct mpls_route_config *cfg)
                index = find_free_label(net);
        }
 
-       /* The first 16 labels are reserved, and may not be set */
-       if (index < 16)
+       /* Reserved labels may not be set */
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
                goto errout;
 
        /* The full 20 bit range may not be supported. */
@@ -494,6 +507,7 @@ static int mpls_route_add(struct mpls_route_config *cfg)
                rt->rt_label[i] = cfg->rc_output_label[i];
        rt->rt_protocol = cfg->rc_protocol;
        RCU_INIT_POINTER(rt->rt_dev, dev);
+       rt->rt_payload_type = cfg->rc_payload_type;
        rt->rt_via_table = cfg->rc_via_table;
        memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
 
@@ -516,8 +530,8 @@ static int mpls_route_del(struct mpls_route_config *cfg)
 
        index = cfg->rc_label;
 
-       /* The first 16 labels are reserved, and may not be removed */
-       if (index < 16)
+       /* Reserved labels may not be removed */
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
                goto errout;
 
        /* The full 20 bit range may not be supported */
@@ -835,8 +849,8 @@ static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
                                           &cfg->rc_label))
                                goto errout;
 
-                       /* The first 16 labels are reserved, and may not be set */
-                       if (cfg->rc_label < 16)
+                       /* Reserved labels may not be set */
+                       if (cfg->rc_label < MPLS_LABEL_FIRST_UNRESERVED)
                                goto errout;
 
                        break;
@@ -961,8 +975,8 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
        ASSERT_RTNL();
 
        index = cb->args[0];
-       if (index < 16)
-               index = 16;
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
+               index = MPLS_LABEL_FIRST_UNRESERVED;
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
@@ -1048,6 +1062,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
                        goto nort0;
                RCU_INIT_POINTER(rt0->rt_dev, lo);
                rt0->rt_protocol = RTPROT_KERNEL;
+               rt0->rt_payload_type = MPT_IPV4;
                rt0->rt_via_table = NEIGH_LINK_TABLE;
                memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
        }
@@ -1058,6 +1073,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
                        goto nort2;
                RCU_INIT_POINTER(rt2->rt_dev, lo);
                rt2->rt_protocol = RTPROT_KERNEL;
+               rt2->rt_payload_type = MPT_IPV6;
                rt2->rt_via_table = NEIGH_LINK_TABLE;
                memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
        }
@@ -1161,8 +1177,10 @@ static int mpls_net_init(struct net *net)
 
        table[0].data = net;
        net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
-       if (net->mpls.ctl == NULL)
+       if (net->mpls.ctl == NULL) {
+               kfree(table);
                return -ENOMEM;
+       }
 
        return 0;
 }