/*
 * count the number of connections matching an arbitrary key.
 *
 * (C) 2017 Red Hat GmbH
 * Author: Florian Westphal <fw@strlen.de>
 *
 * split from xt_connlimit.c:
 *   (c) 2000 Gerd Knorr <kraxel@bytesex.org>
 *   Nov 2002: Martin Bene <martin.bene@icomedias.com>:
 *              only ignore TIME_WAIT or gone connections
 *   (C) CC Computer Consultants GmbH, 2007
 */
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/jhash.h>
#include <linux/slab.h>
#include <linux/list.h>
#include <linux/rbtree.h>
#include <linux/module.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/spinlock.h>
#include <linux/netfilter/nf_conntrack_tcp.h>
#include <linux/netfilter/x_tables.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_count.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_zones.h>

#define CONNCOUNT_SLOTS         256U

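/* With lockdep enabled, use a smaller lock table (several trees then
 * share one lock); presumably this keeps lockdep's lock-tracking
 * overhead manageable.
 */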
#ifdef CONFIG_LOCKDEP
#define CONNCOUNT_LOCK_SLOTS    8U
#else
#define CONNCOUNT_LOCK_SLOTS    256U
#endif

#define CONNCOUNT_GC_MAX_NODES  8
#define MAX_KEYLEN              5

/* we will save the tuples of all connections we care about */
struct nf_conncount_tuple {
        struct list_head                node;
        struct nf_conntrack_tuple       tuple;
        struct nf_conntrack_zone        zone;
        int                             cpu;
        u32                             jiffies32;
        bool                            dead;
        struct rcu_head                 rcu_head;
};

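/* One tree node per distinct key; its embedded list holds the tuples
 * currently counted against that key.
 */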
struct nf_conncount_rb {
        struct rb_node node;
        struct nf_conncount_list list;
        u32 key[MAX_KEYLEN];
        struct rcu_head rcu_head;
};

static spinlock_t nf_conncount_locks[CONNCOUNT_LOCK_SLOTS] __cacheline_aligned_in_smp;

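/* Per-instance state: one rb-tree per hash slot, plus a bitmap of trees
 * with pending garbage-collection work for the gc worker.  Note keylen
 * is stored in u32 words, not bytes.
 */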
struct nf_conncount_data {
        unsigned int keylen;
        struct rb_root root[CONNCOUNT_SLOTS];
        struct net *net;
        struct work_struct gc_work;
        unsigned long pending_trees[BITS_TO_LONGS(CONNCOUNT_SLOTS)];
        unsigned int gc_tree;
};

static u_int32_t conncount_rnd __read_mostly;
static struct kmem_cache *conncount_rb_cachep __read_mostly;
static struct kmem_cache *conncount_conn_cachep __read_mostly;

static inline bool already_closed(const struct nf_conn *conn)
{
        if (nf_ct_protonum(conn) == IPPROTO_TCP)
                return conn->proto.tcp.state == TCP_CONNTRACK_TIME_WAIT ||
                       conn->proto.tcp.state == TCP_CONNTRACK_CLOSE;
        else
                return false;
}

static int key_diff(const u32 *a, const u32 *b, unsigned int klen)
{
        return memcmp(a, b, klen * sizeof(u32));
}

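/* Append @tuple to @list, unless the list has already been marked dead
 * by garbage collection; in that case the caller must insert a fresh
 * tree node instead (NF_CONNCOUNT_SKIP).
 */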
enum nf_conncount_list_add
nf_conncount_add(struct nf_conncount_list *list,
                 const struct nf_conntrack_tuple *tuple,
                 const struct nf_conntrack_zone *zone)
{
        struct nf_conncount_tuple *conn;

        if (WARN_ON_ONCE(list->count > INT_MAX))
                return NF_CONNCOUNT_ERR;

        conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
        if (conn == NULL)
                return NF_CONNCOUNT_ERR;

        conn->tuple = *tuple;
        conn->zone = *zone;
        conn->cpu = raw_smp_processor_id();
        conn->jiffies32 = (u32)jiffies;
        conn->dead = false;
        spin_lock_bh(&list->list_lock);
        if (list->dead) {
                kmem_cache_free(conncount_conn_cachep, conn);
                spin_unlock_bh(&list->list_lock);
                return NF_CONNCOUNT_SKIP;
        }
        list_add_tail(&conn->node, &list->head);
        list->count++;
        spin_unlock_bh(&list->list_lock);
        return NF_CONNCOUNT_ADDED;
}
EXPORT_SYMBOL_GPL(nf_conncount_add);

static void __conn_free(struct rcu_head *h)
{
        struct nf_conncount_tuple *conn;

        conn = container_of(h, struct nf_conncount_tuple, rcu_head);
        kmem_cache_free(conncount_conn_cachep, conn);
}

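/* Unlink @conn under the list lock and free it after an RCU grace
 * period.  Returns true when this removal emptied the list; the list is
 * then marked dead so concurrent adders back off.
 */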
static bool conn_free(struct nf_conncount_list *list,
                      struct nf_conncount_tuple *conn)
{
        bool free_entry = false;

        spin_lock_bh(&list->list_lock);

        if (conn->dead) {
                spin_unlock_bh(&list->list_lock);
                return free_entry;
        }

        list->count--;
        conn->dead = true;
        list_del_rcu(&conn->node);
        if (list->count == 0) {
                list->dead = true;
                free_entry = true;
        }

        spin_unlock_bh(&list->list_lock);
        call_rcu(&conn->rcu_head, __conn_free);
        return free_entry;
}

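/* Look up @conn in the main conntrack table.  If it is gone, evict the
 * stale entry, but only when it is safe to assume the connection never
 * became confirmed (see the race comment below); otherwise ask the
 * caller to retry with -EAGAIN.
 */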
static const struct nf_conntrack_tuple_hash *
find_or_evict(struct net *net, struct nf_conncount_list *list,
              struct nf_conncount_tuple *conn, bool *free_entry)
{
        const struct nf_conntrack_tuple_hash *found;
        unsigned long a, b;
        int cpu = raw_smp_processor_id();
        u32 age;

        found = nf_conntrack_find_get(net, &conn->zone, &conn->tuple);
        if (found)
                return found;
        b = conn->jiffies32;
        a = (u32)jiffies;

        /* conn might have been added just before by another cpu and
         * might still be unconfirmed.  In this case, nf_conntrack_find()
         * returns no result.  Thus only evict if this cpu added the
         * stale entry or if the entry is older than two jiffies.
         *
         * age is unsigned on purpose: a wrapped delta comes out as a
         * huge value, so the entry is evicted instead of kept forever.
         */
        age = a - b;
        if (conn->cpu == cpu || age >= 2) {
                *free_entry = conn_free(list, conn);
                return ERR_PTR(-ENOENT);
        }

        return ERR_PTR(-EAGAIN);
}

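/* Walk the per-key list: drop entries whose conntrack is gone or
 * already closed, and clear *addit when @tuple is already present so
 * the caller does not count the same connection twice.
 */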
void nf_conncount_lookup(struct net *net,
                         struct nf_conncount_list *list,
                         const struct nf_conntrack_tuple *tuple,
                         const struct nf_conntrack_zone *zone,
                         bool *addit)
{
        const struct nf_conntrack_tuple_hash *found;
        struct nf_conncount_tuple *conn, *conn_n;
        struct nf_conn *found_ct;
        unsigned int collect = 0;
        bool free_entry = false;

        /* best effort only */
        *addit = tuple ? true : false;

        /* check the saved connections */
        list_for_each_entry_safe(conn, conn_n, &list->head, node) {
                if (collect > CONNCOUNT_GC_MAX_NODES)
                        break;

                found = find_or_evict(net, list, conn, &free_entry);
                if (IS_ERR(found)) {
                        /* Not found, but might be about to be confirmed */
                        if (PTR_ERR(found) == -EAGAIN) {
                                if (!tuple)
                                        continue;

                                if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
                                    nf_ct_zone_id(&conn->zone, conn->zone.dir) ==
                                    nf_ct_zone_id(zone, zone->dir))
                                        *addit = false;
                        } else if (PTR_ERR(found) == -ENOENT) {
                                collect++;
                        }
                        continue;
                }

                found_ct = nf_ct_tuplehash_to_ctrack(found);

                if (tuple && nf_ct_tuple_equal(&conn->tuple, tuple) &&
                    nf_ct_zone_equal(found_ct, zone, zone->dir)) {
                        /*
                         * We should not see tuples twice unless someone hooks
                         * this into a table without "-p tcp --syn".
                         *
                         * Attempt to avoid a re-add in this case.
                         */
                        *addit = false;
                } else if (already_closed(found_ct)) {
                        /*
                         * we do not care about connections which are
                         * closed already -> ditch it
                         */
                        nf_ct_put(found_ct);
                        conn_free(list, conn);
                        collect++;
                        continue;
                }

                nf_ct_put(found_ct);
        }
}
EXPORT_SYMBOL_GPL(nf_conncount_lookup);

void nf_conncount_list_init(struct nf_conncount_list *list)
{
        spin_lock_init(&list->list_lock);
        INIT_LIST_HEAD(&list->head);
        list->count = 0;
        list->dead = false;
}
EXPORT_SYMBOL_GPL(nf_conncount_list_init);

/* Return true if the list is empty; it is then marked dead and the
 * caller may release the containing tree node.
 */
bool nf_conncount_gc_list(struct net *net,
                          struct nf_conncount_list *list)
{
        const struct nf_conntrack_tuple_hash *found;
        struct nf_conncount_tuple *conn, *conn_n;
        struct nf_conn *found_ct;
        unsigned int collected = 0;
        bool free_entry = false;
        bool ret = false;

        list_for_each_entry_safe(conn, conn_n, &list->head, node) {
                found = find_or_evict(net, list, conn, &free_entry);
                if (IS_ERR(found)) {
                        if (PTR_ERR(found) == -ENOENT) {
                                if (free_entry)
                                        return true;
                                collected++;
                        }
                        continue;
                }

                found_ct = nf_ct_tuplehash_to_ctrack(found);
                if (already_closed(found_ct)) {
                        /*
                         * we do not care about connections which are
                         * closed already -> ditch it
                         */
                        nf_ct_put(found_ct);
                        if (conn_free(list, conn))
                                return true;
                        collected++;
                        continue;
                }

                nf_ct_put(found_ct);
                if (collected > CONNCOUNT_GC_MAX_NODES)
                        return false;
        }

        spin_lock_bh(&list->list_lock);
        if (!list->count) {
                list->dead = true;
                ret = true;
        }
        spin_unlock_bh(&list->list_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_gc_list);

static void __tree_nodes_free(struct rcu_head *h)
{
        struct nf_conncount_rb *rbconn;

        rbconn = container_of(h, struct nf_conncount_rb, rcu_head);
        kmem_cache_free(conncount_rb_cachep, rbconn);
}

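/* Erase the collected dead nodes from @root and free them after an RCU
 * grace period.  The per-node list lock is taken so that a concurrent
 * nf_conncount_add() either completes before the erase or sees the
 * dead mark and backs off.
 */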
static void tree_nodes_free(struct rb_root *root,
                            struct nf_conncount_rb *gc_nodes[],
                            unsigned int gc_count)
{
        struct nf_conncount_rb *rbconn;

        while (gc_count) {
                rbconn = gc_nodes[--gc_count];
                spin_lock(&rbconn->list.list_lock);
                rb_erase(&rbconn->node, root);
                call_rcu(&rbconn->rcu_head, __tree_nodes_free);
                spin_unlock(&rbconn->list.list_lock);
        }
}

static void schedule_gc_worker(struct nf_conncount_data *data, int tree)
{
        set_bit(tree, data->pending_trees);
        schedule_work(&data->gc_work);
}

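/* Slow path, with the per-slot spinlock held: walk the tree, add to an
 * existing node if the key is already present, opportunistically
 * collect up to CONNCOUNT_GC_MAX_NODES empty nodes on the way down, and
 * otherwise insert a new node for @key.
 */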
static unsigned int
insert_tree(struct net *net,
            struct nf_conncount_data *data,
            struct rb_root *root,
            unsigned int hash,
            const u32 *key,
            u8 keylen,
            const struct nf_conntrack_tuple *tuple,
            const struct nf_conntrack_zone *zone)
{
        enum nf_conncount_list_add ret;
        struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES];
        struct rb_node **rbnode, *parent;
        struct nf_conncount_rb *rbconn;
        struct nf_conncount_tuple *conn;
        unsigned int count = 0, gc_count = 0;
        bool node_found = false;

        spin_lock_bh(&nf_conncount_locks[hash % CONNCOUNT_LOCK_SLOTS]);

        parent = NULL;
        rbnode = &(root->rb_node);
        while (*rbnode) {
                int diff;

                rbconn = rb_entry(*rbnode, struct nf_conncount_rb, node);

                parent = *rbnode;
                diff = key_diff(key, rbconn->key, keylen);
                if (diff < 0) {
                        rbnode = &((*rbnode)->rb_left);
                } else if (diff > 0) {
                        rbnode = &((*rbnode)->rb_right);
                } else {
                        /* unlikely: other cpu added node already */
                        node_found = true;
                        ret = nf_conncount_add(&rbconn->list, tuple, zone);
                        if (ret == NF_CONNCOUNT_ERR) {
                                count = 0; /* hotdrop */
                        } else if (ret == NF_CONNCOUNT_ADDED) {
                                count = rbconn->list.count;
                        } else {
                                /* NF_CONNCOUNT_SKIP, rbconn is already
                                 * reclaimed by gc, insert a new tree node
                                 */
                                node_found = false;
                        }
                        break;
                }

                if (gc_count >= ARRAY_SIZE(gc_nodes))
                        continue;

                if (nf_conncount_gc_list(net, &rbconn->list))
                        gc_nodes[gc_count++] = rbconn;
        }

        if (gc_count) {
                tree_nodes_free(root, gc_nodes, gc_count);
                /* tree_nodes_free() before the new allocation permits the
                 * allocator to re-use the just-freed objects.
                 *
                 * This is a rare event; in most cases we will find an
                 * existing node to re-use (or gc_count is 0).
                 */

                if (gc_count >= ARRAY_SIZE(gc_nodes))
                        schedule_gc_worker(data, hash);
        }

        if (node_found)
                goto out_unlock;

        /* expected case: no match, insert new node */
        rbconn = kmem_cache_alloc(conncount_rb_cachep, GFP_ATOMIC);
        if (rbconn == NULL)
                goto out_unlock;

        conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
        if (conn == NULL) {
                kmem_cache_free(conncount_rb_cachep, rbconn);
                goto out_unlock;
        }

        conn->tuple = *tuple;
        conn->zone = *zone;
        memcpy(rbconn->key, key, sizeof(u32) * keylen);

        nf_conncount_list_init(&rbconn->list);
        list_add(&conn->node, &rbconn->list.head);
        count = 1;
        rbconn->list.count = count;

        rb_link_node_rcu(&rbconn->node, parent, rbnode);
        rb_insert_color(&rbconn->node, root);
out_unlock:
        spin_unlock_bh(&nf_conncount_locks[hash % CONNCOUNT_LOCK_SLOTS]);
        return count;
}

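/* Fast path: lockless RCU descent through the tree of this key's hash
 * slot.  On a key match the node's list is refreshed and counted; only
 * when the key is not present (and a tuple was given) do we fall back
 * to insert_tree() and take the slot lock.
 */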
static unsigned int
count_tree(struct net *net,
           struct nf_conncount_data *data,
           const u32 *key,
           const struct nf_conntrack_tuple *tuple,
           const struct nf_conntrack_zone *zone)
{
        enum nf_conncount_list_add ret;
        struct rb_root *root;
        struct rb_node *parent;
        struct nf_conncount_rb *rbconn;
        unsigned int hash;
        u8 keylen = data->keylen;

        hash = jhash2(key, data->keylen, conncount_rnd) % CONNCOUNT_SLOTS;
        root = &data->root[hash];

        parent = rcu_dereference_raw(root->rb_node);
        while (parent) {
                int diff;
                bool addit;

                rbconn = rb_entry(parent, struct nf_conncount_rb, node);

                diff = key_diff(key, rbconn->key, keylen);
                if (diff < 0) {
                        parent = rcu_dereference_raw(parent->rb_left);
                } else if (diff > 0) {
                        parent = rcu_dereference_raw(parent->rb_right);
                } else {
                        /* same source network -> be counted! */
                        nf_conncount_lookup(net, &rbconn->list, tuple, zone,
                                            &addit);

                        if (!addit)
                                return rbconn->list.count;

                        ret = nf_conncount_add(&rbconn->list, tuple, zone);
                        if (ret == NF_CONNCOUNT_ERR) {
                                return 0; /* hotdrop */
                        } else if (ret == NF_CONNCOUNT_ADDED) {
                                return rbconn->list.count;
                        } else {
                                /* NF_CONNCOUNT_SKIP, rbconn is already
                                 * reclaimed by gc, insert a new tree node
                                 */
                                break;
                        }
                }
        }

        if (!tuple)
                return 0;

        return insert_tree(net, data, root, hash, key, keylen, tuple, zone);
}

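/* Deferred garbage collection, one pending tree per invocation: scan
 * the tree under RCU, then take the slot lock to release the nodes
 * whose lists turned out to be empty.  Reschedules itself while more
 * pending_trees bits remain set.
 */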
static void tree_gc_worker(struct work_struct *work)
{
        struct nf_conncount_data *data = container_of(work, struct nf_conncount_data, gc_work);
        struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES], *rbconn;
        struct rb_root *root;
        struct rb_node *node;
        unsigned int tree, next_tree, gc_count = 0;

        tree = data->gc_tree % CONNCOUNT_SLOTS;
        root = &data->root[tree];

        rcu_read_lock();
        for (node = rb_first(root); node != NULL; node = rb_next(node)) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);
                /* bound collection to the scratch array; anything left
                 * over is picked up the next time this tree is scheduled.
                 */
                if (gc_count < ARRAY_SIZE(gc_nodes) &&
                    nf_conncount_gc_list(data->net, &rbconn->list))
                        gc_nodes[gc_count++] = rbconn;
        }
        rcu_read_unlock();

        /* trees can share a lock; index the lock table modulo its size */
        spin_lock_bh(&nf_conncount_locks[tree % CONNCOUNT_LOCK_SLOTS]);

        if (gc_count)
                tree_nodes_free(root, gc_nodes, gc_count);

        clear_bit(tree, data->pending_trees);

        next_tree = (tree + 1) % CONNCOUNT_SLOTS;
        /* find_next_bit() takes the bitmap size before the start offset */
        next_tree = find_next_bit(data->pending_trees, CONNCOUNT_SLOTS, next_tree);

        if (next_tree < CONNCOUNT_SLOTS) {
                data->gc_tree = next_tree;
                schedule_work(work);
        }

        spin_unlock_bh(&nf_conncount_locks[tree % CONNCOUNT_LOCK_SLOTS]);
}

/* Count and return number of conntrack entries in 'net' with particular 'key'.
 * If 'tuple' is not null, insert it into the accounting data structure.
 * Call with RCU read lock.
 */
unsigned int nf_conncount_count(struct net *net,
                                struct nf_conncount_data *data,
                                const u32 *key,
                                const struct nf_conntrack_tuple *tuple,
                                const struct nf_conntrack_zone *zone)
{
        return count_tree(net, data, key, tuple, zone);
}
EXPORT_SYMBOL_GPL(nf_conncount_count);
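
/* A minimal usage sketch (hypothetical caller, not part of this file):
 * count connections per source address by keying on saddr.
 *
 *	static struct nf_conncount_data *cd;
 *
 *	cd = nf_conncount_init(net, NFPROTO_IPV4, sizeof(u32));
 *	if (IS_ERR(cd))
 *		return PTR_ERR(cd);
 *	...
 *	u32 key[MAX_KEYLEN] = { ip_hdr(skb)->saddr };
 *	unsigned int n;
 *
 *	rcu_read_lock();
 *	n = nf_conncount_count(net, cd, key, tuple, zone);
 *	rcu_read_unlock();
 *	if (n > limit)
 *		... drop the packet ...
 */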
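
/* Allocate and initialise a conncount instance.  @keylen is in bytes:
 * it must be a non-zero multiple of sizeof(u32) and at most
 * MAX_KEYLEN * sizeof(u32); internally it is stored in u32 words.
 */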
struct nf_conncount_data *nf_conncount_init(struct net *net, unsigned int family,
                                            unsigned int keylen)
{
        struct nf_conncount_data *data;
        int ret, i;

        if (keylen % sizeof(u32) ||
            keylen / sizeof(u32) > MAX_KEYLEN ||
            keylen == 0)
                return ERR_PTR(-EINVAL);

        net_get_random_once(&conncount_rnd, sizeof(conncount_rnd));

        data = kmalloc(sizeof(*data), GFP_KERNEL);
        if (!data)
                return ERR_PTR(-ENOMEM);

        ret = nf_ct_netns_get(net, family);
        if (ret < 0) {
                kfree(data);
                return ERR_PTR(ret);
        }

        for (i = 0; i < ARRAY_SIZE(data->root); ++i)
                data->root[i] = RB_ROOT;

        data->keylen = keylen / sizeof(u32);
        data->net = net;
        INIT_WORK(&data->gc_work, tree_gc_worker);

        return data;
}
EXPORT_SYMBOL_GPL(nf_conncount_init);

void nf_conncount_cache_free(struct nf_conncount_list *list)
{
        struct nf_conncount_tuple *conn, *conn_n;

        list_for_each_entry_safe(conn, conn_n, &list->head, node)
                kmem_cache_free(conncount_conn_cachep, conn);
}
EXPORT_SYMBOL_GPL(nf_conncount_cache_free);

static void destroy_tree(struct rb_root *r)
{
        struct nf_conncount_rb *rbconn;
        struct rb_node *node;

        while ((node = rb_first(r)) != NULL) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);

                rb_erase(node, r);

                nf_conncount_cache_free(&rbconn->list);

                kmem_cache_free(conncount_rb_cachep, rbconn);
        }
}

void nf_conncount_destroy(struct net *net, unsigned int family,
                          struct nf_conncount_data *data)
{
        unsigned int i;

        cancel_work_sync(&data->gc_work);
        nf_ct_netns_put(net, family);

        for (i = 0; i < ARRAY_SIZE(data->root); ++i)
                destroy_tree(&data->root[i]);

        kfree(data);
}
EXPORT_SYMBOL_GPL(nf_conncount_destroy);

static int __init nf_conncount_modinit(void)
{
        int i;

        BUILD_BUG_ON(CONNCOUNT_LOCK_SLOTS > CONNCOUNT_SLOTS);
        BUILD_BUG_ON((CONNCOUNT_SLOTS % CONNCOUNT_LOCK_SLOTS) != 0);

        for (i = 0; i < CONNCOUNT_LOCK_SLOTS; ++i)
                spin_lock_init(&nf_conncount_locks[i]);

        conncount_conn_cachep = kmem_cache_create("nf_conncount_tuple",
                                           sizeof(struct nf_conncount_tuple),
                                           0, 0, NULL);
        if (!conncount_conn_cachep)
                return -ENOMEM;

        conncount_rb_cachep = kmem_cache_create("nf_conncount_rb",
                                           sizeof(struct nf_conncount_rb),
                                           0, 0, NULL);
        if (!conncount_rb_cachep) {
                kmem_cache_destroy(conncount_conn_cachep);
                return -ENOMEM;
        }

        return 0;
}

static void __exit nf_conncount_modexit(void)
{
        kmem_cache_destroy(conncount_conn_cachep);
        kmem_cache_destroy(conncount_rb_cachep);
}

module_init(nf_conncount_modinit);
module_exit(nf_conncount_modexit);
MODULE_AUTHOR("Jan Engelhardt <jengelh@medozas.de>");
MODULE_AUTHOR("Florian Westphal <fw@strlen.de>");
MODULE_DESCRIPTION("netfilter: count number of connections matching a key");
MODULE_LICENSE("GPL");