Merge branch 'for-davem' of git://git.kernel.org/pub/scm/linux/kernel/git/bwh/sfc...
[firefly-linux-kernel-4.4.55.git] / net / netlink / af_netlink.c
index faa48f70b7c9b132bfaf3c8ce561982e82c4c7b6..5463969da45b9a30dbbb1802624b0e9e8bf92a29 100644 (file)
@@ -80,6 +80,7 @@ struct netlink_sock {
        struct mutex            *cb_mutex;
        struct mutex            cb_def_mutex;
        void                    (*netlink_rcv)(struct sk_buff *skb);
+       void                    (*netlink_bind)(int group);
        struct module           *module;
 };
 
@@ -104,27 +105,28 @@ static inline int netlink_is_kernel(struct sock *sk)
 }
 
 struct nl_pid_hash {
-       struct hlist_head *table;
-       unsigned long rehash_time;
+       struct hlist_head       *table;
+       unsigned long           rehash_time;
 
-       unsigned int mask;
-       unsigned int shift;
+       unsigned int            mask;
+       unsigned int            shift;
 
-       unsigned int entries;
-       unsigned int max_shift;
+       unsigned int            entries;
+       unsigned int            max_shift;
 
-       u32 rnd;
+       u32                     rnd;
 };
 
 struct netlink_table {
-       struct nl_pid_hash hash;
-       struct hlist_head mc_list;
-       struct listeners __rcu *listeners;
-       unsigned int nl_nonroot;
-       unsigned int groups;
-       struct mutex *cb_mutex;
-       struct module *module;
-       int registered;
+       struct nl_pid_hash      hash;
+       struct hlist_head       mc_list;
+       struct listeners __rcu  *listeners;
+       unsigned int            nl_nonroot;
+       unsigned int            groups;
+       struct mutex            *cb_mutex;
+       struct module           *module;
+       void                    (*bind)(int group);
+       int                     registered;
 };
 
 static struct netlink_table *nl_table;
@@ -132,7 +134,6 @@ static struct netlink_table *nl_table;
 static DECLARE_WAIT_QUEUE_HEAD(nl_table_wait);
 
 static int netlink_dump(struct sock *sk);
-static void netlink_destroy_callback(struct netlink_callback *cb);
 
 static DEFINE_RWLOCK(nl_table_lock);
 static atomic_t nl_table_users = ATOMIC_INIT(0);
@@ -149,6 +150,18 @@ static inline struct hlist_head *nl_pid_hashfn(struct nl_pid_hash *hash, u32 pid
        return &hash->table[jhash_1word(pid, hash->rnd) & hash->mask];
 }
 
+static void netlink_destroy_callback(struct netlink_callback *cb)
+{
+       kfree_skb(cb->skb);
+       kfree(cb);
+}
+
+static void netlink_consume_callback(struct netlink_callback *cb)
+{
+       consume_skb(cb->skb);
+       kfree(cb);
+}
+
 static void netlink_sock_destruct(struct sock *sk)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -414,9 +427,9 @@ static int __netlink_create(struct net *net, struct socket *sock,
        sock_init_data(sock, sk);
 
        nlk = nlk_sk(sk);
-       if (cb_mutex)
+       if (cb_mutex) {
                nlk->cb_mutex = cb_mutex;
-       else {
+       else {
                nlk->cb_mutex = &nlk->cb_def_mutex;
                mutex_init(nlk->cb_mutex);
        }
@@ -433,6 +446,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
        struct module *module = NULL;
        struct mutex *cb_mutex;
        struct netlink_sock *nlk;
+       void (*bind)(int group);
        int err = 0;
 
        sock->state = SS_UNCONNECTED;
@@ -457,6 +471,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
        else
                err = -EPROTONOSUPPORT;
        cb_mutex = nl_table[protocol].cb_mutex;
+       bind = nl_table[protocol].bind;
        netlink_unlock_table();
 
        if (err < 0)
@@ -472,6 +487,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
 
        nlk = nlk_sk(sock->sk);
        nlk->module = module;
+       nlk->netlink_bind = bind;
 out:
        return err;
 
@@ -522,8 +538,9 @@ static int netlink_release(struct socket *sock)
                        nl_table[sk->sk_protocol].module = NULL;
                        nl_table[sk->sk_protocol].registered = 0;
                }
-       } else if (nlk->subscriptions)
+       } else if (nlk->subscriptions) {
                netlink_update_listeners(sk);
+       }
        netlink_table_ungrab();
 
        kfree(nlk->groups);
@@ -671,6 +688,15 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
        netlink_update_listeners(sk);
        netlink_table_ungrab();
 
+       if (nlk->netlink_bind && nlk->groups[0]) {
+               int i;
+
+               for (i=0; i<nlk->ngroups; i++) {
+                       if (test_bit(i, nlk->groups))
+                               nlk->netlink_bind(i);
+               }
+       }
+
        return 0;
 }
 
@@ -866,7 +892,7 @@ static struct sk_buff *netlink_trim(struct sk_buff *skb, gfp_t allocation)
                struct sk_buff *nskb = skb_clone(skb, allocation);
                if (!nskb)
                        return skb;
-               kfree_skb(skb);
+               consume_skb(skb);
                skb = nskb;
        }
 
@@ -896,8 +922,10 @@ static int netlink_unicast_kernel(struct sock *sk, struct sk_buff *skb)
                ret = skb->len;
                skb_set_owner_r(skb, sk);
                nlk->netlink_rcv(skb);
+               consume_skb(skb);
+       } else {
+               kfree_skb(skb);
        }
-       kfree_skb(skb);
        sock_put(sk);
        return ret;
 }
@@ -1086,8 +1114,8 @@ int netlink_broadcast_filtered(struct sock *ssk, struct sk_buff *skb, u32 pid,
        if (info.delivery_failure) {
                kfree_skb(info.skb2);
                return -ENOBUFS;
-       } else
-               consume_skb(info.skb2);
+       }
+       consume_skb(info.skb2);
 
        if (info.delivered) {
                if (info.congested && (allocation & __GFP_WAIT))
@@ -1225,6 +1253,10 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
                netlink_update_socket_mc(nlk, val,
                                         optname == NETLINK_ADD_MEMBERSHIP);
                netlink_table_ungrab();
+
+               if (nlk->netlink_bind)
+                       nlk->netlink_bind(val);
+
                err = 0;
                break;
        }
@@ -1240,8 +1272,9 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
                        nlk->flags |= NETLINK_RECV_NO_ENOBUFS;
                        clear_bit(0, &nlk->state);
                        wake_up_interruptible(&nlk->wait);
-               } else
+               } else {
                        nlk->flags &= ~NETLINK_RECV_NO_ENOBUFS;
+               }
                err = 0;
                break;
        default:
@@ -1488,14 +1521,16 @@ static void netlink_data_ready(struct sock *sk, int len)
  */
 
 struct sock *
-netlink_kernel_create(struct net *net, int unit, unsigned int groups,
-                     void (*input)(struct sk_buff *skb),
-                     struct mutex *cb_mutex, struct module *module)
+netlink_kernel_create(struct net *net, int unit,
+                     struct module *module,
+                     struct netlink_kernel_cfg *cfg)
 {
        struct socket *sock;
        struct sock *sk;
        struct netlink_sock *nlk;
        struct listeners *listeners = NULL;
+       struct mutex *cb_mutex = cfg ? cfg->cb_mutex : NULL;
+       unsigned int groups;
 
        BUG_ON(!nl_table);
 
@@ -1517,16 +1552,18 @@ netlink_kernel_create(struct net *net, int unit, unsigned int groups,
        sk = sock->sk;
        sk_change_net(sk, net);
 
-       if (groups < 32)
+       if (!cfg || cfg->groups < 32)
                groups = 32;
+       else
+               groups = cfg->groups;
 
        listeners = kzalloc(sizeof(*listeners) + NLGRPSZ(groups), GFP_KERNEL);
        if (!listeners)
                goto out_sock_release;
 
        sk->sk_data_ready = netlink_data_ready;
-       if (input)
-               nlk_sk(sk)->netlink_rcv = input;
+       if (cfg && cfg->input)
+               nlk_sk(sk)->netlink_rcv = cfg->input;
 
        if (netlink_insert(sk, net, 0))
                goto out_sock_release;
@@ -1540,6 +1577,7 @@ netlink_kernel_create(struct net *net, int unit, unsigned int groups,
                rcu_assign_pointer(nl_table[unit].listeners, listeners);
                nl_table[unit].cb_mutex = cb_mutex;
                nl_table[unit].module = module;
+               nl_table[unit].bind = cfg ? cfg->bind : NULL;
                nl_table[unit].registered = 1;
        } else {
                kfree(listeners);
@@ -1645,12 +1683,6 @@ void netlink_set_nonroot(int protocol, unsigned int flags)
 }
 EXPORT_SYMBOL(netlink_set_nonroot);
 
-static void netlink_destroy_callback(struct netlink_callback *cb)
-{
-       kfree_skb(cb->skb);
-       kfree(cb);
-}
-
 struct nlmsghdr *
 __nlmsg_put(struct sk_buff *skb, u32 pid, u32 seq, int type, int len, int flags)
 {
@@ -1727,7 +1759,7 @@ static int netlink_dump(struct sock *sk)
        nlk->cb = NULL;
        mutex_unlock(nlk->cb_mutex);
 
-       netlink_destroy_callback(cb);
+       netlink_consume_callback(cb);
        return 0;
 
 errout_skb:
@@ -1996,11 +2028,11 @@ static void netlink_seq_stop(struct seq_file *seq, void *v)
 
 static int netlink_seq_show(struct seq_file *seq, void *v)
 {
-       if (v == SEQ_START_TOKEN)
+       if (v == SEQ_START_TOKEN) {
                seq_puts(seq,
                         "sk       Eth Pid    Groups   "
                         "Rmem     Wmem     Dump     Locks     Drops     Inode\n");
-       else {
+       else {
                struct sock *s = v;
                struct netlink_sock *nlk = nlk_sk(s);