about summary refs log tree commit diff stats
path: root/miasm2/jitter
diff options
context:
space:
mode:
authorserpilliere <serpilliere@users.noreply.github.com>2018-07-10 19:04:39 +0200
committerGitHub <noreply@github.com>2018-07-10 19:04:39 +0200
commitc48a8ba7ed9110df962df94ab9db314b2873c6b2 (patch)
tree6e14f8fdaa4471dc1fb8fdcd6bfe9e271500a803 /miasm2/jitter
parenta5221c1b926af7716860fd27039528cfb54d6095 (diff)
parentd65bbbcc4a7d3c0fff9e9c80a04e23bbc4bf5333 (diff)
downloadmiasm-c48a8ba7ed9110df962df94ab9db314b2873c6b2.tar.gz
miasm-c48a8ba7ed9110df962df94ab9db314b2873c6b2.zip
Merge pull request #795 from commial/features/better-float-sse
Better float support & additionnal SSE
Diffstat (limited to '')
-rw-r--r--miasm2/jitter/arch/JitCore_x86.h16
-rw-r--r--miasm2/jitter/llvmconvert.py161
-rw-r--r--miasm2/jitter/op_semantics.c227
-rw-r--r--miasm2/jitter/op_semantics.h38
4 files changed, 288 insertions, 154 deletions
diff --git a/miasm2/jitter/arch/JitCore_x86.h b/miasm2/jitter/arch/JitCore_x86.h
index 221ba5db..a5fc4bd4 100644
--- a/miasm2/jitter/arch/JitCore_x86.h
+++ b/miasm2/jitter/arch/JitCore_x86.h
@@ -49,14 +49,14 @@ typedef struct {
 
 	uint64_t cond;
 
-	double float_st0;
-	double float_st1;
-	double float_st2;
-	double float_st3;
-	double float_st4;
-	double float_st5;
-	double float_st6;
-	double float_st7;
+	uint64_t float_st0;
+	uint64_t float_st1;
+	uint64_t float_st2;
+	uint64_t float_st3;
+	uint64_t float_st4;
+	uint64_t float_st5;
+	uint64_t float_st6;
+	uint64_t float_st7;
 
 	unsigned int float_c0;
 	unsigned int float_c1;
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index d63351cc..c4e6709d 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -51,6 +51,17 @@ class LLVMType(llvm_ir.Type):
         else:
             raise ValueError()
 
+    @classmethod
+    def fptype(cls, size):
+        """Return the floating type corresponding to precision @size"""
+        if size == 32:
+            precision = llvm_ir.FloatType()
+        elif size == 64:
+            precision = llvm_ir.DoubleType()
+        else:
+            raise RuntimeError("Unsupported precision: %x", size)
+        return precision
+
 
 class LLVMContext():
 
@@ -236,8 +247,16 @@ class LLVMContext_JIT(LLVMContext):
         i8 = LLVMType.IntType(8)
         p8 = llvm_ir.PointerType(i8)
         itype = LLVMType.IntType(64)
+        ftype = llvm_ir.FloatType()
+        dtype = llvm_ir.DoubleType()
         fc = {"llvm.ctpop.i8": {"ret": i8,
                                 "args": [i8]},
+              "llvm.nearbyint.f32": {"ret": ftype,
+                                     "args": [ftype]},
+              "llvm.nearbyint.f64": {"ret": dtype,
+                                     "args": [dtype]},
+              "llvm.trunc.f32": {"ret": ftype,
+                                 "args": [ftype]},
               "segm2addr": {"ret": itype,
                             "args": [p8,
                                      itype,
@@ -245,6 +264,22 @@ class LLVMContext_JIT(LLVMContext):
               "x86_cpuid": {"ret": itype,
                         "args": [itype,
                                  itype]},
+              "fcom_c0": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "fcom_c1": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "fcom_c2": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "fcom_c3": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "llvm.sqrt.f32": {"ret": ftype,
+                                "args": [ftype]},
+              "llvm.sqrt.f64": {"ret": dtype,
+                                "args": [dtype]},
         }
 
         for k in [8, 16]:
@@ -466,10 +501,7 @@ class LLVMFunction():
                           [llvm_ir.Constant(LLVMType.IntType(),
                                             offset)])
         regs = self.llvm_context.ir_arch.arch.regs
-        if hasattr(regs, "float_list") and expr in regs.float_list:
-            pointee_type = llvm_ir.DoubleType()
-        else:
-            pointee_type = LLVMType.IntType(expr.size)
+        pointee_type = LLVMType.IntType(expr.size)
         ptr_casted = builder.bitcast(ptr,
                                      llvm_ir.PointerType(pointee_type))
         # Store in cache
@@ -764,15 +796,19 @@ class LLVMFunction():
                 itype = LLVMType.IntType(expr.size)
                 cond_ok = self.builder.icmp_unsigned("<", count,
                                                      itype(expr.size))
+                zero = itype(0)
                 if op == ">>":
                     callback = builder.lshr
                 elif op == "<<":
                     callback = builder.shl
                 elif op == "a>>":
                     callback = builder.ashr
+                    # x a>> size is 0 or -1, depending on x sign
+                    cond_neg = self.builder.icmp_signed("<", value, zero)
+                    zero = self.builder.select(cond_neg, itype(-1), zero)
 
                 ret = self.builder.select(cond_ok, callback(value, count),
-                                          itype(0))
+                                          zero)
                 self.update_cache(expr, ret)
                 return ret
 
@@ -800,19 +836,118 @@ class LLVMFunction():
                 self.update_cache(expr, ret)
                 return ret
 
+            if op.startswith("sint_to_fp"):
+                fptype = LLVMType.fptype(expr.size)
+                arg = self.add_ir(expr.args[0])
+                ret = builder.sitofp(arg, fptype)
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
 
+            if op == "fp_to_sint32":
+                size_arg = expr.args[0].size
+                fptype_orig = LLVMType.fptype(size_arg)
+                arg = self.add_ir(expr.args[0])
+                arg = builder.bitcast(arg, fptype_orig)
+                # Enforce IEEE-754 behavior. This could be enhanced with
+                # 'llvm.experimental.constrained.nearbyint'
+                if size_arg == 32:
+                    func = self.mod.get_global("llvm.nearbyint.f32")
+                elif size_arg == 64:
+                    func = self.mod.get_global("llvm.nearbyint.f64")
+                else:
+                    raise RuntimeError("Unsupported size")
+                rounded = builder.call(func, [arg])
+                ret = builder.fptoui(rounded, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
 
-            if op in ["int_16_to_double", "int_32_to_double", "int_64_to_double",
-                      "mem_16_to_double", "mem_32_to_double", "mem_64_to_double"]:
+            if op.startswith("fpconvert_fp"):
+                assert len(expr.args) == 1
+                size_arg = expr.args[0].size
+                fptype = LLVMType.fptype(expr.size)
+                fptype_orig = LLVMType.fptype(size_arg)
                 arg = self.add_ir(expr.args[0])
-                ret = builder.uitofp(arg, llvm_ir.DoubleType())
+                arg = builder.bitcast(arg, fptype_orig)
+                if expr.size > size_arg:
+                    fc = builder.fpext
+                elif expr.size < size_arg:
+                    fc = builder.fptrunc
+                else:
+                    raise RuntimeError("Not supported, same size")
+                ret = fc(arg, fptype)
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
+
+            if op.startswith("fpround_"):
+                assert len(expr.args) == 1
+                fptype = LLVMType.fptype(expr.size)
+                arg = self.add_ir(expr.args[0])
+                arg = builder.bitcast(arg, fptype)
+                if op == "fpround_towardszero" and expr.size == 32:
+                    fc = self.mod.get_global("llvm.trunc.f32")
+                else:
+                    raise RuntimeError("Not supported, same size")
+                rounded = builder.call(fc, [arg])
+                ret = builder.bitcast(rounded, llvm_ir.IntType(expr.size))
                 self.update_cache(expr, ret)
                 return ret
 
-            if op in ["double_to_int_16", "double_to_int_32", "double_to_int_64",
-                      "double_to_mem_16", "double_to_mem_32", "double_to_mem_64"]:
+            if op in ["fcom_c0", "fcom_c1", "fcom_c2", "fcom_c3"]:
+                arg1 = self.add_ir(expr.args[0])
+                arg2 = self.add_ir(expr.args[0])
+                fc_name = op
+                fc_ptr = self.mod.get_global(fc_name)
+                casted_args = [
+                    builder.bitcast(arg1, llvm_ir.DoubleType()),
+                    builder.bitcast(arg2, llvm_ir.DoubleType()),
+                ]
+                ret = builder.call(fc_ptr, casted_args)
+
+                # Cast ret if needed
+                ret_size = fc_ptr.return_value.type.width
+                if ret_size > expr.size:
+                    ret = builder.trunc(ret, LLVMType.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
+
+            if op in ["fsqrt"]:
                 arg = self.add_ir(expr.args[0])
-                ret = builder.fptoui(arg, llvm_ir.IntType(expr.size))
+
+                # Apply the correct sqrt func
+                if expr.size == 32:
+                    arg = builder.bitcast(arg, llvm_ir.FloatType())
+                    ret = builder.call(self.mod.get_global("llvm.sqrt.f32"),
+                                       [arg])
+                elif expr.size == 64:
+                    arg = builder.bitcast(arg, llvm_ir.DoubleType())
+                    ret = builder.call(self.mod.get_global("llvm.sqrt.f64"),
+                                       [arg])
+                else:
+                    raise RuntimeError("Unsupported precision: %x", expr.size)
+
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
+
+            if op in ["fadd", "fmul", "fsub", "fdiv"]:
+                # More than 2 args not yet supported
+                assert len(expr.args) == 2
+                arg1 = self.add_ir(expr.args[0])
+                arg2 = self.add_ir(expr.args[1])
+                precision = LLVMType.fptype(expr.size)
+                arg1 = builder.bitcast(arg1, precision)
+                arg2 = builder.bitcast(arg2, precision)
+                if op == "fadd":
+                    ret = builder.fadd(arg1, arg2)
+                elif op == "fmul":
+                    ret = builder.fmul(arg1, arg2)
+                elif op == "fsub":
+                    ret = builder.fsub(arg1, arg2)
+                elif op == "fdiv":
+                    ret = builder.fdiv(arg1, arg2)
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
                 self.update_cache(expr, ret)
                 return ret
 
@@ -832,10 +967,6 @@ class LLVMFunction():
                     callback = builder.urem
                 elif op == "/":
                     callback = builder.udiv
-                elif op == "fadd":
-                    callback = builder.fadd
-                elif op == "fdiv":
-                    callback = builder.fdiv
                 else:
                     raise NotImplementedError('Unknown op: %s' % op)
 
diff --git a/miasm2/jitter/op_semantics.c b/miasm2/jitter/op_semantics.c
index 0420532a..0bc3fcc5 100644
--- a/miasm2/jitter/op_semantics.c
+++ b/miasm2/jitter/op_semantics.c
@@ -355,147 +355,92 @@ void dump_float(void)
 	*/
 }
 
-double mem_32_to_double(unsigned int m)
+uint32_t fpu_fadd32(uint32_t a, uint32_t b)
 {
-	float f;
-	double d;
-
-	f = *((float*)&m);
-	d = f;
-#ifdef DEBUG_MIASM_DOUBLE
-	dump_float();
-	printf("%d float %e\n", m, d);
-#endif
-	return d;
-}
-
-
-double mem_64_to_double(uint64_t m)
-{
-	double d;
-	d = *((double*)&m);
+	float c;
+	c = *((float*)&a) + *((float*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%"PRId64" double %e\n", m, d);
-#endif
-	return d;
-}
-
-double int_16_to_double(unsigned int m)
-{
-	double d;
-
-	d = (double)(m&0xffff);
-#ifdef DEBUG_MIASM_DOUBLE
-	dump_float();
-	printf("%d double %e\n", m, d);
-#endif
-	return d;
-}
-
-double int_32_to_double(unsigned int m)
-{
-	double d;
-
-	d = (double)m;
-#ifdef DEBUG_MIASM_DOUBLE
-	dump_float();
-	printf("%d double %e\n", m, d);
+	printf("%e + %e -> %e\n", a, b, c);
 #endif
-	return d;
+	return *((uint32_t*)&c);
 }
 
-double int_64_to_double(uint64_t m)
+uint64_t fpu_fadd64(uint64_t a, uint64_t b)
 {
-	double d;
-
-	d = (double)m;
+	double c;
+	c = *((double*)&a) + *((double*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%"PRId64" double %e\n", m, d);
+	printf("%e + %e -> %e\n", a, b, c);
 #endif
-	return d;
+	return *((uint64_t*)&c);
 }
 
-int16_t double_to_int_16(double d)
+uint32_t fpu_fsub32(uint32_t a, uint32_t b)
 {
-	int16_t i;
-
-	i = (int16_t)d;
+	float c;
+	c = *((float*)&a) - *((float*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%e int %d\n", d, i);
+	printf("%e + %e -> %e\n", a, b, c);
 #endif
-	return i;
+	return *((uint32_t*)&c);
 }
 
-int32_t double_to_int_32(double d)
+uint64_t fpu_fsub64(uint64_t a, uint64_t b)
 {
-	int32_t i;
-
-	i = (int32_t)d;
+	double c;
+	c = *((double*)&a) - *((double*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%e int %d\n", d, i);
+	printf("%e + %e -> %e\n", a, b, c);
 #endif
-	return i;
+	return *((uint64_t*)&c);
 }
 
-int64_t double_to_int_64(double d)
+uint32_t fpu_fmul32(uint32_t a, uint32_t b)
 {
-	int64_t i;
-
-	i = (int64_t)d;
+	float c;
+	c = *((float*)&a) * *((float*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%e int %"PRId64"\n", d, i);
+	printf("%e * %e -> %e\n", a, b, c);
 #endif
-	return i;
+	return *((uint32_t*)&c);
 }
 
-
-double fpu_fadd(double a, double b)
+uint64_t fpu_fmul64(uint64_t a, uint64_t b)
 {
 	double c;
-	c = a + b;
+	c = *((double*)&a) * *((double*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%e + %e -> %e\n", a, b, c);
+	printf("%e * %e -> %e\n", a, b, c);
 #endif
-	return c;
+	return *((uint64_t*)&c);
 }
 
-double fpu_fsub(double a, double b)
+uint32_t fpu_fdiv32(uint32_t a, uint32_t b)
 {
-	double c;
-	c = a - b;
-#ifdef DEBUG_MIASM_DOUBLE
-	dump_float();
-	printf("%e - %e -> %e\n", a, b, c);
-#endif
-	return c;
-}
-
-double fpu_fmul(double a, double b)
-{
-	double c;
-	c = a * b;
+	float c;
+	c = *((float*)&a) / *((float*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
 	printf("%e * %e -> %e\n", a, b, c);
 #endif
-	return c;
+	return *((uint32_t*)&c);
 }
 
-double fpu_fdiv(double a, double b)
+uint64_t fpu_fdiv64(uint64_t a, uint64_t b)
 {
 	double c;
-	c = a / b;
+	c = *((double*)&a) / *((double*)&b);
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
-	printf("%e / %e -> %e\n", a, b, c);
+	printf("%e * %e -> %e\n", a, b, c);
 #endif
-	return c;
+	return *((uint64_t*)&c);
 }
 
 double fpu_ftan(double a)
@@ -567,15 +512,26 @@ double fpu_f2xm1(double a)
 	return b;
 }
 
-double fpu_fsqrt(double a)
+uint32_t fpu_fsqrt32(uint32_t a)
+{
+	float b;
+	b = sqrtf(*((float*)&a));
+#ifdef DEBUG_MIASM_DOUBLE
+	dump_float();
+	printf("%e sqrt %e\n", a, b);
+#endif
+	return *((uint32_t*)&b);
+}
+
+uint64_t fpu_fsqrt64(uint64_t a)
 {
 	double b;
-	b = sqrt(a);
+	b = sqrt(*((double*)&a));
 #ifdef DEBUG_MIASM_DOUBLE
 	dump_float();
 	printf("%e sqrt %e\n", a, b);
 #endif
-	return b;
+	return *((uint64_t*)&b);
 }
 
 double fpu_fabs(double a)
@@ -751,30 +707,75 @@ unsigned int fpu_fxam_c3(double a)
 	}
 }
 
-unsigned int double_to_mem_32(double d)
+uint64_t sint64_to_fp64(int64_t a)
 {
-	unsigned int m;
-	float f;
-	f = d;
-	m = *((unsigned int*)&f);
-#ifdef DEBUG_MIASM_DOUBLE
-	dump_float();
-	printf("%d %e\n", m, d);
-#endif
-	return m;
+	double result = (double) a;
+	return *((uint64_t*)&result);
 }
 
-uint64_t double_to_mem_64(double d)
+uint32_t sint32_to_fp32(int32_t a)
 {
-	uint64_t m;
-	m = *((uint64_t*)&d);
-#ifdef DEBUG_MIASM_DOUBLE
-	dump_float();
-	printf("%"PRId64" %e\n", m, d);
-#endif
-	return m;
+	float result = (float) a;
+	return *((uint32_t*)&result);
+}
+
+uint64_t sint32_to_fp64(int32_t a)
+{
+	double result = (double) a;
+	return *((uint64_t*)&result);
 }
 
+int32_t fp32_to_sint32(uint32_t a)
+{
+	// Enforce nearbyint (IEEE-754 behavior)
+	float rounded = *((float*)&a);
+	rounded = nearbyintf(rounded);
+	return (int32_t) rounded;
+}
+
+int64_t fp64_to_sint64(uint64_t a)
+{
+	// Enforce nearbyint (IEEE-754 behavior)
+	double rounded = *((double*)&a);
+	rounded = nearbyint(rounded);
+	return (int64_t) rounded;
+}
+
+int32_t fp64_to_sint32(uint64_t a)
+{
+	// Enforce nearbyint (IEEE-754 behavior)
+	double rounded = *((double*)&a);
+	rounded = nearbyint(rounded);
+	return (int32_t) rounded;
+}
+
+uint32_t fp64_to_fp32(uint64_t a)
+{
+	float result = (float) *((double*)&a);
+	return *((uint32_t*)&result);
+}
+
+uint64_t fp32_to_fp64(uint32_t a)
+{
+	double result = (double) *((float*)&a);
+	return *((uint64_t*)&result);
+}
+
+uint32_t fpround_towardszero_fp32(uint32_t a)
+{
+	float rounded = *((float*)&a);
+	rounded = truncf(rounded);
+	return *((uint32_t*)&rounded);
+}
+
+uint64_t fpround_towardszero_fp64(uint64_t a)
+{
+	double rounded = *((float*)&a);
+	rounded = trunc(rounded);
+	return *((uint64_t*)&rounded);
+}
+
+
 UDIV(16)
 UDIV(32)
 UDIV(64)
diff --git a/miasm2/jitter/op_semantics.h b/miasm2/jitter/op_semantics.h
index 3eb81cff..f8042895 100644
--- a/miasm2/jitter/op_semantics.h
+++ b/miasm2/jitter/op_semantics.h
@@ -96,19 +96,23 @@ int16_t idiv16(int16_t a, int16_t b);
 int16_t imod16(int16_t a, int16_t b);
 
 unsigned int x86_cpuid(unsigned int a, unsigned int reg_num);
-double int2double(unsigned int m);
 
-double fpu_fadd(double a, double b);
-double fpu_fsub(double a, double b);
-double fpu_fmul(double a, double b);
-double fpu_fdiv(double a, double b);
+uint32_t fpu_fadd32(uint32_t a, uint32_t b);
+uint64_t fpu_fadd64(uint64_t a, uint64_t b);
+uint32_t fpu_fsub32(uint32_t a, uint32_t b);
+uint64_t fpu_fsub64(uint64_t a, uint64_t b);
+uint32_t fpu_fmul32(uint32_t a, uint32_t b);
+uint64_t fpu_fmul64(uint64_t a, uint64_t b);
+uint32_t fpu_fdiv32(uint32_t a, uint32_t b);
+uint64_t fpu_fdiv64(uint64_t a, uint64_t b);
 double fpu_ftan(double a);
 double fpu_frndint(double a);
 double fpu_fsin(double a);
 double fpu_fcos(double a);
 double fpu_fscale(double a, double b);
 double fpu_f2xm1(double a);
-double fpu_fsqrt(double a);
+uint32_t fpu_fsqrt32(uint32_t a);
+uint64_t fpu_fsqrt64(uint64_t a);
 double fpu_fabs(double a);
 double fpu_fprem(double a, double b);
 double fpu_fchs(double a);
@@ -124,18 +128,16 @@ unsigned int fpu_fxam_c1(double a);
 unsigned int fpu_fxam_c2(double a);
 unsigned int fpu_fxam_c3(double a);
 
-
-double mem_32_to_double(unsigned int m);
-double mem_64_to_double(uint64_t m);
-double int_16_to_double(unsigned int m);
-double int_32_to_double(unsigned int m);
-double int_64_to_double(uint64_t m);
-int16_t double_to_int_16(double d);
-int32_t double_to_int_32(double d);
-int64_t double_to_int_64(double d);
-unsigned int double_to_mem_32(double d);
-uint64_t double_to_mem_64(double d);
-
+uint64_t sint64_to_fp64(int64_t a);
+uint32_t sint32_to_fp32(int32_t a);
+uint64_t sint32_to_fp64(int32_t a);
+int32_t fp32_to_sint32(uint32_t a);
+int64_t fp64_to_sint64(uint64_t a);
+int32_t fp64_to_sint32(uint64_t a);
+uint32_t fp64_to_fp32(uint64_t a);
+uint64_t fp32_to_fp64(uint32_t a);
+uint32_t fpround_towardszero_fp32(uint32_t a);
+uint64_t fpround_towardszero_fp64(uint64_t a);
 
 #define SHIFT_RIGHT_ARITH(size, value, shift)				\
 	((uint ## size ## _t)((((uint64_t) (shift)) > ((size) - 1))?	\