ANDROID: crypto: heh - factor out poly_hash algorithm
authorEric Biggers <ebiggers@google.com>
Wed, 11 Jan 2017 18:36:41 +0000 (10:36 -0800)
committerAmit Pundir <amit.pundir@linaro.org>
Mon, 10 Apr 2017 07:42:16 +0000 (13:12 +0530)
Factor most of poly_hash() out into its own keyed hash algorithm so that
optimized architecture-specific implementations of it will be possible.

For now we call poly_hash through the shash API, since HEH already had
an example of using shash for another algorithm (CMAC), and we will not
be adding any poly_hash implementations that require ahash yet.  We can
however switch to ahash later if it becomes useful.

Bug: 32508661
Signed-off-by: Eric Biggers <ebiggers@google.com>
Change-Id: I8de54ddcecd1d7fa6e9842a09506a08129bae0b6

crypto/heh.c

index 48a284cecaa2e45d3a60c7fe87905d5cad0e5d60..10c00aaf797eb36b648856bf3a40550aa4baf95b 100644 (file)
 
 struct heh_instance_ctx {
        struct crypto_shash_spawn cmac;
+       struct crypto_shash_spawn poly_hash;
        struct crypto_skcipher_spawn ecb;
 };
 
 struct heh_tfm_ctx {
        struct crypto_shash *cmac;
+       struct crypto_shash *poly_hash; /* keyed with tau_key */
        struct crypto_ablkcipher *ecb;
-       struct gf128mul_4k *tau_key;
 };
 
 struct heh_cmac_data {
@@ -63,6 +64,10 @@ struct heh_req_ctx { /* aligned to alignmask */
                        struct shash_desc desc;
                        /* + crypto_shash_descsize(cmac) */
                } cmac;
+               struct {
+                       struct shash_desc desc;
+                       /* + crypto_shash_descsize(poly_hash) */
+               } poly_hash;
                struct {
                        u8 keystream[HEH_BLOCK_SIZE];
                        u8 tmp[HEH_BLOCK_SIZE];
@@ -157,73 +162,138 @@ static int generate_betas(struct ablkcipher_request *req,
        return 0;
 }
 
+/*****************************************************************************/
+
 /*
- * Evaluation of a polynomial over GF(2^128) using Horner's rule.  The
- * polynomial is evaluated at 'point'.  The polynomial's coefficients are taken
- * from 'coeffs_sgl' and are for terms with consecutive descending degree ending
- * at degree 1.  'bytes_of_coeffs' is 16 times the number of terms.
+ * This is the generic version of poly_hash.  It does the GF(2^128)
+ * multiplication by 'tau_key' using a precomputed table, without using any
+ * special CPU instructions.  On some platforms, an accelerated version (with
+ * higher cra_priority) may be used instead.
  */
-static be128 evaluate_polynomial(struct gf128mul_4k *point,
-                                struct scatterlist *coeffs_sgl,
-                                unsigned int bytes_of_coeffs)
+
+struct poly_hash_tfm_ctx {
+       struct gf128mul_4k *tau_key;
+};
+
+struct poly_hash_desc_ctx {
+       be128 digest;
+       unsigned int count;
+};
+
+static int poly_hash_setkey(struct crypto_shash *tfm,
+                           const u8 *key, unsigned int keylen)
 {
-       be128 value = {0};
-       struct sg_mapping_iter miter;
-       unsigned int remaining = bytes_of_coeffs;
-       unsigned int needed = 0;
+       struct poly_hash_tfm_ctx *tctx = crypto_shash_ctx(tfm);
+       be128 key128;
 
-       sg_miter_start(&miter, coeffs_sgl, sg_nents(coeffs_sgl),
-                      SG_MITER_FROM_SG | SG_MITER_ATOMIC);
-       while (remaining) {
-               be128 coeff;
-               const u8 *src;
-               unsigned int srclen;
-               u8 *dst = (u8 *)&value;
-
-               /*
-                * Note: scatterlist elements are not necessarily evenly
-                * divisible into blocks, nor are they necessarily aligned to
-                * __alignof__(be128).
-                */
-               sg_miter_next(&miter);
+       if (keylen != HEH_BLOCK_SIZE) {
+               crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
+               return -EINVAL;
+       }
+
+       if (tctx->tau_key)
+               gf128mul_free_4k(tctx->tau_key);
+       memcpy(&key128, key, HEH_BLOCK_SIZE);
+       tctx->tau_key = gf128mul_init_4k_ble(&key128);
+       if (!tctx->tau_key)
+               return -ENOMEM;
+       return 0;
+}
+
+static int poly_hash_init(struct shash_desc *desc)
+{
+       struct poly_hash_desc_ctx *ctx = shash_desc_ctx(desc);
 
-               src = miter.addr;
-               srclen = min_t(unsigned int, miter.length, remaining);
-               remaining -= srclen;
+       ctx->digest = (be128) { 0 };
+       ctx->count = 0;
+       return 0;
+}
 
-               if (needed) {
-                       unsigned int n = min(srclen, needed);
-                       u8 *pos = dst + (HEH_BLOCK_SIZE - needed);
+static int poly_hash_update(struct shash_desc *desc, const u8 *src,
+                           unsigned int len)
+{
+       struct poly_hash_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
+       struct poly_hash_desc_ctx *ctx = shash_desc_ctx(desc);
+       unsigned int partial = ctx->count % HEH_BLOCK_SIZE;
+       u8 *dst = (u8 *)&ctx->digest + partial;
 
-                       needed -= n;
-                       srclen -= n;
+       ctx->count += len;
 
-                       while (n--)
-                               *pos++ ^= *src++;
+       /* Finishing at least one block? */
+       if (partial + len >= HEH_BLOCK_SIZE) {
 
-                       if (!needed)
-                               gf128mul_4k_ble(&value, point);
+               if (partial) {
+                       /* Finish the pending block. */
+                       unsigned int n = HEH_BLOCK_SIZE - partial;
+
+                       len -= n;
+                       do {
+                               *dst++ ^= *src++;
+                       } while (--n);
+
+                       gf128mul_4k_ble(&ctx->digest, tctx->tau_key);
                }
 
-               while (srclen >= HEH_BLOCK_SIZE) {
+               /* Process zero or more full blocks. */
+               while (len >= HEH_BLOCK_SIZE) {
+                       be128 coeff;
+
                        memcpy(&coeff, src, HEH_BLOCK_SIZE);
-                       be128_xor(&value, &value, &coeff);
-                       gf128mul_4k_ble(&value, point);
+                       be128_xor(&ctx->digest, &ctx->digest, &coeff);
                        src += HEH_BLOCK_SIZE;
-                       srclen -= HEH_BLOCK_SIZE;
+                       len -= HEH_BLOCK_SIZE;
+                       gf128mul_4k_ble(&ctx->digest, tctx->tau_key);
                }
+               dst = (u8 *)&ctx->digest;
+       }
 
-               if (srclen) {
-                       needed = HEH_BLOCK_SIZE - srclen;
-                       do {
-                               *dst++ ^= *src++;
-                       } while (--srclen);
-               }
+       /* Continue adding the next block to 'digest'. */
+       while (len--)
+               *dst++ ^= *src++;
+       return 0;
+}
+
+static int poly_hash_final(struct shash_desc *desc, u8 *out)
+{
+       struct poly_hash_desc_ctx *ctx = shash_desc_ctx(desc);
+
+       /* Finish the last block if needed. */
+       if (ctx->count % HEH_BLOCK_SIZE) {
+               struct poly_hash_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
+
+               gf128mul_4k_ble(&ctx->digest, tctx->tau_key);
        }
-       sg_miter_stop(&miter);
-       return value;
+
+       memcpy(out, &ctx->digest, HEH_BLOCK_SIZE);
+       return 0;
 }
 
+static void poly_hash_exit(struct crypto_tfm *tfm)
+{
+       struct poly_hash_tfm_ctx *tctx = crypto_tfm_ctx(tfm);
+
+       gf128mul_free_4k(tctx->tau_key);
+}
+
+static struct shash_alg poly_hash_alg = {
+       .digestsize     = HEH_BLOCK_SIZE,
+       .init           = poly_hash_init,
+       .update         = poly_hash_update,
+       .final          = poly_hash_final,
+       .setkey         = poly_hash_setkey,
+       .descsize       = sizeof(struct poly_hash_desc_ctx),
+       .base           = {
+               .cra_name               = "poly_hash",
+               .cra_driver_name        = "poly_hash-generic",
+               .cra_priority           = 100,
+               .cra_ctxsize            = sizeof(struct poly_hash_tfm_ctx),
+               .cra_exit               = poly_hash_exit,
+               .cra_module             = THIS_MODULE,
+       },
+};
+
+/*****************************************************************************/
+
 /*
  * Split the message into 16 byte blocks, padding out the last block, and use
  * the blocks as coefficients in the evaluation of a polynomial over GF(2^128)
@@ -242,18 +312,42 @@ static be128 evaluate_polynomial(struct gf128mul_4k *point,
  *     N is the number of full blocks in the message
  *     m_i is the i-th full block in the message for i = 0 to N-1 inclusive
  *     m_N is the partial block of the message zero-padded up to 16 bytes
+ *
+ * Note that most of this is now separated out into its own keyed hash
+ * algorithm, to allow optimized implementations.  However, we still handle the
+ * swapping of the last two coefficients here in the HEH template because this
+ * simplifies the poly_hash algorithms: they don't have to buffer an extra
+ * block, don't have to duplicate as much code, and are more similar to GHASH.
  */
-static be128 poly_hash(struct crypto_ablkcipher *tfm, struct scatterlist *sgl,
-                      unsigned int len)
+static int poly_hash(struct ablkcipher_request *req, struct scatterlist *sgl,
+                    be128 *hash)
 {
+       struct heh_req_ctx *rctx = heh_req_ctx(req);
+       struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(req);
        struct heh_tfm_ctx *ctx = crypto_ablkcipher_ctx(tfm);
-       unsigned int tail_offset = get_tail_offset(len);
-       unsigned int tail_len = len - tail_offset;
-       be128 hash;
+       struct shash_desc *desc = &rctx->u.poly_hash.desc;
+       unsigned int tail_offset = get_tail_offset(req->nbytes);
+       unsigned int tail_len = req->nbytes - tail_offset;
        be128 tail[2];
+       unsigned int i, n;
+       struct sg_mapping_iter miter;
+       int err;
+
+       desc->tfm = ctx->poly_hash;
+       desc->flags = req->base.flags;
 
        /* Handle all full blocks except the last */
-       hash = evaluate_polynomial(ctx->tau_key, sgl, tail_offset);
+       err = crypto_shash_init(desc);
+       sg_miter_start(&miter, sgl, sg_nents(sgl),
+                      SG_MITER_FROM_SG | SG_MITER_ATOMIC);
+       for (i = 0; i < tail_offset && !err; i += n) {
+               sg_miter_next(&miter);
+               n = min_t(unsigned int, miter.length, tail_offset - i);
+               err = crypto_shash_update(desc, miter.addr, n);
+       }
+       sg_miter_stop(&miter);
+       if (err)
+               return err;
 
        /* Handle the last full block and the partial block */
        scatterwalk_map_and_copy(tail, sgl, tail_offset, tail_len, 0);
@@ -261,11 +355,15 @@ static be128 poly_hash(struct crypto_ablkcipher *tfm, struct scatterlist *sgl,
        if (tail_len != HEH_BLOCK_SIZE) {
                /* handle the partial block */
                memset((u8 *)tail + tail_len, 0, sizeof(tail) - tail_len);
-               be128_xor(&hash, &hash, &tail[1]);
-               gf128mul_4k_ble(&hash, ctx->tau_key);
+               err = crypto_shash_update(desc, (u8 *)&tail[1], HEH_BLOCK_SIZE);
+               if (err)
+                       return err;
        }
-       be128_xor(&hash, &hash, &tail[0]);
-       return hash;
+       err = crypto_shash_final(desc, (u8 *)hash);
+       if (err)
+               return err;
+       be128_xor(hash, hash, &tail[0]);
+       return 0;
 }
 
 /*
@@ -323,13 +421,14 @@ static int heh_tfm_blocks(struct ablkcipher_request *req,
 static int heh_hash(struct ablkcipher_request *req, const be128 *beta_key)
 {
        be128 hash;
-       struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(req);
        unsigned int tail_offset = get_tail_offset(req->nbytes);
        unsigned int partial_len = req->nbytes % HEH_BLOCK_SIZE;
        int err;
 
        /* poly_hash() the full message including the partial block */
-       hash = poly_hash(tfm, req->src, req->nbytes);
+       err = poly_hash(req, req->src, &hash);
+       if (err)
+               return err;
 
        /* Transform all full blocks except the last */
        err = heh_tfm_blocks(req, req->src, req->dst, tail_offset, &hash,
@@ -361,10 +460,8 @@ static int heh_hash_inv(struct ablkcipher_request *req, const be128 *beta_key)
        be128 tmp;
        struct scatterlist tmp_sgl[2];
        struct scatterlist *tail_sgl;
-       unsigned int len = req->nbytes;
-       unsigned int tail_offset = get_tail_offset(len);
+       unsigned int tail_offset = get_tail_offset(req->nbytes);
        struct scatterlist *sgl = req->dst;
-       struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(req);
        int err;
 
        /*
@@ -388,7 +485,9 @@ static int heh_hash_inv(struct ablkcipher_request *req, const be128 *beta_key)
         */
        memset(&tmp, 0, sizeof(tmp));
        scatterwalk_map_and_copy(&tmp, tail_sgl, 0, HEH_BLOCK_SIZE, 1);
-       tmp = poly_hash(tfm, sgl, len);
+       err = poly_hash(req, sgl, &tmp);
+       if (err)
+               return err;
        be128_xor(&tmp, &tmp, &hash);
        scatterwalk_map_and_copy(&tmp, tail_sgl, 0, HEH_BLOCK_SIZE, 1);
        return 0;
@@ -522,8 +621,6 @@ static int heh_ecb(struct ablkcipher_request *req, bool decrypt)
 
 static int heh_crypt(struct ablkcipher_request *req, bool decrypt)
 {
-       struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(req);
-       struct heh_tfm_ctx *ctx = crypto_ablkcipher_ctx(tfm);
        struct heh_req_ctx *rctx = heh_req_ctx(req);
        int err;
 
@@ -531,9 +628,6 @@ static int heh_crypt(struct ablkcipher_request *req, bool decrypt)
        if (req->nbytes < HEH_BLOCK_SIZE)
                return -EINVAL;
 
-       /* Key must have been set */
-       if (!ctx->tau_key)
-               return -ENOKEY;
        err = generate_betas(req, &rctx->beta1_key, &rctx->beta2_key);
        if (err)
                return err;
@@ -602,11 +696,8 @@ static int heh_setkey(struct crypto_ablkcipher *parent, const u8 *key,
                memcpy(derived_keys + i, digest, HEH_BLOCK_SIZE);
        }
 
-       if (ctx->tau_key)
-               gf128mul_free_4k(ctx->tau_key);
-       err = -ENOMEM;
-       ctx->tau_key = gf128mul_init_4k_ble((const be128 *)derived_keys);
-       if (!ctx->tau_key)
+       err = crypto_shash_setkey(ctx->poly_hash, derived_keys, HEH_BLOCK_SIZE);
+       if (err)
                goto out;
 
        crypto_ablkcipher_clear_flags(ecb, CRYPTO_TFM_REQ_MASK);
@@ -627,6 +718,7 @@ static int heh_init_tfm(struct crypto_tfm *tfm)
        struct heh_instance_ctx *ictx = crypto_instance_ctx(inst);
        struct heh_tfm_ctx *ctx = crypto_tfm_ctx(tfm);
        struct crypto_shash *cmac;
+       struct crypto_shash *poly_hash;
        struct crypto_ablkcipher *ecb;
        unsigned int reqsize;
        int err;
@@ -635,25 +727,37 @@ static int heh_init_tfm(struct crypto_tfm *tfm)
        if (IS_ERR(cmac))
                return PTR_ERR(cmac);
 
+       poly_hash = crypto_spawn_shash(&ictx->poly_hash);
+       err = PTR_ERR(poly_hash);
+       if (IS_ERR(poly_hash))
+               goto err_free_cmac;
+
        ecb = crypto_spawn_skcipher(&ictx->ecb);
        err = PTR_ERR(ecb);
        if (IS_ERR(ecb))
-               goto err_free_cmac;
+               goto err_free_poly_hash;
 
        ctx->cmac = cmac;
+       ctx->poly_hash = poly_hash;
        ctx->ecb = ecb;
 
        reqsize = crypto_tfm_alg_alignmask(tfm) &
                  ~(crypto_tfm_ctx_alignment() - 1);
-       reqsize += max(offsetof(struct heh_req_ctx, u.cmac.desc) +
-                       sizeof(struct shash_desc) +
-                       crypto_shash_descsize(cmac),
-                      offsetof(struct heh_req_ctx, u.ecb.req) +
-                       sizeof(struct ablkcipher_request) +
-                       crypto_ablkcipher_reqsize(ecb));
+       reqsize += max3(offsetof(struct heh_req_ctx, u.cmac.desc) +
+                         sizeof(struct shash_desc) +
+                         crypto_shash_descsize(cmac),
+                       offsetof(struct heh_req_ctx, u.poly_hash.desc) +
+                         sizeof(struct shash_desc) +
+                         crypto_shash_descsize(poly_hash),
+                       offsetof(struct heh_req_ctx, u.ecb.req) +
+                         sizeof(struct ablkcipher_request) +
+                         crypto_ablkcipher_reqsize(ecb));
        tfm->crt_ablkcipher.reqsize = reqsize;
+
        return 0;
 
+err_free_poly_hash:
+       crypto_free_shash(poly_hash);
 err_free_cmac:
        crypto_free_shash(cmac);
        return err;
@@ -663,8 +767,8 @@ static void heh_exit_tfm(struct crypto_tfm *tfm)
 {
        struct heh_tfm_ctx *ctx = crypto_tfm_ctx(tfm);
 
-       gf128mul_free_4k(ctx->tau_key);
        crypto_free_shash(ctx->cmac);
+       crypto_free_shash(ctx->poly_hash);
        crypto_free_ablkcipher(ctx->ecb);
 }
 
@@ -673,6 +777,7 @@ static void heh_free_instance(struct crypto_instance *inst)
        struct heh_instance_ctx *ctx = crypto_instance_ctx(inst);
 
        crypto_drop_shash(&ctx->cmac);
+       crypto_drop_shash(&ctx->poly_hash);
        crypto_drop_skcipher(&ctx->ecb);
        kfree(inst);
 }
@@ -689,12 +794,13 @@ static void heh_free_instance(struct crypto_instance *inst)
  */
 static int heh_create_common(struct crypto_template *tmpl, struct rtattr **tb,
                             const char *full_name, const char *cmac_name,
-                            const char *ecb_name)
+                            const char *poly_hash_name, const char *ecb_name)
 {
        struct crypto_attr_type *algt;
        struct crypto_instance *inst;
        struct heh_instance_ctx *ctx;
        struct shash_alg *cmac;
+       struct shash_alg *poly_hash;
        struct crypto_alg *ecb;
        int err;
 
@@ -713,10 +819,9 @@ static int heh_create_common(struct crypto_template *tmpl, struct rtattr **tb,
 
        ctx = crypto_instance_ctx(inst);
 
-       /* Set up the cmac and ecb spawns */
-
+       /* Set up the cmac spawn */
        ctx->cmac.base.inst = inst;
-       err = crypto_grab_shash(&ctx->cmac, cmac_name, 0, CRYPTO_ALG_ASYNC);
+       err = crypto_grab_shash(&ctx->cmac, cmac_name, 0, 0);
        if (err)
                goto err_free_inst;
        cmac = crypto_spawn_shash_alg(&ctx->cmac);
@@ -724,12 +829,23 @@ static int heh_create_common(struct crypto_template *tmpl, struct rtattr **tb,
        if (cmac->digestsize != HEH_BLOCK_SIZE)
                goto err_drop_cmac;
 
+       /* Set up the poly_hash spawn */
+       ctx->poly_hash.base.inst = inst;
+       err = crypto_grab_shash(&ctx->poly_hash, poly_hash_name, 0, 0);
+       if (err)
+               goto err_drop_cmac;
+       poly_hash = crypto_spawn_shash_alg(&ctx->poly_hash);
+       err = -EINVAL;
+       if (poly_hash->digestsize != HEH_BLOCK_SIZE)
+               goto err_drop_poly_hash;
+
+       /* Set up the ecb spawn */
        ctx->ecb.base.inst = inst;
        err = crypto_grab_skcipher(&ctx->ecb, ecb_name, 0,
                                   crypto_requires_sync(algt->type,
                                                        algt->mask));
        if (err)
-               goto err_drop_cmac;
+               goto err_drop_poly_hash;
        ecb = crypto_skcipher_spawn_alg(&ctx->ecb);
 
        /* HEH only supports block ciphers with 16 byte block size */
@@ -750,7 +866,8 @@ static int heh_create_common(struct crypto_template *tmpl, struct rtattr **tb,
        /* Set the instance names */
        err = -ENAMETOOLONG;
        if (snprintf(inst->alg.cra_driver_name, CRYPTO_MAX_ALG_NAME,
-                    "heh_base(%s,%s)", cmac->base.cra_driver_name,
+                    "heh_base(%s,%s,%s)", cmac->base.cra_driver_name,
+                    poly_hash->base.cra_driver_name,
                     ecb->cra_driver_name) >= CRYPTO_MAX_ALG_NAME)
                goto err_drop_ecb;
 
@@ -762,8 +879,7 @@ static int heh_create_common(struct crypto_template *tmpl, struct rtattr **tb,
        /* Finish initializing the instance */
 
        inst->alg.cra_flags = CRYPTO_ALG_TYPE_ABLKCIPHER |
-                               ((cmac->base.cra_flags | ecb->cra_flags) &
-                                CRYPTO_ALG_ASYNC);
+                               (ecb->cra_flags & CRYPTO_ALG_ASYNC);
        inst->alg.cra_blocksize = HEH_BLOCK_SIZE;
        inst->alg.cra_ctxsize = sizeof(struct heh_tfm_ctx);
        inst->alg.cra_alignmask = ecb->cra_alignmask | (__alignof__(be128) - 1);
@@ -792,6 +908,8 @@ static int heh_create_common(struct crypto_template *tmpl, struct rtattr **tb,
 
 err_drop_ecb:
        crypto_drop_skcipher(&ctx->ecb);
+err_drop_poly_hash:
+       crypto_drop_shash(&ctx->poly_hash);
 err_drop_cmac:
        crypto_drop_shash(&ctx->cmac);
 err_free_inst:
@@ -823,7 +941,8 @@ static int heh_create(struct crypto_template *tmpl, struct rtattr **tb)
            CRYPTO_MAX_ALG_NAME)
                return -ENAMETOOLONG;
 
-       return heh_create_common(tmpl, tb, full_name, cmac_name, ecb_name);
+       return heh_create_common(tmpl, tb, full_name, cmac_name, "poly_hash",
+                                ecb_name);
 }
 
 static struct crypto_template heh_tmpl = {
@@ -837,26 +956,34 @@ static int heh_base_create(struct crypto_template *tmpl, struct rtattr **tb)
 {
        char full_name[CRYPTO_MAX_ALG_NAME];
        const char *cmac_name;
+       const char *poly_hash_name;
        const char *ecb_name;
 
        cmac_name = crypto_attr_alg_name(tb[1]);
        if (IS_ERR(cmac_name))
                return PTR_ERR(cmac_name);
 
-       ecb_name = crypto_attr_alg_name(tb[2]);
+       poly_hash_name = crypto_attr_alg_name(tb[2]);
+       if (IS_ERR(poly_hash_name))
+               return PTR_ERR(poly_hash_name);
+
+       ecb_name = crypto_attr_alg_name(tb[3]);
        if (IS_ERR(ecb_name))
                return PTR_ERR(ecb_name);
 
-       if (snprintf(full_name, CRYPTO_MAX_ALG_NAME, "heh_base(%s,%s)",
-                    cmac_name, ecb_name) >= CRYPTO_MAX_ALG_NAME)
+       if (snprintf(full_name, CRYPTO_MAX_ALG_NAME, "heh_base(%s,%s,%s)",
+                    cmac_name, poly_hash_name, ecb_name) >=
+           CRYPTO_MAX_ALG_NAME)
                return -ENAMETOOLONG;
 
-       return heh_create_common(tmpl, tb, full_name, cmac_name, ecb_name);
+       return heh_create_common(tmpl, tb, full_name, cmac_name, poly_hash_name,
+                                ecb_name);
 }
 
 /*
  * If HEH is instantiated as "heh_base" instead of "heh", then specific
- * implementations of cmac and ecb can be specified instead of just the cipher
+ * implementations of cmac, poly_hash, and ecb can be specified instead of just
+ * the cipher.
  */
 static struct crypto_template heh_base_tmpl = {
        .name = "heh_base",
@@ -877,8 +1004,14 @@ static int __init heh_module_init(void)
        if (err)
                goto out_undo_heh;
 
+       err = crypto_register_shash(&poly_hash_alg);
+       if (err)
+               goto out_undo_heh_base;
+
        return 0;
 
+out_undo_heh_base:
+       crypto_unregister_template(&heh_base_tmpl);
 out_undo_heh:
        crypto_unregister_template(&heh_tmpl);
        return err;
@@ -888,6 +1021,7 @@ static void __exit heh_module_exit(void)
 {
        crypto_unregister_template(&heh_tmpl);
        crypto_unregister_template(&heh_base_tmpl);
+       crypto_unregister_shash(&poly_hash_alg);
 }
 
 module_init(heh_module_init);