basic modification from way back
[powerpc.git] / net / netlink / af_netlink.c
index 383dd4e..1f15821 100644 (file)
@@ -8,7 +8,7 @@
  *             modify it under the terms of the GNU General Public License
  *             as published by the Free Software Foundation; either version
  *             2 of the License, or (at your option) any later version.
- * 
+ *
  * Tue Jun 26 14:36:48 MEST 2001 Herbert "herp" Rosmanith
  *                               added netlink_proto_exit
  * Tue Jan 22 18:32:44 BRST 2002 Arnaldo C. de Melo <acme@conectiva.com.br>
@@ -45,7 +45,6 @@
 #include <linux/rtnetlink.h>
 #include <linux/proc_fs.h>
 #include <linux/seq_file.h>
-#include <linux/smp_lock.h>
 #include <linux/notifier.h>
 #include <linux/security.h>
 #include <linux/jhash.h>
@@ -56,6 +55,7 @@
 #include <linux/types.h>
 #include <linux/audit.h>
 #include <linux/selinux.h>
+#include <linux/mutex.h>
 
 #include <net/sock.h>
 #include <net/scm.h>
@@ -76,7 +76,8 @@ struct netlink_sock {
        unsigned long           state;
        wait_queue_head_t       wait;
        struct netlink_callback *cb;
-       spinlock_t              cb_lock;
+       struct mutex            *cb_mutex;
+       struct mutex            cb_def_mutex;
        void                    (*data_ready)(struct sock *sk, int bytes);
        struct module           *module;
 };
@@ -108,6 +109,7 @@ struct netlink_table {
        unsigned long *listeners;
        unsigned int nl_nonroot;
        unsigned int groups;
+       struct mutex *cb_mutex;
        struct module *module;
        int registered;
 };
@@ -118,6 +120,7 @@ static DECLARE_WAIT_QUEUE_HEAD(nl_table_wait);
 
 static int netlink_dump(struct sock *sk);
 static void netlink_destroy_callback(struct netlink_callback *cb);
+static void netlink_queue_skip(struct nlmsghdr *nlh, struct sk_buff *skb);
 
 static DEFINE_RWLOCK(nl_table_lock);
 static atomic_t nl_table_users = ATOMIC_INIT(0);
@@ -136,6 +139,14 @@ static struct hlist_head *nl_pid_hashfn(struct nl_pid_hash *hash, u32 pid)
 
 static void netlink_sock_destruct(struct sock *sk)
 {
+       struct netlink_sock *nlk = nlk_sk(sk);
+
+       if (nlk->cb) {
+               if (nlk->cb->done)
+                       nlk->cb->done(nlk->cb);
+               netlink_destroy_callback(nlk->cb);
+       }
+
        skb_queue_purge(&sk->sk_receive_queue);
 
        if (!sock_flag(sk, SOCK_DEAD)) {
@@ -144,7 +155,6 @@ static void netlink_sock_destruct(struct sock *sk)
        }
        BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc));
        BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc));
-       BUG_TRAP(!nlk_sk(sk)->cb);
        BUG_TRAP(!nlk_sk(sk)->groups);
 }
 
@@ -370,7 +380,8 @@ static struct proto netlink_proto = {
        .obj_size = sizeof(struct netlink_sock),
 };
 
-static int __netlink_create(struct socket *sock, int protocol)
+static int __netlink_create(struct socket *sock, struct mutex *cb_mutex,
+                           int protocol)
 {
        struct sock *sk;
        struct netlink_sock *nlk;
@@ -384,7 +395,12 @@ static int __netlink_create(struct socket *sock, int protocol)
        sock_init_data(sock, sk);
 
        nlk = nlk_sk(sk);
-       spin_lock_init(&nlk->cb_lock);
+       if (cb_mutex)
+               nlk->cb_mutex = cb_mutex;
+       else {
+               nlk->cb_mutex = &nlk->cb_def_mutex;
+               mutex_init(nlk->cb_mutex);
+       }
        init_waitqueue_head(&nlk->wait);
 
        sk->sk_destruct = netlink_sock_destruct;
@@ -395,8 +411,8 @@ static int __netlink_create(struct socket *sock, int protocol)
 static int netlink_create(struct socket *sock, int protocol)
 {
        struct module *module = NULL;
+       struct mutex *cb_mutex;
        struct netlink_sock *nlk;
-       unsigned int groups;
        int err = 0;
 
        sock->state = SS_UNCONNECTED;
@@ -418,10 +434,10 @@ static int netlink_create(struct socket *sock, int protocol)
        if (nl_table[protocol].registered &&
            try_module_get(nl_table[protocol].module))
                module = nl_table[protocol].module;
-       groups = nl_table[protocol].groups;
+       cb_mutex = nl_table[protocol].cb_mutex;
        netlink_unlock_table();
 
-       if ((err = __netlink_create(sock, protocol)) < 0)
+       if ((err = __netlink_create(sock, cb_mutex, protocol)) < 0)
                goto out_module;
 
        nlk = nlk_sk(sock->sk);
@@ -443,21 +459,14 @@ static int netlink_release(struct socket *sock)
                return 0;
 
        netlink_remove(sk);
+       sock_orphan(sk);
        nlk = nlk_sk(sk);
 
-       spin_lock(&nlk->cb_lock);
-       if (nlk->cb) {
-               if (nlk->cb->done)
-                       nlk->cb->done(nlk->cb);
-               netlink_destroy_callback(nlk->cb);
-               nlk->cb = NULL;
-       }
-       spin_unlock(&nlk->cb_lock);
-
-       /* OK. Socket is unlinked, and, therefore,
-          no new packets will arrive */
+       /*
+        * OK. Socket is unlinked, any packets that arrive now
+        * will be purged.
+        */
 
-       sock_orphan(sk);
        sock->sk = NULL;
        wake_up_interruptible_all(&nlk->wait);
 
@@ -470,7 +479,7 @@ static int netlink_release(struct socket *sock)
                                          };
                atomic_notifier_call_chain(&netlink_chain,
                                NETLINK_URELEASE, &n);
-       }       
+       }
 
        module_put(nlk->module);
 
@@ -528,11 +537,11 @@ retry:
        return err;
 }
 
-static inline int netlink_capable(struct socket *sock, unsigned int flag) 
-{ 
+static inline int netlink_capable(struct socket *sock, unsigned int flag)
+{
        return (nl_table[sock->sk->sk_protocol].nl_nonroot & flag) ||
               capable(CAP_NET_ADMIN);
-} 
+}
 
 static void
 netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
@@ -574,7 +583,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
        struct netlink_sock *nlk = nlk_sk(sk);
        struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
        int err;
-       
+
        if (nladdr->nl_family != AF_NETLINK)
                return -EINVAL;
 
@@ -605,9 +614,9 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
 
        netlink_table_grab();
        netlink_update_subscriptions(sk, nlk->subscriptions +
-                                        hweight32(nladdr->nl_groups) -
-                                        hweight32(nlk->groups[0]));
-       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; 
+                                        hweight32(nladdr->nl_groups) -
+                                        hweight32(nlk->groups[0]));
+       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups;
        netlink_update_listeners(sk);
        netlink_table_ungrab();
 
@@ -652,7 +661,7 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
        struct sockaddr_nl *nladdr=(struct sockaddr_nl *)addr;
-       
+
        nladdr->nl_family = AF_NETLINK;
        nladdr->nl_pad = 0;
        *addr_len = sizeof(*nladdr);
@@ -999,7 +1008,7 @@ void netlink_set_err(struct sock *ssk, u32 pid, u32 group, int code)
 }
 
 static int netlink_setsockopt(struct socket *sock, int level, int optname,
-                              char __user *optval, int optlen)
+                             char __user *optval, int optlen)
 {
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -1054,7 +1063,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
 }
 
 static int netlink_getsockopt(struct socket *sock, int level, int optname,
-                              char __user *optval, int __user *optlen)
+                             char __user *optval, int __user *optlen)
 {
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -1215,7 +1224,7 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
                copied = len;
        }
 
-       skb->h.raw = skb->data;
+       skb_reset_transport_header(skb);
        err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, copied);
 
        if (msg->msg_name) {
@@ -1235,13 +1244,14 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
                siocb->scm = &scm;
        }
        siocb->scm->creds = *NETLINK_CREDS(skb);
+       if (flags & MSG_TRUNC)
+               copied = skb->len;
        skb_free_datagram(sk, skb);
 
        if (nlk->cb && atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf / 2)
                netlink_dump(sk);
 
        scm_recv(sock, msg, siocb->scm, flags);
-
 out:
        netlink_rcv_wake(sk);
        return err ? : copied;
@@ -1257,15 +1267,15 @@ static void netlink_data_ready(struct sock *sk, int len)
 }
 
 /*
- *     We export these functions to other modules. They provide a 
+ *     We export these functions to other modules. They provide a
  *     complete set of kernel non-blocking support for message
  *     queueing.
  */
 
 struct sock *
 netlink_kernel_create(int unit, unsigned int groups,
-                      void (*input)(struct sock *sk, int len),
-                      struct module *module)
+                     void (*input)(struct sock *sk, int len),
+                     struct mutex *cb_mutex, struct module *module)
 {
        struct socket *sock;
        struct sock *sk;
@@ -1280,7 +1290,7 @@ netlink_kernel_create(int unit, unsigned int groups,
        if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock))
                return NULL;
 
-       if (__netlink_create(sock, unit) < 0)
+       if (__netlink_create(sock, cb_mutex, unit) < 0)
                goto out_sock_release;
 
        if (groups < 32)
@@ -1304,6 +1314,7 @@ netlink_kernel_create(int unit, unsigned int groups,
        netlink_table_grab();
        nl_table[unit].groups = groups;
        nl_table[unit].listeners = listeners;
+       nl_table[unit].cb_mutex = cb_mutex;
        nl_table[unit].module = module;
        nl_table[unit].registered = 1;
        netlink_table_ungrab();
@@ -1317,10 +1328,10 @@ out_sock_release:
 }
 
 void netlink_set_nonroot(int protocol, unsigned int flags)
-{ 
-       if ((unsigned int)protocol < MAX_LINKS) 
+{
+       if ((unsigned int)protocol < MAX_LINKS)
                nl_table[protocol].nl_nonroot = flags;
-} 
+}
 
 static void netlink_destroy_callback(struct netlink_callback *cb)
 {
@@ -1341,12 +1352,12 @@ static int netlink_dump(struct sock *sk)
        struct sk_buff *skb;
        struct nlmsghdr *nlh;
        int len, err = -ENOBUFS;
-       
+
        skb = sock_rmalloc(sk, NLMSG_GOODSIZE, 0, GFP_KERNEL);
        if (!skb)
                goto errout;
 
-       spin_lock(&nlk->cb_lock);
+       mutex_lock(nlk->cb_mutex);
 
        cb = nlk->cb;
        if (cb == NULL) {
@@ -1357,7 +1368,7 @@ static int netlink_dump(struct sock *sk)
        len = cb->dump(skb, cb);
 
        if (len > 0) {
-               spin_unlock(&nlk->cb_lock);
+               mutex_unlock(nlk->cb_mutex);
                skb_queue_tail(&sk->sk_receive_queue, skb);
                sk->sk_data_ready(sk, len);
                return 0;
@@ -1375,13 +1386,13 @@ static int netlink_dump(struct sock *sk)
        if (cb->done)
                cb->done(cb);
        nlk->cb = NULL;
-       spin_unlock(&nlk->cb_lock);
+       mutex_unlock(nlk->cb_mutex);
 
        netlink_destroy_callback(cb);
        return 0;
 
 errout_skb:
-       spin_unlock(&nlk->cb_lock);
+       mutex_unlock(nlk->cb_mutex);
        kfree_skb(skb);
 errout:
        return err;
@@ -1413,19 +1424,24 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
        }
        nlk = nlk_sk(sk);
        /* A dump is in progress... */
-       spin_lock(&nlk->cb_lock);
+       mutex_lock(nlk->cb_mutex);
        if (nlk->cb) {
-               spin_unlock(&nlk->cb_lock);
+               mutex_unlock(nlk->cb_mutex);
                netlink_destroy_callback(cb);
                sock_put(sk);
                return -EBUSY;
        }
        nlk->cb = cb;
-       spin_unlock(&nlk->cb_lock);
+       mutex_unlock(nlk->cb_mutex);
 
        netlink_dump(sk);
        sock_put(sk);
-       return 0;
+
+       /* We successfully started a dump, by returning -EINTR we
+        * signal the queue mangement to interrupt processing of
+        * any netlink messages so userspace gets a chance to read
+        * the results. */
+       return -EINTR;
 }
 
 void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err)
@@ -1462,27 +1478,35 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err)
 }
 
 static int netlink_rcv_skb(struct sk_buff *skb, int (*cb)(struct sk_buff *,
-                                                    struct nlmsghdr *, int *))
+                                                    struct nlmsghdr *))
 {
        struct nlmsghdr *nlh;
        int err;
 
        while (skb->len >= nlmsg_total_size(0)) {
-               nlh = (struct nlmsghdr *) skb->data;
+               nlh = nlmsg_hdr(skb);
+               err = 0;
 
                if (nlh->nlmsg_len < NLMSG_HDRLEN || skb->len < nlh->nlmsg_len)
                        return 0;
 
-               if (cb(skb, nlh, &err) < 0) {
-                       /* Not an error, but we have to interrupt processing
-                        * here. Note: that in this case we do not pull
-                        * message from skb, it will be processed later.
-                        */
-                       if (err == 0)
-                               return -1;
+               /* Only requests are handled by the kernel */
+               if (!(nlh->nlmsg_flags & NLM_F_REQUEST))
+                       goto skip;
+
+               /* Skip control messages */
+               if (nlh->nlmsg_type < NLMSG_MIN_TYPE)
+                       goto skip;
+
+               err = cb(skb, nlh);
+               if (err == -EINTR) {
+                       /* Not an error, but we interrupt processing */
+                       netlink_queue_skip(nlh, skb);
+                       return err;
+               }
+skip:
+               if (nlh->nlmsg_flags & NLM_F_ACK || err)
                        netlink_ack(skb, nlh, err);
-               } else if (nlh->nlmsg_flags & NLM_F_ACK)
-                       netlink_ack(skb, nlh, 0);
 
                netlink_queue_skip(nlh, skb);
        }
@@ -1504,9 +1528,14 @@ static int netlink_rcv_skb(struct sk_buff *skb, int (*cb)(struct sk_buff *,
  *
  * qlen must be initialized to 0 before the initial entry, afterwards
  * the function may be called repeatedly until qlen reaches 0.
+ *
+ * The callback function may return -EINTR to signal that processing
+ * of netlink messages shall be interrupted. In this case the message
+ * currently being processed will NOT be requeued onto the receive
+ * queue.
  */
 void netlink_run_queue(struct sock *sk, unsigned int *qlen,
-                      int (*cb)(struct sk_buff *, struct nlmsghdr *, int *))
+                      int (*cb)(struct sk_buff *, struct nlmsghdr *))
 {
        struct sk_buff *skb;
 
@@ -1537,7 +1566,7 @@ void netlink_run_queue(struct sock *sk, unsigned int *qlen,
  * Pulls the given netlink message off the socket buffer so the next
  * call to netlink_queue_run() will not reconsider the message.
  */
-void netlink_queue_skip(struct nlmsghdr *nlh, struct sk_buff *skb)
+static void netlink_queue_skip(struct nlmsghdr *nlh, struct sk_buff *skb)
 {
        int msglen = NLMSG_ALIGN(nlh->nlmsg_len);
 
@@ -1626,7 +1655,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 
        if (v == SEQ_START_TOKEN)
                return netlink_seq_socket_idx(seq, 0);
-               
+
        s = sk_next(v);
        if (s)
                return s;
@@ -1713,7 +1742,7 @@ static int netlink_seq_open(struct inode *inode, struct file *file)
        return 0;
 }
 
-static struct file_operations netlink_seq_fops = {
+static const struct file_operations netlink_seq_fops = {
        .owner          = THIS_MODULE,
        .open           = netlink_seq_open,
        .read           = seq_read,
@@ -1732,7 +1761,7 @@ int netlink_unregister_notifier(struct notifier_block *nb)
 {
        return atomic_notifier_chain_unregister(&netlink_chain, nb);
 }
-                
+
 static const struct proto_ops netlink_ops = {
        .family =       PF_NETLINK,
        .owner =        THIS_MODULE,
@@ -1808,7 +1837,7 @@ static int __init netlink_proto_init(void)
 #ifdef CONFIG_PROC_FS
        proc_net_fops_create("netlink", 0, &netlink_seq_fops);
 #endif
-       /* The netlink device handler may be needed early. */ 
+       /* The netlink device handler may be needed early. */
        rtnetlink_init();
 out:
        return err;
@@ -1820,12 +1849,10 @@ core_initcall(netlink_proto_init);
 
 EXPORT_SYMBOL(netlink_ack);
 EXPORT_SYMBOL(netlink_run_queue);
-EXPORT_SYMBOL(netlink_queue_skip);
 EXPORT_SYMBOL(netlink_broadcast);
 EXPORT_SYMBOL(netlink_dump_start);
 EXPORT_SYMBOL(netlink_kernel_create);
 EXPORT_SYMBOL(netlink_register_notifier);
-EXPORT_SYMBOL(netlink_set_err);
 EXPORT_SYMBOL(netlink_set_nonroot);
 EXPORT_SYMBOL(netlink_unicast);
 EXPORT_SYMBOL(netlink_unregister_notifier);