Merge branch 'mount.part1' of git://git.kernel.org/pub/scm/linux/kernel/git/viro/vfs
[linux] / net / sunrpc / auth.c
index ad8ead7..1ff9768 100644 (file)
@@ -39,6 +39,20 @@ static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
 static LIST_HEAD(cred_unused);
 static unsigned long number_cred_unused;
 
+static struct cred machine_cred = {
+       .usage = ATOMIC_INIT(1),
+};
+
+/*
+ * Return the machine_cred pointer to be used whenever
+ * the a generic machine credential is needed.
+ */
+const struct cred *rpc_machine_cred(void)
+{
+       return &machine_cred;
+}
+EXPORT_SYMBOL_GPL(rpc_machine_cred);
+
 #define MAX_HASHTABLE_BITS (14)
 static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
 {
@@ -346,29 +360,6 @@ out_nocache:
 }
 EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 
-/*
- * Setup a credential key lifetime timeout notification
- */
-int
-rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
-{
-       if (!cred->cr_auth->au_ops->key_timeout)
-               return 0;
-       return cred->cr_auth->au_ops->key_timeout(auth, cred);
-}
-EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);
-
-bool
-rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred)
-{
-       if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT)
-               return false;
-       if (!cred->cr_ops->crkey_to_expire)
-               return false;
-       return cred->cr_ops->crkey_to_expire(cred);
-}
-EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);
-
 char *
 rpcauth_stringify_acceptor(struct rpc_cred *cred)
 {
@@ -587,13 +578,6 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
        hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
                if (!entry->cr_ops->crmatch(acred, entry, flags))
                        continue;
-               if (flags & RPCAUTH_LOOKUP_RCU) {
-                       if (test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags) ||
-                           refcount_read(&entry->cr_count) == 0)
-                               continue;
-                       cred = entry;
-                       break;
-               }
                cred = get_rpccred(entry);
                if (cred)
                        break;
@@ -603,9 +587,6 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
        if (cred != NULL)
                goto found;
 
-       if (flags & RPCAUTH_LOOKUP_RCU)
-               return ERR_PTR(-ECHILD);
-
        new = auth->au_ops->crcreate(auth, acred, flags, gfp);
        if (IS_ERR(new)) {
                cred = new;
@@ -656,9 +637,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags)
                auth->au_ops->au_name);
 
        memset(&acred, 0, sizeof(acred));
-       acred.uid = cred->fsuid;
-       acred.gid = cred->fsgid;
-       acred.group_info = cred->group_info;
+       acred.cred = cred;
        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
        return ret;
 }
@@ -672,31 +651,41 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
        INIT_LIST_HEAD(&cred->cr_lru);
        refcount_set(&cred->cr_count, 1);
        cred->cr_auth = auth;
+       cred->cr_flags = 0;
        cred->cr_ops = ops;
        cred->cr_expire = jiffies;
-       cred->cr_uid = acred->uid;
+       cred->cr_cred = get_cred(acred->cred);
 }
 EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 
-struct rpc_cred *
-rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
+static struct rpc_cred *
+rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 {
-       dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
-                       cred->cr_auth->au_ops->au_name, cred);
-       return get_rpccred(cred);
+       struct rpc_auth *auth = task->tk_client->cl_auth;
+       struct auth_cred acred = {
+               .cred = get_task_cred(&init_task),
+       };
+       struct rpc_cred *ret;
+
+       dprintk("RPC: %5u looking up %s cred\n",
+               task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
+       ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
+       put_cred(acred.cred);
+       return ret;
 }
-EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
 
 static struct rpc_cred *
-rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
+rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags)
 {
        struct rpc_auth *auth = task->tk_client->cl_auth;
        struct auth_cred acred = {
-               .uid = GLOBAL_ROOT_UID,
-               .gid = GLOBAL_ROOT_GID,
+               .principal = task->tk_client->cl_principal,
+               .cred = init_task.cred,
        };
 
-       dprintk("RPC: %5u looking up %s cred\n",
+       if (!acred.principal)
+               return NULL;
+       dprintk("RPC: %5u looking up %s machine cred\n",
                task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 }
@@ -712,18 +701,33 @@ rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 }
 
 static int
-rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
+rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags)
 {
        struct rpc_rqst *req = task->tk_rqstp;
-       struct rpc_cred *new;
+       struct rpc_cred *new = NULL;
        int lookupflags = 0;
+       struct rpc_auth *auth = task->tk_client->cl_auth;
+       struct auth_cred acred = {
+               .cred = cred,
+       };
 
        if (flags & RPC_TASK_ASYNC)
                lookupflags |= RPCAUTH_LOOKUP_NEW;
-       if (cred != NULL)
-               new = cred->cr_ops->crbind(task, cred, lookupflags);
-       else if (flags & RPC_TASK_ROOTCREDS)
+       if (task->tk_op_cred)
+               /* Task must use exactly this rpc_cred */
+               new = get_rpccred(task->tk_op_cred);
+       else if (cred != NULL && cred != &machine_cred)
+               new = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
+       else if (cred == &machine_cred)
+               new = rpcauth_bind_machine_cred(task, lookupflags);
+
+       /* If machine cred couldn't be bound, try a root cred */
+       if (new)
+               ;
+       else if (cred == &machine_cred || (flags & RPC_TASK_ROOTCREDS))
                new = rpcauth_bind_root_cred(task, lookupflags);
+       else if (flags & RPC_TASK_NULLCREDS)
+               new = authnull_ops.lookup_cred(NULL, NULL, 0);
        else
                new = rpcauth_bind_new_cred(task, lookupflags);
        if (IS_ERR(new))
@@ -901,15 +905,10 @@ int __init rpcauth_init_module(void)
        err = rpc_init_authunix();
        if (err < 0)
                goto out1;
-       err = rpc_init_generic_auth();
-       if (err < 0)
-               goto out2;
        err = register_shrinker(&rpc_cred_shrinker);
        if (err < 0)
-               goto out3;
+               goto out2;
        return 0;
-out3:
-       rpc_destroy_generic_auth();
 out2:
        rpc_destroy_authunix();
 out1:
@@ -919,6 +918,5 @@ out1:
 void rpcauth_remove_module(void)
 {
        rpc_destroy_authunix();
-       rpc_destroy_generic_auth();
        unregister_shrinker(&rpc_cred_shrinker);
 }