Merge tag 'arm64-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/arm64/linux
[firefly-linux-kernel-4.4.55.git] / net / rxrpc / ar-call.c
index a3bbb360a3f96e0ee4aa6aac2e7eea830305b564..a9e05db0f5d5900e93f87a8567e7533a1745c82a 100644 (file)
 #include <linux/slab.h>
 #include <linux/module.h>
 #include <linux/circ_buf.h>
+#include <linux/hashtable.h>
+#include <linux/spinlock_types.h>
 #include <net/sock.h>
 #include <net/af_rxrpc.h>
 #include "ar-internal.h"
 
+/*
+ * Maximum lifetime of a call (in jiffies).
+ */
+unsigned rxrpc_max_call_lifetime = 60 * HZ;
+
+/*
+ * Time till dead call expires after last use (in jiffies).
+ */
+unsigned rxrpc_dead_call_expiry = 2 * HZ;
+
 const char *const rxrpc_call_states[] = {
        [RXRPC_CALL_CLIENT_SEND_REQUEST]        = "ClSndReq",
        [RXRPC_CALL_CLIENT_AWAIT_REPLY]         = "ClAwtRpl",
@@ -38,8 +50,6 @@ const char *const rxrpc_call_states[] = {
 struct kmem_cache *rxrpc_call_jar;
 LIST_HEAD(rxrpc_calls);
 DEFINE_RWLOCK(rxrpc_call_lock);
-static unsigned int rxrpc_call_max_lifetime = 60;
-static unsigned int rxrpc_dead_call_timeout = 2;
 
 static void rxrpc_destroy_call(struct work_struct *work);
 static void rxrpc_call_life_expired(unsigned long _call);
@@ -47,6 +57,145 @@ static void rxrpc_dead_call_expired(unsigned long _call);
 static void rxrpc_ack_time_expired(unsigned long _call);
 static void rxrpc_resend_time_expired(unsigned long _call);
 
+static DEFINE_SPINLOCK(rxrpc_call_hash_lock);
+static DEFINE_HASHTABLE(rxrpc_call_hash, 10);
+
+/*
+ * Hash function for rxrpc_call_hash
+ */
+static unsigned long rxrpc_call_hashfunc(
+       u8              clientflag,
+       __be32          cid,
+       __be32          call_id,
+       __be32          epoch,
+       __be16          service_id,
+       sa_family_t     proto,
+       void            *localptr,
+       unsigned int    addr_size,
+       const u8        *peer_addr)
+{
+       const u16 *p;
+       unsigned int i;
+       unsigned long key;
+       u32 hcid = ntohl(cid);
+
+       _enter("");
+
+       key = (unsigned long)localptr;
+       /* We just want to add up the __be32 values, so forcing the
+        * cast should be okay.
+        */
+       key += (__force u32)epoch;
+       key += (__force u16)service_id;
+       key += (__force u32)call_id;
+       key += (hcid & RXRPC_CIDMASK) >> RXRPC_CIDSHIFT;
+       key += hcid & RXRPC_CHANNELMASK;
+       key += clientflag;
+       key += proto;
+       /* Step through the peer address in 16-bit portions for speed */
+       for (i = 0, p = (const u16 *)peer_addr; i < addr_size >> 1; i++, p++)
+               key += *p;
+       _leave(" key = 0x%lx", key);
+       return key;
+}
+
+/*
+ * Add a call to the hashtable
+ */
+static void rxrpc_call_hash_add(struct rxrpc_call *call)
+{
+       unsigned long key;
+       unsigned int addr_size = 0;
+
+       _enter("");
+       switch (call->proto) {
+       case AF_INET:
+               addr_size = sizeof(call->peer_ip.ipv4_addr);
+               break;
+       case AF_INET6:
+               addr_size = sizeof(call->peer_ip.ipv6_addr);
+               break;
+       default:
+               break;
+       }
+       key = rxrpc_call_hashfunc(call->in_clientflag, call->cid,
+                                 call->call_id, call->epoch,
+                                 call->service_id, call->proto,
+                                 call->conn->trans->local, addr_size,
+                                 call->peer_ip.ipv6_addr);
+       /* Store the full key in the call */
+       call->hash_key = key;
+       spin_lock(&rxrpc_call_hash_lock);
+       hash_add_rcu(rxrpc_call_hash, &call->hash_node, key);
+       spin_unlock(&rxrpc_call_hash_lock);
+       _leave("");
+}
+
+/*
+ * Remove a call from the hashtable
+ */
+static void rxrpc_call_hash_del(struct rxrpc_call *call)
+{
+       _enter("");
+       spin_lock(&rxrpc_call_hash_lock);
+       hash_del_rcu(&call->hash_node);
+       spin_unlock(&rxrpc_call_hash_lock);
+       _leave("");
+}
+
+/*
+ * Find a call in the hashtable and return it, or NULL if it
+ * isn't there.
+ */
+struct rxrpc_call *rxrpc_find_call_hash(
+       u8              clientflag,
+       __be32          cid,
+       __be32          call_id,
+       __be32          epoch,
+       __be16          service_id,
+       void            *localptr,
+       sa_family_t     proto,
+       const u8        *peer_addr)
+{
+       unsigned long key;
+       unsigned int addr_size = 0;
+       struct rxrpc_call *call = NULL;
+       struct rxrpc_call *ret = NULL;
+
+       _enter("");
+       switch (proto) {
+       case AF_INET:
+               addr_size = sizeof(call->peer_ip.ipv4_addr);
+               break;
+       case AF_INET6:
+               addr_size = sizeof(call->peer_ip.ipv6_addr);
+               break;
+       default:
+               break;
+       }
+
+       key = rxrpc_call_hashfunc(clientflag, cid, call_id, epoch,
+                                 service_id, proto, localptr, addr_size,
+                                 peer_addr);
+       hash_for_each_possible_rcu(rxrpc_call_hash, call, hash_node, key) {
+               if (call->hash_key == key &&
+                   call->call_id == call_id &&
+                   call->cid == cid &&
+                   call->in_clientflag == clientflag &&
+                   call->service_id == service_id &&
+                   call->proto == proto &&
+                   call->local == localptr &&
+                   memcmp(call->peer_ip.ipv6_addr, peer_addr,
+                             addr_size) == 0 &&
+                   call->epoch == epoch) {
+                       ret = call;
+                       break;
+               }
+       }
+       _leave(" = %p", ret);
+       return ret;
+}
+
 /*
  * allocate a new call
  */
@@ -91,7 +240,7 @@ static struct rxrpc_call *rxrpc_alloc_call(gfp_t gfp)
        call->rx_data_expect = 1;
        call->rx_data_eaten = 0;
        call->rx_first_oos = 0;
-       call->ackr_win_top = call->rx_data_eaten + 1 + RXRPC_MAXACKS;
+       call->ackr_win_top = call->rx_data_eaten + 1 + rxrpc_rx_window_size;
        call->creation_jif = jiffies;
        return call;
 }
@@ -128,11 +277,31 @@ static struct rxrpc_call *rxrpc_alloc_client_call(
                return ERR_PTR(ret);
        }
 
+       /* Record copies of information for hashtable lookup */
+       call->proto = rx->proto;
+       call->local = trans->local;
+       switch (call->proto) {
+       case AF_INET:
+               call->peer_ip.ipv4_addr =
+                       trans->peer->srx.transport.sin.sin_addr.s_addr;
+               break;
+       case AF_INET6:
+               memcpy(call->peer_ip.ipv6_addr,
+                      trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
+                      sizeof(call->peer_ip.ipv6_addr));
+               break;
+       }
+       call->epoch = call->conn->epoch;
+       call->service_id = call->conn->service_id;
+       call->in_clientflag = call->conn->in_clientflag;
+       /* Add the new call to the hashtable */
+       rxrpc_call_hash_add(call);
+
        spin_lock(&call->conn->trans->peer->lock);
        list_add(&call->error_link, &call->conn->trans->peer->error_targets);
        spin_unlock(&call->conn->trans->peer->lock);
 
-       call->lifetimer.expires = jiffies + rxrpc_call_max_lifetime * HZ;
+       call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime;
        add_timer(&call->lifetimer);
 
        _leave(" = %p", call);
@@ -320,9 +489,12 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx,
                parent = *p;
                call = rb_entry(parent, struct rxrpc_call, conn_node);
 
-               if (call_id < call->call_id)
+               /* The tree is sorted in order of the __be32 value without
+                * turning it into host order.
+                */
+               if ((__force u32)call_id < (__force u32)call->call_id)
                        p = &(*p)->rb_left;
-               else if (call_id > call->call_id)
+               else if ((__force u32)call_id > (__force u32)call->call_id)
                        p = &(*p)->rb_right;
                else
                        goto old_call;
@@ -347,9 +519,31 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx,
        list_add_tail(&call->link, &rxrpc_calls);
        write_unlock_bh(&rxrpc_call_lock);
 
+       /* Record copies of information for hashtable lookup */
+       call->proto = rx->proto;
+       call->local = conn->trans->local;
+       switch (call->proto) {
+       case AF_INET:
+               call->peer_ip.ipv4_addr =
+                       conn->trans->peer->srx.transport.sin.sin_addr.s_addr;
+               break;
+       case AF_INET6:
+               memcpy(call->peer_ip.ipv6_addr,
+                      conn->trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
+                      sizeof(call->peer_ip.ipv6_addr));
+               break;
+       default:
+               break;
+       }
+       call->epoch = conn->epoch;
+       call->service_id = conn->service_id;
+       call->in_clientflag = conn->in_clientflag;
+       /* Add the new call to the hashtable */
+       rxrpc_call_hash_add(call);
+
        _net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id);
 
-       call->lifetimer.expires = jiffies + rxrpc_call_max_lifetime * HZ;
+       call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime;
        add_timer(&call->lifetimer);
        _leave(" = %p {%d} [new]", call, call->debug_id);
        return call;
@@ -533,7 +727,7 @@ void rxrpc_release_call(struct rxrpc_call *call)
        del_timer_sync(&call->resend_timer);
        del_timer_sync(&call->ack_timer);
        del_timer_sync(&call->lifetimer);
-       call->deadspan.expires = jiffies + rxrpc_dead_call_timeout * HZ;
+       call->deadspan.expires = jiffies + rxrpc_dead_call_expiry;
        add_timer(&call->deadspan);
 
        _leave("");
@@ -665,6 +859,9 @@ static void rxrpc_cleanup_call(struct rxrpc_call *call)
                rxrpc_put_connection(call->conn);
        }
 
+       /* Remove the call from the hash */
+       rxrpc_call_hash_del(call);
+
        if (call->acks_window) {
                _debug("kill Tx window %d",
                       CIRC_CNT(call->acks_head, call->acks_tail,