summary refs log tree commit diff stats
path: root/target/arm/tcg/vec_helper.c
diff options
context:
space:
mode:
Diffstat (limited to 'target/arm/tcg/vec_helper.c')
-rw-r--r--target/arm/tcg/vec_helper.c208
1 files changed, 163 insertions, 45 deletions
diff --git a/target/arm/tcg/vec_helper.c b/target/arm/tcg/vec_helper.c
index 98604d170f..22ddb96881 100644
--- a/target/arm/tcg/vec_helper.c
+++ b/target/arm/tcg/vec_helper.c
@@ -2790,44 +2790,115 @@ DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)
  * BFloat16 Dot Product
  */
 
-float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
+bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
 {
-    /* FPCR is ignored for BFDOT and BFMMLA. */
-    float_status bf_status = {
+    /*
+     * For BFDOT, BFMMLA, etc, the behaviour depends on FPCR.EBF.
+     * For EBF = 0, we ignore the FPCR bits which determine rounding
+     * mode and denormal-flushing, and we do unfused multiplies and
+     * additions with intermediate rounding of all products and sums.
+     * For EBF = 1, we honour FPCR rounding mode and denormal-flushing bits,
+     * and we perform a fused two-way sum-of-products without intermediate
+     * rounding of the products.
+     * In either case, we don't set fp exception flags.
+     *
+     * EBF is AArch64 only, so even if it's set in the FPCR it has
+     * no effect on AArch32 instructions.
+     */
+    bool ebf = is_a64(env) && env->vfp.fpcr & FPCR_EBF;
+    *statusp = (float_status){
         .tininess_before_rounding = float_tininess_before_rounding,
         .float_rounding_mode = float_round_to_odd_inf,
         .flush_to_zero = true,
         .flush_inputs_to_zero = true,
         .default_nan_mode = true,
     };
+
+    if (ebf) {
+        float_status *fpst = &env->vfp.fp_status;
+        set_flush_to_zero(get_flush_to_zero(fpst), statusp);
+        set_flush_inputs_to_zero(get_flush_inputs_to_zero(fpst), statusp);
+        set_float_rounding_mode(get_float_rounding_mode(fpst), statusp);
+
+        /* EBF=1 needs to do a step with round-to-odd semantics */
+        *oddstatusp = *statusp;
+        set_float_rounding_mode(float_round_to_odd, oddstatusp);
+    }
+
+    return ebf;
+}
+
+float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst)
+{
     float32 t1, t2;
 
     /*
      * Extract each BFloat16 from the element pair, and shift
      * them such that they become float32.
      */
-    t1 = float32_mul(e1 << 16, e2 << 16, &bf_status);
-    t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, &bf_status);
-    t1 = float32_add(t1, t2, &bf_status);
-    t1 = float32_add(sum, t1, &bf_status);
+    t1 = float32_mul(e1 << 16, e2 << 16, fpst);
+    t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, fpst);
+    t1 = float32_add(t1, t2, fpst);
+    t1 = float32_add(sum, t1, fpst);
 
     return t1;
 }
 
-void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va, uint32_t desc)
+float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
+                     float_status *fpst, float_status *fpst_odd)
+{
+    /*
+     * Compare f16_dotadd() in sme_helper.c, but here we have
+     * bfloat16 inputs. In particular that means that we do not
+     * want the FPCR.FZ16 flush semantics, so we use the normal
+     * float_status for the input handling here.
+     */
+    float64 e1r = float32_to_float64(e1 << 16, fpst);
+    float64 e1c = float32_to_float64(e1 & 0xffff0000u, fpst);
+    float64 e2r = float32_to_float64(e2 << 16, fpst);
+    float64 e2c = float32_to_float64(e2 & 0xffff0000u, fpst);
+    float64 t64;
+    float32 t32;
+
+    /*
+     * The ARM pseudocode function FPDot performs both multiplies
+     * and the add with a single rounding operation.  Emulate this
+     * by performing the first multiply in round-to-odd, then doing
+     * the second multiply as fused multiply-add, and rounding to
+     * float32 all in one step.
+     */
+    t64 = float64_mul(e1r, e2r, fpst_odd);
+    t64 = float64r32_muladd(e1c, e2c, t64, 0, fpst);
+
+    /* This conversion is exact, because we've already rounded. */
+    t32 = float64_to_float32(t64, fpst);
+
+    /* The final accumulation step is not fused. */
+    return float32_add(sum, t32, fpst);
+}
+
+void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va,
+                        CPUARMState *env, uint32_t desc)
 {
     intptr_t i, opr_sz = simd_oprsz(desc);
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
+    float_status fpst, fpst_odd;
 
-    for (i = 0; i < opr_sz / 4; ++i) {
-        d[i] = bfdotadd(a[i], n[i], m[i]);
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (i = 0; i < opr_sz / 4; ++i) {
+            d[i] = bfdotadd_ebf(a[i], n[i], m[i], &fpst, &fpst_odd);
+        }
+    } else {
+        for (i = 0; i < opr_sz / 4; ++i) {
+            d[i] = bfdotadd(a[i], n[i], m[i], &fpst);
+        }
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }
 
 void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
-                            void *va, uint32_t desc)
+                            void *va, CPUARMState *env, uint32_t desc)
 {
     intptr_t i, j, opr_sz = simd_oprsz(desc);
     intptr_t index = simd_data(desc);
@@ -2835,53 +2906,100 @@ void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
     intptr_t eltspersegment = MIN(16 / 4, elements);
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
+    float_status fpst, fpst_odd;
 
-    for (i = 0; i < elements; i += eltspersegment) {
-        uint32_t m_idx = m[i + H4(index)];
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (i = 0; i < elements; i += eltspersegment) {
+            uint32_t m_idx = m[i + H4(index)];
 
-        for (j = i; j < i + eltspersegment; j++) {
-            d[j] = bfdotadd(a[j], n[j], m_idx);
+            for (j = i; j < i + eltspersegment; j++) {
+                d[j] = bfdotadd_ebf(a[j], n[j], m_idx, &fpst, &fpst_odd);
+            }
+        }
+    } else {
+        for (i = 0; i < elements; i += eltspersegment) {
+            uint32_t m_idx = m[i + H4(index)];
+
+            for (j = i; j < i + eltspersegment; j++) {
+                d[j] = bfdotadd(a[j], n[j], m_idx, &fpst);
+            }
         }
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }
 
-void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va, uint32_t desc)
+void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va,
+                         CPUARMState *env, uint32_t desc)
 {
     intptr_t s, opr_sz = simd_oprsz(desc);
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
+    float_status fpst, fpst_odd;
 
-    for (s = 0; s < opr_sz / 4; s += 4) {
-        float32 sum00, sum01, sum10, sum11;
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (s = 0; s < opr_sz / 4; s += 4) {
+            float32 sum00, sum01, sum10, sum11;
 
-        /*
-         * Process the entire segment at once, writing back the
-         * results only after we've consumed all of the inputs.
-         *
-         * Key to indices by column:
-         *               i   j           i   k             j   k
-         */
-        sum00 = a[s + H4(0 + 0)];
-        sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]);
-        sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]);
-
-        sum01 = a[s + H4(0 + 1)];
-        sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]);
-        sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]);
-
-        sum10 = a[s + H4(2 + 0)];
-        sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]);
-        sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]);
-
-        sum11 = a[s + H4(2 + 1)];
-        sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]);
-        sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]);
-
-        d[s + H4(0 + 0)] = sum00;
-        d[s + H4(0 + 1)] = sum01;
-        d[s + H4(2 + 0)] = sum10;
-        d[s + H4(2 + 1)] = sum11;
+            /*
+             * Process the entire segment at once, writing back the
+             * results only after we've consumed all of the inputs.
+             *
+             * Key to indices by column:
+             *               i   j               i   k             j   k
+             */
+            sum00 = a[s + H4(0 + 0)];
+            sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd);
+            sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd);
+
+            sum01 = a[s + H4(0 + 1)];
+            sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd);
+            sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd);
+
+            sum10 = a[s + H4(2 + 0)];
+            sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd);
+            sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd);
+
+            sum11 = a[s + H4(2 + 1)];
+            sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd);
+            sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd);
+
+            d[s + H4(0 + 0)] = sum00;
+            d[s + H4(0 + 1)] = sum01;
+            d[s + H4(2 + 0)] = sum10;
+            d[s + H4(2 + 1)] = sum11;
+        }
+    } else {
+        for (s = 0; s < opr_sz / 4; s += 4) {
+            float32 sum00, sum01, sum10, sum11;
+
+            /*
+             * Process the entire segment at once, writing back the
+             * results only after we've consumed all of the inputs.
+             *
+             * Key to indices by column:
+             *               i   j           i   k             j   k
+             */
+            sum00 = a[s + H4(0 + 0)];
+            sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst);
+            sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst);
+
+            sum01 = a[s + H4(0 + 1)];
+            sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst);
+            sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst);
+
+            sum10 = a[s + H4(2 + 0)];
+            sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst);
+            sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst);
+
+            sum11 = a[s + H4(2 + 1)];
+            sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst);
+            sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst);
+
+            d[s + H4(0 + 0)] = sum00;
+            d[s + H4(0 + 1)] = sum01;
+            d[s + H4(2 + 0)] = sum10;
+            d[s + H4(2 + 1)] = sum11;
+        }
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }