Merge remote-tracking branches 'asoc/topic/ad1836', 'asoc/topic/ad193x', 'asoc/topic...
[firefly-linux-kernel-4.4.55.git] / net / ipv4 / udp.c
index 44f6a20fa29df830c1208e825816eaa11785f1ab..f140048334ce21f38d86ac4a2fdf770df5330d78 100644 (file)
@@ -560,15 +560,11 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
                                                 __be16 sport, __be16 dport,
                                                 struct udp_table *udptable)
 {
-       struct sock *sk;
        const struct iphdr *iph = ip_hdr(skb);
 
-       if (unlikely(sk = skb_steal_sock(skb)))
-               return sk;
-       else
-               return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
-                                        iph->daddr, dport, inet_iif(skb),
-                                        udptable);
+       return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
+                                iph->daddr, dport, inet_iif(skb),
+                                udptable);
 }
 
 struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
@@ -1603,12 +1599,16 @@ static void flush_stack(struct sock **stack, unsigned int count,
                kfree_skb(skb1);
 }
 
-static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
+/* For TCP sockets, sk_rx_dst is protected by socket lock
+ * For UDP, we use xchg() to guard against concurrent changes.
+ */
+static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
 {
-       struct dst_entry *dst = skb_dst(skb);
+       struct dst_entry *old;
 
        dst_hold(dst);
-       sk->sk_rx_dst = dst;
+       old = xchg(&sk->sk_rx_dst, dst);
+       dst_release(old);
 }
 
 /*
@@ -1739,15 +1739,16 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
-       if (skb->sk) {
+       sk = skb_steal_sock(skb);
+       if (sk) {
+               struct dst_entry *dst = skb_dst(skb);
                int ret;
-               sk = skb->sk;
 
-               if (unlikely(sk->sk_rx_dst == NULL))
-                       udp_sk_rx_dst_set(sk, skb);
+               if (unlikely(sk->sk_rx_dst != dst))
+                       udp_sk_rx_dst_set(sk, dst);
 
                ret = udp_queue_rcv_skb(sk, skb);
-
+               sock_put(sk);
                /* a return value > 0 means to resubmit the input, but
                 * it wants the return to be -protocol, or 0
                 */
@@ -1913,17 +1914,20 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
 
 void udp_v4_early_demux(struct sk_buff *skb)
 {
-       const struct iphdr *iph = ip_hdr(skb);
-       const struct udphdr *uh = udp_hdr(skb);
+       struct net *net = dev_net(skb->dev);
+       const struct iphdr *iph;
+       const struct udphdr *uh;
        struct sock *sk;
        struct dst_entry *dst;
-       struct net *net = dev_net(skb->dev);
        int dif = skb->dev->ifindex;
 
        /* validate the packet */
        if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr)))
                return;
 
+       iph = ip_hdr(skb);
+       uh = udp_hdr(skb);
+
        if (skb->pkt_type == PACKET_BROADCAST ||
            skb->pkt_type == PACKET_MULTICAST)
                sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,