net: diag: allow socket bytecode filters to match socket marks
[firefly-linux-kernel-4.4.55.git] / net / ipv4 / inet_diag.c
1 /*
2  * inet_diag.c  Module for monitoring INET transport protocols sockets.
3  *
4  * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
5  *
6  *      This program is free software; you can redistribute it and/or
7  *      modify it under the terms of the GNU General Public License
8  *      as published by the Free Software Foundation; either version
9  *      2 of the License, or (at your option) any later version.
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/types.h>
15 #include <linux/fcntl.h>
16 #include <linux/random.h>
17 #include <linux/slab.h>
18 #include <linux/cache.h>
19 #include <linux/init.h>
20 #include <linux/time.h>
21
22 #include <net/icmp.h>
23 #include <net/tcp.h>
24 #include <net/ipv6.h>
25 #include <net/inet_common.h>
26 #include <net/inet_connection_sock.h>
27 #include <net/inet_hashtables.h>
28 #include <net/inet_timewait_sock.h>
29 #include <net/inet6_hashtables.h>
30 #include <net/netlink.h>
31
32 #include <linux/inet.h>
33 #include <linux/stddef.h>
34
35 #include <linux/inet_diag.h>
36 #include <linux/sock_diag.h>
37
38 static const struct inet_diag_handler **inet_diag_table;
39
40 struct inet_diag_entry {
41         const __be32 *saddr;
42         const __be32 *daddr;
43         u16 sport;
44         u16 dport;
45         u16 family;
46         u16 userlocks;
47         u32 ifindex;
48         u32 mark;
49 };
50
51 static DEFINE_MUTEX(inet_diag_table_mutex);
52
53 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
54 {
55         if (!inet_diag_table[proto])
56                 request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
57                                NETLINK_SOCK_DIAG, AF_INET, proto);
58
59         mutex_lock(&inet_diag_table_mutex);
60         if (!inet_diag_table[proto])
61                 return ERR_PTR(-ENOENT);
62
63         return inet_diag_table[proto];
64 }
65
66 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
67 {
68         mutex_unlock(&inet_diag_table_mutex);
69 }
70
71 static void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
72 {
73         r->idiag_family = sk->sk_family;
74
75         r->id.idiag_sport = htons(sk->sk_num);
76         r->id.idiag_dport = sk->sk_dport;
77         r->id.idiag_if = sk->sk_bound_dev_if;
78         sock_diag_save_cookie(sk, r->id.idiag_cookie);
79
80 #if IS_ENABLED(CONFIG_IPV6)
81         if (sk->sk_family == AF_INET6) {
82                 *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
83                 *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
84         } else
85 #endif
86         {
87         memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
88         memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
89
90         r->id.idiag_src[0] = sk->sk_rcv_saddr;
91         r->id.idiag_dst[0] = sk->sk_daddr;
92         }
93 }
94
95 static size_t inet_sk_attr_size(void)
96 {
97         return    nla_total_size(sizeof(struct tcp_info))
98                 + nla_total_size(1) /* INET_DIAG_SHUTDOWN */
99                 + nla_total_size(1) /* INET_DIAG_TOS */
100                 + nla_total_size(1) /* INET_DIAG_TCLASS */
101                 + nla_total_size(sizeof(struct inet_diag_meminfo))
102                 + nla_total_size(sizeof(struct inet_diag_msg))
103                 + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
104                 + nla_total_size(TCP_CA_NAME_MAX)
105                 + nla_total_size(sizeof(struct tcpvegas_info))
106                 + 64;
107 }
108
109 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
110                       struct sk_buff *skb, const struct inet_diag_req_v2 *req,
111                       struct user_namespace *user_ns,
112                       u32 portid, u32 seq, u16 nlmsg_flags,
113                       const struct nlmsghdr *unlh)
114 {
115         const struct inet_sock *inet = inet_sk(sk);
116         const struct tcp_congestion_ops *ca_ops;
117         const struct inet_diag_handler *handler;
118         int ext = req->idiag_ext;
119         struct inet_diag_msg *r;
120         struct nlmsghdr  *nlh;
121         struct nlattr *attr;
122         void *info = NULL;
123
124         handler = inet_diag_table[req->sdiag_protocol];
125         BUG_ON(!handler);
126
127         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
128                         nlmsg_flags);
129         if (!nlh)
130                 return -EMSGSIZE;
131
132         r = nlmsg_data(nlh);
133         BUG_ON(!sk_fullsock(sk));
134
135         inet_diag_msg_common_fill(r, sk);
136         r->idiag_state = sk->sk_state;
137         r->idiag_timer = 0;
138         r->idiag_retrans = 0;
139
140         if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
141                 goto errout;
142
143         /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
144          * hence this needs to be included regardless of socket family.
145          */
146         if (ext & (1 << (INET_DIAG_TOS - 1)))
147                 if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
148                         goto errout;
149
150 #if IS_ENABLED(CONFIG_IPV6)
151         if (r->idiag_family == AF_INET6) {
152                 if (ext & (1 << (INET_DIAG_TCLASS - 1)))
153                         if (nla_put_u8(skb, INET_DIAG_TCLASS,
154                                        inet6_sk(sk)->tclass) < 0)
155                                 goto errout;
156
157                 if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
158                     nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
159                         goto errout;
160         }
161 #endif
162
163         r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
164         r->idiag_inode = sock_i_ino(sk);
165
166         if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
167                 struct inet_diag_meminfo minfo = {
168                         .idiag_rmem = sk_rmem_alloc_get(sk),
169                         .idiag_wmem = sk->sk_wmem_queued,
170                         .idiag_fmem = sk->sk_forward_alloc,
171                         .idiag_tmem = sk_wmem_alloc_get(sk),
172                 };
173
174                 if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
175                         goto errout;
176         }
177
178         if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
179                 if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
180                         goto errout;
181
182         if (!icsk) {
183                 handler->idiag_get_info(sk, r, NULL);
184                 goto out;
185         }
186
187 #define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)
188
189         if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
190             icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
191             icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
192                 r->idiag_timer = 1;
193                 r->idiag_retrans = icsk->icsk_retransmits;
194                 r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
195         } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
196                 r->idiag_timer = 4;
197                 r->idiag_retrans = icsk->icsk_probes_out;
198                 r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
199         } else if (timer_pending(&sk->sk_timer)) {
200                 r->idiag_timer = 2;
201                 r->idiag_retrans = icsk->icsk_probes_out;
202                 r->idiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
203         } else {
204                 r->idiag_timer = 0;
205                 r->idiag_expires = 0;
206         }
207 #undef EXPIRES_IN_MS
208
209         if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
210                 attr = nla_reserve(skb, INET_DIAG_INFO,
211                                    handler->idiag_info_size);
212                 if (!attr)
213                         goto errout;
214
215                 info = nla_data(attr);
216         }
217
218         if (ext & (1 << (INET_DIAG_CONG - 1))) {
219                 int err = 0;
220
221                 rcu_read_lock();
222                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
223                 if (ca_ops)
224                         err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
225                 rcu_read_unlock();
226                 if (err < 0)
227                         goto errout;
228         }
229
230         handler->idiag_get_info(sk, r, info);
231
232         if (sk->sk_state < TCP_TIME_WAIT) {
233                 union tcp_cc_info info;
234                 size_t sz = 0;
235                 int attr;
236
237                 rcu_read_lock();
238                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
239                 if (ca_ops && ca_ops->get_info)
240                         sz = ca_ops->get_info(sk, ext, &attr, &info);
241                 rcu_read_unlock();
242                 if (sz && nla_put(skb, attr, sz, &info) < 0)
243                         goto errout;
244         }
245
246 out:
247         nlmsg_end(skb, nlh);
248         return 0;
249
250 errout:
251         nlmsg_cancel(skb, nlh);
252         return -EMSGSIZE;
253 }
254 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
255
256 static int inet_csk_diag_fill(struct sock *sk,
257                               struct sk_buff *skb,
258                               const struct inet_diag_req_v2 *req,
259                               struct user_namespace *user_ns,
260                               u32 portid, u32 seq, u16 nlmsg_flags,
261                               const struct nlmsghdr *unlh)
262 {
263         return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
264                                  user_ns, portid, seq, nlmsg_flags, unlh);
265 }
266
267 static int inet_twsk_diag_fill(struct sock *sk,
268                                struct sk_buff *skb,
269                                u32 portid, u32 seq, u16 nlmsg_flags,
270                                const struct nlmsghdr *unlh)
271 {
272         struct inet_timewait_sock *tw = inet_twsk(sk);
273         struct inet_diag_msg *r;
274         struct nlmsghdr *nlh;
275         long tmo;
276
277         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
278                         nlmsg_flags);
279         if (!nlh)
280                 return -EMSGSIZE;
281
282         r = nlmsg_data(nlh);
283         BUG_ON(tw->tw_state != TCP_TIME_WAIT);
284
285         tmo = tw->tw_timer.expires - jiffies;
286         if (tmo < 0)
287                 tmo = 0;
288
289         inet_diag_msg_common_fill(r, sk);
290         r->idiag_retrans      = 0;
291
292         r->idiag_state        = tw->tw_substate;
293         r->idiag_timer        = 3;
294         r->idiag_expires      = jiffies_to_msecs(tmo);
295         r->idiag_rqueue       = 0;
296         r->idiag_wqueue       = 0;
297         r->idiag_uid          = 0;
298         r->idiag_inode        = 0;
299
300         nlmsg_end(skb, nlh);
301         return 0;
302 }
303
304 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
305                               u32 portid, u32 seq, u16 nlmsg_flags,
306                               const struct nlmsghdr *unlh)
307 {
308         struct inet_diag_msg *r;
309         struct nlmsghdr *nlh;
310         long tmo;
311
312         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
313                         nlmsg_flags);
314         if (!nlh)
315                 return -EMSGSIZE;
316
317         r = nlmsg_data(nlh);
318         inet_diag_msg_common_fill(r, sk);
319         r->idiag_state = TCP_SYN_RECV;
320         r->idiag_timer = 1;
321         r->idiag_retrans = inet_reqsk(sk)->num_retrans;
322
323         BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
324                      offsetof(struct sock, sk_cookie));
325
326         tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
327         r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
328         r->idiag_rqueue = 0;
329         r->idiag_wqueue = 0;
330         r->idiag_uid    = 0;
331         r->idiag_inode  = 0;
332
333         nlmsg_end(skb, nlh);
334         return 0;
335 }
336
337 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
338                         const struct inet_diag_req_v2 *r,
339                         struct user_namespace *user_ns,
340                         u32 portid, u32 seq, u16 nlmsg_flags,
341                         const struct nlmsghdr *unlh)
342 {
343         if (sk->sk_state == TCP_TIME_WAIT)
344                 return inet_twsk_diag_fill(sk, skb, portid, seq,
345                                            nlmsg_flags, unlh);
346
347         if (sk->sk_state == TCP_NEW_SYN_RECV)
348                 return inet_req_diag_fill(sk, skb, portid, seq,
349                                           nlmsg_flags, unlh);
350
351         return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
352                                   nlmsg_flags, unlh);
353 }
354
355 struct sock *inet_diag_find_one_icsk(struct net *net,
356                                      struct inet_hashinfo *hashinfo,
357                                      const struct inet_diag_req_v2 *req)
358 {
359         struct sock *sk;
360
361         if (req->sdiag_family == AF_INET)
362                 sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
363                                  req->id.idiag_dport, req->id.idiag_src[0],
364                                  req->id.idiag_sport, req->id.idiag_if);
365 #if IS_ENABLED(CONFIG_IPV6)
366         else if (req->sdiag_family == AF_INET6) {
367                 if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
368                     ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
369                         sk = inet_lookup(net, hashinfo, req->id.idiag_dst[3],
370                                          req->id.idiag_dport, req->id.idiag_src[3],
371                                          req->id.idiag_sport, req->id.idiag_if);
372                 else
373                         sk = inet6_lookup(net, hashinfo,
374                                           (struct in6_addr *)req->id.idiag_dst,
375                                           req->id.idiag_dport,
376                                           (struct in6_addr *)req->id.idiag_src,
377                                           req->id.idiag_sport,
378                                           req->id.idiag_if);
379         }
380 #endif
381         else
382                 return ERR_PTR(-EINVAL);
383
384         if (!sk)
385                 return ERR_PTR(-ENOENT);
386
387         if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
388                 sock_gen_put(sk);
389                 return ERR_PTR(-ENOENT);
390         }
391
392         return sk;
393 }
394 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
395
396 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
397                             struct sk_buff *in_skb,
398                             const struct nlmsghdr *nlh,
399                             const struct inet_diag_req_v2 *req)
400 {
401         struct net *net = sock_net(in_skb->sk);
402         struct sk_buff *rep;
403         struct sock *sk;
404         int err;
405
406         sk = inet_diag_find_one_icsk(net, hashinfo, req);
407         if (IS_ERR(sk))
408                 return PTR_ERR(sk);
409
410         rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL);
411         if (!rep) {
412                 err = -ENOMEM;
413                 goto out;
414         }
415
416         err = sk_diag_fill(sk, rep, req,
417                            sk_user_ns(NETLINK_CB(in_skb).sk),
418                            NETLINK_CB(in_skb).portid,
419                            nlh->nlmsg_seq, 0, nlh);
420         if (err < 0) {
421                 WARN_ON(err == -EMSGSIZE);
422                 nlmsg_free(rep);
423                 goto out;
424         }
425         err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
426                               MSG_DONTWAIT);
427         if (err > 0)
428                 err = 0;
429
430 out:
431         if (sk)
432                 sock_gen_put(sk);
433
434         return err;
435 }
436 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
437
438 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
439                                const struct nlmsghdr *nlh,
440                                const struct inet_diag_req_v2 *req)
441 {
442         const struct inet_diag_handler *handler;
443         int err;
444
445         handler = inet_diag_lock_handler(req->sdiag_protocol);
446         if (IS_ERR(handler))
447                 err = PTR_ERR(handler);
448         else if (cmd == SOCK_DIAG_BY_FAMILY)
449                 err = handler->dump_one(in_skb, nlh, req);
450         else if (cmd == SOCK_DESTROY_BACKPORT && handler->destroy)
451                 err = handler->destroy(in_skb, req);
452         else
453                 err = -EOPNOTSUPP;
454         inet_diag_unlock_handler(handler);
455
456         return err;
457 }
458
/*
 * Compare the leading @bits bits of two network-byte-order address word
 * arrays. Returns 1 if they are equal, 0 otherwise. @bits == 0 always
 * matches.
 */
static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
{
	int nwords = bits >> 5;
	int rem = bits & 0x1f;
	__be32 mask;

	/* Whole 32-bit words must be identical. */
	if (nwords && memcmp(a1, a2, nwords << 2) != 0)
		return 0;

	/* No partial word left to check. */
	if (!rem)
		return 1;

	/* Compare only the top @rem bits of the next word. */
	mask = htonl((0xffffffff) << (32 - rem));
	return ((a1[nwords] ^ a2[nwords]) & mask) ? 0 : 1;
}
484
485 static int inet_diag_bc_run(const struct nlattr *_bc,
486                             const struct inet_diag_entry *entry)
487 {
488         const void *bc = nla_data(_bc);
489         int len = nla_len(_bc);
490
491         while (len > 0) {
492                 int yes = 1;
493                 const struct inet_diag_bc_op *op = bc;
494
495                 switch (op->code) {
496                 case INET_DIAG_BC_NOP:
497                         break;
498                 case INET_DIAG_BC_JMP:
499                         yes = 0;
500                         break;
501                 case INET_DIAG_BC_S_GE:
502                         yes = entry->sport >= op[1].no;
503                         break;
504                 case INET_DIAG_BC_S_LE:
505                         yes = entry->sport <= op[1].no;
506                         break;
507                 case INET_DIAG_BC_D_GE:
508                         yes = entry->dport >= op[1].no;
509                         break;
510                 case INET_DIAG_BC_D_LE:
511                         yes = entry->dport <= op[1].no;
512                         break;
513                 case INET_DIAG_BC_AUTO:
514                         yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
515                         break;
516                 case INET_DIAG_BC_S_COND:
517                 case INET_DIAG_BC_D_COND: {
518                         const struct inet_diag_hostcond *cond;
519                         const __be32 *addr;
520
521                         cond = (const struct inet_diag_hostcond *)(op + 1);
522                         if (cond->port != -1 &&
523                             cond->port != (op->code == INET_DIAG_BC_S_COND ?
524                                              entry->sport : entry->dport)) {
525                                 yes = 0;
526                                 break;
527                         }
528
529                         if (op->code == INET_DIAG_BC_S_COND)
530                                 addr = entry->saddr;
531                         else
532                                 addr = entry->daddr;
533
534                         if (cond->family != AF_UNSPEC &&
535                             cond->family != entry->family) {
536                                 if (entry->family == AF_INET6 &&
537                                     cond->family == AF_INET) {
538                                         if (addr[0] == 0 && addr[1] == 0 &&
539                                             addr[2] == htonl(0xffff) &&
540                                             bitstring_match(addr + 3,
541                                                             cond->addr,
542                                                             cond->prefix_len))
543                                                 break;
544                                 }
545                                 yes = 0;
546                                 break;
547                         }
548
549                         if (cond->prefix_len == 0)
550                                 break;
551                         if (bitstring_match(addr, cond->addr,
552                                             cond->prefix_len))
553                                 break;
554                         yes = 0;
555                         break;
556                 }
557                 case INET_DIAG_BC_DEV_COND: {
558                         u32 ifindex;
559
560                         ifindex = *((const u32 *)(op + 1));
561                         if (ifindex != entry->ifindex)
562                                 yes = 0;
563                         break;
564                 }
565                 case INET_DIAG_BC_MARK_COND: {
566                         struct inet_diag_markcond *cond;
567
568                         cond = (struct inet_diag_markcond *)(op + 1);
569                         if ((entry->mark & cond->mask) != cond->mark)
570                                 yes = 0;
571                         break;
572                 }
573                 }
574
575                 if (yes) {
576                         len -= op->yes;
577                         bc += op->yes;
578                 } else {
579                         len -= op->no;
580                         bc += op->no;
581                 }
582         }
583         return len == 0;
584 }
585
586 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
587  */
588 static void entry_fill_addrs(struct inet_diag_entry *entry,
589                              const struct sock *sk)
590 {
591 #if IS_ENABLED(CONFIG_IPV6)
592         if (sk->sk_family == AF_INET6) {
593                 entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
594                 entry->daddr = sk->sk_v6_daddr.s6_addr32;
595         } else
596 #endif
597         {
598                 entry->saddr = &sk->sk_rcv_saddr;
599                 entry->daddr = &sk->sk_daddr;
600         }
601 }
602
603 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
604 {
605         struct inet_sock *inet = inet_sk(sk);
606         struct inet_diag_entry entry;
607
608         if (!bc)
609                 return 1;
610
611         entry.family = sk->sk_family;
612         entry_fill_addrs(&entry, sk);
613         entry.sport = inet->inet_num;
614         entry.dport = ntohs(inet->inet_dport);
615         entry.ifindex = sk->sk_bound_dev_if;
616         entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
617         if (sk_fullsock(sk))
618                 entry.mark = sk->sk_mark;
619         else if (sk->sk_state == TCP_NEW_SYN_RECV)
620                 entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
621         else
622                 entry.mark = 0;
623
624         return inet_diag_bc_run(bc, &entry);
625 }
626 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
627
/*
 * Check that byte offset @cc (counted as bytes remaining from the end of
 * the @len-byte program) is an op boundary reachable by following 'yes'
 * jumps from the start. Returns 1 if so, 0 otherwise (including on any
 * 'yes' jump shorter than 4 bytes or not 4-byte aligned).
 */
static int valid_cc(const void *bc, int len, int cc)
{
	const char *cursor = bc;
	int remaining = len;

	while (remaining >= 0) {
		const struct inet_diag_bc_op *op =
			(const struct inet_diag_bc_op *)cursor;

		if (cc > remaining)
			return 0;
		if (cc == remaining)
			return 1;
		if (op->yes < 4 || (op->yes & 3))
			return 0;
		remaining -= op->yes;
		cursor += op->yes;
	}
	return 0;
}
644
645 /* data is u32 ifindex */
646 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
647                           int *min_len)
648 {
649         /* Check ifindex space. */
650         *min_len += sizeof(u32);
651         if (len < *min_len)
652                 return false;
653
654         return true;
655 }
/*
 * Validate an INET_DIAG_BC_S_COND / INET_DIAG_BC_D_COND op.
 *
 * The op must be followed by a struct inet_diag_hostcond and then by an
 * address whose length depends on cond->family (none for AF_UNSPEC, 4
 * bytes for AF_INET, 16 for AF_INET6; other families are rejected).
 * prefix_len must not exceed the address length in bits.
 *
 * *min_len is increased by the operand bytes consumed. Returns true if
 * the op is well formed within the @len remaining bytes.
 */
static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	/* The operand lives in the (read-only) bytecode: keep it const. */
	const struct inet_diag_hostcond *cond;
	int addr_len;

	/* Check hostcond space. */
	*min_len += sizeof(struct inet_diag_hostcond);
	if (len < *min_len)
		return false;
	cond = (const struct inet_diag_hostcond *)(op + 1);

	/* Check address family and address length. */
	switch (cond->family) {
	case AF_UNSPEC:
		addr_len = 0;
		break;
	case AF_INET:
		addr_len = sizeof(struct in_addr);
		break;
	case AF_INET6:
		addr_len = sizeof(struct in6_addr);
		break;
	default:
		return false;
	}
	*min_len += addr_len;
	if (len < *min_len)
		return false;

	/* Check prefix length (in bits) vs address length (in bytes). */
	if (cond->prefix_len > 8 * addr_len)
		return false;

	return true;
}
693
/*
 * Validate a port comparison op (S_GE, S_LE, D_GE, D_LE). The port operand
 * travels in a second struct inet_diag_bc_op right after this one, so
 * *min_len grows by one op. Returns true if @len bytes are enough.
 */
static bool valid_port_comparison(const struct inet_diag_bc_op *op,
				  int len, int *min_len)
{
	*min_len += sizeof(struct inet_diag_bc_op);
	return !(len < *min_len);
}
704
/*
 * Validate an INET_DIAG_BC_MARK_COND op, which must be followed by a
 * struct inet_diag_markcond. Grows *min_len accordingly and reports
 * whether @len bytes are enough.
 */
static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	const int needed = *min_len + (int)sizeof(struct inet_diag_markcond);

	*min_len = needed;
	if (len < needed)
		return false;
	return true;
}
711
712 static int inet_diag_bc_audit(const struct nlattr *attr,
713                               const struct sk_buff *skb)
714 {
715         bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
716         const void *bytecode, *bc;
717         int bytecode_len, len;
718
719         if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
720                 return -EINVAL;
721
722         bytecode = bc = nla_data(attr);
723         len = bytecode_len = nla_len(attr);
724
725         while (len > 0) {
726                 int min_len = sizeof(struct inet_diag_bc_op);
727                 const struct inet_diag_bc_op *op = bc;
728
729                 switch (op->code) {
730                 case INET_DIAG_BC_S_COND:
731                 case INET_DIAG_BC_D_COND:
732                         if (!valid_hostcond(bc, len, &min_len))
733                                 return -EINVAL;
734                         break;
735                 case INET_DIAG_BC_DEV_COND:
736                         if (!valid_devcond(bc, len, &min_len))
737                                 return -EINVAL;
738                         break;
739                 case INET_DIAG_BC_S_GE:
740                 case INET_DIAG_BC_S_LE:
741                 case INET_DIAG_BC_D_GE:
742                 case INET_DIAG_BC_D_LE:
743                         if (!valid_port_comparison(bc, len, &min_len))
744                                 return -EINVAL;
745                         break;
746                 case INET_DIAG_BC_MARK_COND:
747                         if (!net_admin)
748                                 return -EPERM;
749                         if (!valid_markcond(bc, len, &min_len))
750                                 return -EINVAL;
751                         break;
752                 case INET_DIAG_BC_AUTO:
753                 case INET_DIAG_BC_JMP:
754                 case INET_DIAG_BC_NOP:
755                         break;
756                 default:
757                         return -EINVAL;
758                 }
759
760                 if (op->code != INET_DIAG_BC_NOP) {
761                         if (op->no < min_len || op->no > len + 4 || op->no & 3)
762                                 return -EINVAL;
763                         if (op->no < len &&
764                             !valid_cc(bytecode, bytecode_len, len - op->no))
765                                 return -EINVAL;
766                 }
767
768                 if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
769                         return -EINVAL;
770                 bc  += op->yes;
771                 len -= op->yes;
772         }
773         return len == 0 ? 0 : -EINVAL;
774 }
775
776 static int inet_csk_diag_dump(struct sock *sk,
777                               struct sk_buff *skb,
778                               struct netlink_callback *cb,
779                               const struct inet_diag_req_v2 *r,
780                               const struct nlattr *bc)
781 {
782         if (!inet_diag_bc_sk(bc, sk))
783                 return 0;
784
785         return inet_csk_diag_fill(sk, skb, r,
786                                   sk_user_ns(NETLINK_CB(cb->skb).sk),
787                                   NETLINK_CB(cb->skb).portid,
788                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
789 }
790
791 static void twsk_build_assert(void)
792 {
793         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
794                      offsetof(struct sock, sk_family));
795
796         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
797                      offsetof(struct inet_sock, inet_num));
798
799         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
800                      offsetof(struct inet_sock, inet_dport));
801
802         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
803                      offsetof(struct inet_sock, inet_rcv_saddr));
804
805         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
806                      offsetof(struct inet_sock, inet_daddr));
807
808 #if IS_ENABLED(CONFIG_IPV6)
809         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
810                      offsetof(struct sock, sk_v6_rcv_saddr));
811
812         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
813                      offsetof(struct sock, sk_v6_daddr));
814 #endif
815 }
816
/**
 * inet_diag_dump_icsk - dump sockets of a connection-oriented protocol
 * @hashinfo: the protocol's socket hash tables (listening hash and ehash)
 * @skb: netlink reply being filled
 * @cb: netlink dump callback; cb->args[] carries the resume position
 * @r: request (family, state mask, ports, extensions)
 * @bc: INET_DIAG_REQ_BYTECODE filter attribute, or NULL
 *
 * Walks the listening hash first (only when TCPF_LISTEN is requested), then
 * the established hash (only when some non-listen state is requested), and
 * emits one reply per matching socket in the requester's network namespace.
 * If adding a reply fails, the current position is saved and the function
 * returns, so a later call can continue from there.
 *
 * Resume state in cb->args[]:
 *   [0] 0 while the listening hash still has to be walked, then 1
 *   [1] bucket index to restart from
 *   [2] number of same-netns sockets already walked in that bucket
 *   [3], [4] only reset here; a positive [3] makes listeners be skipped
 *       (NOTE(review): nothing in this function sets them - possibly a
 *       leftover of an older resume scheme, confirm before relying on it)
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r, struct nlattr *bc)
{
	struct net *net = sock_net(skb->sk);
	int i, num, s_i, s_num;
	u32 idiag_states = r->idiag_states;

	/* Asking for SYN_RECV also selects sockets in TCP_NEW_SYN_RECV. */
	if (idiag_states & TCPF_SYN_RECV)
		idiag_states |= TCPF_NEW_SYN_RECV;
	s_i = cb->args[1];
	s_num = num = cb->args[2];

	if (cb->args[0] == 0) {
		if (!(idiag_states & TCPF_LISTEN))
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct inet_listen_hashbucket *ilb;
			struct hlist_nulls_node *node;
			struct sock *sk;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock_bh(&ilb->lock);
			sk_nulls_for_each(sk, node, &ilb->head) {
				struct inet_sock *inet = inet_sk(sk);

				/* Sockets of other namespaces are not counted in num. */
				if (!net_eq(sock_net(sk), net))
					continue;

				/* Skip entries already walked by a previous call. */
				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				/* Any destination-port filter excludes listeners. */
				if (r->id.idiag_dport ||
				    cb->args[3] > 0)
					goto next_listen;

				/* On failure, record the position (at "done") and stop. */
				if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

next_listen:
				cb->args[3] = 0;
				cb->args[4] = 0;
				++num;
			}
			spin_unlock_bh(&ilb->lock);

			/* Every bucket after the resume bucket is walked from its start. */
			s_num = 0;
			cb->args[3] = 0;
			cb->args[4] = 0;
		}
skip_listen_ht:
		/* Listening hash finished: later calls go straight to ehash. */
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	/* Only LISTEN was requested: the established hash has nothing to add. */
	if (!(idiag_states & ~TCPF_LISTEN))
		goto out;

	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct hlist_nulls_node *node;
		struct sock *sk;

		num = 0;

		/* Empty buckets are skipped without taking the lock. */
		if (hlist_nulls_empty(&head->chain))
			continue;

		/* Only the bucket being resumed skips already-walked entries. */
		if (i > s_i)
			s_num = 0;

		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			int state, res;

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			/* Timewait sockets are matched on their tw_substate. */
			state = (sk->sk_state == TCP_TIME_WAIT) ?
				inet_twsk(sk)->tw_substate : sk->sk_state;
			if (!(idiag_states & (1 << state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != htons(sk->sk_num) &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != sk->sk_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			/*
			 * The struct sock fields read above may belong to a
			 * timewait socket; this checks at build time that their
			 * offsets match.
			 */
			twsk_build_assert();

			if (!inet_diag_bc_sk(bc, sk))
				goto next_normal;

			res = sk_diag_fill(sk, skb, r,
					   sk_user_ns(NETLINK_CB(cb->skb).sk),
					   NETLINK_CB(cb->skb).portid,
					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
					   cb->nlh);
			/* On failure, record the position (at "done") and stop. */
			if (res < 0) {
				spin_unlock_bh(lock);
				goto done;
			}
next_normal:
			++num;
		}

		spin_unlock_bh(lock);
	}

done:
	/* Resume point for the next call: bucket index and in-bucket count. */
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
952
953 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
954                             const struct inet_diag_req_v2 *r,
955                             struct nlattr *bc)
956 {
957         const struct inet_diag_handler *handler;
958         int err = 0;
959
960         handler = inet_diag_lock_handler(r->sdiag_protocol);
961         if (!IS_ERR(handler))
962                 handler->dump(skb, cb, r, bc);
963         else
964                 err = PTR_ERR(handler);
965         inet_diag_unlock_handler(handler);
966
967         return err ? : skb->len;
968 }
969
970 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
971 {
972         int hdrlen = sizeof(struct inet_diag_req_v2);
973         struct nlattr *bc = NULL;
974
975         if (nlmsg_attrlen(cb->nlh, hdrlen))
976                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
977
978         return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
979 }
980
981 static int inet_diag_type2proto(int type)
982 {
983         switch (type) {
984         case TCPDIAG_GETSOCK:
985                 return IPPROTO_TCP;
986         case DCCPDIAG_GETSOCK:
987                 return IPPROTO_DCCP;
988         default:
989                 return 0;
990         }
991 }
992
993 static int inet_diag_dump_compat(struct sk_buff *skb,
994                                  struct netlink_callback *cb)
995 {
996         struct inet_diag_req *rc = nlmsg_data(cb->nlh);
997         int hdrlen = sizeof(struct inet_diag_req);
998         struct inet_diag_req_v2 req;
999         struct nlattr *bc = NULL;
1000
1001         req.sdiag_family = AF_UNSPEC; /* compatibility */
1002         req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1003         req.idiag_ext = rc->idiag_ext;
1004         req.idiag_states = rc->idiag_states;
1005         req.id = rc->id;
1006
1007         if (nlmsg_attrlen(cb->nlh, hdrlen))
1008                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1009
1010         return __inet_diag_dump(skb, cb, &req, bc);
1011 }
1012
1013 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1014                                       const struct nlmsghdr *nlh)
1015 {
1016         struct inet_diag_req *rc = nlmsg_data(nlh);
1017         struct inet_diag_req_v2 req;
1018
1019         req.sdiag_family = rc->idiag_family;
1020         req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1021         req.idiag_ext = rc->idiag_ext;
1022         req.idiag_states = rc->idiag_states;
1023         req.id = rc->id;
1024
1025         return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
1026 }
1027
1028 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1029 {
1030         int hdrlen = sizeof(struct inet_diag_req);
1031         struct net *net = sock_net(skb->sk);
1032
1033         if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1034             nlmsg_len(nlh) < hdrlen)
1035                 return -EINVAL;
1036
1037         if (nlh->nlmsg_flags & NLM_F_DUMP) {
1038                 if (nlmsg_attrlen(nlh, hdrlen)) {
1039                         struct nlattr *attr;
1040                         int err;
1041
1042                         attr = nlmsg_find_attr(nlh, hdrlen,
1043                                                INET_DIAG_REQ_BYTECODE);
1044                         err = inet_diag_bc_audit(attr, skb);
1045                         if (err)
1046                                 return err;
1047                 }
1048                 {
1049                         struct netlink_dump_control c = {
1050                                 .dump = inet_diag_dump_compat,
1051                         };
1052                         return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1053                 }
1054         }
1055
1056         return inet_diag_get_exact_compat(skb, nlh);
1057 }
1058
1059 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1060 {
1061         int hdrlen = sizeof(struct inet_diag_req_v2);
1062         struct net *net = sock_net(skb->sk);
1063
1064         if (nlmsg_len(h) < hdrlen)
1065                 return -EINVAL;
1066
1067         if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1068             h->nlmsg_flags & NLM_F_DUMP) {
1069                 if (nlmsg_attrlen(h, hdrlen)) {
1070                         struct nlattr *attr;
1071                         int err;
1072
1073                         attr = nlmsg_find_attr(h, hdrlen,
1074                                                INET_DIAG_REQ_BYTECODE);
1075                         err = inet_diag_bc_audit(attr, skb);
1076                         if (err)
1077                                 return err;
1078                 }
1079                 {
1080                         struct netlink_dump_control c = {
1081                                 .dump = inet_diag_dump,
1082                         };
1083                         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1084                 }
1085         }
1086
1087         return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1088 }
1089
1090 static
1091 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1092 {
1093         const struct inet_diag_handler *handler;
1094         struct nlmsghdr *nlh;
1095         struct nlattr *attr;
1096         struct inet_diag_msg *r;
1097         void *info = NULL;
1098         int err = 0;
1099
1100         nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1101         if (!nlh)
1102                 return -ENOMEM;
1103
1104         r = nlmsg_data(nlh);
1105         memset(r, 0, sizeof(*r));
1106         inet_diag_msg_common_fill(r, sk);
1107         if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1108                 r->id.idiag_sport = inet_sk(sk)->inet_sport;
1109         r->idiag_state = sk->sk_state;
1110
1111         if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1112                 nlmsg_cancel(skb, nlh);
1113                 return err;
1114         }
1115
1116         handler = inet_diag_lock_handler(sk->sk_protocol);
1117         if (IS_ERR(handler)) {
1118                 inet_diag_unlock_handler(handler);
1119                 nlmsg_cancel(skb, nlh);
1120                 return PTR_ERR(handler);
1121         }
1122
1123         attr = handler->idiag_info_size
1124                 ? nla_reserve(skb, INET_DIAG_INFO, handler->idiag_info_size)
1125                 : NULL;
1126         if (attr)
1127                 info = nla_data(attr);
1128
1129         handler->idiag_get_info(sk, r, info);
1130         inet_diag_unlock_handler(handler);
1131
1132         nlmsg_end(skb, nlh);
1133         return 0;
1134 }
1135
1136 static const struct sock_diag_handler inet_diag_handler = {
1137         .family = AF_INET,
1138         .dump = inet_diag_handler_cmd,
1139         .get_info = inet_diag_handler_get_info,
1140         .destroy = inet_diag_handler_cmd,
1141 };
1142
1143 static const struct sock_diag_handler inet6_diag_handler = {
1144         .family = AF_INET6,
1145         .dump = inet_diag_handler_cmd,
1146         .get_info = inet_diag_handler_get_info,
1147         .destroy = inet_diag_handler_cmd,
1148 };
1149
1150 int inet_diag_register(const struct inet_diag_handler *h)
1151 {
1152         const __u16 type = h->idiag_type;
1153         int err = -EINVAL;
1154
1155         if (type >= IPPROTO_MAX)
1156                 goto out;
1157
1158         mutex_lock(&inet_diag_table_mutex);
1159         err = -EEXIST;
1160         if (!inet_diag_table[type]) {
1161                 inet_diag_table[type] = h;
1162                 err = 0;
1163         }
1164         mutex_unlock(&inet_diag_table_mutex);
1165 out:
1166         return err;
1167 }
1168 EXPORT_SYMBOL_GPL(inet_diag_register);
1169
1170 void inet_diag_unregister(const struct inet_diag_handler *h)
1171 {
1172         const __u16 type = h->idiag_type;
1173
1174         if (type >= IPPROTO_MAX)
1175                 return;
1176
1177         mutex_lock(&inet_diag_table_mutex);
1178         inet_diag_table[type] = NULL;
1179         mutex_unlock(&inet_diag_table_mutex);
1180 }
1181 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1182
1183 static int __init inet_diag_init(void)
1184 {
1185         const int inet_diag_table_size = (IPPROTO_MAX *
1186                                           sizeof(struct inet_diag_handler *));
1187         int err = -ENOMEM;
1188
1189         inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1190         if (!inet_diag_table)
1191                 goto out;
1192
1193         err = sock_diag_register(&inet_diag_handler);
1194         if (err)
1195                 goto out_free_nl;
1196
1197         err = sock_diag_register(&inet6_diag_handler);
1198         if (err)
1199                 goto out_free_inet;
1200
1201         sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1202 out:
1203         return err;
1204
1205 out_free_inet:
1206         sock_diag_unregister(&inet_diag_handler);
1207 out_free_nl:
1208         kfree(inet_diag_table);
1209         goto out;
1210 }
1211
/*
 * Module exit: unregister the AF_INET6 and AF_INET sock_diag handlers and the
 * compat entry point, then free the protocol handler table.
 */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	/* Freed last: every entry point unregistered above reads this table. */
	kfree(inet_diag_table);
}
1219
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
/*
 * Module aliases for NETLINK_SOCK_DIAG requests in the AF_INET and AF_INET6
 * families, presumably so the module can be auto-loaded on demand.
 * NOTE(review): literal family numbers are used, presumably because the
 * macro stringifies its arguments. They must stay equal to AF_INET (2) and
 * AF_INET6 (10).
 */
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);