Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[firefly-linux-kernel-4.4.55.git] / net / vmw_vsock / vmci_transport.c
index a70ace83a1531232a2f11afc63c8b4ab1c4ea918..daff75200e256705b3e0e9605cf6e6baf6ac17e3 100644 (file)
@@ -123,6 +123,14 @@ static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
        return err > 0 ? -err : err;
 }
 
+static u32 vmci_transport_peer_rid(u32 peer_cid)
+{
+       if (VMADDR_CID_HYPERVISOR == peer_cid)
+               return VMCI_TRANSPORT_HYPERVISOR_PACKET_RID;
+
+       return VMCI_TRANSPORT_PACKET_RID;
+}
+
 static inline void
 vmci_transport_packet_init(struct vmci_transport_packet *pkt,
                           struct sockaddr_vm *src,
@@ -140,7 +148,7 @@ vmci_transport_packet_init(struct vmci_transport_packet *pkt,
        pkt->dg.src = vmci_make_handle(VMADDR_CID_ANY,
                                       VMCI_TRANSPORT_PACKET_RID);
        pkt->dg.dst = vmci_make_handle(dst->svm_cid,
-                                      VMCI_TRANSPORT_PACKET_RID);
+                                      vmci_transport_peer_rid(dst->svm_cid));
        pkt->dg.payload_size = sizeof(*pkt) - sizeof(pkt->dg);
        pkt->version = VMCI_TRANSPORT_PACKET_VERSION;
        pkt->type = type;
@@ -464,19 +472,16 @@ static struct sock *vmci_transport_get_pending(
        struct vsock_sock *vlistener;
        struct vsock_sock *vpending;
        struct sock *pending;
+       struct sockaddr_vm src;
+
+       vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 
        vlistener = vsock_sk(listener);
 
        list_for_each_entry(vpending, &vlistener->pending_links,
                            pending_links) {
-               struct sockaddr_vm src;
-               struct sockaddr_vm dst;
-
-               vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
-               vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);
-
                if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
-                   vsock_addr_equals_addr(&dst, &vpending->local_addr)) {
+                   pkt->dst_port == vpending->local_addr.svm_port) {
                        pending = sk_vsock(vpending);
                        sock_hold(pending);
                        goto found;
@@ -511,6 +516,9 @@ static bool vmci_transport_is_trusted(struct vsock_sock *vsock, u32 peer_cid)
 
 static bool vmci_transport_allow_dgram(struct vsock_sock *vsock, u32 peer_cid)
 {
+       if (VMADDR_CID_HYPERVISOR == peer_cid)
+               return true;
+
        if (vsock->cached_peer != peer_cid) {
                vsock->cached_peer = peer_cid;
                if (!vmci_transport_is_trusted(vsock, peer_cid) &&
@@ -631,7 +639,6 @@ static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg)
 static bool vmci_transport_stream_allow(u32 cid, u32 port)
 {
        static const u32 non_socket_contexts[] = {
-               VMADDR_CID_HYPERVISOR,
                VMADDR_CID_RESERVED,
        };
        int i;
@@ -670,7 +677,7 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
         */
 
        if (!vmci_transport_stream_allow(dg->src.context, -1)
-           || VMCI_TRANSPORT_PACKET_RID != dg->src.resource)
+           || vmci_transport_peer_rid(dg->src.context) != dg->src.resource)
                return VMCI_ERROR_NO_ACCESS;
 
        if (VMCI_DG_SIZE(dg) < sizeof(*pkt))
@@ -739,10 +746,15 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
         */
        bh_lock_sock(sk);
 
-       if (!sock_owned_by_user(sk) && sk->sk_state == SS_CONNECTED)
-               vmci_trans(vsk)->notify_ops->handle_notify_pkt(
-                               sk, pkt, true, &dst, &src,
-                               &bh_process_pkt);
+       if (!sock_owned_by_user(sk)) {
+               /* The local context ID may be out of date, update it. */
+               vsk->local_addr.svm_cid = dst.svm_cid;
+
+               if (sk->sk_state == SS_CONNECTED)
+                       vmci_trans(vsk)->notify_ops->handle_notify_pkt(
+                                       sk, pkt, true, &dst, &src,
+                                       &bh_process_pkt);
+       }
 
        bh_unlock_sock(sk);
 
@@ -902,6 +914,9 @@ static void vmci_transport_recv_pkt_work(struct work_struct *work)
 
        lock_sock(sk);
 
+       /* The local context ID may be out of date. */
+       vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context;
+
        switch (sk->sk_state) {
        case SS_LISTEN:
                vmci_transport_recv_listen(sk, pkt);
@@ -958,6 +973,10 @@ static int vmci_transport_recv_listen(struct sock *sk,
        pending = vmci_transport_get_pending(sk, pkt);
        if (pending) {
                lock_sock(pending);
+
+               /* The local context ID may be out of date. */
+               vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context;
+
                switch (pending->sk_state) {
                case SS_CONNECTING:
                        err = vmci_transport_recv_connecting_server(sk,
@@ -1727,6 +1746,8 @@ static int vmci_transport_dgram_dequeue(struct kiocb *kiocb,
        if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
                return -EOPNOTSUPP;
 
+       msg->msg_namelen = 0;
+
        /* Retrieve the head sk_buff from the socket's receive queue. */
        err = 0;
        skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err);
@@ -1759,7 +1780,6 @@ static int vmci_transport_dgram_dequeue(struct kiocb *kiocb,
        if (err)
                goto out;
 
-       msg->msg_namelen = 0;
        if (msg->msg_name) {
                struct sockaddr_vm *vm_addr;