netfilter: nf_conncount: move all list iterations under spinlock
net/netfilter/nf_conncount.c
/*
 * count the number of connections matching an arbitrary key.
 *
 * (C) 2017 Red Hat GmbH
 * Author: Florian Westphal <fw@strlen.de>
 *
 * split from xt_connlimit.c:
 *   (c) 2000 Gerd Knorr <kraxel@bytesex.org>
 *   Nov 2002: Martin Bene <martin.bene@icomedias.com>:
 *              only ignore TIME_WAIT or gone connections
 *   (C) CC Computer Consultants GmbH, 2007
 */
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/jhash.h>
#include <linux/slab.h>
#include <linux/list.h>
#include <linux/rbtree.h>
#include <linux/module.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/spinlock.h>
#include <linux/netfilter/nf_conntrack_tcp.h>
#include <linux/netfilter/x_tables.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_count.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_zones.h>

#define CONNCOUNT_SLOTS         256U

#define CONNCOUNT_GC_MAX_NODES  8
#define MAX_KEYLEN              5

/* we will save the tuples of all connections we care about */
struct nf_conncount_tuple {
        struct list_head                node;
        struct nf_conntrack_tuple       tuple;
        struct nf_conntrack_zone        zone;
        int                             cpu;
        u32                             jiffies32;
};

struct nf_conncount_rb {
        struct rb_node node;
        struct nf_conncount_list list;
        u32 key[MAX_KEYLEN];
        struct rcu_head rcu_head;
};

static spinlock_t nf_conncount_locks[CONNCOUNT_SLOTS] __cacheline_aligned_in_smp;

struct nf_conncount_data {
        unsigned int keylen;
        struct rb_root root[CONNCOUNT_SLOTS];
        struct net *net;
        struct work_struct gc_work;
        unsigned long pending_trees[BITS_TO_LONGS(CONNCOUNT_SLOTS)];
        unsigned int gc_tree;
};

static u_int32_t conncount_rnd __read_mostly;
static struct kmem_cache *conncount_rb_cachep __read_mostly;
static struct kmem_cache *conncount_conn_cachep __read_mostly;

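/* TCP connections in TIME_WAIT or CLOSE no longer count towards the limit;
 * other protocols are never treated as closed here.
 */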
static inline bool already_closed(const struct nf_conn *conn)
{
        if (nf_ct_protonum(conn) == IPPROTO_TCP)
                return conn->proto.tcp.state == TCP_CONNTRACK_TIME_WAIT ||
                       conn->proto.tcp.state == TCP_CONNTRACK_CLOSE;
        else
                return false;
}

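/* Compare two keys of klen 32-bit words, memcmp()-style. */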
static int key_diff(const u32 *a, const u32 *b, unsigned int klen)
{
        return memcmp(a, b, klen * sizeof(u32));
}

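/* Unlink and free one saved tuple.  Returns true if the list became empty,
 * so that the caller can dispose of the containing rbtree node.
 * Caller must hold list->list_lock.
 */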
static bool conn_free(struct nf_conncount_list *list,
                      struct nf_conncount_tuple *conn)
{
        bool free_entry = false;

        lockdep_assert_held(&list->list_lock);

        list->count--;
        list_del(&conn->node);
        if (list->count == 0) {
                list->dead = true;
                free_entry = true;
        }

        kmem_cache_free(conncount_conn_cachep, conn);
        return free_entry;
}

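/* Look up the conntrack entry behind a saved tuple.  If it is gone, evict
 * the stale list entry and return -ENOENT (*free_entry tells the caller
 * whether the list became empty).  If the entry might still be unconfirmed
 * on another cpu, return -EAGAIN and leave it alone.
 */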
static const struct nf_conntrack_tuple_hash *
find_or_evict(struct net *net, struct nf_conncount_list *list,
              struct nf_conncount_tuple *conn, bool *free_entry)
{
        const struct nf_conntrack_tuple_hash *found;
        unsigned long a, b;
        int cpu = raw_smp_processor_id();
        u32 age;

        found = nf_conntrack_find_get(net, &conn->zone, &conn->tuple);
        if (found)
                return found;
        b = conn->jiffies32;
        a = (u32)jiffies;

        /* conn might have been added just before by another cpu and
         * might still be unconfirmed.  In this case, nf_conntrack_find()
         * returns no result.  Thus only evict if this cpu added the
         * stale entry or if the entry is older than two jiffies.
         */
        age = a - b;
        if (conn->cpu == cpu || age >= 2) {
                *free_entry = conn_free(list, conn);
                return ERR_PTR(-ENOENT);
        }

        return ERR_PTR(-EAGAIN);
}

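/* Add the tuple to the list unless an identical entry already exists,
 * evicting up to CONNCOUNT_GC_MAX_NODES stale or closed entries on the way.
 * Returns 0 on success or if the tuple is already present.
 * Caller must hold list->list_lock.
 */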
static int __nf_conncount_add(struct net *net,
                              struct nf_conncount_list *list,
                              const struct nf_conntrack_tuple *tuple,
                              const struct nf_conntrack_zone *zone)
{
        const struct nf_conntrack_tuple_hash *found;
        struct nf_conncount_tuple *conn, *conn_n;
        struct nf_conn *found_ct;
        unsigned int collect = 0;
        bool free_entry = false;

        /* check the saved connections */
        list_for_each_entry_safe(conn, conn_n, &list->head, node) {
                if (collect > CONNCOUNT_GC_MAX_NODES)
                        break;

                found = find_or_evict(net, list, conn, &free_entry);
                if (IS_ERR(found)) {
                        /* Not found, but might be about to be confirmed */
                        if (PTR_ERR(found) == -EAGAIN) {
                                if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
                                    nf_ct_zone_id(&conn->zone, conn->zone.dir) ==
                                    nf_ct_zone_id(zone, zone->dir))
                                        return 0; /* already exists */
                        } else {
                                collect++;
                        }
                        continue;
                }

                found_ct = nf_ct_tuplehash_to_ctrack(found);

                if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
                    nf_ct_zone_equal(found_ct, zone, zone->dir)) {
                        /*
                         * We should not see tuples twice unless someone hooks
                         * this into a table without "-p tcp --syn".
                         *
                         * Attempt to avoid a re-add in this case.
                         */
                        nf_ct_put(found_ct);
                        return 0;
                } else if (already_closed(found_ct)) {
                        /*
                         * we do not care about connections which are
                         * closed already -> ditch it
                         */
                        nf_ct_put(found_ct);
                        conn_free(list, conn);
                        collect++;
                        continue;
                }

                nf_ct_put(found_ct);
        }

        if (WARN_ON_ONCE(list->count > INT_MAX))
                return -EOVERFLOW;

        conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
        if (conn == NULL)
                return -ENOMEM;

        conn->tuple = *tuple;
        conn->zone = *zone;
        conn->cpu = raw_smp_processor_id();
        conn->jiffies32 = (u32)jiffies;
        list_add_tail(&conn->node, &list->head);
        list->count++;
        return 0;
}

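/* Locked wrapper around __nf_conncount_add(). */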
int nf_conncount_add(struct net *net,
                     struct nf_conncount_list *list,
                     const struct nf_conntrack_tuple *tuple,
                     const struct nf_conntrack_zone *zone)
{
        int ret;

        /* check the saved connections */
        spin_lock_bh(&list->list_lock);
        ret = __nf_conncount_add(net, list, tuple, zone);
        spin_unlock_bh(&list->list_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_add);

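/* Initialise a caller-provided list: empty, unlocked and not yet dead. */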
void nf_conncount_list_init(struct nf_conncount_list *list)
{
        spin_lock_init(&list->list_lock);
        INIT_LIST_HEAD(&list->head);
        list->count = 0;
        list->dead = false;
}
EXPORT_SYMBOL_GPL(nf_conncount_list_init);

/* Return true if the list is empty. Must be called with BH disabled. */
bool nf_conncount_gc_list(struct net *net,
                          struct nf_conncount_list *list)
{
        const struct nf_conntrack_tuple_hash *found;
        struct nf_conncount_tuple *conn, *conn_n;
        struct nf_conn *found_ct;
        unsigned int collected = 0;
        bool free_entry = false;
        bool ret = false;

        /* don't bother if other cpu is already doing GC */
        if (!spin_trylock(&list->list_lock))
                return false;

        list_for_each_entry_safe(conn, conn_n, &list->head, node) {
                found = find_or_evict(net, list, conn, &free_entry);
                if (IS_ERR(found)) {
                        if (PTR_ERR(found) == -ENOENT)  {
                                if (free_entry) {
                                        spin_unlock(&list->list_lock);
                                        return true;
                                }
                                collected++;
                        }
                        continue;
                }

                found_ct = nf_ct_tuplehash_to_ctrack(found);
                if (already_closed(found_ct)) {
                        /*
                         * we do not care about connections which are
                         * closed already -> ditch it
                         */
                        nf_ct_put(found_ct);
                        if (conn_free(list, conn)) {
                                spin_unlock(&list->list_lock);
                                return true;
                        }
                        collected++;
                        continue;
                }

                nf_ct_put(found_ct);
                if (collected > CONNCOUNT_GC_MAX_NODES)
                        break;
        }

        if (!list->count) {
                list->dead = true;
                ret = true;
        }
        spin_unlock(&list->list_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_gc_list);

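/* RCU callback: actually free an rbtree node once concurrent readers are done. */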
static void __tree_nodes_free(struct rcu_head *h)
{
        struct nf_conncount_rb *rbconn;

        rbconn = container_of(h, struct nf_conncount_rb, rcu_head);
        kmem_cache_free(conncount_rb_cachep, rbconn);
}

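/* Unlink the collected nodes from the tree and free them after an RCU grace
 * period.  Called with the per-slot tree lock held.
 */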
static void tree_nodes_free(struct rb_root *root,
                            struct nf_conncount_rb *gc_nodes[],
                            unsigned int gc_count)
{
        struct nf_conncount_rb *rbconn;

        while (gc_count) {
                rbconn = gc_nodes[--gc_count];
                spin_lock(&rbconn->list.list_lock);
                rb_erase(&rbconn->node, root);
                call_rcu(&rbconn->rcu_head, __tree_nodes_free);
                spin_unlock(&rbconn->list.list_lock);
        }
}

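/* Mark a tree slot as pending and kick the gc work queue. */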
static void schedule_gc_worker(struct nf_conncount_data *data, int tree)
{
        set_bit(tree, data->pending_trees);
        schedule_work(&data->gc_work);
}

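/* Slow path: take the per-slot tree lock and re-walk the tree.  Add the
 * tuple to an existing node if the key matches, otherwise allocate and
 * insert a new node.  Empty lists found on the way are pruned and the gc
 * worker is scheduled for this slot.
 */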
static unsigned int
insert_tree(struct net *net,
            struct nf_conncount_data *data,
            struct rb_root *root,
            unsigned int hash,
            const u32 *key,
            u8 keylen,
            const struct nf_conntrack_tuple *tuple,
            const struct nf_conntrack_zone *zone)
{
        struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES];
        struct rb_node **rbnode, *parent;
        struct nf_conncount_rb *rbconn;
        struct nf_conncount_tuple *conn;
        unsigned int count = 0, gc_count = 0;
        bool do_gc = true;

        spin_lock_bh(&nf_conncount_locks[hash]);
restart:
        parent = NULL;
        rbnode = &(root->rb_node);
        while (*rbnode) {
                int diff;
                rbconn = rb_entry(*rbnode, struct nf_conncount_rb, node);

                parent = *rbnode;
                diff = key_diff(key, rbconn->key, keylen);
                if (diff < 0) {
                        rbnode = &((*rbnode)->rb_left);
                } else if (diff > 0) {
                        rbnode = &((*rbnode)->rb_right);
                } else {
                        int ret;

                        ret = nf_conncount_add(net, &rbconn->list, tuple, zone);
                        if (ret)
                                count = 0; /* hotdrop */
                        else
                                count = rbconn->list.count;
                        tree_nodes_free(root, gc_nodes, gc_count);
                        goto out_unlock;
                }

                if (gc_count >= ARRAY_SIZE(gc_nodes))
                        continue;

                if (do_gc && nf_conncount_gc_list(net, &rbconn->list))
                        gc_nodes[gc_count++] = rbconn;
        }

        if (gc_count) {
                tree_nodes_free(root, gc_nodes, gc_count);
                schedule_gc_worker(data, hash);
                gc_count = 0;
                do_gc = false;
                goto restart;
        }

        /* no match, need to insert a new node */
        rbconn = kmem_cache_alloc(conncount_rb_cachep, GFP_ATOMIC);
        if (rbconn == NULL)
                goto out_unlock;

        conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
        if (conn == NULL) {
                kmem_cache_free(conncount_rb_cachep, rbconn);
                goto out_unlock;
        }

        conn->tuple = *tuple;
        conn->zone = *zone;
        memcpy(rbconn->key, key, sizeof(u32) * keylen);

        nf_conncount_list_init(&rbconn->list);
        list_add(&conn->node, &rbconn->list.head);
        count = 1;
        rbconn->list.count = count;

        rb_link_node_rcu(&rbconn->node, parent, rbnode);
        rb_insert_color(&rbconn->node, root);
out_unlock:
        spin_unlock_bh(&nf_conncount_locks[hash]);
        return count;
}

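/* Fast path: walk the tree for this key under RCU, without the tree lock.
 * On a match the tuple is added to the node's list under the list lock;
 * if no node matches, or the matching node is about to go away, defer to
 * insert_tree().  With a NULL tuple only the current count is reported.
 */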
static unsigned int
count_tree(struct net *net,
           struct nf_conncount_data *data,
           const u32 *key,
           const struct nf_conntrack_tuple *tuple,
           const struct nf_conntrack_zone *zone)
{
        struct rb_root *root;
        struct rb_node *parent;
        struct nf_conncount_rb *rbconn;
        unsigned int hash;
        u8 keylen = data->keylen;

        hash = jhash2(key, data->keylen, conncount_rnd) % CONNCOUNT_SLOTS;
        root = &data->root[hash];

        parent = rcu_dereference_raw(root->rb_node);
        while (parent) {
                int diff;

                rbconn = rb_entry(parent, struct nf_conncount_rb, node);

                diff = key_diff(key, rbconn->key, keylen);
                if (diff < 0) {
                        parent = rcu_dereference_raw(parent->rb_left);
                } else if (diff > 0) {
                        parent = rcu_dereference_raw(parent->rb_right);
                } else {
                        int ret;

                        if (!tuple) {
                                nf_conncount_gc_list(net, &rbconn->list);
                                return rbconn->list.count;
                        }

                        spin_lock_bh(&rbconn->list.list_lock);
                        /* Node might be about to be free'd.
                         * We need to defer to insert_tree() in this case.
                         */
                        if (rbconn->list.count == 0) {
                                spin_unlock_bh(&rbconn->list.list_lock);
                                break;
                        }

                        /* same source network -> be counted! */
                        ret = __nf_conncount_add(net, &rbconn->list, tuple, zone);
                        spin_unlock_bh(&rbconn->list.list_lock);
                        if (ret)
                                return 0; /* hotdrop */
                        else
                                return rbconn->list.count;
                }
        }

        if (!tuple)
                return 0;

        return insert_tree(net, data, root, hash, key, keylen, tuple, zone);
}

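/* Work queue handler: garbage-collect the lists of one pending tree slot,
 * prune rbtree nodes whose lists became empty once enough of them have
 * accumulated, then reschedule itself while other slots are still pending.
 */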
static void tree_gc_worker(struct work_struct *work)
{
        struct nf_conncount_data *data = container_of(work, struct nf_conncount_data, gc_work);
        struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES], *rbconn;
        struct rb_root *root;
        struct rb_node *node;
        unsigned int tree, next_tree, gc_count = 0;

        tree = data->gc_tree % CONNCOUNT_SLOTS;
        root = &data->root[tree];

        local_bh_disable();
        rcu_read_lock();
        for (node = rb_first(root); node != NULL; node = rb_next(node)) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);
                if (nf_conncount_gc_list(data->net, &rbconn->list))
                        gc_count++;
        }
        rcu_read_unlock();
        local_bh_enable();

        cond_resched();

        spin_lock_bh(&nf_conncount_locks[tree]);
        if (gc_count < ARRAY_SIZE(gc_nodes))
                goto next; /* do not bother */

        gc_count = 0;
        node = rb_first(root);
        while (node != NULL) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);
                node = rb_next(node);

                if (rbconn->list.count > 0)
                        continue;

                gc_nodes[gc_count++] = rbconn;
                if (gc_count >= ARRAY_SIZE(gc_nodes)) {
                        tree_nodes_free(root, gc_nodes, gc_count);
                        gc_count = 0;
                }
        }

        tree_nodes_free(root, gc_nodes, gc_count);
next:
        clear_bit(tree, data->pending_trees);

        next_tree = (tree + 1) % CONNCOUNT_SLOTS;
        next_tree = find_next_bit(data->pending_trees, CONNCOUNT_SLOTS, next_tree);

        if (next_tree < CONNCOUNT_SLOTS) {
                data->gc_tree = next_tree;
                schedule_work(work);
        }

        spin_unlock_bh(&nf_conncount_locks[tree]);
}

/* Count and return number of conntrack entries in 'net' with particular 'key'.
 * If 'tuple' is not null, insert it into the accounting data structure.
 * Call with RCU read lock.
 */
unsigned int nf_conncount_count(struct net *net,
                                struct nf_conncount_data *data,
                                const u32 *key,
                                const struct nf_conntrack_tuple *tuple,
                                const struct nf_conntrack_zone *zone)
{
        return count_tree(net, data, key, tuple, zone);
}
EXPORT_SYMBOL_GPL(nf_conncount_count);

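/* Allocate and set up the per-key accounting structure.  keylen is given in
 * bytes and must be a non-zero multiple of sizeof(u32), at most MAX_KEYLEN
 * words.
 */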
struct nf_conncount_data *nf_conncount_init(struct net *net, unsigned int family,
                                            unsigned int keylen)
{
        struct nf_conncount_data *data;
        int ret, i;

        if (keylen % sizeof(u32) ||
            keylen / sizeof(u32) > MAX_KEYLEN ||
            keylen == 0)
                return ERR_PTR(-EINVAL);

        net_get_random_once(&conncount_rnd, sizeof(conncount_rnd));

        data = kmalloc(sizeof(*data), GFP_KERNEL);
        if (!data)
                return ERR_PTR(-ENOMEM);

        ret = nf_ct_netns_get(net, family);
        if (ret < 0) {
                kfree(data);
                return ERR_PTR(ret);
        }

        for (i = 0; i < ARRAY_SIZE(data->root); ++i)
                data->root[i] = RB_ROOT;

        data->keylen = keylen / sizeof(u32);
        data->net = net;
        INIT_WORK(&data->gc_work, tree_gc_worker);

        return data;
}
EXPORT_SYMBOL_GPL(nf_conncount_init);

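/* Free all tuples saved on a list; meant for teardown paths, no locks taken. */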
void nf_conncount_cache_free(struct nf_conncount_list *list)
{
        struct nf_conncount_tuple *conn, *conn_n;

        list_for_each_entry_safe(conn, conn_n, &list->head, node)
                kmem_cache_free(conncount_conn_cachep, conn);
}
EXPORT_SYMBOL_GPL(nf_conncount_cache_free);

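/* Free every node of one tree slot together with its saved tuples. */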
static void destroy_tree(struct rb_root *r)
{
        struct nf_conncount_rb *rbconn;
        struct rb_node *node;

        while ((node = rb_first(r)) != NULL) {
                rbconn = rb_entry(node, struct nf_conncount_rb, node);

                rb_erase(node, r);

                nf_conncount_cache_free(&rbconn->list);

                kmem_cache_free(conncount_rb_cachep, rbconn);
        }
}

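/* Undo nf_conncount_init(): stop the gc worker and free all trees. */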
void nf_conncount_destroy(struct net *net, unsigned int family,
                          struct nf_conncount_data *data)
{
        unsigned int i;

        cancel_work_sync(&data->gc_work);
        nf_ct_netns_put(net, family);

        for (i = 0; i < ARRAY_SIZE(data->root); ++i)
                destroy_tree(&data->root[i]);

        kfree(data);
}
EXPORT_SYMBOL_GPL(nf_conncount_destroy);

static int __init nf_conncount_modinit(void)
{
        int i;

        for (i = 0; i < CONNCOUNT_SLOTS; ++i)
                spin_lock_init(&nf_conncount_locks[i]);

        conncount_conn_cachep = kmem_cache_create("nf_conncount_tuple",
                                           sizeof(struct nf_conncount_tuple),
                                           0, 0, NULL);
        if (!conncount_conn_cachep)
                return -ENOMEM;

        conncount_rb_cachep = kmem_cache_create("nf_conncount_rb",
                                           sizeof(struct nf_conncount_rb),
                                           0, 0, NULL);
        if (!conncount_rb_cachep) {
                kmem_cache_destroy(conncount_conn_cachep);
                return -ENOMEM;
        }

        return 0;
}

static void __exit nf_conncount_modexit(void)
{
        kmem_cache_destroy(conncount_conn_cachep);
        kmem_cache_destroy(conncount_rb_cachep);
}

module_init(nf_conncount_modinit);
module_exit(nf_conncount_modexit);
MODULE_AUTHOR("Jan Engelhardt <jengelh@medozas.de>");
MODULE_AUTHOR("Florian Westphal <fw@strlen.de>");
MODULE_DESCRIPTION("netfilter: count number of connections matching a key");
MODULE_LICENSE("GPL");