SUNRPC: Fix a lock recursion in the auth_gss downcall
authorTrond Myklebust <Trond.Myklebust@netapp.com>
Wed, 1 Feb 2006 17:18:36 +0000 (12:18 -0500)
committerTrond Myklebust <Trond.Myklebust@netapp.com>
Wed, 1 Feb 2006 17:52:23 +0000 (12:52 -0500)
 When we look up a new cred in the auth_gss downcall so that we can stuff
 the credcache, we do not want that lookup to queue up an upcall in order
 to initialise it. To do an upcall here not only redundant, but since we
 are already holding the inode->i_mutex, it will trigger a lock recursion.

 This patch allows rpcauth cache searches to indicate that they can cope
 with uninitialised credentials.

Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
include/linux/sunrpc/auth.h
net/sunrpc/auth.c
net/sunrpc/auth_gss/auth_gss.c
net/sunrpc/auth_unix.c

index b68c11a2d6dd912274b146df30d17bc4045800a9..bfc5fb27953949edc244f0b8228701b8f047917c 100644 (file)
@@ -50,6 +50,7 @@ struct rpc_cred {
 };
 #define RPCAUTH_CRED_LOCKED    0x0001
 #define RPCAUTH_CRED_UPTODATE  0x0002
+#define RPCAUTH_CRED_NEW       0x0004
 
 #define RPCAUTH_CRED_MAGIC     0x0f4aa4f0
 
@@ -87,6 +88,10 @@ struct rpc_auth {
                                                 * uid/gid, fs[ug]id, gids)
                                                 */
 
+/* Flags for rpcauth_lookupcred() */
+#define RPCAUTH_LOOKUP_NEW             0x01    /* Accept an uninitialised cred */
+#define RPCAUTH_LOOKUP_ROOTCREDS       0x02    /* This really ought to go! */
+
 /*
  * Client authentication ops
  */
index 9ac1b8c26c01184595f34de62db8ef23b15c6890..1ca89c36da7abaa629e5485b0b04ca25afe42a69 100644 (file)
@@ -184,7 +184,7 @@ rpcauth_gc_credcache(struct rpc_auth *auth, struct hlist_head *free)
  */
 struct rpc_cred *
 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
-               int taskflags)
+               int flags)
 {
        struct rpc_cred_cache *cache = auth->au_credcache;
        HLIST_HEAD(free);
@@ -193,7 +193,7 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
                        *cred = NULL;
        int             nr = 0;
 
-       if (!(taskflags & RPC_TASK_ROOTCREDS))
+       if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS))
                nr = acred->uid & RPC_CREDCACHE_MASK;
 retry:
        spin_lock(&rpc_credcache_lock);
@@ -202,7 +202,7 @@ retry:
        hlist_for_each_safe(pos, next, &cache->hashtable[nr]) {
                struct rpc_cred *entry;
                entry = hlist_entry(pos, struct rpc_cred, cr_hash);
-               if (entry->cr_ops->crmatch(acred, entry, taskflags)) {
+               if (entry->cr_ops->crmatch(acred, entry, flags)) {
                        hlist_del(&entry->cr_hash);
                        cred = entry;
                        break;
@@ -224,7 +224,7 @@ retry:
        rpcauth_destroy_credlist(&free);
 
        if (!cred) {
-               new = auth->au_ops->crcreate(auth, acred, taskflags);
+               new = auth->au_ops->crcreate(auth, acred, flags);
                if (!IS_ERR(new)) {
 #ifdef RPC_DEBUG
                        new->cr_magic = RPCAUTH_CRED_MAGIC;
@@ -238,7 +238,7 @@ retry:
 }
 
 struct rpc_cred *
-rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
+rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 {
        struct auth_cred acred = {
                .uid = current->fsuid,
@@ -250,7 +250,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
        dprintk("RPC:     looking up %s cred\n",
                auth->au_ops->au_name);
        get_group_info(acred.group_info);
-       ret = auth->au_ops->lookup_cred(auth, &acred, taskflags);
+       ret = auth->au_ops->lookup_cred(auth, &acred, flags);
        put_group_info(acred.group_info);
        return ret;
 }
@@ -265,11 +265,14 @@ rpcauth_bindcred(struct rpc_task *task)
                .group_info = current->group_info,
        };
        struct rpc_cred *ret;
+       int flags = 0;
 
        dprintk("RPC: %4d looking up %s cred\n",
                task->tk_pid, task->tk_auth->au_ops->au_name);
        get_group_info(acred.group_info);
-       ret = auth->au_ops->lookup_cred(auth, &acred, task->tk_flags);
+       if (task->tk_flags & RPC_TASK_ROOTCREDS)
+               flags |= RPCAUTH_LOOKUP_ROOTCREDS;
+       ret = auth->au_ops->lookup_cred(auth, &acred, flags);
        if (!IS_ERR(ret))
                task->tk_msg.rpc_cred = ret;
        else
index 8d782282ec194c6f682fd71b2a100c25141c45c0..03affcbf6292b69dfe821cbea0998de2841a9a12 100644 (file)
@@ -158,6 +158,7 @@ gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
        old = gss_cred->gc_ctx;
        gss_cred->gc_ctx = ctx;
        cred->cr_flags |= RPCAUTH_CRED_UPTODATE;
+       cred->cr_flags &= ~RPCAUTH_CRED_NEW;
        write_unlock(&gss_ctx_lock);
        if (old)
                gss_put_ctx(old);
@@ -580,7 +581,7 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
        } else {
                struct auth_cred acred = { .uid = uid };
                spin_unlock(&gss_auth->lock);
-               cred = rpcauth_lookup_credcache(clnt->cl_auth, &acred, 0);
+               cred = rpcauth_lookup_credcache(clnt->cl_auth, &acred, RPCAUTH_LOOKUP_NEW);
                if (IS_ERR(cred)) {
                        err = PTR_ERR(cred);
                        goto err_put_ctx;
@@ -758,13 +759,13 @@ gss_destroy_cred(struct rpc_cred *rc)
  * Lookup RPCSEC_GSS cred for the current process
  */
 static struct rpc_cred *
-gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
+gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
 {
-       return rpcauth_lookup_credcache(auth, acred, taskflags);
+       return rpcauth_lookup_credcache(auth, acred, flags);
 }
 
 static struct rpc_cred *
-gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
+gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
 {
        struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
        struct gss_cred *cred = NULL;
@@ -785,13 +786,17 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
         */
        cred->gc_flags = 0;
        cred->gc_base.cr_ops = &gss_credops;
+       cred->gc_base.cr_flags = RPCAUTH_CRED_NEW;
        cred->gc_service = gss_auth->service;
+       /* Is the caller prepared to initialise the credential? */
+       if (flags & RPCAUTH_LOOKUP_NEW)
+               goto out;
        do {
                err = gss_create_upcall(gss_auth, cred);
        } while (err == -EAGAIN);
        if (err < 0)
                goto out_err;
-
+out:
        return &cred->gc_base;
 
 out_err:
@@ -801,13 +806,21 @@ out_err:
 }
 
 static int
-gss_match(struct auth_cred *acred, struct rpc_cred *rc, int taskflags)
+gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
 {
        struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
 
+       /*
+        * If the searchflags have set RPCAUTH_LOOKUP_NEW, then
+        * we don't really care if the credential has expired or not,
+        * since the caller should be prepared to reinitialise it.
+        */
+       if ((flags & RPCAUTH_LOOKUP_NEW) && (rc->cr_flags & RPCAUTH_CRED_NEW))
+               goto out;
        /* Don't match with creds that have expired. */
        if (gss_cred->gc_ctx && time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
                return 0;
+out:
        return (rc->cr_uid == acred->uid);
 }
 
index 1b3ed4fd198735e332b3f8f46492956f91457860..df14b6bfbf10f9a0bc8874b0a332b3b170a06199 100644 (file)
@@ -75,7 +75,7 @@ unx_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
 
        atomic_set(&cred->uc_count, 1);
        cred->uc_flags = RPCAUTH_CRED_UPTODATE;
-       if (flags & RPC_TASK_ROOTCREDS) {
+       if (flags & RPCAUTH_LOOKUP_ROOTCREDS) {
                cred->uc_uid = 0;
                cred->uc_gid = 0;
                cred->uc_gids[0] = NOGROUP;
@@ -108,12 +108,12 @@ unx_destroy_cred(struct rpc_cred *cred)
  * request root creds (e.g. for NFS swapping).
  */
 static int
-unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int taskflags)
+unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags)
 {
        struct unx_cred *cred = (struct unx_cred *) rcred;
        int             i;
 
-       if (!(taskflags & RPC_TASK_ROOTCREDS)) {
+       if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS)) {
                int groups;
 
                if (cred->uc_uid != acred->uid