[IPSEC]: Store IPv6 nh pointer in mac_header on output
[powerpc.git] / net / ipv6 / esp6.c
index 7fdf84d..9fc1940 100644 (file)
@@ -51,6 +51,7 @@ static int esp6_output(struct xfrm_state *x, struct sk_buff *skb)
        int clen;
        int alen;
        int nfrags;
+       u8 *tail;
        struct esp_data *esp = x->data;
        int hdr_len = (skb_transport_offset(skb) +
                       sizeof(*esph) + esp->conf.ivlen);
@@ -78,23 +79,24 @@ static int esp6_output(struct xfrm_state *x, struct sk_buff *skb)
        }
 
        /* Fill padding... */
+       tail = skb_tail_pointer(trailer);
        do {
                int i;
                for (i=0; i<clen-skb->len - 2; i++)
-                       *(u8*)(trailer->tail + i) = i+1;
+                       tail[i] = i + 1;
        } while (0);
-       *(u8*)(trailer->tail + clen-skb->len - 2) = (clen - skb->len)-2;
+       tail[clen-skb->len - 2] = (clen - skb->len) - 2;
        pskb_put(skb, trailer, clen - skb->len);
 
-       top_iph = (struct ipv6hdr *)__skb_push(skb, hdr_len);
+       __skb_push(skb, -skb_network_offset(skb));
+       top_iph = ipv6_hdr(skb);
        esph = (struct ipv6_esp_hdr *)skb_transport_header(skb);
        top_iph->payload_len = htons(skb->len + alen - sizeof(*top_iph));
-       *(u8 *)(trailer->tail - 1) = *skb_network_header(skb);
-       *skb_network_header(skb) = IPPROTO_ESP;
+       *(skb_tail_pointer(trailer) - 1) = *skb_mac_header(skb);
+       *skb_mac_header(skb) = IPPROTO_ESP;
 
        esph->spi = x->id.spi;
-       esph->seq_no = htonl(++x->replay.oseq);
-       xfrm_aevent_doreplay(x);
+       esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq);
 
        if (esp->conf.ivlen) {
                if (unlikely(!esp->conf.ivinitted)) {
@@ -233,22 +235,24 @@ out:
        return ret;
 }
 
-static u32 esp6_get_max_size(struct xfrm_state *x, int mtu)
+static u32 esp6_get_mtu(struct xfrm_state *x, int mtu)
 {
        struct esp_data *esp = x->data;
        u32 blksize = ALIGN(crypto_blkcipher_blocksize(esp->conf.tfm), 4);
+       u32 align = max_t(u32, blksize, esp->conf.padlen);
+       u32 rem;
 
-       if (x->props.mode == XFRM_MODE_TUNNEL) {
-               mtu = ALIGN(mtu + 2, blksize);
-       } else {
-               /* The worst case. */
+       mtu -= x->props.header_len + esp->auth.icv_trunc_len;
+       rem = mtu & (align - 1);
+       mtu &= ~(align - 1);
+
+       if (x->props.mode != XFRM_MODE_TUNNEL) {
                u32 padsize = ((blksize - 1) & 7) + 1;
-               mtu = ALIGN(mtu + 2, padsize) + blksize - padsize;
+               mtu -= blksize - padsize;
+               mtu += min_t(u32, blksize - padsize, rem);
        }
-       if (esp->conf.padlen)
-               mtu = ALIGN(mtu, esp->conf.padlen);
 
-       return mtu + x->props.header_len + esp->auth.icv_trunc_len;
+       return mtu - 2;
 }
 
 static void esp6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
@@ -293,11 +297,6 @@ static int esp6_init_state(struct xfrm_state *x)
        struct esp_data *esp = NULL;
        struct crypto_blkcipher *tfm;
 
-       /* null auth and encryption can have zero length keys */
-       if (x->aalg) {
-               if (x->aalg->alg_key_len > 512)
-                       goto error;
-       }
        if (x->ealg == NULL)
                goto error;
 
@@ -312,15 +311,14 @@ static int esp6_init_state(struct xfrm_state *x)
                struct xfrm_algo_desc *aalg_desc;
                struct crypto_hash *hash;
 
-               esp->auth.key = x->aalg->alg_key;
-               esp->auth.key_len = (x->aalg->alg_key_len+7)/8;
                hash = crypto_alloc_hash(x->aalg->alg_name, 0,
                                         CRYPTO_ALG_ASYNC);
                if (IS_ERR(hash))
                        goto error;
 
                esp->auth.tfm = hash;
-               if (crypto_hash_setkey(hash, esp->auth.key, esp->auth.key_len))
+               if (crypto_hash_setkey(hash, x->aalg->alg_key,
+                                      (x->aalg->alg_key_len + 7) / 8))
                        goto error;
 
                aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0);
@@ -342,8 +340,6 @@ static int esp6_init_state(struct xfrm_state *x)
                if (!esp->auth.work_icv)
                        goto error;
        }
-       esp->conf.key = x->ealg->alg_key;
-       esp->conf.key_len = (x->ealg->alg_key_len+7)/8;
        tfm = crypto_alloc_blkcipher(x->ealg->alg_name, 0, CRYPTO_ALG_ASYNC);
        if (IS_ERR(tfm))
                goto error;
@@ -356,7 +352,8 @@ static int esp6_init_state(struct xfrm_state *x)
                        goto error;
                esp->conf.ivinitted = 0;
        }
-       if (crypto_blkcipher_setkey(tfm, esp->conf.key, esp->conf.key_len))
+       if (crypto_blkcipher_setkey(tfm, x->ealg->alg_key,
+                                   (x->ealg->alg_key_len + 7) / 8))
                goto error;
        x->props.header_len = sizeof(struct ipv6_esp_hdr) + esp->conf.ivlen;
        if (x->props.mode == XFRM_MODE_TUNNEL)
@@ -376,9 +373,10 @@ static struct xfrm_type esp6_type =
        .description    = "ESP6",
        .owner          = THIS_MODULE,
        .proto          = IPPROTO_ESP,
+       .flags          = XFRM_TYPE_REPLAY_PROT,
        .init_state     = esp6_init_state,
        .destructor     = esp6_destroy,
-       .get_max_size   = esp6_get_max_size,
+       .get_mtu        = esp6_get_mtu,
        .input          = esp6_input,
        .output         = esp6_output,
        .hdr_offset     = xfrm6_find_1stfragopt,
@@ -417,3 +415,4 @@ module_init(esp6_init);
 module_exit(esp6_fini);
 
 MODULE_LICENSE("GPL");
+MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_ESP);