summary refs log tree commit diff stats
path: root/target/arm/tcg/sme_helper.c
diff options
context:
space:
mode:
Diffstat (limited to 'target/arm/tcg/sme_helper.c')
-rw-r--r--target/arm/tcg/sme_helper.c141
1 files changed, 118 insertions, 23 deletions
diff --git a/target/arm/tcg/sme_helper.c b/target/arm/tcg/sme_helper.c
index 4772c97deb..eff0ce7480 100644
--- a/target/arm/tcg/sme_helper.c
+++ b/target/arm/tcg/sme_helper.c
@@ -1002,19 +1002,18 @@ void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
     }
 }
 
-void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
-                         void *vpm, float_status *fpst, uint32_t desc)
+static void do_fmopa_s(void *vza, void *vzn, void *vzm, uint16_t *pn,
+                       uint16_t *pm, float_status *fpst, uint32_t desc,
+                       uint32_t negx, int negf)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    uint32_t neg = simd_data(desc) << 31;
-    uint16_t *pn = vpn, *pm = vpm;
 
     for (row = 0; row < oprsz; ) {
         uint16_t pa = pn[H2(row >> 4)];
         do {
             if (pa & 1) {
                 void *vza_row = vza + tile_vslice_offset(row);
-                uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ neg;
+                uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ negx;
 
                 for (col = 0; col < oprsz; ) {
                     uint16_t pb = pm[H2(col >> 4)];
@@ -1022,7 +1021,7 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
                         if (pb & 1) {
                             uint32_t *a = vza_row + H1_4(col);
                             uint32_t *m = vzm + H1_4(col);
-                            *a = float32_muladd(n, *m, *a, 0, fpst);
+                            *a = float32_muladd(n, *m, *a, negf, fpst);
                         }
                         col += 4;
                         pb >>= 4;
@@ -1035,29 +1034,65 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
     }
 }
 
-void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
+void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
+}
+
+void HELPER(sme_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
                          void *vpm, float_status *fpst, uint32_t desc)
 {
+    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 31, 0);
+}
+
+void HELPER(sme_ah_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
+                            void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
+               float_muladd_negate_product);
+}
+
+static void do_fmopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, uint8_t *pn,
+                       uint8_t *pm, float_status *fpst, uint32_t desc,
+                       uint64_t negx, int negf)
+{
     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
-    uint64_t neg = (uint64_t)simd_data(desc) << 63;
-    uint64_t *za = vza, *zn = vzn, *zm = vzm;
-    uint8_t *pn = vpn, *pm = vpm;
 
     for (row = 0; row < oprsz; ++row) {
         if (pn[H1(row)] & 1) {
             uint64_t *za_row = &za[tile_vslice_index(row)];
-            uint64_t n = zn[row] ^ neg;
+            uint64_t n = zn[row] ^ negx;
 
             for (col = 0; col < oprsz; ++col) {
                 if (pm[H1(col)] & 1) {
                     uint64_t *a = &za_row[col];
-                    *a = float64_muladd(n, zm[col], *a, 0, fpst);
+                    *a = float64_muladd(n, zm[col], *a, negf, fpst);
                 }
             }
         }
     }
 }
 
+void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
+}
+
+void HELPER(sme_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 1ull << 63, 0);
+}
+
+void HELPER(sme_ah_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                            void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
+               float_muladd_negate_product);
+}
+
 /*
  * Alter PAIR as needed for controlling predicates being false,
  * and for NEG on an enabled row element.
@@ -1078,6 +1113,20 @@ static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
     return pair;
 }
 
+static inline uint32_t f16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
+{
+    uint32_t l = pg & 1 ? float16_ah_chs(pair) : 0;
+    uint32_t h = pg & 4 ? float16_ah_chs(pair >> 16) : 0;
+    return l | (h << 16);
+}
+
+static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
+{
+    uint32_t l = pg & 1 ? bfloat16_ah_chs(pair) : 0;
+    uint32_t h = pg & 4 ? bfloat16_ah_chs(pair >> 16) : 0;
+    return l | (h << 16);
+}
+
 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
                           float_status *s_f16, float_status *s_std,
                           float_status *s_odd)
@@ -1146,12 +1195,11 @@ static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
     return float32_add(sum, t32, s_std);
 }
 
-void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
-                           void *vpm, CPUARMState *env, uint32_t desc)
+static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
+                         uint16_t *pm, CPUARMState *env, uint32_t desc,
+                         uint32_t negx, bool ah_neg)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    uint32_t neg = simd_data(desc) * 0x80008000u;
-    uint16_t *pn = vpn, *pm = vpm;
     float_status fpst_odd = env->vfp.fp_status[FPST_ZA];
 
     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
@@ -1162,7 +1210,11 @@ void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
             void *vza_row = vza + tile_vslice_offset(row);
             uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-            n = f16mop_adj_pair(n, prow, neg);
+            if (ah_neg) {
+                n = f16mop_ah_neg_adj_pair(n, prow);
+            } else {
+                n = f16mop_adj_pair(n, prow, negx);
+            }
 
             for (col = 0; col < oprsz; ) {
                 uint16_t pcol = pm[H2(col >> 4)];
@@ -1187,6 +1239,24 @@ void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
     }
 }
 
+void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                           void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
+}
+
+void HELPER(sme_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                           void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
+}
+
+void HELPER(sme_ah_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                              void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
+}
+
 void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va,
                          CPUARMState *env, uint32_t desc)
 {
@@ -1261,12 +1331,11 @@ void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va,
     }
 }
 
-void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
-                          void *vpn, void *vpm, CPUARMState *env, uint32_t desc)
+static void do_bfmopa_w(void *vza, void *vzn, void *vzm,
+                        uint16_t *pn, uint16_t *pm, CPUARMState *env,
+                        uint32_t desc, uint32_t negx, bool ah_neg)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    uint32_t neg = simd_data(desc) * 0x80008000u;
-    uint16_t *pn = vpn, *pm = vpm;
     float_status fpst, fpst_odd;
 
     if (is_ebf(env, &fpst, &fpst_odd)) {
@@ -1276,7 +1345,11 @@ void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
                 void *vza_row = vza + tile_vslice_offset(row);
                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-                n = f16mop_adj_pair(n, prow, neg);
+                if (ah_neg) {
+                    n = bf16mop_ah_neg_adj_pair(n, prow);
+                } else {
+                    n = f16mop_adj_pair(n, prow, negx);
+                }
 
                 for (col = 0; col < oprsz; ) {
                     uint16_t pcol = pm[H2(col >> 4)];
@@ -1303,7 +1376,11 @@ void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
                 void *vza_row = vza + tile_vslice_offset(row);
                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-                n = f16mop_adj_pair(n, prow, neg);
+                if (ah_neg) {
+                    n = bf16mop_ah_neg_adj_pair(n, prow);
+                } else {
+                    n = f16mop_adj_pair(n, prow, negx);
+                }
 
                 for (col = 0; col < oprsz; ) {
                     uint16_t pcol = pm[H2(col >> 4)];
@@ -1326,6 +1403,24 @@ void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
     }
 }
 
+void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm, void *vpn,
+                          void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
+}
+
+void HELPER(sme_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
+                          void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
+}
+
+void HELPER(sme_ah_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
+                             void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
+}
+
 typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool);
 static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
                               uint8_t *pn, uint8_t *pm,