l2tp: fix races in tunnel creation
[linux] / net / l2tp / l2tp_core.c
index 14b67df..afb42d1 100644 (file)
@@ -1436,74 +1436,11 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
 {
        struct l2tp_tunnel *tunnel = NULL;
        int err;
-       struct socket *sock = NULL;
-       struct sock *sk = NULL;
-       struct l2tp_net *pn;
        enum l2tp_encap_type encap = L2TP_ENCAPTYPE_UDP;
 
-       /* Get the tunnel socket from the fd, which was opened by
-        * the userspace L2TP daemon. If not specified, create a
-        * kernel socket.
-        */
-       if (fd < 0) {
-               err = l2tp_tunnel_sock_create(net, tunnel_id, peer_tunnel_id,
-                               cfg, &sock);
-               if (err < 0)
-                       goto err;
-       } else {
-               sock = sockfd_lookup(fd, &err);
-               if (!sock) {
-                       pr_err("tunl %u: sockfd_lookup(fd=%d) returned %d\n",
-                              tunnel_id, fd, err);
-                       err = -EBADF;
-                       goto err;
-               }
-
-               /* Reject namespace mismatches */
-               if (!net_eq(sock_net(sock->sk), net)) {
-                       pr_err("tunl %u: netns mismatch\n", tunnel_id);
-                       err = -EINVAL;
-                       goto err;
-               }
-       }
-
-       sk = sock->sk;
-
        if (cfg != NULL)
                encap = cfg->encap;
 
-       /* Quick sanity checks */
-       err = -EPROTONOSUPPORT;
-       if (sk->sk_type != SOCK_DGRAM) {
-               pr_debug("tunl %hu: fd %d wrong socket type\n",
-                        tunnel_id, fd);
-               goto err;
-       }
-       switch (encap) {
-       case L2TP_ENCAPTYPE_UDP:
-               if (sk->sk_protocol != IPPROTO_UDP) {
-                       pr_err("tunl %hu: fd %d wrong protocol, got %d, expected %d\n",
-                              tunnel_id, fd, sk->sk_protocol, IPPROTO_UDP);
-                       goto err;
-               }
-               break;
-       case L2TP_ENCAPTYPE_IP:
-               if (sk->sk_protocol != IPPROTO_L2TP) {
-                       pr_err("tunl %hu: fd %d wrong protocol, got %d, expected %d\n",
-                              tunnel_id, fd, sk->sk_protocol, IPPROTO_L2TP);
-                       goto err;
-               }
-               break;
-       }
-
-       /* Check if this socket has already been prepped */
-       tunnel = l2tp_tunnel(sk);
-       if (tunnel != NULL) {
-               /* This socket has already been prepped */
-               err = -EBUSY;
-               goto err;
-       }
-
        tunnel = kzalloc(sizeof(struct l2tp_tunnel), GFP_KERNEL);
        if (tunnel == NULL) {
                err = -ENOMEM;
@@ -1520,72 +1457,113 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
        rwlock_init(&tunnel->hlist_lock);
        tunnel->acpt_newsess = true;
 
-       /* The net we belong to */
-       tunnel->l2tp_net = net;
-       pn = l2tp_pernet(net);
-
        if (cfg != NULL)
                tunnel->debug = cfg->debug;
 
-       /* Mark socket as an encapsulation socket. See net/ipv4/udp.c */
        tunnel->encap = encap;
-       if (encap == L2TP_ENCAPTYPE_UDP) {
-               struct udp_tunnel_sock_cfg udp_cfg = { };
-
-               udp_cfg.sk_user_data = tunnel;
-               udp_cfg.encap_type = UDP_ENCAP_L2TPINUDP;
-               udp_cfg.encap_rcv = l2tp_udp_encap_recv;
-               udp_cfg.encap_destroy = l2tp_udp_encap_destroy;
-
-               setup_udp_tunnel_sock(net, sock, &udp_cfg);
-       } else {
-               sk->sk_user_data = tunnel;
-       }
 
-       /* Bump the reference count. The tunnel context is deleted
-        * only when this drops to zero. A reference is also held on
-        * the tunnel socket to ensure that it is not released while
-        * the tunnel is extant. Must be done before sk_destruct is
-        * set.
-        */
        refcount_set(&tunnel->ref_count, 1);
-       sock_hold(sk);
-       tunnel->sock = sk;
        tunnel->fd = fd;
 
-       /* Hook on the tunnel socket destructor so that we can cleanup
-        * if the tunnel socket goes away.
-        */
-       tunnel->old_sk_destruct = sk->sk_destruct;
-       sk->sk_destruct = &l2tp_tunnel_destruct;
-       lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class, "l2tp_sock");
-
-       sk->sk_allocation = GFP_ATOMIC;
-
        /* Init delete workqueue struct */
        INIT_WORK(&tunnel->del_work, l2tp_tunnel_del_work);
 
-       /* Add tunnel to our list */
        INIT_LIST_HEAD(&tunnel->list);
-       spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-       list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
-       spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
 
        err = 0;
 err:
        if (tunnelp)
                *tunnelp = tunnel;
 
-       /* If tunnel's socket was created by the kernel, it doesn't
-        *  have a file.
-        */
-       if (sock && sock->file)
-               sockfd_put(sock);
-
        return err;
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_create);
 
+static int l2tp_validate_socket(const struct sock *sk, const struct net *net,
+                               enum l2tp_encap_type encap)
+{
+       if (!net_eq(sock_net(sk), net))
+               return -EINVAL;
+
+       if (sk->sk_type != SOCK_DGRAM)
+               return -EPROTONOSUPPORT;
+
+       if ((encap == L2TP_ENCAPTYPE_UDP && sk->sk_protocol != IPPROTO_UDP) ||
+           (encap == L2TP_ENCAPTYPE_IP && sk->sk_protocol != IPPROTO_L2TP))
+               return -EPROTONOSUPPORT;
+
+       if (sk->sk_user_data)
+               return -EBUSY;
+
+       return 0;
+}
+
+int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
+                        struct l2tp_tunnel_cfg *cfg)
+{
+       struct l2tp_net *pn;
+       struct socket *sock;
+       struct sock *sk;
+       int ret;
+
+       if (tunnel->fd < 0) {
+               ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id,
+                                             tunnel->peer_tunnel_id, cfg,
+                                             &sock);
+               if (ret < 0)
+                       goto err;
+       } else {
+               sock = sockfd_lookup(tunnel->fd, &ret);
+               if (!sock)
+                       goto err;
+
+               ret = l2tp_validate_socket(sock->sk, net, tunnel->encap);
+               if (ret < 0)
+                       goto err_sock;
+       }
+
+       sk = sock->sk;
+
+       sock_hold(sk);
+       tunnel->sock = sk;
+       tunnel->l2tp_net = net;
+
+       pn = l2tp_pernet(net);
+       spin_lock_bh(&pn->l2tp_tunnel_list_lock);
+       list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
+       spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
+
+       if (tunnel->encap == L2TP_ENCAPTYPE_UDP) {
+               struct udp_tunnel_sock_cfg udp_cfg = {
+                       .sk_user_data = tunnel,
+                       .encap_type = UDP_ENCAP_L2TPINUDP,
+                       .encap_rcv = l2tp_udp_encap_recv,
+                       .encap_destroy = l2tp_udp_encap_destroy,
+               };
+
+               setup_udp_tunnel_sock(net, sock, &udp_cfg);
+       } else {
+               sk->sk_user_data = tunnel;
+       }
+
+       tunnel->old_sk_destruct = sk->sk_destruct;
+       sk->sk_destruct = &l2tp_tunnel_destruct;
+       lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class,
+                                  "l2tp_sock");
+       sk->sk_allocation = GFP_ATOMIC;
+
+       if (tunnel->fd >= 0)
+               sockfd_put(sock);
+
+       return 0;
+
+err_sock:
+       sockfd_put(sock);
+err:
+       return ret;
+}
+EXPORT_SYMBOL_GPL(l2tp_tunnel_register);
+
 /* This function is used by the netlink TUNNEL_DELETE command.
  */
 void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel)