Merge tag 'arm64-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/arm64/linux
[linux] / net / netfilter / nft_ct.c
1 /*
2  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
3  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  *
9  * Development of this code funded by Astaro AG (http://www.astaro.com/)
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/init.h>
14 #include <linux/module.h>
15 #include <linux/netlink.h>
16 #include <linux/netfilter.h>
17 #include <linux/netfilter/nf_tables.h>
18 #include <net/netfilter/nf_tables.h>
19 #include <net/netfilter/nf_conntrack.h>
20 #include <net/netfilter/nf_conntrack_acct.h>
21 #include <net/netfilter/nf_conntrack_tuple.h>
22 #include <net/netfilter/nf_conntrack_helper.h>
23 #include <net/netfilter/nf_conntrack_ecache.h>
24 #include <net/netfilter/nf_conntrack_labels.h>
25 #include <net/netfilter/nf_conntrack_timeout.h>
26 #include <net/netfilter/nf_conntrack_l4proto.h>
27
28 struct nft_ct {
29         enum nft_ct_keys        key:8;
30         enum ip_conntrack_dir   dir:8;
31         union {
32                 enum nft_registers      dreg:8;
33                 enum nft_registers      sreg:8;
34         };
35 };
36
37 struct nft_ct_helper_obj  {
38         struct nf_conntrack_helper *helper4;
39         struct nf_conntrack_helper *helper6;
40         u8 l4proto;
41 };
42
#ifdef CONFIG_NF_CONNTRACK_ZONES
/* One conntrack template per possible CPU, used by the zone set expression. */
static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
/*
 * Number of NFT_CT_ZONE set expressions alive; the per-cpu templates are
 * released when it drops back to zero.
 */
static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
#endif
47
48 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
49                                    enum nft_ct_keys k,
50                                    enum ip_conntrack_dir d)
51 {
52         if (d < IP_CT_DIR_MAX)
53                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
54                                            atomic64_read(&c[d].packets);
55
56         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
57                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
58 }
59
60 static void nft_ct_get_eval(const struct nft_expr *expr,
61                             struct nft_regs *regs,
62                             const struct nft_pktinfo *pkt)
63 {
64         const struct nft_ct *priv = nft_expr_priv(expr);
65         u32 *dest = &regs->data[priv->dreg];
66         enum ip_conntrack_info ctinfo;
67         const struct nf_conn *ct;
68         const struct nf_conn_help *help;
69         const struct nf_conntrack_tuple *tuple;
70         const struct nf_conntrack_helper *helper;
71         unsigned int state;
72
73         ct = nf_ct_get(pkt->skb, &ctinfo);
74
75         switch (priv->key) {
76         case NFT_CT_STATE:
77                 if (ct)
78                         state = NF_CT_STATE_BIT(ctinfo);
79                 else if (ctinfo == IP_CT_UNTRACKED)
80                         state = NF_CT_STATE_UNTRACKED_BIT;
81                 else
82                         state = NF_CT_STATE_INVALID_BIT;
83                 *dest = state;
84                 return;
85         default:
86                 break;
87         }
88
89         if (ct == NULL)
90                 goto err;
91
92         switch (priv->key) {
93         case NFT_CT_DIRECTION:
94                 nft_reg_store8(dest, CTINFO2DIR(ctinfo));
95                 return;
96         case NFT_CT_STATUS:
97                 *dest = ct->status;
98                 return;
99 #ifdef CONFIG_NF_CONNTRACK_MARK
100         case NFT_CT_MARK:
101                 *dest = ct->mark;
102                 return;
103 #endif
104 #ifdef CONFIG_NF_CONNTRACK_SECMARK
105         case NFT_CT_SECMARK:
106                 *dest = ct->secmark;
107                 return;
108 #endif
109         case NFT_CT_EXPIRATION:
110                 *dest = jiffies_to_msecs(nf_ct_expires(ct));
111                 return;
112         case NFT_CT_HELPER:
113                 if (ct->master == NULL)
114                         goto err;
115                 help = nfct_help(ct->master);
116                 if (help == NULL)
117                         goto err;
118                 helper = rcu_dereference(help->helper);
119                 if (helper == NULL)
120                         goto err;
121                 strncpy((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
122                 return;
123 #ifdef CONFIG_NF_CONNTRACK_LABELS
124         case NFT_CT_LABELS: {
125                 struct nf_conn_labels *labels = nf_ct_labels_find(ct);
126
127                 if (labels)
128                         memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
129                 else
130                         memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
131                 return;
132         }
133 #endif
134         case NFT_CT_BYTES: /* fallthrough */
135         case NFT_CT_PKTS: {
136                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
137                 u64 count = 0;
138
139                 if (acct)
140                         count = nft_ct_get_eval_counter(acct->counter,
141                                                         priv->key, priv->dir);
142                 memcpy(dest, &count, sizeof(count));
143                 return;
144         }
145         case NFT_CT_AVGPKT: {
146                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
147                 u64 avgcnt = 0, bcnt = 0, pcnt = 0;
148
149                 if (acct) {
150                         pcnt = nft_ct_get_eval_counter(acct->counter,
151                                                        NFT_CT_PKTS, priv->dir);
152                         bcnt = nft_ct_get_eval_counter(acct->counter,
153                                                        NFT_CT_BYTES, priv->dir);
154                         if (pcnt != 0)
155                                 avgcnt = div64_u64(bcnt, pcnt);
156                 }
157
158                 memcpy(dest, &avgcnt, sizeof(avgcnt));
159                 return;
160         }
161         case NFT_CT_L3PROTOCOL:
162                 nft_reg_store8(dest, nf_ct_l3num(ct));
163                 return;
164         case NFT_CT_PROTOCOL:
165                 nft_reg_store8(dest, nf_ct_protonum(ct));
166                 return;
167 #ifdef CONFIG_NF_CONNTRACK_ZONES
168         case NFT_CT_ZONE: {
169                 const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
170                 u16 zoneid;
171
172                 if (priv->dir < IP_CT_DIR_MAX)
173                         zoneid = nf_ct_zone_id(zone, priv->dir);
174                 else
175                         zoneid = zone->id;
176
177                 nft_reg_store16(dest, zoneid);
178                 return;
179         }
180 #endif
181         default:
182                 break;
183         }
184
185         tuple = &ct->tuplehash[priv->dir].tuple;
186         switch (priv->key) {
187         case NFT_CT_SRC:
188                 memcpy(dest, tuple->src.u3.all,
189                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
190                 return;
191         case NFT_CT_DST:
192                 memcpy(dest, tuple->dst.u3.all,
193                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
194                 return;
195         case NFT_CT_PROTO_SRC:
196                 nft_reg_store16(dest, (__force u16)tuple->src.u.all);
197                 return;
198         case NFT_CT_PROTO_DST:
199                 nft_reg_store16(dest, (__force u16)tuple->dst.u.all);
200                 return;
201         case NFT_CT_SRC_IP:
202                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
203                         goto err;
204                 *dest = tuple->src.u3.ip;
205                 return;
206         case NFT_CT_DST_IP:
207                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
208                         goto err;
209                 *dest = tuple->dst.u3.ip;
210                 return;
211         case NFT_CT_SRC_IP6:
212                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
213                         goto err;
214                 memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr));
215                 return;
216         case NFT_CT_DST_IP6:
217                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
218                         goto err;
219                 memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr));
220                 return;
221         default:
222                 break;
223         }
224         return;
225 err:
226         regs->verdict.code = NFT_BREAK;
227 }
228
229 #ifdef CONFIG_NF_CONNTRACK_ZONES
230 static void nft_ct_set_zone_eval(const struct nft_expr *expr,
231                                  struct nft_regs *regs,
232                                  const struct nft_pktinfo *pkt)
233 {
234         struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR };
235         const struct nft_ct *priv = nft_expr_priv(expr);
236         struct sk_buff *skb = pkt->skb;
237         enum ip_conntrack_info ctinfo;
238         u16 value = nft_reg_load16(&regs->data[priv->sreg]);
239         struct nf_conn *ct;
240
241         ct = nf_ct_get(skb, &ctinfo);
242         if (ct) /* already tracked */
243                 return;
244
245         zone.id = value;
246
247         switch (priv->dir) {
248         case IP_CT_DIR_ORIGINAL:
249                 zone.dir = NF_CT_ZONE_DIR_ORIG;
250                 break;
251         case IP_CT_DIR_REPLY:
252                 zone.dir = NF_CT_ZONE_DIR_REPL;
253                 break;
254         default:
255                 break;
256         }
257
258         ct = this_cpu_read(nft_ct_pcpu_template);
259
260         if (likely(atomic_read(&ct->ct_general.use) == 1)) {
261                 nf_ct_zone_add(ct, &zone);
262         } else {
263                 /* previous skb got queued to userspace */
264                 ct = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
265                 if (!ct) {
266                         regs->verdict.code = NF_DROP;
267                         return;
268                 }
269         }
270
271         atomic_inc(&ct->ct_general.use);
272         nf_ct_set(skb, ct, IP_CT_NEW);
273 }
274 #endif
275
276 static void nft_ct_set_eval(const struct nft_expr *expr,
277                             struct nft_regs *regs,
278                             const struct nft_pktinfo *pkt)
279 {
280         const struct nft_ct *priv = nft_expr_priv(expr);
281         struct sk_buff *skb = pkt->skb;
282 #if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK)
283         u32 value = regs->data[priv->sreg];
284 #endif
285         enum ip_conntrack_info ctinfo;
286         struct nf_conn *ct;
287
288         ct = nf_ct_get(skb, &ctinfo);
289         if (ct == NULL || nf_ct_is_template(ct))
290                 return;
291
292         switch (priv->key) {
293 #ifdef CONFIG_NF_CONNTRACK_MARK
294         case NFT_CT_MARK:
295                 if (ct->mark != value) {
296                         ct->mark = value;
297                         nf_conntrack_event_cache(IPCT_MARK, ct);
298                 }
299                 break;
300 #endif
301 #ifdef CONFIG_NF_CONNTRACK_SECMARK
302         case NFT_CT_SECMARK:
303                 if (ct->secmark != value) {
304                         ct->secmark = value;
305                         nf_conntrack_event_cache(IPCT_SECMARK, ct);
306                 }
307                 break;
308 #endif
309 #ifdef CONFIG_NF_CONNTRACK_LABELS
310         case NFT_CT_LABELS:
311                 nf_connlabels_replace(ct,
312                                       &regs->data[priv->sreg],
313                                       &regs->data[priv->sreg],
314                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
315                 break;
316 #endif
317 #ifdef CONFIG_NF_CONNTRACK_EVENTS
318         case NFT_CT_EVENTMASK: {
319                 struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct);
320                 u32 ctmask = regs->data[priv->sreg];
321
322                 if (e) {
323                         if (e->ctmask != ctmask)
324                                 e->ctmask = ctmask;
325                         break;
326                 }
327
328                 if (ctmask && !nf_ct_is_confirmed(ct))
329                         nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC);
330                 break;
331         }
332 #endif
333         default:
334                 break;
335         }
336 }
337
338 static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
339         [NFTA_CT_DREG]          = { .type = NLA_U32 },
340         [NFTA_CT_KEY]           = { .type = NLA_U32 },
341         [NFTA_CT_DIRECTION]     = { .type = NLA_U8 },
342         [NFTA_CT_SREG]          = { .type = NLA_U32 },
343 };
344
345 #ifdef CONFIG_NF_CONNTRACK_ZONES
346 static void nft_ct_tmpl_put_pcpu(void)
347 {
348         struct nf_conn *ct;
349         int cpu;
350
351         for_each_possible_cpu(cpu) {
352                 ct = per_cpu(nft_ct_pcpu_template, cpu);
353                 if (!ct)
354                         break;
355                 nf_ct_put(ct);
356                 per_cpu(nft_ct_pcpu_template, cpu) = NULL;
357         }
358 }
359
360 static bool nft_ct_tmpl_alloc_pcpu(void)
361 {
362         struct nf_conntrack_zone zone = { .id = 0 };
363         struct nf_conn *tmp;
364         int cpu;
365
366         if (nft_ct_pcpu_template_refcnt)
367                 return true;
368
369         for_each_possible_cpu(cpu) {
370                 tmp = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL);
371                 if (!tmp) {
372                         nft_ct_tmpl_put_pcpu();
373                         return false;
374                 }
375
376                 atomic_set(&tmp->ct_general.use, 1);
377                 per_cpu(nft_ct_pcpu_template, cpu) = tmp;
378         }
379
380         return true;
381 }
382 #endif
383
384 static int nft_ct_get_init(const struct nft_ctx *ctx,
385                            const struct nft_expr *expr,
386                            const struct nlattr * const tb[])
387 {
388         struct nft_ct *priv = nft_expr_priv(expr);
389         unsigned int len;
390         int err;
391
392         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
393         priv->dir = IP_CT_DIR_MAX;
394         switch (priv->key) {
395         case NFT_CT_DIRECTION:
396                 if (tb[NFTA_CT_DIRECTION] != NULL)
397                         return -EINVAL;
398                 len = sizeof(u8);
399                 break;
400         case NFT_CT_STATE:
401         case NFT_CT_STATUS:
402 #ifdef CONFIG_NF_CONNTRACK_MARK
403         case NFT_CT_MARK:
404 #endif
405 #ifdef CONFIG_NF_CONNTRACK_SECMARK
406         case NFT_CT_SECMARK:
407 #endif
408         case NFT_CT_EXPIRATION:
409                 if (tb[NFTA_CT_DIRECTION] != NULL)
410                         return -EINVAL;
411                 len = sizeof(u32);
412                 break;
413 #ifdef CONFIG_NF_CONNTRACK_LABELS
414         case NFT_CT_LABELS:
415                 if (tb[NFTA_CT_DIRECTION] != NULL)
416                         return -EINVAL;
417                 len = NF_CT_LABELS_MAX_SIZE;
418                 break;
419 #endif
420         case NFT_CT_HELPER:
421                 if (tb[NFTA_CT_DIRECTION] != NULL)
422                         return -EINVAL;
423                 len = NF_CT_HELPER_NAME_LEN;
424                 break;
425
426         case NFT_CT_L3PROTOCOL:
427         case NFT_CT_PROTOCOL:
428                 /* For compatibility, do not report error if NFTA_CT_DIRECTION
429                  * attribute is specified.
430                  */
431                 len = sizeof(u8);
432                 break;
433         case NFT_CT_SRC:
434         case NFT_CT_DST:
435                 if (tb[NFTA_CT_DIRECTION] == NULL)
436                         return -EINVAL;
437
438                 switch (ctx->family) {
439                 case NFPROTO_IPV4:
440                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
441                                            src.u3.ip);
442                         break;
443                 case NFPROTO_IPV6:
444                 case NFPROTO_INET:
445                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
446                                            src.u3.ip6);
447                         break;
448                 default:
449                         return -EAFNOSUPPORT;
450                 }
451                 break;
452         case NFT_CT_SRC_IP:
453         case NFT_CT_DST_IP:
454                 if (tb[NFTA_CT_DIRECTION] == NULL)
455                         return -EINVAL;
456
457                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip);
458                 break;
459         case NFT_CT_SRC_IP6:
460         case NFT_CT_DST_IP6:
461                 if (tb[NFTA_CT_DIRECTION] == NULL)
462                         return -EINVAL;
463
464                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip6);
465                 break;
466         case NFT_CT_PROTO_SRC:
467         case NFT_CT_PROTO_DST:
468                 if (tb[NFTA_CT_DIRECTION] == NULL)
469                         return -EINVAL;
470                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u.all);
471                 break;
472         case NFT_CT_BYTES:
473         case NFT_CT_PKTS:
474         case NFT_CT_AVGPKT:
475                 len = sizeof(u64);
476                 break;
477 #ifdef CONFIG_NF_CONNTRACK_ZONES
478         case NFT_CT_ZONE:
479                 len = sizeof(u16);
480                 break;
481 #endif
482         default:
483                 return -EOPNOTSUPP;
484         }
485
486         if (tb[NFTA_CT_DIRECTION] != NULL) {
487                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
488                 switch (priv->dir) {
489                 case IP_CT_DIR_ORIGINAL:
490                 case IP_CT_DIR_REPLY:
491                         break;
492                 default:
493                         return -EINVAL;
494                 }
495         }
496
497         priv->dreg = nft_parse_register(tb[NFTA_CT_DREG]);
498         err = nft_validate_register_store(ctx, priv->dreg, NULL,
499                                           NFT_DATA_VALUE, len);
500         if (err < 0)
501                 return err;
502
503         err = nf_ct_netns_get(ctx->net, ctx->family);
504         if (err < 0)
505                 return err;
506
507         if (priv->key == NFT_CT_BYTES ||
508             priv->key == NFT_CT_PKTS  ||
509             priv->key == NFT_CT_AVGPKT)
510                 nf_ct_set_acct(ctx->net, true);
511
512         return 0;
513 }
514
515 static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
516 {
517         switch (priv->key) {
518 #ifdef CONFIG_NF_CONNTRACK_LABELS
519         case NFT_CT_LABELS:
520                 nf_connlabels_put(ctx->net);
521                 break;
522 #endif
523 #ifdef CONFIG_NF_CONNTRACK_ZONES
524         case NFT_CT_ZONE:
525                 if (--nft_ct_pcpu_template_refcnt == 0)
526                         nft_ct_tmpl_put_pcpu();
527 #endif
528         default:
529                 break;
530         }
531 }
532
533 static int nft_ct_set_init(const struct nft_ctx *ctx,
534                            const struct nft_expr *expr,
535                            const struct nlattr * const tb[])
536 {
537         struct nft_ct *priv = nft_expr_priv(expr);
538         unsigned int len;
539         int err;
540
541         priv->dir = IP_CT_DIR_MAX;
542         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
543         switch (priv->key) {
544 #ifdef CONFIG_NF_CONNTRACK_MARK
545         case NFT_CT_MARK:
546                 if (tb[NFTA_CT_DIRECTION])
547                         return -EINVAL;
548                 len = FIELD_SIZEOF(struct nf_conn, mark);
549                 break;
550 #endif
551 #ifdef CONFIG_NF_CONNTRACK_LABELS
552         case NFT_CT_LABELS:
553                 if (tb[NFTA_CT_DIRECTION])
554                         return -EINVAL;
555                 len = NF_CT_LABELS_MAX_SIZE;
556                 err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
557                 if (err)
558                         return err;
559                 break;
560 #endif
561 #ifdef CONFIG_NF_CONNTRACK_ZONES
562         case NFT_CT_ZONE:
563                 if (!nft_ct_tmpl_alloc_pcpu())
564                         return -ENOMEM;
565                 nft_ct_pcpu_template_refcnt++;
566                 len = sizeof(u16);
567                 break;
568 #endif
569 #ifdef CONFIG_NF_CONNTRACK_EVENTS
570         case NFT_CT_EVENTMASK:
571                 if (tb[NFTA_CT_DIRECTION])
572                         return -EINVAL;
573                 len = sizeof(u32);
574                 break;
575 #endif
576 #ifdef CONFIG_NF_CONNTRACK_SECMARK
577         case NFT_CT_SECMARK:
578                 if (tb[NFTA_CT_DIRECTION])
579                         return -EINVAL;
580                 len = sizeof(u32);
581                 break;
582 #endif
583         default:
584                 return -EOPNOTSUPP;
585         }
586
587         if (tb[NFTA_CT_DIRECTION]) {
588                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
589                 switch (priv->dir) {
590                 case IP_CT_DIR_ORIGINAL:
591                 case IP_CT_DIR_REPLY:
592                         break;
593                 default:
594                         err = -EINVAL;
595                         goto err1;
596                 }
597         }
598
599         priv->sreg = nft_parse_register(tb[NFTA_CT_SREG]);
600         err = nft_validate_register_load(priv->sreg, len);
601         if (err < 0)
602                 goto err1;
603
604         err = nf_ct_netns_get(ctx->net, ctx->family);
605         if (err < 0)
606                 goto err1;
607
608         return 0;
609
610 err1:
611         __nft_ct_set_destroy(ctx, priv);
612         return err;
613 }
614
615 static void nft_ct_get_destroy(const struct nft_ctx *ctx,
616                                const struct nft_expr *expr)
617 {
618         nf_ct_netns_put(ctx->net, ctx->family);
619 }
620
621 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
622                                const struct nft_expr *expr)
623 {
624         struct nft_ct *priv = nft_expr_priv(expr);
625
626         __nft_ct_set_destroy(ctx, priv);
627         nf_ct_netns_put(ctx->net, ctx->family);
628 }
629
630 static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr)
631 {
632         const struct nft_ct *priv = nft_expr_priv(expr);
633
634         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
635                 goto nla_put_failure;
636         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
637                 goto nla_put_failure;
638
639         switch (priv->key) {
640         case NFT_CT_SRC:
641         case NFT_CT_DST:
642         case NFT_CT_SRC_IP:
643         case NFT_CT_DST_IP:
644         case NFT_CT_SRC_IP6:
645         case NFT_CT_DST_IP6:
646         case NFT_CT_PROTO_SRC:
647         case NFT_CT_PROTO_DST:
648                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
649                         goto nla_put_failure;
650                 break;
651         case NFT_CT_BYTES:
652         case NFT_CT_PKTS:
653         case NFT_CT_AVGPKT:
654         case NFT_CT_ZONE:
655                 if (priv->dir < IP_CT_DIR_MAX &&
656                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
657                         goto nla_put_failure;
658                 break;
659         default:
660                 break;
661         }
662
663         return 0;
664
665 nla_put_failure:
666         return -1;
667 }
668
669 static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr)
670 {
671         const struct nft_ct *priv = nft_expr_priv(expr);
672
673         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
674                 goto nla_put_failure;
675         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
676                 goto nla_put_failure;
677
678         switch (priv->key) {
679         case NFT_CT_ZONE:
680                 if (priv->dir < IP_CT_DIR_MAX &&
681                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
682                         goto nla_put_failure;
683                 break;
684         default:
685                 break;
686         }
687
688         return 0;
689
690 nla_put_failure:
691         return -1;
692 }
693
694 static struct nft_expr_type nft_ct_type;
695 static const struct nft_expr_ops nft_ct_get_ops = {
696         .type           = &nft_ct_type,
697         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
698         .eval           = nft_ct_get_eval,
699         .init           = nft_ct_get_init,
700         .destroy        = nft_ct_get_destroy,
701         .dump           = nft_ct_get_dump,
702 };
703
704 static const struct nft_expr_ops nft_ct_set_ops = {
705         .type           = &nft_ct_type,
706         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
707         .eval           = nft_ct_set_eval,
708         .init           = nft_ct_set_init,
709         .destroy        = nft_ct_set_destroy,
710         .dump           = nft_ct_set_dump,
711 };
712
713 #ifdef CONFIG_NF_CONNTRACK_ZONES
714 static const struct nft_expr_ops nft_ct_set_zone_ops = {
715         .type           = &nft_ct_type,
716         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
717         .eval           = nft_ct_set_zone_eval,
718         .init           = nft_ct_set_init,
719         .destroy        = nft_ct_set_destroy,
720         .dump           = nft_ct_set_dump,
721 };
722 #endif
723
724 static const struct nft_expr_ops *
725 nft_ct_select_ops(const struct nft_ctx *ctx,
726                     const struct nlattr * const tb[])
727 {
728         if (tb[NFTA_CT_KEY] == NULL)
729                 return ERR_PTR(-EINVAL);
730
731         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
732                 return ERR_PTR(-EINVAL);
733
734         if (tb[NFTA_CT_DREG])
735                 return &nft_ct_get_ops;
736
737         if (tb[NFTA_CT_SREG]) {
738 #ifdef CONFIG_NF_CONNTRACK_ZONES
739                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
740                         return &nft_ct_set_zone_ops;
741 #endif
742                 return &nft_ct_set_ops;
743         }
744
745         return ERR_PTR(-EINVAL);
746 }
747
748 static struct nft_expr_type nft_ct_type __read_mostly = {
749         .name           = "ct",
750         .select_ops     = nft_ct_select_ops,
751         .policy         = nft_ct_policy,
752         .maxattr        = NFTA_CT_MAX,
753         .owner          = THIS_MODULE,
754 };
755
756 static void nft_notrack_eval(const struct nft_expr *expr,
757                              struct nft_regs *regs,
758                              const struct nft_pktinfo *pkt)
759 {
760         struct sk_buff *skb = pkt->skb;
761         enum ip_conntrack_info ctinfo;
762         struct nf_conn *ct;
763
764         ct = nf_ct_get(pkt->skb, &ctinfo);
765         /* Previously seen (loopback or untracked)?  Ignore. */
766         if (ct || ctinfo == IP_CT_UNTRACKED)
767                 return;
768
769         nf_ct_set(skb, ct, IP_CT_UNTRACKED);
770 }
771
772 static struct nft_expr_type nft_notrack_type;
773 static const struct nft_expr_ops nft_notrack_ops = {
774         .type           = &nft_notrack_type,
775         .size           = NFT_EXPR_SIZE(0),
776         .eval           = nft_notrack_eval,
777 };
778
779 static struct nft_expr_type nft_notrack_type __read_mostly = {
780         .name           = "notrack",
781         .ops            = &nft_notrack_ops,
782         .owner          = THIS_MODULE,
783 };
784
785 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
786 static int
787 nft_ct_timeout_parse_policy(void *timeouts,
788                             const struct nf_conntrack_l4proto *l4proto,
789                             struct net *net, const struct nlattr *attr)
790 {
791         struct nlattr **tb;
792         int ret = 0;
793
794         tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb),
795                      GFP_KERNEL);
796
797         if (!tb)
798                 return -ENOMEM;
799
800         ret = nla_parse_nested(tb, l4proto->ctnl_timeout.nlattr_max,
801                                attr, l4proto->ctnl_timeout.nla_policy,
802                                NULL);
803         if (ret < 0)
804                 goto err;
805
806         ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeouts);
807
808 err:
809         kfree(tb);
810         return ret;
811 }
812
813 struct nft_ct_timeout_obj {
814         struct nf_ct_timeout    *timeout;
815         u8                      l4proto;
816 };
817
818 static void nft_ct_timeout_obj_eval(struct nft_object *obj,
819                                     struct nft_regs *regs,
820                                     const struct nft_pktinfo *pkt)
821 {
822         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
823         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
824         struct nf_conn_timeout *timeout;
825         const unsigned int *values;
826
827         if (priv->l4proto != pkt->tprot)
828                 return;
829
830         if (!ct || nf_ct_is_template(ct) || nf_ct_is_confirmed(ct))
831                 return;
832
833         timeout = nf_ct_timeout_find(ct);
834         if (!timeout) {
835                 timeout = nf_ct_timeout_ext_add(ct, priv->timeout, GFP_ATOMIC);
836                 if (!timeout) {
837                         regs->verdict.code = NF_DROP;
838                         return;
839                 }
840         }
841
842         rcu_assign_pointer(timeout->timeout, priv->timeout);
843
844         /* adjust the timeout as per 'new' state. ct is unconfirmed,
845          * so the current timestamp must not be added.
846          */
847         values = nf_ct_timeout_data(timeout);
848         if (values)
849                 nf_ct_refresh(ct, pkt->skb, values[0]);
850 }
851
852 static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx,
853                                    const struct nlattr * const tb[],
854                                    struct nft_object *obj)
855 {
856         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
857         const struct nf_conntrack_l4proto *l4proto;
858         struct nf_ct_timeout *timeout;
859         int l3num = ctx->family;
860         __u8 l4num;
861         int ret;
862
863         if (!tb[NFTA_CT_TIMEOUT_L4PROTO] ||
864             !tb[NFTA_CT_TIMEOUT_DATA])
865                 return -EINVAL;
866
867         if (tb[NFTA_CT_TIMEOUT_L3PROTO])
868                 l3num = ntohs(nla_get_be16(tb[NFTA_CT_TIMEOUT_L3PROTO]));
869
870         l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]);
871         priv->l4proto = l4num;
872
873         l4proto = nf_ct_l4proto_find_get(l4num);
874
875         if (l4proto->l4proto != l4num) {
876                 ret = -EOPNOTSUPP;
877                 goto err_proto_put;
878         }
879
880         timeout = kzalloc(sizeof(struct nf_ct_timeout) +
881                           l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
882         if (timeout == NULL) {
883                 ret = -ENOMEM;
884                 goto err_proto_put;
885         }
886
887         ret = nft_ct_timeout_parse_policy(&timeout->data, l4proto, ctx->net,
888                                           tb[NFTA_CT_TIMEOUT_DATA]);
889         if (ret < 0)
890                 goto err_free_timeout;
891
892         timeout->l3num = l3num;
893         timeout->l4proto = l4proto;
894
895         ret = nf_ct_netns_get(ctx->net, ctx->family);
896         if (ret < 0)
897                 goto err_free_timeout;
898
899         priv->timeout = timeout;
900         return 0;
901
902 err_free_timeout:
903         kfree(timeout);
904 err_proto_put:
905         nf_ct_l4proto_put(l4proto);
906         return ret;
907 }
908
909 static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx,
910                                        struct nft_object *obj)
911 {
912         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
913         struct nf_ct_timeout *timeout = priv->timeout;
914
915         nf_ct_untimeout(ctx->net, timeout);
916         nf_ct_l4proto_put(timeout->l4proto);
917         nf_ct_netns_put(ctx->net, ctx->family);
918         kfree(priv->timeout);
919 }
920
921 static int nft_ct_timeout_obj_dump(struct sk_buff *skb,
922                                    struct nft_object *obj, bool reset)
923 {
924         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
925         const struct nf_ct_timeout *timeout = priv->timeout;
926         struct nlattr *nest_params;
927         int ret;
928
929         if (nla_put_u8(skb, NFTA_CT_TIMEOUT_L4PROTO, timeout->l4proto->l4proto) ||
930             nla_put_be16(skb, NFTA_CT_TIMEOUT_L3PROTO, htons(timeout->l3num)))
931                 return -1;
932
933         nest_params = nla_nest_start(skb, NFTA_CT_TIMEOUT_DATA | NLA_F_NESTED);
934         if (!nest_params)
935                 return -1;
936
937         ret = timeout->l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->data);
938         if (ret < 0)
939                 return -1;
940         nla_nest_end(skb, nest_params);
941         return 0;
942 }
943
944 static const struct nla_policy nft_ct_timeout_policy[NFTA_CT_TIMEOUT_MAX + 1] = {
945         [NFTA_CT_TIMEOUT_L3PROTO] = {.type = NLA_U16 },
946         [NFTA_CT_TIMEOUT_L4PROTO] = {.type = NLA_U8 },
947         [NFTA_CT_TIMEOUT_DATA]    = {.type = NLA_NESTED },
948 };
949
950 static struct nft_object_type nft_ct_timeout_obj_type;
951
952 static const struct nft_object_ops nft_ct_timeout_obj_ops = {
953         .type           = &nft_ct_timeout_obj_type,
954         .size           = sizeof(struct nft_ct_timeout_obj),
955         .eval           = nft_ct_timeout_obj_eval,
956         .init           = nft_ct_timeout_obj_init,
957         .destroy        = nft_ct_timeout_obj_destroy,
958         .dump           = nft_ct_timeout_obj_dump,
959 };
960
961 static struct nft_object_type nft_ct_timeout_obj_type __read_mostly = {
962         .type           = NFT_OBJECT_CT_TIMEOUT,
963         .ops            = &nft_ct_timeout_obj_ops,
964         .maxattr        = NFTA_CT_TIMEOUT_MAX,
965         .policy         = nft_ct_timeout_policy,
966         .owner          = THIS_MODULE,
967 };
968 #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
969
970 static int nft_ct_helper_obj_init(const struct nft_ctx *ctx,
971                                   const struct nlattr * const tb[],
972                                   struct nft_object *obj)
973 {
974         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
975         struct nf_conntrack_helper *help4, *help6;
976         char name[NF_CT_HELPER_NAME_LEN];
977         int family = ctx->family;
978         int err;
979
980         if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO])
981                 return -EINVAL;
982
983         priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]);
984         if (!priv->l4proto)
985                 return -ENOENT;
986
987         nla_strlcpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name));
988
989         if (tb[NFTA_CT_HELPER_L3PROTO])
990                 family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO]));
991
992         help4 = NULL;
993         help6 = NULL;
994
995         switch (family) {
996         case NFPROTO_IPV4:
997                 if (ctx->family == NFPROTO_IPV6)
998                         return -EINVAL;
999
1000                 help4 = nf_conntrack_helper_try_module_get(name, family,
1001                                                            priv->l4proto);
1002                 break;
1003         case NFPROTO_IPV6:
1004                 if (ctx->family == NFPROTO_IPV4)
1005                         return -EINVAL;
1006
1007                 help6 = nf_conntrack_helper_try_module_get(name, family,
1008                                                            priv->l4proto);
1009                 break;
1010         case NFPROTO_NETDEV: /* fallthrough */
1011         case NFPROTO_BRIDGE: /* same */
1012         case NFPROTO_INET:
1013                 help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4,
1014                                                            priv->l4proto);
1015                 help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6,
1016                                                            priv->l4proto);
1017                 break;
1018         default:
1019                 return -EAFNOSUPPORT;
1020         }
1021
1022         /* && is intentional; only error if INET found neither ipv4 or ipv6 */
1023         if (!help4 && !help6)
1024                 return -ENOENT;
1025
1026         priv->helper4 = help4;
1027         priv->helper6 = help6;
1028
1029         err = nf_ct_netns_get(ctx->net, ctx->family);
1030         if (err < 0)
1031                 goto err_put_helper;
1032
1033         return 0;
1034
1035 err_put_helper:
1036         if (priv->helper4)
1037                 nf_conntrack_helper_put(priv->helper4);
1038         if (priv->helper6)
1039                 nf_conntrack_helper_put(priv->helper6);
1040         return err;
1041 }
1042
1043 static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx,
1044                                       struct nft_object *obj)
1045 {
1046         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1047
1048         if (priv->helper4)
1049                 nf_conntrack_helper_put(priv->helper4);
1050         if (priv->helper6)
1051                 nf_conntrack_helper_put(priv->helper6);
1052
1053         nf_ct_netns_put(ctx->net, ctx->family);
1054 }
1055
1056 static void nft_ct_helper_obj_eval(struct nft_object *obj,
1057                                    struct nft_regs *regs,
1058                                    const struct nft_pktinfo *pkt)
1059 {
1060         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1061         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
1062         struct nf_conntrack_helper *to_assign = NULL;
1063         struct nf_conn_help *help;
1064
1065         if (!ct ||
1066             nf_ct_is_confirmed(ct) ||
1067             nf_ct_is_template(ct) ||
1068             priv->l4proto != nf_ct_protonum(ct))
1069                 return;
1070
1071         switch (nf_ct_l3num(ct)) {
1072         case NFPROTO_IPV4:
1073                 to_assign = priv->helper4;
1074                 break;
1075         case NFPROTO_IPV6:
1076                 to_assign = priv->helper6;
1077                 break;
1078         default:
1079                 WARN_ON_ONCE(1);
1080                 return;
1081         }
1082
1083         if (!to_assign)
1084                 return;
1085
1086         if (test_bit(IPS_HELPER_BIT, &ct->status))
1087                 return;
1088
1089         help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1090         if (help) {
1091                 rcu_assign_pointer(help->helper, to_assign);
1092                 set_bit(IPS_HELPER_BIT, &ct->status);
1093         }
1094 }
1095
1096 static int nft_ct_helper_obj_dump(struct sk_buff *skb,
1097                                   struct nft_object *obj, bool reset)
1098 {
1099         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1100         const struct nf_conntrack_helper *helper;
1101         u16 family;
1102
1103         if (priv->helper4 && priv->helper6) {
1104                 family = NFPROTO_INET;
1105                 helper = priv->helper4;
1106         } else if (priv->helper6) {
1107                 family = NFPROTO_IPV6;
1108                 helper = priv->helper6;
1109         } else {
1110                 family = NFPROTO_IPV4;
1111                 helper = priv->helper4;
1112         }
1113
1114         if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name))
1115                 return -1;
1116
1117         if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto))
1118                 return -1;
1119
1120         if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family)))
1121                 return -1;
1122
1123         return 0;
1124 }
1125
1126 static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = {
1127         [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING,
1128                                   .len = NF_CT_HELPER_NAME_LEN - 1 },
1129         [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 },
1130         [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 },
1131 };
1132
1133 static struct nft_object_type nft_ct_helper_obj_type;
1134 static const struct nft_object_ops nft_ct_helper_obj_ops = {
1135         .type           = &nft_ct_helper_obj_type,
1136         .size           = sizeof(struct nft_ct_helper_obj),
1137         .eval           = nft_ct_helper_obj_eval,
1138         .init           = nft_ct_helper_obj_init,
1139         .destroy        = nft_ct_helper_obj_destroy,
1140         .dump           = nft_ct_helper_obj_dump,
1141 };
1142
1143 static struct nft_object_type nft_ct_helper_obj_type __read_mostly = {
1144         .type           = NFT_OBJECT_CT_HELPER,
1145         .ops            = &nft_ct_helper_obj_ops,
1146         .maxattr        = NFTA_CT_HELPER_MAX,
1147         .policy         = nft_ct_helper_policy,
1148         .owner          = THIS_MODULE,
1149 };
1150
1151 static int __init nft_ct_module_init(void)
1152 {
1153         int err;
1154
1155         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
1156
1157         err = nft_register_expr(&nft_ct_type);
1158         if (err < 0)
1159                 return err;
1160
1161         err = nft_register_expr(&nft_notrack_type);
1162         if (err < 0)
1163                 goto err1;
1164
1165         err = nft_register_obj(&nft_ct_helper_obj_type);
1166         if (err < 0)
1167                 goto err2;
1168 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1169         err = nft_register_obj(&nft_ct_timeout_obj_type);
1170         if (err < 0)
1171                 goto err3;
1172 #endif
1173         return 0;
1174
1175 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1176 err3:
1177         nft_unregister_obj(&nft_ct_helper_obj_type);
1178 #endif
1179 err2:
1180         nft_unregister_expr(&nft_notrack_type);
1181 err1:
1182         nft_unregister_expr(&nft_ct_type);
1183         return err;
1184 }
1185
1186 static void __exit nft_ct_module_exit(void)
1187 {
1188 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1189         nft_unregister_obj(&nft_ct_timeout_obj_type);
1190 #endif
1191         nft_unregister_obj(&nft_ct_helper_obj_type);
1192         nft_unregister_expr(&nft_notrack_type);
1193         nft_unregister_expr(&nft_ct_type);
1194 }
1195
1196 module_init(nft_ct_module_init);
1197 module_exit(nft_ct_module_exit);
1198
1199 MODULE_LICENSE("GPL");
1200 MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
1201 MODULE_ALIAS_NFT_EXPR("ct");
1202 MODULE_ALIAS_NFT_EXPR("notrack");
1203 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER);
1204 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_TIMEOUT);