Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[firefly-linux-kernel-4.4.55.git] / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/rculist.h>
19 #include <net/netlink.h>
20
21 #include <linux/netfilter.h>
22 #include <linux/netfilter/x_tables.h>
23 #include <linux/netfilter/nfnetlink.h>
24 #include <linux/netfilter/ipset/ip_set.h>
25
26 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
27 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
28 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
29
30 static struct ip_set * __rcu *ip_set_list;      /* all individual sets */
31 static ip_set_id_t ip_set_max = CONFIG_IP_SET_MAX; /* max number of sets */
32
33 #define IP_SET_INC      64
34 #define STREQ(a, b)     (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
35
36 static unsigned int max_sets;
37
38 module_param(max_sets, int, 0600);
39 MODULE_PARM_DESC(max_sets, "maximal number of sets");
40 MODULE_LICENSE("GPL");
41 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
42 MODULE_DESCRIPTION("core IP set support");
43 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
44
45 /* When the nfnl mutex is held: */
46 #define nfnl_dereference(p)             \
47         rcu_dereference_protected(p, 1)
48 #define nfnl_set(id)                    \
49         nfnl_dereference(ip_set_list)[id]
50
51 /*
52  * The set types are implemented in modules and registered set types
53  * can be found in ip_set_type_list. Adding/deleting types is
54  * serialized by ip_set_type_mutex.
55  */
56
57 static inline void
58 ip_set_type_lock(void)
59 {
60         mutex_lock(&ip_set_type_mutex);
61 }
62
63 static inline void
64 ip_set_type_unlock(void)
65 {
66         mutex_unlock(&ip_set_type_mutex);
67 }
68
69 /* Register and deregister settype */
70
71 static struct ip_set_type *
72 find_set_type(const char *name, u8 family, u8 revision)
73 {
74         struct ip_set_type *type;
75
76         list_for_each_entry_rcu(type, &ip_set_type_list, list)
77                 if (STREQ(type->name, name) &&
78                     (type->family == family ||
79                      type->family == NFPROTO_UNSPEC) &&
80                     revision >= type->revision_min &&
81                     revision <= type->revision_max)
82                         return type;
83         return NULL;
84 }
85
86 /* Unlock, try to load a set type module and lock again */
87 static bool
88 load_settype(const char *name)
89 {
90         nfnl_unlock(NFNL_SUBSYS_IPSET);
91         pr_debug("try to load ip_set_%s\n", name);
92         if (request_module("ip_set_%s", name) < 0) {
93                 pr_warning("Can't find ip_set type %s\n", name);
94                 nfnl_lock(NFNL_SUBSYS_IPSET);
95                 return false;
96         }
97         nfnl_lock(NFNL_SUBSYS_IPSET);
98         return true;
99 }
100
101 /* Find a set type and reference it */
102 #define find_set_type_get(name, family, revision, found)        \
103         __find_set_type_get(name, family, revision, found, false)
104
105 static int
106 __find_set_type_get(const char *name, u8 family, u8 revision,
107                     struct ip_set_type **found, bool retry)
108 {
109         struct ip_set_type *type;
110         int err;
111
112         if (retry && !load_settype(name))
113                 return -IPSET_ERR_FIND_TYPE;
114
115         rcu_read_lock();
116         *found = find_set_type(name, family, revision);
117         if (*found) {
118                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
119                 goto unlock;
120         }
121         /* Make sure the type is already loaded
122          * but we don't support the revision */
123         list_for_each_entry_rcu(type, &ip_set_type_list, list)
124                 if (STREQ(type->name, name)) {
125                         err = -IPSET_ERR_FIND_TYPE;
126                         goto unlock;
127                 }
128         rcu_read_unlock();
129
130         return retry ? -IPSET_ERR_FIND_TYPE :
131                 __find_set_type_get(name, family, revision, found, true);
132
133 unlock:
134         rcu_read_unlock();
135         return err;
136 }
137
138 /* Find a given set type by name and family.
139  * If we succeeded, the supported minimal and maximum revisions are
140  * filled out.
141  */
142 #define find_set_type_minmax(name, family, min, max) \
143         __find_set_type_minmax(name, family, min, max, false)
144
145 static int
146 __find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
147                        bool retry)
148 {
149         struct ip_set_type *type;
150         bool found = false;
151
152         if (retry && !load_settype(name))
153                 return -IPSET_ERR_FIND_TYPE;
154
155         *min = 255; *max = 0;
156         rcu_read_lock();
157         list_for_each_entry_rcu(type, &ip_set_type_list, list)
158                 if (STREQ(type->name, name) &&
159                     (type->family == family ||
160                      type->family == NFPROTO_UNSPEC)) {
161                         found = true;
162                         if (type->revision_min < *min)
163                                 *min = type->revision_min;
164                         if (type->revision_max > *max)
165                                 *max = type->revision_max;
166                 }
167         rcu_read_unlock();
168         if (found)
169                 return 0;
170
171         return retry ? -IPSET_ERR_FIND_TYPE :
172                 __find_set_type_minmax(name, family, min, max, true);
173 }
174
175 #define family_name(f)  ((f) == NFPROTO_IPV4 ? "inet" : \
176                          (f) == NFPROTO_IPV6 ? "inet6" : "any")
177
178 /* Register a set type structure. The type is identified by
179  * the unique triple of name, family and revision.
180  */
181 int
182 ip_set_type_register(struct ip_set_type *type)
183 {
184         int ret = 0;
185
186         if (type->protocol != IPSET_PROTOCOL) {
187                 pr_warning("ip_set type %s, family %s, revision %u:%u uses "
188                            "wrong protocol version %u (want %u)\n",
189                            type->name, family_name(type->family),
190                            type->revision_min, type->revision_max,
191                            type->protocol, IPSET_PROTOCOL);
192                 return -EINVAL;
193         }
194
195         ip_set_type_lock();
196         if (find_set_type(type->name, type->family, type->revision_min)) {
197                 /* Duplicate! */
198                 pr_warning("ip_set type %s, family %s with revision min %u "
199                            "already registered!\n", type->name,
200                            family_name(type->family), type->revision_min);
201                 ret = -EINVAL;
202                 goto unlock;
203         }
204         list_add_rcu(&type->list, &ip_set_type_list);
205         pr_debug("type %s, family %s, revision %u:%u registered.\n",
206                  type->name, family_name(type->family),
207                  type->revision_min, type->revision_max);
208 unlock:
209         ip_set_type_unlock();
210         return ret;
211 }
212 EXPORT_SYMBOL_GPL(ip_set_type_register);
213
214 /* Unregister a set type. There's a small race with ip_set_create */
215 void
216 ip_set_type_unregister(struct ip_set_type *type)
217 {
218         ip_set_type_lock();
219         if (!find_set_type(type->name, type->family, type->revision_min)) {
220                 pr_warning("ip_set type %s, family %s with revision min %u "
221                            "not registered\n", type->name,
222                            family_name(type->family), type->revision_min);
223                 goto unlock;
224         }
225         list_del_rcu(&type->list);
226         pr_debug("type %s, family %s with revision min %u unregistered.\n",
227                  type->name, family_name(type->family), type->revision_min);
228 unlock:
229         ip_set_type_unlock();
230
231         synchronize_rcu();
232 }
233 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
234
235 /* Utility functions */
236 void *
237 ip_set_alloc(size_t size)
238 {
239         void *members = NULL;
240
241         if (size < KMALLOC_MAX_SIZE)
242                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
243
244         if (members) {
245                 pr_debug("%p: allocated with kmalloc\n", members);
246                 return members;
247         }
248
249         members = vzalloc(size);
250         if (!members)
251                 return NULL;
252         pr_debug("%p: allocated with vmalloc\n", members);
253
254         return members;
255 }
256 EXPORT_SYMBOL_GPL(ip_set_alloc);
257
258 void
259 ip_set_free(void *members)
260 {
261         pr_debug("%p: free with %s\n", members,
262                  is_vmalloc_addr(members) ? "vfree" : "kfree");
263         if (is_vmalloc_addr(members))
264                 vfree(members);
265         else
266                 kfree(members);
267 }
268 EXPORT_SYMBOL_GPL(ip_set_free);
269
270 static inline bool
271 flag_nested(const struct nlattr *nla)
272 {
273         return nla->nla_type & NLA_F_NESTED;
274 }
275
276 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
277         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
278         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
279                                             .len = sizeof(struct in6_addr) },
280 };
281
282 int
283 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
284 {
285         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
286
287         if (unlikely(!flag_nested(nla)))
288                 return -IPSET_ERR_PROTOCOL;
289         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
290                 return -IPSET_ERR_PROTOCOL;
291         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
292                 return -IPSET_ERR_PROTOCOL;
293
294         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
295         return 0;
296 }
297 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
298
299 int
300 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
301 {
302         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
303
304         if (unlikely(!flag_nested(nla)))
305                 return -IPSET_ERR_PROTOCOL;
306
307         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
308                 return -IPSET_ERR_PROTOCOL;
309         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
310                 return -IPSET_ERR_PROTOCOL;
311
312         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
313                 sizeof(struct in6_addr));
314         return 0;
315 }
316 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
317
318 /*
319  * Creating/destroying/renaming/swapping affect the existence and
320  * the properties of a set. All of these can be executed from userspace
321  * only and serialized by the nfnl mutex indirectly from nfnetlink.
322  *
323  * Sets are identified by their index in ip_set_list and the index
324  * is used by the external references (set/SET netfilter modules).
325  *
326  * The set behind an index may change by swapping only, from userspace.
327  */
328
329 static inline void
330 __ip_set_get(struct ip_set *set)
331 {
332         write_lock_bh(&ip_set_ref_lock);
333         set->ref++;
334         write_unlock_bh(&ip_set_ref_lock);
335 }
336
337 static inline void
338 __ip_set_put(struct ip_set *set)
339 {
340         write_lock_bh(&ip_set_ref_lock);
341         BUG_ON(set->ref == 0);
342         set->ref--;
343         write_unlock_bh(&ip_set_ref_lock);
344 }
345
346 /*
347  * Add, del and test set entries from kernel.
348  *
349  * The set behind the index must exist and must be referenced
350  * so it can't be destroyed (or changed) under our foot.
351  */
352
353 static inline struct ip_set *
354 ip_set_rcu_get(ip_set_id_t index)
355 {
356         struct ip_set *set;
357
358         rcu_read_lock();
359         /* ip_set_list itself needs to be protected */
360         set = rcu_dereference(ip_set_list)[index];
361         rcu_read_unlock();
362
363         return set;
364 }
365
366 int
367 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
368             const struct xt_action_param *par,
369             const struct ip_set_adt_opt *opt)
370 {
371         struct ip_set *set = ip_set_rcu_get(index);
372         int ret = 0;
373
374         BUG_ON(set == NULL);
375         pr_debug("set %s, index %u\n", set->name, index);
376
377         if (opt->dim < set->type->dimension ||
378             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
379                 return 0;
380
381         read_lock_bh(&set->lock);
382         ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
383         read_unlock_bh(&set->lock);
384
385         if (ret == -EAGAIN) {
386                 /* Type requests element to be completed */
387                 pr_debug("element must be competed, ADD is triggered\n");
388                 write_lock_bh(&set->lock);
389                 set->variant->kadt(set, skb, par, IPSET_ADD, opt);
390                 write_unlock_bh(&set->lock);
391                 ret = 1;
392         } else {
393                 /* --return-nomatch: invert matched element */
394                 if ((opt->flags & IPSET_RETURN_NOMATCH) &&
395                     (set->type->features & IPSET_TYPE_NOMATCH) &&
396                     (ret > 0 || ret == -ENOTEMPTY))
397                         ret = -ret;
398         }
399
400         /* Convert error codes to nomatch */
401         return (ret < 0 ? 0 : ret);
402 }
403 EXPORT_SYMBOL_GPL(ip_set_test);
404
405 int
406 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
407            const struct xt_action_param *par,
408            const struct ip_set_adt_opt *opt)
409 {
410         struct ip_set *set = ip_set_rcu_get(index);
411         int ret;
412
413         BUG_ON(set == NULL);
414         pr_debug("set %s, index %u\n", set->name, index);
415
416         if (opt->dim < set->type->dimension ||
417             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
418                 return 0;
419
420         write_lock_bh(&set->lock);
421         ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
422         write_unlock_bh(&set->lock);
423
424         return ret;
425 }
426 EXPORT_SYMBOL_GPL(ip_set_add);
427
428 int
429 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
430            const struct xt_action_param *par,
431            const struct ip_set_adt_opt *opt)
432 {
433         struct ip_set *set = ip_set_rcu_get(index);
434         int ret = 0;
435
436         BUG_ON(set == NULL);
437         pr_debug("set %s, index %u\n", set->name, index);
438
439         if (opt->dim < set->type->dimension ||
440             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
441                 return 0;
442
443         write_lock_bh(&set->lock);
444         ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
445         write_unlock_bh(&set->lock);
446
447         return ret;
448 }
449 EXPORT_SYMBOL_GPL(ip_set_del);
450
451 /*
452  * Find set by name, reference it once. The reference makes sure the
453  * thing pointed to, does not go away under our feet.
454  *
455  */
456 ip_set_id_t
457 ip_set_get_byname(const char *name, struct ip_set **set)
458 {
459         ip_set_id_t i, index = IPSET_INVALID_ID;
460         struct ip_set *s;
461
462         rcu_read_lock();
463         for (i = 0; i < ip_set_max; i++) {
464                 s = rcu_dereference(ip_set_list)[i];
465                 if (s != NULL && STREQ(s->name, name)) {
466                         __ip_set_get(s);
467                         index = i;
468                         *set = s;
469                         break;
470                 }
471         }
472         rcu_read_unlock();
473
474         return index;
475 }
476 EXPORT_SYMBOL_GPL(ip_set_get_byname);
477
478 /*
479  * If the given set pointer points to a valid set, decrement
480  * reference count by 1. The caller shall not assume the index
481  * to be valid, after calling this function.
482  *
483  */
484 void
485 ip_set_put_byindex(ip_set_id_t index)
486 {
487         struct ip_set *set;
488
489         rcu_read_lock();
490         set = rcu_dereference(ip_set_list)[index];
491         if (set != NULL)
492                 __ip_set_put(set);
493         rcu_read_unlock();
494 }
495 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
496
497 /*
498  * Get the name of a set behind a set index.
499  * We assume the set is referenced, so it does exist and
500  * can't be destroyed. The set cannot be renamed due to
501  * the referencing either.
502  *
503  */
504 const char *
505 ip_set_name_byindex(ip_set_id_t index)
506 {
507         const struct ip_set *set = ip_set_rcu_get(index);
508
509         BUG_ON(set == NULL);
510         BUG_ON(set->ref == 0);
511
512         /* Referenced, so it's safe */
513         return set->name;
514 }
515 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
516
517 /*
518  * Routines to call by external subsystems, which do not
519  * call nfnl_lock for us.
520  */
521
522 /*
523  * Find set by name, reference it once. The reference makes sure the
524  * thing pointed to, does not go away under our feet.
525  *
526  * The nfnl mutex is used in the function.
527  */
528 ip_set_id_t
529 ip_set_nfnl_get(const char *name)
530 {
531         ip_set_id_t i, index = IPSET_INVALID_ID;
532         struct ip_set *s;
533
534         nfnl_lock(NFNL_SUBSYS_IPSET);
535         for (i = 0; i < ip_set_max; i++) {
536                 s = nfnl_set(i);
537                 if (s != NULL && STREQ(s->name, name)) {
538                         __ip_set_get(s);
539                         index = i;
540                         break;
541                 }
542         }
543         nfnl_unlock(NFNL_SUBSYS_IPSET);
544
545         return index;
546 }
547 EXPORT_SYMBOL_GPL(ip_set_nfnl_get);
548
549 /*
550  * Find set by index, reference it once. The reference makes sure the
551  * thing pointed to, does not go away under our feet.
552  *
553  * The nfnl mutex is used in the function.
554  */
555 ip_set_id_t
556 ip_set_nfnl_get_byindex(ip_set_id_t index)
557 {
558         struct ip_set *set;
559
560         if (index > ip_set_max)
561                 return IPSET_INVALID_ID;
562
563         nfnl_lock(NFNL_SUBSYS_IPSET);
564         set = nfnl_set(index);
565         if (set)
566                 __ip_set_get(set);
567         else
568                 index = IPSET_INVALID_ID;
569         nfnl_unlock(NFNL_SUBSYS_IPSET);
570
571         return index;
572 }
573 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
574
575 /*
576  * If the given set pointer points to a valid set, decrement
577  * reference count by 1. The caller shall not assume the index
578  * to be valid, after calling this function.
579  *
580  * The nfnl mutex is used in the function.
581  */
582 void
583 ip_set_nfnl_put(ip_set_id_t index)
584 {
585         struct ip_set *set;
586         nfnl_lock(NFNL_SUBSYS_IPSET);
587         set = nfnl_set(index);
588         if (set != NULL)
589                 __ip_set_put(set);
590         nfnl_unlock(NFNL_SUBSYS_IPSET);
591 }
592 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
593
594 /*
595  * Communication protocol with userspace over netlink.
596  *
597  * The commands are serialized by the nfnl mutex.
598  */
599
600 static inline bool
601 protocol_failed(const struct nlattr * const tb[])
602 {
603         return !tb[IPSET_ATTR_PROTOCOL] ||
604                nla_get_u8(tb[IPSET_ATTR_PROTOCOL]) != IPSET_PROTOCOL;
605 }
606
607 static inline u32
608 flag_exist(const struct nlmsghdr *nlh)
609 {
610         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
611 }
612
613 static struct nlmsghdr *
614 start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
615           enum ipset_cmd cmd)
616 {
617         struct nlmsghdr *nlh;
618         struct nfgenmsg *nfmsg;
619
620         nlh = nlmsg_put(skb, portid, seq, cmd | (NFNL_SUBSYS_IPSET << 8),
621                         sizeof(*nfmsg), flags);
622         if (nlh == NULL)
623                 return NULL;
624
625         nfmsg = nlmsg_data(nlh);
626         nfmsg->nfgen_family = NFPROTO_IPV4;
627         nfmsg->version = NFNETLINK_V0;
628         nfmsg->res_id = 0;
629
630         return nlh;
631 }
632
633 /* Create a set */
634
635 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
636         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
637         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
638                                     .len = IPSET_MAXNAMELEN - 1 },
639         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
640                                     .len = IPSET_MAXNAMELEN - 1},
641         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
642         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
643         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
644 };
645
646 static struct ip_set *
647 find_set_and_id(const char *name, ip_set_id_t *id)
648 {
649         struct ip_set *set = NULL;
650         ip_set_id_t i;
651
652         *id = IPSET_INVALID_ID;
653         for (i = 0; i < ip_set_max; i++) {
654                 set = nfnl_set(i);
655                 if (set != NULL && STREQ(set->name, name)) {
656                         *id = i;
657                         break;
658                 }
659         }
660         return (*id == IPSET_INVALID_ID ? NULL : set);
661 }
662
663 static inline struct ip_set *
664 find_set(const char *name)
665 {
666         ip_set_id_t id;
667
668         return find_set_and_id(name, &id);
669 }
670
671 static int
672 find_free_id(const char *name, ip_set_id_t *index, struct ip_set **set)
673 {
674         struct ip_set *s;
675         ip_set_id_t i;
676
677         *index = IPSET_INVALID_ID;
678         for (i = 0;  i < ip_set_max; i++) {
679                 s = nfnl_set(i);
680                 if (s == NULL) {
681                         if (*index == IPSET_INVALID_ID)
682                                 *index = i;
683                 } else if (STREQ(name, s->name)) {
684                         /* Name clash */
685                         *set = s;
686                         return -EEXIST;
687                 }
688         }
689         if (*index == IPSET_INVALID_ID)
690                 /* No free slot remained */
691                 return -IPSET_ERR_MAX_SETS;
692         return 0;
693 }
694
695 static int
696 ip_set_none(struct sock *ctnl, struct sk_buff *skb,
697             const struct nlmsghdr *nlh,
698             const struct nlattr * const attr[])
699 {
700         return -EOPNOTSUPP;
701 }
702
703 static int
704 ip_set_create(struct sock *ctnl, struct sk_buff *skb,
705               const struct nlmsghdr *nlh,
706               const struct nlattr * const attr[])
707 {
708         struct ip_set *set, *clash = NULL;
709         ip_set_id_t index = IPSET_INVALID_ID;
710         struct nlattr *tb[IPSET_ATTR_CREATE_MAX+1] = {};
711         const char *name, *typename;
712         u8 family, revision;
713         u32 flags = flag_exist(nlh);
714         int ret = 0;
715
716         if (unlikely(protocol_failed(attr) ||
717                      attr[IPSET_ATTR_SETNAME] == NULL ||
718                      attr[IPSET_ATTR_TYPENAME] == NULL ||
719                      attr[IPSET_ATTR_REVISION] == NULL ||
720                      attr[IPSET_ATTR_FAMILY] == NULL ||
721                      (attr[IPSET_ATTR_DATA] != NULL &&
722                       !flag_nested(attr[IPSET_ATTR_DATA]))))
723                 return -IPSET_ERR_PROTOCOL;
724
725         name = nla_data(attr[IPSET_ATTR_SETNAME]);
726         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
727         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
728         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
729         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
730                  name, typename, family_name(family), revision);
731
732         /*
733          * First, and without any locks, allocate and initialize
734          * a normal base set structure.
735          */
736         set = kzalloc(sizeof(struct ip_set), GFP_KERNEL);
737         if (!set)
738                 return -ENOMEM;
739         rwlock_init(&set->lock);
740         strlcpy(set->name, name, IPSET_MAXNAMELEN);
741         set->family = family;
742         set->revision = revision;
743
744         /*
745          * Next, check that we know the type, and take
746          * a reference on the type, to make sure it stays available
747          * while constructing our new set.
748          *
749          * After referencing the type, we try to create the type
750          * specific part of the set without holding any locks.
751          */
752         ret = find_set_type_get(typename, family, revision, &(set->type));
753         if (ret)
754                 goto out;
755
756         /*
757          * Without holding any locks, create private part.
758          */
759         if (attr[IPSET_ATTR_DATA] &&
760             nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
761                              set->type->create_policy)) {
762                 ret = -IPSET_ERR_PROTOCOL;
763                 goto put_out;
764         }
765
766         ret = set->type->create(set, tb, flags);
767         if (ret != 0)
768                 goto put_out;
769
770         /* BTW, ret==0 here. */
771
772         /*
773          * Here, we have a valid, constructed set and we are protected
774          * by the nfnl mutex. Find the first free index in ip_set_list
775          * and check clashing.
776          */
777         ret = find_free_id(set->name, &index, &clash);
778         if (ret == -EEXIST) {
779                 /* If this is the same set and requested, ignore error */
780                 if ((flags & IPSET_FLAG_EXIST) &&
781                     STREQ(set->type->name, clash->type->name) &&
782                     set->type->family == clash->type->family &&
783                     set->type->revision_min == clash->type->revision_min &&
784                     set->type->revision_max == clash->type->revision_max &&
785                     set->variant->same_set(set, clash))
786                         ret = 0;
787                 goto cleanup;
788         } else if (ret == -IPSET_ERR_MAX_SETS) {
789                 struct ip_set **list, **tmp;
790                 ip_set_id_t i = ip_set_max + IP_SET_INC;
791
792                 if (i < ip_set_max || i == IPSET_INVALID_ID)
793                         /* Wraparound */
794                         goto cleanup;
795
796                 list = kzalloc(sizeof(struct ip_set *) * i, GFP_KERNEL);
797                 if (!list)
798                         goto cleanup;
799                 /* nfnl mutex is held, both lists are valid */
800                 tmp = nfnl_dereference(ip_set_list);
801                 memcpy(list, tmp, sizeof(struct ip_set *) * ip_set_max);
802                 rcu_assign_pointer(ip_set_list, list);
803                 /* Make sure all current packets have passed through */
804                 synchronize_net();
805                 /* Use new list */
806                 index = ip_set_max;
807                 ip_set_max = i;
808                 kfree(tmp);
809                 ret = 0;
810         } else if (ret)
811                 goto cleanup;
812
813         /*
814          * Finally! Add our shiny new set to the list, and be done.
815          */
816         pr_debug("create: '%s' created with index %u!\n", set->name, index);
817         nfnl_set(index) = set;
818
819         return ret;
820
821 cleanup:
822         set->variant->destroy(set);
823 put_out:
824         module_put(set->type->me);
825 out:
826         kfree(set);
827         return ret;
828 }
829
830 /* Destroy sets */
831
832 static const struct nla_policy
833 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
834         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
835         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
836                                     .len = IPSET_MAXNAMELEN - 1 },
837 };
838
839 static void
840 ip_set_destroy_set(ip_set_id_t index)
841 {
842         struct ip_set *set = nfnl_set(index);
843
844         pr_debug("set: %s\n",  set->name);
845         nfnl_set(index) = NULL;
846
847         /* Must call it without holding any lock */
848         set->variant->destroy(set);
849         module_put(set->type->me);
850         kfree(set);
851 }
852
853 static int
854 ip_set_destroy(struct sock *ctnl, struct sk_buff *skb,
855                const struct nlmsghdr *nlh,
856                const struct nlattr * const attr[])
857 {
858         struct ip_set *s;
859         ip_set_id_t i;
860         int ret = 0;
861
862         if (unlikely(protocol_failed(attr)))
863                 return -IPSET_ERR_PROTOCOL;
864
865         /* Commands are serialized and references are
866          * protected by the ip_set_ref_lock.
867          * External systems (i.e. xt_set) must call
868          * ip_set_put|get_nfnl_* functions, that way we
869          * can safely check references here.
870          *
871          * list:set timer can only decrement the reference
872          * counter, so if it's already zero, we can proceed
873          * without holding the lock.
874          */
875         read_lock_bh(&ip_set_ref_lock);
876         if (!attr[IPSET_ATTR_SETNAME]) {
877                 for (i = 0; i < ip_set_max; i++) {
878                         s = nfnl_set(i);
879                         if (s != NULL && s->ref) {
880                                 ret = -IPSET_ERR_BUSY;
881                                 goto out;
882                         }
883                 }
884                 read_unlock_bh(&ip_set_ref_lock);
885                 for (i = 0; i < ip_set_max; i++) {
886                         s = nfnl_set(i);
887                         if (s != NULL)
888                                 ip_set_destroy_set(i);
889                 }
890         } else {
891                 s = find_set_and_id(nla_data(attr[IPSET_ATTR_SETNAME]), &i);
892                 if (s == NULL) {
893                         ret = -ENOENT;
894                         goto out;
895                 } else if (s->ref) {
896                         ret = -IPSET_ERR_BUSY;
897                         goto out;
898                 }
899                 read_unlock_bh(&ip_set_ref_lock);
900
901                 ip_set_destroy_set(i);
902         }
903         return 0;
904 out:
905         read_unlock_bh(&ip_set_ref_lock);
906         return ret;
907 }
908
909 /* Flush sets */
910
911 static void
912 ip_set_flush_set(struct ip_set *set)
913 {
914         pr_debug("set: %s\n",  set->name);
915
916         write_lock_bh(&set->lock);
917         set->variant->flush(set);
918         write_unlock_bh(&set->lock);
919 }
920
921 static int
922 ip_set_flush(struct sock *ctnl, struct sk_buff *skb,
923              const struct nlmsghdr *nlh,
924              const struct nlattr * const attr[])
925 {
926         struct ip_set *s;
927         ip_set_id_t i;
928
929         if (unlikely(protocol_failed(attr)))
930                 return -IPSET_ERR_PROTOCOL;
931
932         if (!attr[IPSET_ATTR_SETNAME]) {
933                 for (i = 0; i < ip_set_max; i++) {
934                         s = nfnl_set(i);
935                         if (s != NULL)
936                                 ip_set_flush_set(s);
937                 }
938         } else {
939                 s = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
940                 if (s == NULL)
941                         return -ENOENT;
942
943                 ip_set_flush_set(s);
944         }
945
946         return 0;
947 }
948
949 /* Rename a set */
950
951 static const struct nla_policy
952 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
953         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
954         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
955                                     .len = IPSET_MAXNAMELEN - 1 },
956         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
957                                     .len = IPSET_MAXNAMELEN - 1 },
958 };
959
960 static int
961 ip_set_rename(struct sock *ctnl, struct sk_buff *skb,
962               const struct nlmsghdr *nlh,
963               const struct nlattr * const attr[])
964 {
965         struct ip_set *set, *s;
966         const char *name2;
967         ip_set_id_t i;
968         int ret = 0;
969
970         if (unlikely(protocol_failed(attr) ||
971                      attr[IPSET_ATTR_SETNAME] == NULL ||
972                      attr[IPSET_ATTR_SETNAME2] == NULL))
973                 return -IPSET_ERR_PROTOCOL;
974
975         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
976         if (set == NULL)
977                 return -ENOENT;
978
979         read_lock_bh(&ip_set_ref_lock);
980         if (set->ref != 0) {
981                 ret = -IPSET_ERR_REFERENCED;
982                 goto out;
983         }
984
985         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
986         for (i = 0; i < ip_set_max; i++) {
987                 s = nfnl_set(i);
988                 if (s != NULL && STREQ(s->name, name2)) {
989                         ret = -IPSET_ERR_EXIST_SETNAME2;
990                         goto out;
991                 }
992         }
993         strncpy(set->name, name2, IPSET_MAXNAMELEN);
994
995 out:
996         read_unlock_bh(&ip_set_ref_lock);
997         return ret;
998 }
999
1000 /* Swap two sets so that name/index points to the other.
1001  * References and set names are also swapped.
1002  *
1003  * The commands are serialized by the nfnl mutex and references are
1004  * protected by the ip_set_ref_lock. The kernel interfaces
1005  * do not hold the mutex but the pointer settings are atomic
1006  * so the ip_set_list always contains valid pointers to the sets.
1007  */
1008
1009 static int
1010 ip_set_swap(struct sock *ctnl, struct sk_buff *skb,
1011             const struct nlmsghdr *nlh,
1012             const struct nlattr * const attr[])
1013 {
1014         struct ip_set *from, *to;
1015         ip_set_id_t from_id, to_id;
1016         char from_name[IPSET_MAXNAMELEN];
1017
1018         if (unlikely(protocol_failed(attr) ||
1019                      attr[IPSET_ATTR_SETNAME] == NULL ||
1020                      attr[IPSET_ATTR_SETNAME2] == NULL))
1021                 return -IPSET_ERR_PROTOCOL;
1022
1023         from = find_set_and_id(nla_data(attr[IPSET_ATTR_SETNAME]), &from_id);
1024         if (from == NULL)
1025                 return -ENOENT;
1026
1027         to = find_set_and_id(nla_data(attr[IPSET_ATTR_SETNAME2]), &to_id);
1028         if (to == NULL)
1029                 return -IPSET_ERR_EXIST_SETNAME2;
1030
1031         /* Features must not change.
1032          * Not an artificial restriction anymore, as we must prevent
1033          * possible loops created by swapping in setlist type of sets. */
1034         if (!(from->type->features == to->type->features &&
1035               from->type->family == to->type->family))
1036                 return -IPSET_ERR_TYPE_MISMATCH;
1037
1038         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1039         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1040         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1041
1042         write_lock_bh(&ip_set_ref_lock);
1043         swap(from->ref, to->ref);
1044         nfnl_set(from_id) = to;
1045         nfnl_set(to_id) = from;
1046         write_unlock_bh(&ip_set_ref_lock);
1047
1048         return 0;
1049 }
1050
1051 /* List/save set data */
1052
1053 #define DUMP_INIT       0
1054 #define DUMP_ALL        1
1055 #define DUMP_ONE        2
1056 #define DUMP_LAST       3
1057
1058 #define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
1059 #define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
1060
1061 static int
1062 ip_set_dump_done(struct netlink_callback *cb)
1063 {
1064         if (cb->args[2]) {
1065                 pr_debug("release set %s\n", nfnl_set(cb->args[1])->name);
1066                 ip_set_put_byindex((ip_set_id_t) cb->args[1]);
1067         }
1068         return 0;
1069 }
1070
1071 static inline void
1072 dump_attrs(struct nlmsghdr *nlh)
1073 {
1074         const struct nlattr *attr;
1075         int rem;
1076
1077         pr_debug("dump nlmsg\n");
1078         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1079                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1080         }
1081 }
1082
1083 static int
1084 dump_init(struct netlink_callback *cb)
1085 {
1086         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1087         int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1088         struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1089         struct nlattr *attr = (void *)nlh + min_len;
1090         u32 dump_type;
1091         ip_set_id_t index;
1092
1093         /* Second pass, so parser can't fail */
1094         nla_parse(cda, IPSET_ATTR_CMD_MAX,
1095                   attr, nlh->nlmsg_len - min_len, ip_set_setname_policy);
1096
1097         /* cb->args[0] : dump single set/all sets
1098          *         [1] : set index
1099          *         [..]: type specific
1100          */
1101
1102         if (cda[IPSET_ATTR_SETNAME]) {
1103                 struct ip_set *set;
1104
1105                 set = find_set_and_id(nla_data(cda[IPSET_ATTR_SETNAME]),
1106                                       &index);
1107                 if (set == NULL)
1108                         return -ENOENT;
1109
1110                 dump_type = DUMP_ONE;
1111                 cb->args[1] = index;
1112         } else
1113                 dump_type = DUMP_ALL;
1114
1115         if (cda[IPSET_ATTR_FLAGS]) {
1116                 u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1117                 dump_type |= (f << 16);
1118         }
1119         cb->args[0] = dump_type;
1120
1121         return 0;
1122 }
1123
1124 static int
1125 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1126 {
1127         ip_set_id_t index = IPSET_INVALID_ID, max;
1128         struct ip_set *set = NULL;
1129         struct nlmsghdr *nlh = NULL;
1130         unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
1131         u32 dump_type, dump_flags;
1132         int ret = 0;
1133
1134         if (!cb->args[0]) {
1135                 ret = dump_init(cb);
1136                 if (ret < 0) {
1137                         nlh = nlmsg_hdr(cb->skb);
1138                         /* We have to create and send the error message
1139                          * manually :-( */
1140                         if (nlh->nlmsg_flags & NLM_F_ACK)
1141                                 netlink_ack(cb->skb, nlh, ret);
1142                         return ret;
1143                 }
1144         }
1145
1146         if (cb->args[1] >= ip_set_max)
1147                 goto out;
1148
1149         dump_type = DUMP_TYPE(cb->args[0]);
1150         dump_flags = DUMP_FLAGS(cb->args[0]);
1151         max = dump_type == DUMP_ONE ? cb->args[1] + 1 : ip_set_max;
1152 dump_last:
1153         pr_debug("args[0]: %u %u args[1]: %ld\n",
1154                  dump_type, dump_flags, cb->args[1]);
1155         for (; cb->args[1] < max; cb->args[1]++) {
1156                 index = (ip_set_id_t) cb->args[1];
1157                 set = nfnl_set(index);
1158                 if (set == NULL) {
1159                         if (dump_type == DUMP_ONE) {
1160                                 ret = -ENOENT;
1161                                 goto out;
1162                         }
1163                         continue;
1164                 }
1165                 /* When dumping all sets, we must dump "sorted"
1166                  * so that lists (unions of sets) are dumped last.
1167                  */
1168                 if (dump_type != DUMP_ONE &&
1169                     ((dump_type == DUMP_ALL) ==
1170                      !!(set->type->features & IPSET_DUMP_LAST)))
1171                         continue;
1172                 pr_debug("List set: %s\n", set->name);
1173                 if (!cb->args[2]) {
1174                         /* Start listing: make sure set won't be destroyed */
1175                         pr_debug("reference set\n");
1176                         __ip_set_get(set);
1177                 }
1178                 nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
1179                                 cb->nlh->nlmsg_seq, flags,
1180                                 IPSET_CMD_LIST);
1181                 if (!nlh) {
1182                         ret = -EMSGSIZE;
1183                         goto release_refcount;
1184                 }
1185                 if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL) ||
1186                     nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1187                         goto nla_put_failure;
1188                 if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1189                         goto next_set;
1190                 switch (cb->args[2]) {
1191                 case 0:
1192                         /* Core header data */
1193                         if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
1194                                            set->type->name) ||
1195                             nla_put_u8(skb, IPSET_ATTR_FAMILY,
1196                                        set->family) ||
1197                             nla_put_u8(skb, IPSET_ATTR_REVISION,
1198                                        set->revision))
1199                                 goto nla_put_failure;
1200                         ret = set->variant->head(set, skb);
1201                         if (ret < 0)
1202                                 goto release_refcount;
1203                         if (dump_flags & IPSET_FLAG_LIST_HEADER)
1204                                 goto next_set;
1205                         /* Fall through and add elements */
1206                 default:
1207                         read_lock_bh(&set->lock);
1208                         ret = set->variant->list(set, skb, cb);
1209                         read_unlock_bh(&set->lock);
1210                         if (!cb->args[2])
1211                                 /* Set is done, proceed with next one */
1212                                 goto next_set;
1213                         goto release_refcount;
1214                 }
1215         }
1216         /* If we dump all sets, continue with dumping last ones */
1217         if (dump_type == DUMP_ALL) {
1218                 dump_type = DUMP_LAST;
1219                 cb->args[0] = dump_type | (dump_flags << 16);
1220                 cb->args[1] = 0;
1221                 goto dump_last;
1222         }
1223         goto out;
1224
1225 nla_put_failure:
1226         ret = -EFAULT;
1227 next_set:
1228         if (dump_type == DUMP_ONE)
1229                 cb->args[1] = IPSET_INVALID_ID;
1230         else
1231                 cb->args[1]++;
1232 release_refcount:
1233         /* If there was an error or set is done, release set */
1234         if (ret || !cb->args[2]) {
1235                 pr_debug("release set %s\n", nfnl_set(index)->name);
1236                 ip_set_put_byindex(index);
1237                 cb->args[2] = 0;
1238         }
1239 out:
1240         if (nlh) {
1241                 nlmsg_end(skb, nlh);
1242                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1243                 dump_attrs(nlh);
1244         }
1245
1246         return ret < 0 ? ret : skb->len;
1247 }
1248
1249 static int
1250 ip_set_dump(struct sock *ctnl, struct sk_buff *skb,
1251             const struct nlmsghdr *nlh,
1252             const struct nlattr * const attr[])
1253 {
1254         if (unlikely(protocol_failed(attr)))
1255                 return -IPSET_ERR_PROTOCOL;
1256
1257         {
1258                 struct netlink_dump_control c = {
1259                         .dump = ip_set_dump_start,
1260                         .done = ip_set_dump_done,
1261                 };
1262                 return netlink_dump_start(ctnl, skb, nlh, &c);
1263         }
1264 }
1265
1266 /* Add, del and test */
1267
1268 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1269         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1270         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1271                                     .len = IPSET_MAXNAMELEN - 1 },
1272         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1273         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1274         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1275 };
1276
1277 static int
1278 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1279         struct nlattr *tb[], enum ipset_adt adt,
1280         u32 flags, bool use_lineno)
1281 {
1282         int ret;
1283         u32 lineno = 0;
1284         bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1285
1286         do {
1287                 write_lock_bh(&set->lock);
1288                 ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1289                 write_unlock_bh(&set->lock);
1290                 retried = true;
1291         } while (ret == -EAGAIN &&
1292                  set->variant->resize &&
1293                  (ret = set->variant->resize(set, retried)) == 0);
1294
1295         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1296                 return 0;
1297         if (lineno && use_lineno) {
1298                 /* Error in restore/batch mode: send back lineno */
1299                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1300                 struct sk_buff *skb2;
1301                 struct nlmsgerr *errmsg;
1302                 size_t payload = sizeof(*errmsg) + nlmsg_len(nlh);
1303                 int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1304                 struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1305                 struct nlattr *cmdattr;
1306                 u32 *errline;
1307
1308                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1309                 if (skb2 == NULL)
1310                         return -ENOMEM;
1311                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).portid,
1312                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1313                 errmsg = nlmsg_data(rep);
1314                 errmsg->error = ret;
1315                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1316                 cmdattr = (void *)&errmsg->msg + min_len;
1317
1318                 nla_parse(cda, IPSET_ATTR_CMD_MAX,
1319                           cmdattr, nlh->nlmsg_len - min_len,
1320                           ip_set_adt_policy);
1321
1322                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1323
1324                 *errline = lineno;
1325
1326                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1327                 /* Signal netlink not to send its ACK/errmsg.  */
1328                 return -EINTR;
1329         }
1330
1331         return ret;
1332 }
1333
1334 static int
1335 ip_set_uadd(struct sock *ctnl, struct sk_buff *skb,
1336             const struct nlmsghdr *nlh,
1337             const struct nlattr * const attr[])
1338 {
1339         struct ip_set *set;
1340         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1341         const struct nlattr *nla;
1342         u32 flags = flag_exist(nlh);
1343         bool use_lineno;
1344         int ret = 0;
1345
1346         if (unlikely(protocol_failed(attr) ||
1347                      attr[IPSET_ATTR_SETNAME] == NULL ||
1348                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1349                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1350                      (attr[IPSET_ATTR_DATA] != NULL &&
1351                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1352                      (attr[IPSET_ATTR_ADT] != NULL &&
1353                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1354                        attr[IPSET_ATTR_LINENO] == NULL))))
1355                 return -IPSET_ERR_PROTOCOL;
1356
1357         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1358         if (set == NULL)
1359                 return -ENOENT;
1360
1361         use_lineno = !!attr[IPSET_ATTR_LINENO];
1362         if (attr[IPSET_ATTR_DATA]) {
1363                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1364                                      attr[IPSET_ATTR_DATA],
1365                                      set->type->adt_policy))
1366                         return -IPSET_ERR_PROTOCOL;
1367                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1368                               use_lineno);
1369         } else {
1370                 int nla_rem;
1371
1372                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1373                         memset(tb, 0, sizeof(tb));
1374                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1375                             !flag_nested(nla) ||
1376                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1377                                              set->type->adt_policy))
1378                                 return -IPSET_ERR_PROTOCOL;
1379                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1380                                       flags, use_lineno);
1381                         if (ret < 0)
1382                                 return ret;
1383                 }
1384         }
1385         return ret;
1386 }
1387
1388 static int
1389 ip_set_udel(struct sock *ctnl, struct sk_buff *skb,
1390             const struct nlmsghdr *nlh,
1391             const struct nlattr * const attr[])
1392 {
1393         struct ip_set *set;
1394         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1395         const struct nlattr *nla;
1396         u32 flags = flag_exist(nlh);
1397         bool use_lineno;
1398         int ret = 0;
1399
1400         if (unlikely(protocol_failed(attr) ||
1401                      attr[IPSET_ATTR_SETNAME] == NULL ||
1402                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1403                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1404                      (attr[IPSET_ATTR_DATA] != NULL &&
1405                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1406                      (attr[IPSET_ATTR_ADT] != NULL &&
1407                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1408                        attr[IPSET_ATTR_LINENO] == NULL))))
1409                 return -IPSET_ERR_PROTOCOL;
1410
1411         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1412         if (set == NULL)
1413                 return -ENOENT;
1414
1415         use_lineno = !!attr[IPSET_ATTR_LINENO];
1416         if (attr[IPSET_ATTR_DATA]) {
1417                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1418                                      attr[IPSET_ATTR_DATA],
1419                                      set->type->adt_policy))
1420                         return -IPSET_ERR_PROTOCOL;
1421                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1422                               use_lineno);
1423         } else {
1424                 int nla_rem;
1425
1426                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1427                         memset(tb, 0, sizeof(*tb));
1428                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1429                             !flag_nested(nla) ||
1430                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1431                                              set->type->adt_policy))
1432                                 return -IPSET_ERR_PROTOCOL;
1433                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1434                                       flags, use_lineno);
1435                         if (ret < 0)
1436                                 return ret;
1437                 }
1438         }
1439         return ret;
1440 }
1441
1442 static int
1443 ip_set_utest(struct sock *ctnl, struct sk_buff *skb,
1444              const struct nlmsghdr *nlh,
1445              const struct nlattr * const attr[])
1446 {
1447         struct ip_set *set;
1448         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1449         int ret = 0;
1450
1451         if (unlikely(protocol_failed(attr) ||
1452                      attr[IPSET_ATTR_SETNAME] == NULL ||
1453                      attr[IPSET_ATTR_DATA] == NULL ||
1454                      !flag_nested(attr[IPSET_ATTR_DATA])))
1455                 return -IPSET_ERR_PROTOCOL;
1456
1457         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1458         if (set == NULL)
1459                 return -ENOENT;
1460
1461         if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1462                              set->type->adt_policy))
1463                 return -IPSET_ERR_PROTOCOL;
1464
1465         read_lock_bh(&set->lock);
1466         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1467         read_unlock_bh(&set->lock);
1468         /* Userspace can't trigger element to be re-added */
1469         if (ret == -EAGAIN)
1470                 ret = 1;
1471
1472         return (ret < 0 && ret != -ENOTEMPTY) ? ret :
1473                 ret > 0 ? 0 : -IPSET_ERR_EXIST;
1474 }
1475
1476 /* Get headed data of a set */
1477
1478 static int
1479 ip_set_header(struct sock *ctnl, struct sk_buff *skb,
1480               const struct nlmsghdr *nlh,
1481               const struct nlattr * const attr[])
1482 {
1483         const struct ip_set *set;
1484         struct sk_buff *skb2;
1485         struct nlmsghdr *nlh2;
1486         int ret = 0;
1487
1488         if (unlikely(protocol_failed(attr) ||
1489                      attr[IPSET_ATTR_SETNAME] == NULL))
1490                 return -IPSET_ERR_PROTOCOL;
1491
1492         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1493         if (set == NULL)
1494                 return -ENOENT;
1495
1496         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1497         if (skb2 == NULL)
1498                 return -ENOMEM;
1499
1500         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1501                          IPSET_CMD_HEADER);
1502         if (!nlh2)
1503                 goto nlmsg_failure;
1504         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL) ||
1505             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1506             nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1507             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1508             nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1509                 goto nla_put_failure;
1510         nlmsg_end(skb2, nlh2);
1511
1512         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1513         if (ret < 0)
1514                 return ret;
1515
1516         return 0;
1517
1518 nla_put_failure:
1519         nlmsg_cancel(skb2, nlh2);
1520 nlmsg_failure:
1521         kfree_skb(skb2);
1522         return -EMSGSIZE;
1523 }
1524
1525 /* Get type data */
1526
1527 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1528         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1529         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1530                                     .len = IPSET_MAXNAMELEN - 1 },
1531         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1532 };
1533
1534 static int
1535 ip_set_type(struct sock *ctnl, struct sk_buff *skb,
1536             const struct nlmsghdr *nlh,
1537             const struct nlattr * const attr[])
1538 {
1539         struct sk_buff *skb2;
1540         struct nlmsghdr *nlh2;
1541         u8 family, min, max;
1542         const char *typename;
1543         int ret = 0;
1544
1545         if (unlikely(protocol_failed(attr) ||
1546                      attr[IPSET_ATTR_TYPENAME] == NULL ||
1547                      attr[IPSET_ATTR_FAMILY] == NULL))
1548                 return -IPSET_ERR_PROTOCOL;
1549
1550         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1551         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1552         ret = find_set_type_minmax(typename, family, &min, &max);
1553         if (ret)
1554                 return ret;
1555
1556         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1557         if (skb2 == NULL)
1558                 return -ENOMEM;
1559
1560         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1561                          IPSET_CMD_TYPE);
1562         if (!nlh2)
1563                 goto nlmsg_failure;
1564         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL) ||
1565             nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1566             nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1567             nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1568             nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1569                 goto nla_put_failure;
1570         nlmsg_end(skb2, nlh2);
1571
1572         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1573         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1574         if (ret < 0)
1575                 return ret;
1576
1577         return 0;
1578
1579 nla_put_failure:
1580         nlmsg_cancel(skb2, nlh2);
1581 nlmsg_failure:
1582         kfree_skb(skb2);
1583         return -EMSGSIZE;
1584 }
1585
1586 /* Get protocol version */
1587
1588 static const struct nla_policy
1589 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1590         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1591 };
1592
1593 static int
1594 ip_set_protocol(struct sock *ctnl, struct sk_buff *skb,
1595                 const struct nlmsghdr *nlh,
1596                 const struct nlattr * const attr[])
1597 {
1598         struct sk_buff *skb2;
1599         struct nlmsghdr *nlh2;
1600         int ret = 0;
1601
1602         if (unlikely(attr[IPSET_ATTR_PROTOCOL] == NULL))
1603                 return -IPSET_ERR_PROTOCOL;
1604
1605         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1606         if (skb2 == NULL)
1607                 return -ENOMEM;
1608
1609         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1610                          IPSET_CMD_PROTOCOL);
1611         if (!nlh2)
1612                 goto nlmsg_failure;
1613         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1614                 goto nla_put_failure;
1615         nlmsg_end(skb2, nlh2);
1616
1617         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1618         if (ret < 0)
1619                 return ret;
1620
1621         return 0;
1622
1623 nla_put_failure:
1624         nlmsg_cancel(skb2, nlh2);
1625 nlmsg_failure:
1626         kfree_skb(skb2);
1627         return -EMSGSIZE;
1628 }
1629
1630 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1631         [IPSET_CMD_NONE]        = {
1632                 .call           = ip_set_none,
1633                 .attr_count     = IPSET_ATTR_CMD_MAX,
1634         },
1635         [IPSET_CMD_CREATE]      = {
1636                 .call           = ip_set_create,
1637                 .attr_count     = IPSET_ATTR_CMD_MAX,
1638                 .policy         = ip_set_create_policy,
1639         },
1640         [IPSET_CMD_DESTROY]     = {
1641                 .call           = ip_set_destroy,
1642                 .attr_count     = IPSET_ATTR_CMD_MAX,
1643                 .policy         = ip_set_setname_policy,
1644         },
1645         [IPSET_CMD_FLUSH]       = {
1646                 .call           = ip_set_flush,
1647                 .attr_count     = IPSET_ATTR_CMD_MAX,
1648                 .policy         = ip_set_setname_policy,
1649         },
1650         [IPSET_CMD_RENAME]      = {
1651                 .call           = ip_set_rename,
1652                 .attr_count     = IPSET_ATTR_CMD_MAX,
1653                 .policy         = ip_set_setname2_policy,
1654         },
1655         [IPSET_CMD_SWAP]        = {
1656                 .call           = ip_set_swap,
1657                 .attr_count     = IPSET_ATTR_CMD_MAX,
1658                 .policy         = ip_set_setname2_policy,
1659         },
1660         [IPSET_CMD_LIST]        = {
1661                 .call           = ip_set_dump,
1662                 .attr_count     = IPSET_ATTR_CMD_MAX,
1663                 .policy         = ip_set_setname_policy,
1664         },
1665         [IPSET_CMD_SAVE]        = {
1666                 .call           = ip_set_dump,
1667                 .attr_count     = IPSET_ATTR_CMD_MAX,
1668                 .policy         = ip_set_setname_policy,
1669         },
1670         [IPSET_CMD_ADD] = {
1671                 .call           = ip_set_uadd,
1672                 .attr_count     = IPSET_ATTR_CMD_MAX,
1673                 .policy         = ip_set_adt_policy,
1674         },
1675         [IPSET_CMD_DEL] = {
1676                 .call           = ip_set_udel,
1677                 .attr_count     = IPSET_ATTR_CMD_MAX,
1678                 .policy         = ip_set_adt_policy,
1679         },
1680         [IPSET_CMD_TEST]        = {
1681                 .call           = ip_set_utest,
1682                 .attr_count     = IPSET_ATTR_CMD_MAX,
1683                 .policy         = ip_set_adt_policy,
1684         },
1685         [IPSET_CMD_HEADER]      = {
1686                 .call           = ip_set_header,
1687                 .attr_count     = IPSET_ATTR_CMD_MAX,
1688                 .policy         = ip_set_setname_policy,
1689         },
1690         [IPSET_CMD_TYPE]        = {
1691                 .call           = ip_set_type,
1692                 .attr_count     = IPSET_ATTR_CMD_MAX,
1693                 .policy         = ip_set_type_policy,
1694         },
1695         [IPSET_CMD_PROTOCOL]    = {
1696                 .call           = ip_set_protocol,
1697                 .attr_count     = IPSET_ATTR_CMD_MAX,
1698                 .policy         = ip_set_protocol_policy,
1699         },
1700 };
1701
1702 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
1703         .name           = "ip_set",
1704         .subsys_id      = NFNL_SUBSYS_IPSET,
1705         .cb_count       = IPSET_MSG_MAX,
1706         .cb             = ip_set_netlink_subsys_cb,
1707 };
1708
1709 /* Interface to iptables/ip6tables */
1710
1711 static int
1712 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
1713 {
1714         unsigned int *op;
1715         void *data;
1716         int copylen = *len, ret = 0;
1717
1718         if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
1719                 return -EPERM;
1720         if (optval != SO_IP_SET)
1721                 return -EBADF;
1722         if (*len < sizeof(unsigned int))
1723                 return -EINVAL;
1724
1725         data = vmalloc(*len);
1726         if (!data)
1727                 return -ENOMEM;
1728         if (copy_from_user(data, user, *len) != 0) {
1729                 ret = -EFAULT;
1730                 goto done;
1731         }
1732         op = (unsigned int *) data;
1733
1734         if (*op < IP_SET_OP_VERSION) {
1735                 /* Check the version at the beginning of operations */
1736                 struct ip_set_req_version *req_version = data;
1737                 if (req_version->version != IPSET_PROTOCOL) {
1738                         ret = -EPROTO;
1739                         goto done;
1740                 }
1741         }
1742
1743         switch (*op) {
1744         case IP_SET_OP_VERSION: {
1745                 struct ip_set_req_version *req_version = data;
1746
1747                 if (*len != sizeof(struct ip_set_req_version)) {
1748                         ret = -EINVAL;
1749                         goto done;
1750                 }
1751
1752                 req_version->version = IPSET_PROTOCOL;
1753                 ret = copy_to_user(user, req_version,
1754                                    sizeof(struct ip_set_req_version));
1755                 goto done;
1756         }
1757         case IP_SET_OP_GET_BYNAME: {
1758                 struct ip_set_req_get_set *req_get = data;
1759                 ip_set_id_t id;
1760
1761                 if (*len != sizeof(struct ip_set_req_get_set)) {
1762                         ret = -EINVAL;
1763                         goto done;
1764                 }
1765                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
1766                 nfnl_lock(NFNL_SUBSYS_IPSET);
1767                 find_set_and_id(req_get->set.name, &id);
1768                 req_get->set.index = id;
1769                 nfnl_unlock(NFNL_SUBSYS_IPSET);
1770                 goto copy;
1771         }
1772         case IP_SET_OP_GET_BYINDEX: {
1773                 struct ip_set_req_get_set *req_get = data;
1774                 struct ip_set *set;
1775
1776                 if (*len != sizeof(struct ip_set_req_get_set) ||
1777                     req_get->set.index >= ip_set_max) {
1778                         ret = -EINVAL;
1779                         goto done;
1780                 }
1781                 nfnl_lock(NFNL_SUBSYS_IPSET);
1782                 set = nfnl_set(req_get->set.index);
1783                 strncpy(req_get->set.name, set ? set->name : "",
1784                         IPSET_MAXNAMELEN);
1785                 nfnl_unlock(NFNL_SUBSYS_IPSET);
1786                 goto copy;
1787         }
1788         default:
1789                 ret = -EBADMSG;
1790                 goto done;
1791         }       /* end of switch(op) */
1792
1793 copy:
1794         ret = copy_to_user(user, data, copylen);
1795
1796 done:
1797         vfree(data);
1798         if (ret > 0)
1799                 ret = 0;
1800         return ret;
1801 }
1802
1803 static struct nf_sockopt_ops so_set __read_mostly = {
1804         .pf             = PF_INET,
1805         .get_optmin     = SO_IP_SET,
1806         .get_optmax     = SO_IP_SET + 1,
1807         .get            = &ip_set_sockfn_get,
1808         .owner          = THIS_MODULE,
1809 };
1810
1811 static int __init
1812 ip_set_init(void)
1813 {
1814         struct ip_set **list;
1815         int ret;
1816
1817         if (max_sets)
1818                 ip_set_max = max_sets;
1819         if (ip_set_max >= IPSET_INVALID_ID)
1820                 ip_set_max = IPSET_INVALID_ID - 1;
1821
1822         list = kzalloc(sizeof(struct ip_set *) * ip_set_max, GFP_KERNEL);
1823         if (!list)
1824                 return -ENOMEM;
1825
1826         rcu_assign_pointer(ip_set_list, list);
1827         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
1828         if (ret != 0) {
1829                 pr_err("ip_set: cannot register with nfnetlink.\n");
1830                 kfree(list);
1831                 return ret;
1832         }
1833         ret = nf_register_sockopt(&so_set);
1834         if (ret != 0) {
1835                 pr_err("SO_SET registry failed: %d\n", ret);
1836                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1837                 kfree(list);
1838                 return ret;
1839         }
1840
1841         pr_notice("ip_set: protocol %u\n", IPSET_PROTOCOL);
1842         return 0;
1843 }
1844
1845 static void __exit
1846 ip_set_fini(void)
1847 {
1848         struct ip_set **list = rcu_dereference_protected(ip_set_list, 1);
1849
1850         /* There can't be any existing set */
1851         nf_unregister_sockopt(&so_set);
1852         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1853         kfree(list);
1854         pr_debug("these are the famous last words\n");
1855 }
1856
1857 module_init(ip_set_init);
1858 module_exit(ip_set_fini);