summary refs log tree commit diff stats
path: root/target/arm/crypto_helper.c
diff options
context:
space:
mode:
Diffstat (limited to 'target/arm/crypto_helper.c')
-rw-r--r--target/arm/crypto_helper.c271
1 files changed, 194 insertions, 77 deletions
diff --git a/target/arm/crypto_helper.c b/target/arm/crypto_helper.c
index f800266727..c76806dc8d 100644
--- a/target/arm/crypto_helper.c
+++ b/target/arm/crypto_helper.c
@@ -13,7 +13,9 @@
 
 #include "cpu.h"
 #include "exec/helper-proto.h"
+#include "tcg/tcg-gvec-desc.h"
 #include "crypto/aes.h"
+#include "vec_internal.h"
 
 union CRYPTO_STATE {
     uint8_t    bytes[16];
@@ -22,25 +24,35 @@ union CRYPTO_STATE {
 };
 
 #ifdef HOST_WORDS_BIGENDIAN
-#define CR_ST_BYTE(state, i)   (state.bytes[(15 - (i)) ^ 8])
-#define CR_ST_WORD(state, i)   (state.words[(3 - (i)) ^ 2])
+#define CR_ST_BYTE(state, i)   ((state).bytes[(15 - (i)) ^ 8])
+#define CR_ST_WORD(state, i)   ((state).words[(3 - (i)) ^ 2])
 #else
-#define CR_ST_BYTE(state, i)   (state.bytes[i])
-#define CR_ST_WORD(state, i)   (state.words[i])
+#define CR_ST_BYTE(state, i)   ((state).bytes[i])
+#define CR_ST_WORD(state, i)   ((state).words[i])
 #endif
 
-void HELPER(crypto_aese)(void *vd, void *vm, uint32_t decrypt)
+/*
+ * The caller has not been converted to full gvec, and so only
+ * modifies the low 16 bytes of the vector register.
+ */
+static void clear_tail_16(void *vd, uint32_t desc)
+{
+    int opr_sz = simd_oprsz(desc);
+    int max_sz = simd_maxsz(desc);
+
+    assert(opr_sz == 16);
+    clear_tail(vd, opr_sz, max_sz);
+}
+
+static void do_crypto_aese(uint64_t *rd, uint64_t *rn,
+                           uint64_t *rm, bool decrypt)
 {
     static uint8_t const * const sbox[2] = { AES_sbox, AES_isbox };
     static uint8_t const * const shift[2] = { AES_shifts, AES_ishifts };
-    uint64_t *rd = vd;
-    uint64_t *rm = vm;
     union CRYPTO_STATE rk = { .l = { rm[0], rm[1] } };
-    union CRYPTO_STATE st = { .l = { rd[0], rd[1] } };
+    union CRYPTO_STATE st = { .l = { rn[0], rn[1] } };
     int i;
 
-    assert(decrypt < 2);
-
     /* xor state vector with round key */
     rk.l[0] ^= st.l[0];
     rk.l[1] ^= st.l[1];
@@ -54,7 +66,18 @@ void HELPER(crypto_aese)(void *vd, void *vm, uint32_t decrypt)
     rd[1] = st.l[1];
 }
 
-void HELPER(crypto_aesmc)(void *vd, void *vm, uint32_t decrypt)
+void HELPER(crypto_aese)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    intptr_t i, opr_sz = simd_oprsz(desc);
+    bool decrypt = simd_data(desc);
+
+    for (i = 0; i < opr_sz; i += 16) {
+        do_crypto_aese(vd + i, vn + i, vm + i, decrypt);
+    }
+    clear_tail(vd, opr_sz, simd_maxsz(desc));
+}
+
+static void do_crypto_aesmc(uint64_t *rd, uint64_t *rm, bool decrypt)
 {
     static uint32_t const mc[][256] = { {
         /* MixColumns lookup table */
@@ -190,13 +213,9 @@ void HELPER(crypto_aesmc)(void *vd, void *vm, uint32_t decrypt)
         0xbe805d9f, 0xb58d5491, 0xa89a4f83, 0xa397468d,
     } };
 
-    uint64_t *rd = vd;
-    uint64_t *rm = vm;
     union CRYPTO_STATE st = { .l = { rm[0], rm[1] } };
     int i;
 
-    assert(decrypt < 2);
-
     for (i = 0; i < 16; i += 4) {
         CR_ST_WORD(st, i >> 2) =
             mc[decrypt][CR_ST_BYTE(st, i)] ^
@@ -209,6 +228,17 @@ void HELPER(crypto_aesmc)(void *vd, void *vm, uint32_t decrypt)
     rd[1] = st.l[1];
 }
 
+void HELPER(crypto_aesmc)(void *vd, void *vm, uint32_t desc)
+{
+    intptr_t i, opr_sz = simd_oprsz(desc);
+    bool decrypt = simd_data(desc);
+
+    for (i = 0; i < opr_sz; i += 16) {
+        do_crypto_aesmc(vd + i, vm + i, decrypt);
+    }
+    clear_tail(vd, opr_sz, simd_maxsz(desc));
+}
+
 /*
  * SHA-1 logical functions
  */
@@ -228,52 +258,77 @@ static uint32_t maj(uint32_t x, uint32_t y, uint32_t z)
     return (x & y) | ((x | y) & z);
 }
 
-void HELPER(crypto_sha1_3reg)(void *vd, void *vn, void *vm, uint32_t op)
+void HELPER(crypto_sha1su0)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    uint64_t *d = vd, *n = vn, *m = vm;
+    uint64_t d0, d1;
+
+    d0 = d[1] ^ d[0] ^ m[0];
+    d1 = n[0] ^ d[1] ^ m[1];
+    d[0] = d0;
+    d[1] = d1;
+
+    clear_tail_16(vd, desc);
+}
+
+static inline void crypto_sha1_3reg(uint64_t *rd, uint64_t *rn,
+                                    uint64_t *rm, uint32_t desc,
+                                    uint32_t (*fn)(union CRYPTO_STATE *d))
 {
-    uint64_t *rd = vd;
-    uint64_t *rn = vn;
-    uint64_t *rm = vm;
     union CRYPTO_STATE d = { .l = { rd[0], rd[1] } };
     union CRYPTO_STATE n = { .l = { rn[0], rn[1] } };
     union CRYPTO_STATE m = { .l = { rm[0], rm[1] } };
+    int i;
 
-    if (op == 3) { /* sha1su0 */
-        d.l[0] ^= d.l[1] ^ m.l[0];
-        d.l[1] ^= n.l[0] ^ m.l[1];
-    } else {
-        int i;
-
-        for (i = 0; i < 4; i++) {
-            uint32_t t;
-
-            switch (op) {
-            case 0: /* sha1c */
-                t = cho(CR_ST_WORD(d, 1), CR_ST_WORD(d, 2), CR_ST_WORD(d, 3));
-                break;
-            case 1: /* sha1p */
-                t = par(CR_ST_WORD(d, 1), CR_ST_WORD(d, 2), CR_ST_WORD(d, 3));
-                break;
-            case 2: /* sha1m */
-                t = maj(CR_ST_WORD(d, 1), CR_ST_WORD(d, 2), CR_ST_WORD(d, 3));
-                break;
-            default:
-                g_assert_not_reached();
-            }
-            t += rol32(CR_ST_WORD(d, 0), 5) + CR_ST_WORD(n, 0)
-                 + CR_ST_WORD(m, i);
-
-            CR_ST_WORD(n, 0) = CR_ST_WORD(d, 3);
-            CR_ST_WORD(d, 3) = CR_ST_WORD(d, 2);
-            CR_ST_WORD(d, 2) = ror32(CR_ST_WORD(d, 1), 2);
-            CR_ST_WORD(d, 1) = CR_ST_WORD(d, 0);
-            CR_ST_WORD(d, 0) = t;
-        }
+    for (i = 0; i < 4; i++) {
+        uint32_t t = fn(&d);
+
+        t += rol32(CR_ST_WORD(d, 0), 5) + CR_ST_WORD(n, 0)
+             + CR_ST_WORD(m, i);
+
+        CR_ST_WORD(n, 0) = CR_ST_WORD(d, 3);
+        CR_ST_WORD(d, 3) = CR_ST_WORD(d, 2);
+        CR_ST_WORD(d, 2) = ror32(CR_ST_WORD(d, 1), 2);
+        CR_ST_WORD(d, 1) = CR_ST_WORD(d, 0);
+        CR_ST_WORD(d, 0) = t;
     }
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(rd, desc);
+}
+
+static uint32_t do_sha1c(union CRYPTO_STATE *d)
+{
+    return cho(CR_ST_WORD(*d, 1), CR_ST_WORD(*d, 2), CR_ST_WORD(*d, 3));
+}
+
+void HELPER(crypto_sha1c)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    crypto_sha1_3reg(vd, vn, vm, desc, do_sha1c);
+}
+
+static uint32_t do_sha1p(union CRYPTO_STATE *d)
+{
+    return par(CR_ST_WORD(*d, 1), CR_ST_WORD(*d, 2), CR_ST_WORD(*d, 3));
+}
+
+void HELPER(crypto_sha1p)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    crypto_sha1_3reg(vd, vn, vm, desc, do_sha1p);
 }
 
-void HELPER(crypto_sha1h)(void *vd, void *vm)
+static uint32_t do_sha1m(union CRYPTO_STATE *d)
+{
+    return maj(CR_ST_WORD(*d, 1), CR_ST_WORD(*d, 2), CR_ST_WORD(*d, 3));
+}
+
+void HELPER(crypto_sha1m)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    crypto_sha1_3reg(vd, vn, vm, desc, do_sha1m);
+}
+
+void HELPER(crypto_sha1h)(void *vd, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rm = vm;
@@ -284,9 +339,11 @@ void HELPER(crypto_sha1h)(void *vd, void *vm)
 
     rd[0] = m.l[0];
     rd[1] = m.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha1su1)(void *vd, void *vm)
+void HELPER(crypto_sha1su1)(void *vd, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rm = vm;
@@ -300,6 +357,8 @@ void HELPER(crypto_sha1su1)(void *vd, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
 /*
@@ -327,7 +386,7 @@ static uint32_t s1(uint32_t x)
     return ror32(x, 17) ^ ror32(x, 19) ^ (x >> 10);
 }
 
-void HELPER(crypto_sha256h)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sha256h)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -358,9 +417,11 @@ void HELPER(crypto_sha256h)(void *vd, void *vn, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha256h2)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sha256h2)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -383,9 +444,11 @@ void HELPER(crypto_sha256h2)(void *vd, void *vn, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha256su0)(void *vd, void *vm)
+void HELPER(crypto_sha256su0)(void *vd, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rm = vm;
@@ -399,9 +462,11 @@ void HELPER(crypto_sha256su0)(void *vd, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha256su1)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sha256su1)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -417,6 +482,8 @@ void HELPER(crypto_sha256su1)(void *vd, void *vn, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
 /*
@@ -453,7 +520,7 @@ static uint64_t s1_512(uint64_t x)
     return ror64(x, 19) ^ ror64(x, 61) ^ (x >> 6);
 }
 
-void HELPER(crypto_sha512h)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sha512h)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -466,9 +533,11 @@ void HELPER(crypto_sha512h)(void *vd, void *vn, void *vm)
 
     rd[0] = d0;
     rd[1] = d1;
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha512h2)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sha512h2)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -481,9 +550,11 @@ void HELPER(crypto_sha512h2)(void *vd, void *vn, void *vm)
 
     rd[0] = d0;
     rd[1] = d1;
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha512su0)(void *vd, void *vn)
+void HELPER(crypto_sha512su0)(void *vd, void *vn, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -495,9 +566,11 @@ void HELPER(crypto_sha512su0)(void *vd, void *vn)
 
     rd[0] = d0;
     rd[1] = d1;
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sha512su1)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sha512su1)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -505,9 +578,11 @@ void HELPER(crypto_sha512su1)(void *vd, void *vn, void *vm)
 
     rd[0] += s1_512(rn[0]) + rm[0];
     rd[1] += s1_512(rn[1]) + rm[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sm3partw1)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sm3partw1)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -531,9 +606,11 @@ void HELPER(crypto_sm3partw1)(void *vd, void *vn, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sm3partw2)(void *vd, void *vn, void *vm)
+void HELPER(crypto_sm3partw2)(void *vd, void *vn, void *vm, uint32_t desc)
 {
     uint64_t *rd = vd;
     uint64_t *rn = vn;
@@ -551,17 +628,18 @@ void HELPER(crypto_sm3partw2)(void *vd, void *vn, void *vm)
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(vd, desc);
 }
 
-void HELPER(crypto_sm3tt)(void *vd, void *vn, void *vm, uint32_t imm2,
-                          uint32_t opcode)
+static inline void QEMU_ALWAYS_INLINE
+crypto_sm3tt(uint64_t *rd, uint64_t *rn, uint64_t *rm,
+             uint32_t desc, uint32_t opcode)
 {
-    uint64_t *rd = vd;
-    uint64_t *rn = vn;
-    uint64_t *rm = vm;
     union CRYPTO_STATE d = { .l = { rd[0], rd[1] } };
     union CRYPTO_STATE n = { .l = { rn[0], rn[1] } };
     union CRYPTO_STATE m = { .l = { rm[0], rm[1] } };
+    uint32_t imm2 = simd_data(desc);
     uint32_t t;
 
     assert(imm2 < 4);
@@ -576,7 +654,7 @@ void HELPER(crypto_sm3tt)(void *vd, void *vn, void *vm, uint32_t imm2,
         /* SM3TT2B */
         t = cho(CR_ST_WORD(d, 3), CR_ST_WORD(d, 2), CR_ST_WORD(d, 1));
     } else {
-        g_assert_not_reached();
+        qemu_build_not_reached();
     }
 
     t += CR_ST_WORD(d, 0) + CR_ST_WORD(m, imm2);
@@ -601,8 +679,21 @@ void HELPER(crypto_sm3tt)(void *vd, void *vn, void *vm, uint32_t imm2,
 
     rd[0] = d.l[0];
     rd[1] = d.l[1];
+
+    clear_tail_16(rd, desc);
 }
 
+#define DO_SM3TT(NAME, OPCODE) \
+    void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc) \
+    { crypto_sm3tt(vd, vn, vm, desc, OPCODE); }
+
+DO_SM3TT(crypto_sm3tt1a, 0)
+DO_SM3TT(crypto_sm3tt1b, 1)
+DO_SM3TT(crypto_sm3tt2a, 2)
+DO_SM3TT(crypto_sm3tt2b, 3)
+
+#undef DO_SM3TT
+
 static uint8_t const sm4_sbox[] = {
     0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7,
     0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05,
@@ -638,12 +729,10 @@ static uint8_t const sm4_sbox[] = {
     0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48,
 };
 
-void HELPER(crypto_sm4e)(void *vd, void *vn)
+static void do_crypto_sm4e(uint64_t *rd, uint64_t *rn, uint64_t *rm)
 {
-    uint64_t *rd = vd;
-    uint64_t *rn = vn;
-    union CRYPTO_STATE d = { .l = { rd[0], rd[1] } };
-    union CRYPTO_STATE n = { .l = { rn[0], rn[1] } };
+    union CRYPTO_STATE d = { .l = { rn[0], rn[1] } };
+    union CRYPTO_STATE n = { .l = { rm[0], rm[1] } };
     uint32_t t, i;
 
     for (i = 0; i < 4; i++) {
@@ -665,11 +754,18 @@ void HELPER(crypto_sm4e)(void *vd, void *vn)
     rd[1] = d.l[1];
 }
 
-void HELPER(crypto_sm4ekey)(void *vd, void *vn, void* vm)
+void HELPER(crypto_sm4e)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    intptr_t i, opr_sz = simd_oprsz(desc);
+
+    for (i = 0; i < opr_sz; i += 16) {
+        do_crypto_sm4e(vd + i, vn + i, vm + i);
+    }
+    clear_tail(vd, opr_sz, simd_maxsz(desc));
+}
+
+static void do_crypto_sm4ekey(uint64_t *rd, uint64_t *rn, uint64_t *rm)
 {
-    uint64_t *rd = vd;
-    uint64_t *rn = vn;
-    uint64_t *rm = vm;
     union CRYPTO_STATE d;
     union CRYPTO_STATE n = { .l = { rn[0], rn[1] } };
     union CRYPTO_STATE m = { .l = { rm[0], rm[1] } };
@@ -693,3 +789,24 @@ void HELPER(crypto_sm4ekey)(void *vd, void *vn, void* vm)
     rd[0] = d.l[0];
     rd[1] = d.l[1];
 }
+
+void HELPER(crypto_sm4ekey)(void *vd, void *vn, void* vm, uint32_t desc)
+{
+    intptr_t i, opr_sz = simd_oprsz(desc);
+
+    for (i = 0; i < opr_sz; i += 16) {
+        do_crypto_sm4ekey(vd + i, vn + i, vm + i);
+    }
+    clear_tail(vd, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(crypto_rax1)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+    intptr_t i, opr_sz = simd_oprsz(desc);
+    uint64_t *d = vd, *n = vn, *m = vm;
+
+    for (i = 0; i < opr_sz / 8; ++i) {
+        d[i] = n[i] ^ rol64(m[i], 1);
+    }
+    clear_tail(vd, opr_sz, simd_maxsz(desc));
+}