batman-adv: Correct rcu refcounting for neigh_node
[firefly-linux-kernel-4.4.55.git] / net / batman-adv / icmp_socket.c
index 8e0cd8a1bc0292b4121aa5e33d243f51e39187ac..7fa5bb8a940921b42fa91c6e470c456bb3e90c02 100644 (file)
@@ -156,7 +156,8 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
        struct sk_buff *skb;
        struct icmp_packet_rr *icmp_packet;
 
-       struct orig_node *orig_node;
+       struct orig_node *orig_node = NULL;
+       struct neigh_node *neigh_node = NULL;
        struct batman_if *batman_if;
        size_t packet_len = sizeof(struct icmp_packet);
        uint8_t dstaddr[ETH_ALEN];
@@ -224,17 +225,25 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
        orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash,
                                                   compare_orig, choose_orig,
                                                   icmp_packet->dst));
-       rcu_read_unlock();
 
        if (!orig_node)
                goto unlock;
 
-       if (!orig_node->router)
+       kref_get(&orig_node->refcount);
+       neigh_node = orig_node->router;
+
+       if (!neigh_node)
+               goto unlock;
+
+       if (!atomic_inc_not_zero(&neigh_node->refcount)) {
+               neigh_node = NULL;
                goto unlock;
+       }
+
+       rcu_read_unlock();
 
        batman_if = orig_node->router->if_incoming;
        memcpy(dstaddr, orig_node->router->addr, ETH_ALEN);
-
        spin_unlock_bh(&bat_priv->orig_hash_lock);
 
        if (!batman_if)
@@ -247,14 +256,14 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
               bat_priv->primary_if->net_dev->dev_addr, ETH_ALEN);
 
        if (packet_len == sizeof(struct icmp_packet_rr))
-               memcpy(icmp_packet->rr, batman_if->net_dev->dev_addr, ETH_ALEN);
-
+               memcpy(icmp_packet->rr,
+                      batman_if->net_dev->dev_addr, ETH_ALEN);
 
        send_skb_packet(skb, batman_if, dstaddr);
-
        goto out;
 
 unlock:
+       rcu_read_unlock();
        spin_unlock_bh(&bat_priv->orig_hash_lock);
 dst_unreach:
        icmp_packet->msg_type = DESTINATION_UNREACHABLE;
@@ -262,6 +271,10 @@ dst_unreach:
 free_skb:
        kfree_skb(skb);
 out:
+       if (neigh_node)
+               neigh_node_free_ref(neigh_node);
+       if (orig_node)
+               kref_put(&orig_node->refcount, orig_node_free_ref);
        return len;
 }