/*
 * count the number of connections matching an arbitrary key.
 *
 * (C) 2017 Red Hat GmbH
 * Author: Florian Westphal <fw@strlen.de>
 *
 * split from xt_connlimit.c:
 *   (c) 2000 Gerd Knorr <kraxel@bytesex.org>
 *   Nov 2002: Martin Bene <martin.bene@icomedias.com>:
 *              only ignore TIME_WAIT or gone connections
 *   (C) CC Computer Consultants GmbH, 2007
 */
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/jhash.h>
#include <linux/slab.h>
#include <linux/list.h>
#include <linux/rbtree.h>
#include <linux/module.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/spinlock.h>
#include <linux/netfilter/nf_conntrack_tcp.h>
#include <linux/netfilter/x_tables.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_count.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_zones.h>

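/* 256 hash slots, each backed by its own rb-tree and lock; at most
 * CONNCOUNT_GC_MAX_NODES empty tree nodes are collected per walk, and
 * keys may be up to MAX_KEYLEN u32 words long.
 */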
#define CONNCOUNT_SLOTS         256U

#define CONNCOUNT_GC_MAX_NODES  8
#define MAX_KEYLEN              5

/* we will save the tuples of all connections we care about */
struct nf_conncount_tuple {
        struct list_head                node;
        struct nf_conntrack_tuple       tuple;
        struct nf_conntrack_zone        zone;
        int                             cpu;
        u32                             jiffies32;
        bool                            dead;
        struct rcu_head                 rcu_head;
};

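/* one rb-tree node per distinct key; it carries the list of saved
 * tuples counted against that key
 */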
struct nf_conncount_rb {
        struct rb_node node;
        struct nf_conncount_list list;
        u32 key[MAX_KEYLEN];
        struct rcu_head rcu_head;
};

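/* one lock per hash slot, taken for inserts and tree garbage collection */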
static spinlock_t nf_conncount_locks[CONNCOUNT_SLOTS] __cacheline_aligned_in_smp;

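/* per-instance state: one rb-tree per hash slot plus a deferred gc work
 * item that sweeps the trees flagged in pending_trees
 */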
struct nf_conncount_data {
        unsigned int keylen;
        struct rb_root root[CONNCOUNT_SLOTS];
        struct net *net;
        struct work_struct gc_work;
        unsigned long pending_trees[BITS_TO_LONGS(CONNCOUNT_SLOTS)];
        unsigned int gc_tree;
};

static u_int32_t conncount_rnd __read_mostly;
static struct kmem_cache *conncount_rb_cachep __read_mostly;
static struct kmem_cache *conncount_conn_cachep __read_mostly;

static inline bool already_closed(const struct nf_conn *conn)
{
        if (nf_ct_protonum(conn) == IPPROTO_TCP)
                return conn->proto.tcp.state == TCP_CONNTRACK_TIME_WAIT ||
                       conn->proto.tcp.state == TCP_CONNTRACK_CLOSE;
        else
                return false;
}

static int key_diff(const u32 *a, const u32 *b, unsigned int klen)
{
        return memcmp(a, b, klen * sizeof(u32));
}

static void __conn_free(struct rcu_head *h)
{
        struct nf_conncount_tuple *conn;

        conn = container_of(h, struct nf_conncount_tuple, rcu_head);
        kmem_cache_free(conncount_conn_cachep, conn);
}

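/* unlink a saved tuple under list_lock; returns true when this removal
 * emptied the list, so the caller may dispose of the containing tree node
 */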
static bool conn_free(struct nf_conncount_list *list,
                      struct nf_conncount_tuple *conn)
{
        bool free_entry = false;

        spin_lock_bh(&list->list_lock);

        if (conn->dead) {
                spin_unlock_bh(&list->list_lock);
                return free_entry;
        }

        list->count--;
        conn->dead = true;
        list_del_rcu(&conn->node);
        if (list->count == 0) {
                list->dead = true;
                free_entry = true;
        }

        spin_unlock_bh(&list->list_lock);
        call_rcu(&conn->rcu_head, __conn_free);
        return free_entry;
}

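/* look up the conntrack entry for a saved tuple; if it is gone, evict
 * the stale entry (-ENOENT) unless it may merely be unconfirmed on
 * another cpu (-EAGAIN)
 */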
static const struct nf_conntrack_tuple_hash *
find_or_evict(struct net *net, struct nf_conncount_list *list,
              struct nf_conncount_tuple *conn, bool *free_entry)
{
        const struct nf_conntrack_tuple_hash *found;
        unsigned long a, b;
        int cpu = raw_smp_processor_id();
        u32 age;

        found = nf_conntrack_find_get(net, &conn->zone, &conn->tuple);
        if (found)
                return found;
        b = conn->jiffies32;
        a = (u32)jiffies;

        /* conn might have been added just before by another cpu and
         * might still be unconfirmed.  In this case, nf_conntrack_find()
         * returns no result.  Thus only evict if this cpu added the
         * stale entry or if the entry is older than two jiffies.
         */
        age = a - b;
        if (conn->cpu == cpu || age >= 2) {
                *free_entry = conn_free(list, conn);
                return ERR_PTR(-ENOENT);
        }

        return ERR_PTR(-EAGAIN);
}

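/* caller must hold list->list_lock */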
static int __nf_conncount_add(struct net *net,
                              struct nf_conncount_list *list,
                              const struct nf_conntrack_tuple *tuple,
                              const struct nf_conntrack_zone *zone)
{
        const struct nf_conntrack_tuple_hash *found;
        struct nf_conncount_tuple *conn, *conn_n;
        struct nf_conn *found_ct;
        unsigned int collect = 0;
        bool free_entry = false;

        /* check the saved connections */
        list_for_each_entry_safe(conn, conn_n, &list->head, node) {
                if (collect > CONNCOUNT_GC_MAX_NODES)
                        break;

                found = find_or_evict(net, list, conn, &free_entry);
                if (IS_ERR(found)) {
                        /* Not found, but might be about to be confirmed */
                        if (PTR_ERR(found) == -EAGAIN) {
                                if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
                                    nf_ct_zone_id(&conn->zone, conn->zone.dir) ==
                                    nf_ct_zone_id(zone, zone->dir))
                                        return 0; /* already exists */
                        } else {
                                collect++;
                        }
                        continue;
                }

                found_ct = nf_ct_tuplehash_to_ctrack(found);

                if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
                    nf_ct_zone_equal(found_ct, zone, zone->dir)) {
                        /*
                         * We should not see tuples twice unless someone hooks
                         * this into a table without "-p tcp --syn".
                         *
                         * Attempt to avoid a re-add in this case.
                         */
                        nf_ct_put(found_ct);
                        return 0;
                } else if (already_closed(found_ct)) {
                        /*
                         * we do not care about connections which are
                         * closed already -> ditch it
                         */
                        nf_ct_put(found_ct);
                        conn_free(list, conn);
                        collect++;
                        continue;
                }

                nf_ct_put(found_ct);
        }

        if (WARN_ON_ONCE(list->count > INT_MAX))
                return -EOVERFLOW;

        conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
        if (conn == NULL)
                return -ENOMEM;

        conn->tuple = *tuple;
        conn->zone = *zone;
        conn->cpu = raw_smp_processor_id();
        conn->jiffies32 = (u32)jiffies;
        list_add_tail(&conn->node, &list->head);
        list->count++;
        return 0;
}

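/* locked wrapper around __nf_conncount_add() */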
int nf_conncount_add(struct net *net,
                     struct nf_conncount_list *list,
                     const struct nf_conntrack_tuple *tuple,
                     const struct nf_conntrack_zone *zone)
{
        int ret;

        /* check the saved connections */
        spin_lock_bh(&list->list_lock);
        ret = __nf_conncount_add(net, list, tuple, zone);
        spin_unlock_bh(&list->list_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_add);

void nf_conncount_list_init(struct nf_conncount_list *list)
{
        spin_lock_init(&list->list_lock);
        INIT_LIST_HEAD(&list->head);
        list->count = 0;
        list->dead = false;
}
EXPORT_SYMBOL_GPL(nf_conncount_list_init);

/* Return true if the list is empty */
bool nf_conncount_gc_list(struct net *net,
                          struct nf_conncount_list *list)
{
        const struct nf_conntrack_tuple_hash *found;
        struct nf_conncount_tuple *conn, *conn_n;
        struct nf_conn *found_ct;
        unsigned int collected = 0;
        bool free_entry = false;
        bool ret = false;

        list_for_each_entry_safe(conn, conn_n, &list->head, node) {
                found = find_or_evict(net, list, conn, &free_entry);
                if (IS_ERR(found)) {
                        if (PTR_ERR(found) == -ENOENT) {
                                if (free_entry)
                                        return true;
                                collected++;
                        }
                        continue;
                }

                found_ct = nf_ct_tuplehash_to_ctrack(found);
                if (already_closed(found_ct)) {
                        /*
                         * we do not care about connections which are
                         * closed already -> ditch it
                         */
                        nf_ct_put(found_ct);
                        if (conn_free(list, conn))
                                return true;
                        collected++;
                        continue;
                }

                nf_ct_put(found_ct);
                if (collected > CONNCOUNT_GC_MAX_NODES)
                        return false;
        }

        spin_lock_bh(&list->list_lock);
        if (!list->count) {
                list->dead = true;
                ret = true;
        }
        spin_unlock_bh(&list->list_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_gc_list);

static void __tree_nodes_free(struct rcu_head *h)
{
        struct nf_conncount_rb *rbconn;

        rbconn = container_of(h, struct nf_conncount_rb, rcu_head);
        kmem_cache_free(conncount_rb_cachep, rbconn);
}

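/* erase the collected empty nodes from the tree and free them after an
 * RCU grace period; called with the slot lock held
 */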
static void tree_nodes_free(struct rb_root *root,
                            struct nf_conncount_rb *gc_nodes[],
                            unsigned int gc_count)
{
        struct nf_conncount_rb *rbconn;

        while (gc_count) {
                rbconn = gc_nodes[--gc_count];
                spin_lock(&rbconn->list.list_lock);
                rb_erase(&rbconn->node, root);
                call_rcu(&rbconn->rcu_head, __tree_nodes_free);
                spin_unlock(&rbconn->list.list_lock);
        }
}

static void schedule_gc_worker(struct nf_conncount_data *data, int tree)
{
        set_bit(tree, data->pending_trees);
        schedule_work(&data->gc_work);
}

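/* slow path: no node matched during the lockless lookup (or the match
 * was about to be freed).  Take the slot lock, re-walk the tree, reclaim
 * up to CONNCOUNT_GC_MAX_NODES empty nodes, and insert a new node for
 * this key if it is still absent.
 */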
static unsigned int
insert_tree(struct net *net,
            struct nf_conncount_data *data,
            struct rb_root *root,
            unsigned int hash,
            const u32 *key,
            u8 keylen,
            const struct nf_conntrack_tuple *tuple,
            const struct nf_conntrack_zone *zone)
{
        struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES];
        struct rb_node **rbnode, *parent;
        struct nf_conncount_rb *rbconn;
        struct nf_conncount_tuple *conn;
        unsigned int count = 0, gc_count = 0;
        bool do_gc = true;

        spin_lock_bh(&nf_conncount_locks[hash]);
restart:
        parent = NULL;
        rbnode = &(root->rb_node);
        while (*rbnode) {
                int diff;
                rbconn = rb_entry(*rbnode, struct nf_conncount_rb, node);

                parent = *rbnode;
                diff = key_diff(key, rbconn->key, keylen);
                if (diff < 0) {
                        rbnode = &((*rbnode)->rb_left);
                } else if (diff > 0) {
                        rbnode = &((*rbnode)->rb_right);
                } else {
                        int ret;

                        ret = nf_conncount_add(net, &rbconn->list, tuple, zone);
                        if (ret)
                                count = 0; /* hotdrop */
                        else
                                count = rbconn->list.count;
                        tree_nodes_free(root, gc_nodes, gc_count);
                        goto out_unlock;
                }

                if (gc_count >= ARRAY_SIZE(gc_nodes))
                        continue;

                if (do_gc && nf_conncount_gc_list(net, &rbconn->list))
                        gc_nodes[gc_count++] = rbconn;
        }

        if (gc_count) {
                tree_nodes_free(root, gc_nodes, gc_count);
                schedule_gc_worker(data, hash);
                gc_count = 0;
                do_gc = false;
                goto restart;
        }

        /* expected case: match, insert new node */
        rbconn = kmem_cache_alloc(conncount_rb_cachep, GFP_ATOMIC);
        if (rbconn == NULL)
                goto out_unlock;

        conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
        if (conn == NULL) {
                kmem_cache_free(conncount_rb_cachep, rbconn);
                goto out_unlock;
        }

        conn->tuple = *tuple;
        conn->zone = *zone;
        memcpy(rbconn->key, key, sizeof(u32) * keylen);

        nf_conncount_list_init(&rbconn->list);
        list_add(&conn->node, &rbconn->list.head);
        count = 1;
        rbconn->list.count = count;

        rb_link_node_rcu(&rbconn->node, parent, rbnode);
        rb_insert_color(&rbconn->node, root);
out_unlock:
        spin_unlock_bh(&nf_conncount_locks[hash]);
        return count;
}

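/* fast path: lockless RCU walk of the slot's tree.  On a match the tuple
 * is added to that node's list (or only the count is returned when tuple
 * is NULL); otherwise fall back to insert_tree().
 */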
static unsigned int
count_tree(struct net *net,
           struct nf_conncount_data *data,
           const u32 *key,
           const struct nf_conntrack_tuple *tuple,
           const struct nf_conntrack_zone *zone)
{
        struct rb_root *root;
        struct rb_node *parent;
        struct nf_conncount_rb *rbconn;
        unsigned int hash;
        u8 keylen = data->keylen;

        hash = jhash2(key, data->keylen, conncount_rnd) % CONNCOUNT_SLOTS;
        root = &data->root[hash];

        parent = rcu_dereference_raw(root->rb_node);
        while (parent) {
                int diff;

                rbconn = rb_entry(parent, struct nf_conncount_rb, node);

                diff = key_diff(key, rbconn->key, keylen);
                if (diff < 0) {
                        parent = rcu_dereference_raw(parent->rb_left);
                } else if (diff > 0) {
                        parent = rcu_dereference_raw(parent->rb_right);
                } else {
                        int ret;

                        if (!tuple) {
                                nf_conncount_gc_list(net, &rbconn->list);
                                return rbconn->list.count;
                        }

                        spin_lock_bh(&rbconn->list.list_lock);
                        /* Node might be about to be free'd.
                         * We need to defer to insert_tree() in this case.
                         */
                        if (rbconn->list.count == 0) {
                                spin_unlock_bh(&rbconn->list.list_lock);
                                break;
                        }

                        /* same source network -> be counted! */
                        ret = __nf_conncount_add(net, &rbconn->list, tuple, zone);
                        spin_unlock_bh(&rbconn->list.list_lock);
                        if (ret)
                                return 0; /* hotdrop */
                        else
                                return rbconn->list.count;
                }
        }

        if (!tuple)
                return 0;

        return insert_tree(net, data, root, hash, key, keylen, tuple, zone);
}

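/* deferred gc: a first pass under RCU evicts stale entries from each
 * node's list, a second pass under the slot lock erases tree nodes left
 * with an empty list (skipped when too few were found to bother); the
 * work reschedules itself while other slots are still pending
 */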
static void tree_gc_worker(struct work_struct *work)
{
        struct nf_conncount_data *data = container_of(work, struct nf_conncount_data, gc_work);
        struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES], *rbconn;
        struct rb_root *root;
        struct rb_node *node;
        unsigned int tree, next_tree, gc_count = 0;

        tree = data->gc_tree % CONNCOUNT_SLOTS;
        root = &data->root[tree];

        rcu_read_lock();
        for (node = rb_first(root); node != NULL; node = rb_next(node)) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);
                if (nf_conncount_gc_list(data->net, &rbconn->list))
                        gc_count++;
        }
        rcu_read_unlock();

        spin_lock_bh(&nf_conncount_locks[tree]);
        if (gc_count < ARRAY_SIZE(gc_nodes))
                goto next; /* do not bother */

        gc_count = 0;
        node = rb_first(root);
        while (node != NULL) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);
                node = rb_next(node);

                if (rbconn->list.count > 0)
                        continue;

                gc_nodes[gc_count++] = rbconn;
                if (gc_count >= ARRAY_SIZE(gc_nodes)) {
                        tree_nodes_free(root, gc_nodes, gc_count);
                        gc_count = 0;
                }
        }

        tree_nodes_free(root, gc_nodes, gc_count);
next:
        clear_bit(tree, data->pending_trees);

        next_tree = (tree + 1) % CONNCOUNT_SLOTS;
        next_tree = find_next_bit(data->pending_trees, CONNCOUNT_SLOTS, next_tree);

        if (next_tree < CONNCOUNT_SLOTS) {
                data->gc_tree = next_tree;
                schedule_work(work);
        }

        spin_unlock_bh(&nf_conncount_locks[tree]);
}

/* Count and return number of conntrack entries in 'net' with particular 'key'.
 * If 'tuple' is not null, insert it into the accounting data structure.
 * Call with RCU read lock.
 */
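/* Usage sketch, loosely modelled on the xt_connlimit caller; 'info' and
 * its fields are illustrative only:
 *
 *	connections = nf_conncount_count(net, info->data, key,
 *					 tuple_ptr, zone);
 *	if (connections == 0)
 *		goto hotdrop;		// memory allocation failed
 *	return connections > info->limit;
 */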
unsigned int nf_conncount_count(struct net *net,
                                struct nf_conncount_data *data,
                                const u32 *key,
                                const struct nf_conntrack_tuple *tuple,
                                const struct nf_conntrack_zone *zone)
{
        return count_tree(net, data, key, tuple, zone);
}
EXPORT_SYMBOL_GPL(nf_conncount_count);

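/* keylen is given in bytes and must be a non-zero multiple of sizeof(u32),
 * no larger than MAX_KEYLEN u32 words in total
 */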
struct nf_conncount_data *nf_conncount_init(struct net *net, unsigned int family,
                                            unsigned int keylen)
{
        struct nf_conncount_data *data;
        int ret, i;

        if (keylen % sizeof(u32) ||
            keylen / sizeof(u32) > MAX_KEYLEN ||
            keylen == 0)
                return ERR_PTR(-EINVAL);

        net_get_random_once(&conncount_rnd, sizeof(conncount_rnd));

        data = kmalloc(sizeof(*data), GFP_KERNEL);
        if (!data)
                return ERR_PTR(-ENOMEM);

        ret = nf_ct_netns_get(net, family);
        if (ret < 0) {
                kfree(data);
                return ERR_PTR(ret);
        }

        for (i = 0; i < ARRAY_SIZE(data->root); ++i)
                data->root[i] = RB_ROOT;

        data->keylen = keylen / sizeof(u32);
        data->net = net;
        INIT_WORK(&data->gc_work, tree_gc_worker);

        return data;
}
EXPORT_SYMBOL_GPL(nf_conncount_init);

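/* free all saved tuples immediately (no RCU deferral); only safe once the
 * list can no longer be reached by other CPUs
 */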
void nf_conncount_cache_free(struct nf_conncount_list *list)
{
        struct nf_conncount_tuple *conn, *conn_n;

        list_for_each_entry_safe(conn, conn_n, &list->head, node)
                kmem_cache_free(conncount_conn_cachep, conn);
}
EXPORT_SYMBOL_GPL(nf_conncount_cache_free);

static void destroy_tree(struct rb_root *r)
{
        struct nf_conncount_rb *rbconn;
        struct rb_node *node;

        while ((node = rb_first(r)) != NULL) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);

                rb_erase(node, r);

                nf_conncount_cache_free(&rbconn->list);

                kmem_cache_free(conncount_rb_cachep, rbconn);
        }
}

void nf_conncount_destroy(struct net *net, unsigned int family,
                          struct nf_conncount_data *data)
{
        unsigned int i;

        cancel_work_sync(&data->gc_work);
        nf_ct_netns_put(net, family);

        for (i = 0; i < ARRAY_SIZE(data->root); ++i)
                destroy_tree(&data->root[i]);

        kfree(data);
}
EXPORT_SYMBOL_GPL(nf_conncount_destroy);

static int __init nf_conncount_modinit(void)
{
        int i;

        for (i = 0; i < CONNCOUNT_SLOTS; ++i)
                spin_lock_init(&nf_conncount_locks[i]);

        conncount_conn_cachep = kmem_cache_create("nf_conncount_tuple",
                                           sizeof(struct nf_conncount_tuple),
                                           0, 0, NULL);
        if (!conncount_conn_cachep)
                return -ENOMEM;

        conncount_rb_cachep = kmem_cache_create("nf_conncount_rb",
                                           sizeof(struct nf_conncount_rb),
                                           0, 0, NULL);
        if (!conncount_rb_cachep) {
                kmem_cache_destroy(conncount_conn_cachep);
                return -ENOMEM;
        }

        return 0;
}

static void __exit nf_conncount_modexit(void)
{
        kmem_cache_destroy(conncount_conn_cachep);
        kmem_cache_destroy(conncount_rb_cachep);
}

module_init(nf_conncount_modinit);
module_exit(nf_conncount_modexit);
MODULE_AUTHOR("Jan Engelhardt <jengelh@medozas.de>");
MODULE_AUTHOR("Florian Westphal <fw@strlen.de>");
MODULE_DESCRIPTION("netfilter: count number of connections matching a key");
MODULE_LICENSE("GPL");