9p connect fixes
authorAl Viro <viro@zeniv.linux.org.uk>
Wed, 5 Aug 2009 22:02:43 +0000 (02:02 +0400)
committerAl Viro <viro@zeniv.linux.org.uk>
Wed, 16 Dec 2009 17:16:41 +0000 (12:16 -0500)
* if we fail in p9_conn_create(), we shouldn't leak references to struct file.
  Logics in ->close() doesn't help - ->trans is already gone by the time it's
  called.
* sock_create_kern() can fail.
* use of sock_map_fd() is all fscked up; I'd fixed most of that, but the
  rest will have to wait for a bit more work in net/socket.c (we still are
  violating the basic rule of working with descriptor table: "once the reference
  is installed there, don't rely on finding it there again").

Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
net/9p/trans_fd.c

index 4dd873e3a1bb185554fb2bd702b53b49758ec0b2..be1cb909d8c00e5e9cac2a910fa2d8f793249779 100644 (file)
@@ -42,6 +42,8 @@
 #include <net/9p/client.h>
 #include <net/9p/transport.h>
 
+#include <linux/syscalls.h> /* killme */
+
 #define P9_PORT 564
 #define MAX_SOCK_BUF (64*1024)
 #define MAXPOLLWADDR   2
@@ -788,24 +790,41 @@ static int p9_fd_open(struct p9_client *client, int rfd, int wfd)
 
 static int p9_socket_open(struct p9_client *client, struct socket *csocket)
 {
-       int fd, ret;
+       struct p9_trans_fd *p;
+       int ret, fd;
+
+       p = kmalloc(sizeof(struct p9_trans_fd), GFP_KERNEL);
+       if (!p)
+               return -ENOMEM;
 
        csocket->sk->sk_allocation = GFP_NOIO;
        fd = sock_map_fd(csocket, 0);
        if (fd < 0) {
                P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to map fd\n");
+               sock_release(csocket);
+               kfree(p);
                return fd;
        }
 
-       ret = p9_fd_open(client, fd, fd);
-       if (ret < 0) {
-               P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to open fd\n");
+       get_file(csocket->file);
+       get_file(csocket->file);
+       p->wr = p->rd = csocket->file;
+       client->trans = p;
+       client->status = Connected;
+
+       sys_close(fd);  /* still racy */
+
+       p->rd->f_flags |= O_NONBLOCK;
+
+       p->conn = p9_conn_create(client);
+       if (IS_ERR(p->conn)) {
+               ret = PTR_ERR(p->conn);
+               p->conn = NULL;
+               kfree(p);
+               sockfd_put(csocket);
                sockfd_put(csocket);
                return ret;
        }
-
-       ((struct p9_trans_fd *)client->trans)->rd->f_flags |= O_NONBLOCK;
-
        return 0;
 }
 
@@ -883,7 +902,6 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
        struct socket *csocket;
        struct sockaddr_in sin_server;
        struct p9_fd_opts opts;
-       struct p9_trans_fd *p = NULL; /* this gets allocated in p9_fd_open */
 
        err = parse_opts(args, &opts);
        if (err < 0)
@@ -897,12 +915,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
        sin_server.sin_family = AF_INET;
        sin_server.sin_addr.s_addr = in_aton(addr);
        sin_server.sin_port = htons(opts.port);
-       sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket);
+       err = sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket);
 
-       if (!csocket) {
+       if (err) {
                P9_EPRINTK(KERN_ERR, "p9_trans_tcp: problem creating socket\n");
-               err = -EIO;
-               goto error;
+               return err;
        }
 
        err = csocket->ops->connect(csocket,
@@ -912,30 +929,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
                P9_EPRINTK(KERN_ERR,
                        "p9_trans_tcp: problem connecting socket to %s\n",
                        addr);
-               goto error;
-       }
-
-       err = p9_socket_open(client, csocket);
-       if (err < 0)
-               goto error;
-
-       p = (struct p9_trans_fd *) client->trans;
-       p->conn = p9_conn_create(client);
-       if (IS_ERR(p->conn)) {
-               err = PTR_ERR(p->conn);
-               p->conn = NULL;
-               goto error;
-       }
-
-       return 0;
-
-error:
-       if (csocket)
                sock_release(csocket);
+               return err;
+       }
 
-       kfree(p);
-
-       return err;
+       return p9_socket_open(client, csocket);
 }
 
 static int
@@ -944,49 +942,33 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args)
        int err;
        struct socket *csocket;
        struct sockaddr_un sun_server;
-       struct p9_trans_fd *p = NULL; /* this gets allocated in p9_fd_open */
 
        csocket = NULL;
 
        if (strlen(addr) > UNIX_PATH_MAX) {
                P9_EPRINTK(KERN_ERR, "p9_trans_unix: address too long: %s\n",
                        addr);
-               err = -ENAMETOOLONG;
-               goto error;
+               return -ENAMETOOLONG;
        }
 
        sun_server.sun_family = PF_UNIX;
        strcpy(sun_server.sun_path, addr);
-       sock_create_kern(PF_UNIX, SOCK_STREAM, 0, &csocket);
+       err = sock_create_kern(PF_UNIX, SOCK_STREAM, 0, &csocket);
+       if (err < 0) {
+               P9_EPRINTK(KERN_ERR, "p9_trans_unix: problem creating socket\n");
+               return err;
+       }
        err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server,
                        sizeof(struct sockaddr_un) - 1, 0);
        if (err < 0) {
                P9_EPRINTK(KERN_ERR,
                        "p9_trans_unix: problem connecting socket: %s: %d\n",
                        addr, err);
-               goto error;
-       }
-
-       err = p9_socket_open(client, csocket);
-       if (err < 0)
-               goto error;
-
-       p = (struct p9_trans_fd *) client->trans;
-       p->conn = p9_conn_create(client);
-       if (IS_ERR(p->conn)) {
-               err = PTR_ERR(p->conn);
-               p->conn = NULL;
-               goto error;
-       }
-
-       return 0;
-
-error:
-       if (csocket)
                sock_release(csocket);
+               return err;
+       }
 
-       kfree(p);
-       return err;
+       return p9_socket_open(client, csocket);
 }
 
 static int
@@ -994,7 +976,7 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args)
 {
        int err;
        struct p9_fd_opts opts;
-       struct p9_trans_fd *p = NULL; /* this get allocated in p9_fd_open */
+       struct p9_trans_fd *p;
 
        parse_opts(args, &opts);
 
@@ -1005,21 +987,19 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args)
 
        err = p9_fd_open(client, opts.rfd, opts.wfd);
        if (err < 0)
-               goto error;
+               return err;
 
        p = (struct p9_trans_fd *) client->trans;
        p->conn = p9_conn_create(client);
        if (IS_ERR(p->conn)) {
                err = PTR_ERR(p->conn);
                p->conn = NULL;
-               goto error;
+               fput(p->rd);
+               fput(p->wr);
+               return err;
        }
 
        return 0;
-
-error:
-       kfree(p);
-       return err;
 }
 
 static struct p9_trans_module p9_tcp_trans = {