about summary refs log tree commit diff stats
path: root/miasm2/jitter/llvmconvert.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/jitter/llvmconvert.py')
-rw-r--r--miasm2/jitter/llvmconvert.py161
1 files changed, 146 insertions, 15 deletions
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index d63351cc..c4e6709d 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -51,6 +51,17 @@ class LLVMType(llvm_ir.Type):
         else:
             raise ValueError()
 
+    @classmethod
+    def fptype(cls, size):
+        """Return the floating type corresponding to precision @size"""
+        if size == 32:
+            precision = llvm_ir.FloatType()
+        elif size == 64:
+            precision = llvm_ir.DoubleType()
+        else:
+            raise RuntimeError("Unsupported precision: %x", size)
+        return precision
+
 
 class LLVMContext():
 
@@ -236,8 +247,16 @@ class LLVMContext_JIT(LLVMContext):
         i8 = LLVMType.IntType(8)
         p8 = llvm_ir.PointerType(i8)
         itype = LLVMType.IntType(64)
+        ftype = llvm_ir.FloatType()
+        dtype = llvm_ir.DoubleType()
         fc = {"llvm.ctpop.i8": {"ret": i8,
                                 "args": [i8]},
+              "llvm.nearbyint.f32": {"ret": ftype,
+                                     "args": [ftype]},
+              "llvm.nearbyint.f64": {"ret": dtype,
+                                     "args": [dtype]},
+              "llvm.trunc.f32": {"ret": ftype,
+                                 "args": [ftype]},
               "segm2addr": {"ret": itype,
                             "args": [p8,
                                      itype,
@@ -245,6 +264,22 @@ class LLVMContext_JIT(LLVMContext):
               "x86_cpuid": {"ret": itype,
                         "args": [itype,
                                  itype]},
+              "fcom_c0": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "fcom_c1": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "fcom_c2": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "fcom_c3": {"ret": itype,
+                          "args": [dtype,
+                                   dtype]},
+              "llvm.sqrt.f32": {"ret": ftype,
+                                "args": [ftype]},
+              "llvm.sqrt.f64": {"ret": dtype,
+                                "args": [dtype]},
         }
 
         for k in [8, 16]:
@@ -466,10 +501,7 @@ class LLVMFunction():
                           [llvm_ir.Constant(LLVMType.IntType(),
                                             offset)])
         regs = self.llvm_context.ir_arch.arch.regs
-        if hasattr(regs, "float_list") and expr in regs.float_list:
-            pointee_type = llvm_ir.DoubleType()
-        else:
-            pointee_type = LLVMType.IntType(expr.size)
+        pointee_type = LLVMType.IntType(expr.size)
         ptr_casted = builder.bitcast(ptr,
                                      llvm_ir.PointerType(pointee_type))
         # Store in cache
@@ -764,15 +796,19 @@ class LLVMFunction():
                 itype = LLVMType.IntType(expr.size)
                 cond_ok = self.builder.icmp_unsigned("<", count,
                                                      itype(expr.size))
+                zero = itype(0)
                 if op == ">>":
                     callback = builder.lshr
                 elif op == "<<":
                     callback = builder.shl
                 elif op == "a>>":
                     callback = builder.ashr
+                    # x a>> size is 0 or -1, depending on x sign
+                    cond_neg = self.builder.icmp_signed("<", value, zero)
+                    zero = self.builder.select(cond_neg, itype(-1), zero)
 
                 ret = self.builder.select(cond_ok, callback(value, count),
-                                          itype(0))
+                                          zero)
                 self.update_cache(expr, ret)
                 return ret
 
@@ -800,19 +836,118 @@ class LLVMFunction():
                 self.update_cache(expr, ret)
                 return ret
 
+            if op.startswith("sint_to_fp"):
+                fptype = LLVMType.fptype(expr.size)
+                arg = self.add_ir(expr.args[0])
+                ret = builder.sitofp(arg, fptype)
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
 
+            if op == "fp_to_sint32":
+                size_arg = expr.args[0].size
+                fptype_orig = LLVMType.fptype(size_arg)
+                arg = self.add_ir(expr.args[0])
+                arg = builder.bitcast(arg, fptype_orig)
+                # Enforce IEEE-754 behavior. This could be enhanced with
+                # 'llvm.experimental.constrained.nearbyint'
+                if size_arg == 32:
+                    func = self.mod.get_global("llvm.nearbyint.f32")
+                elif size_arg == 64:
+                    func = self.mod.get_global("llvm.nearbyint.f64")
+                else:
+                    raise RuntimeError("Unsupported size")
+                rounded = builder.call(func, [arg])
+                ret = builder.fptoui(rounded, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
 
-            if op in ["int_16_to_double", "int_32_to_double", "int_64_to_double",
-                      "mem_16_to_double", "mem_32_to_double", "mem_64_to_double"]:
+            if op.startswith("fpconvert_fp"):
+                assert len(expr.args) == 1
+                size_arg = expr.args[0].size
+                fptype = LLVMType.fptype(expr.size)
+                fptype_orig = LLVMType.fptype(size_arg)
                 arg = self.add_ir(expr.args[0])
-                ret = builder.uitofp(arg, llvm_ir.DoubleType())
+                arg = builder.bitcast(arg, fptype_orig)
+                if expr.size > size_arg:
+                    fc = builder.fpext
+                elif expr.size < size_arg:
+                    fc = builder.fptrunc
+                else:
+                    raise RuntimeError("Not supported, same size")
+                ret = fc(arg, fptype)
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
+
+            if op.startswith("fpround_"):
+                assert len(expr.args) == 1
+                fptype = LLVMType.fptype(expr.size)
+                arg = self.add_ir(expr.args[0])
+                arg = builder.bitcast(arg, fptype)
+                if op == "fpround_towardszero" and expr.size == 32:
+                    fc = self.mod.get_global("llvm.trunc.f32")
+                else:
+                    raise RuntimeError("Not supported, same size")
+                rounded = builder.call(fc, [arg])
+                ret = builder.bitcast(rounded, llvm_ir.IntType(expr.size))
                 self.update_cache(expr, ret)
                 return ret
 
-            if op in ["double_to_int_16", "double_to_int_32", "double_to_int_64",
-                      "double_to_mem_16", "double_to_mem_32", "double_to_mem_64"]:
+            if op in ["fcom_c0", "fcom_c1", "fcom_c2", "fcom_c3"]:
+                arg1 = self.add_ir(expr.args[0])
+                arg2 = self.add_ir(expr.args[0])
+                fc_name = op
+                fc_ptr = self.mod.get_global(fc_name)
+                casted_args = [
+                    builder.bitcast(arg1, llvm_ir.DoubleType()),
+                    builder.bitcast(arg2, llvm_ir.DoubleType()),
+                ]
+                ret = builder.call(fc_ptr, casted_args)
+
+                # Cast ret if needed
+                ret_size = fc_ptr.return_value.type.width
+                if ret_size > expr.size:
+                    ret = builder.trunc(ret, LLVMType.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
+
+            if op in ["fsqrt"]:
                 arg = self.add_ir(expr.args[0])
-                ret = builder.fptoui(arg, llvm_ir.IntType(expr.size))
+
+                # Apply the correct sqrt func
+                if expr.size == 32:
+                    arg = builder.bitcast(arg, llvm_ir.FloatType())
+                    ret = builder.call(self.mod.get_global("llvm.sqrt.f32"),
+                                       [arg])
+                elif expr.size == 64:
+                    arg = builder.bitcast(arg, llvm_ir.DoubleType())
+                    ret = builder.call(self.mod.get_global("llvm.sqrt.f64"),
+                                       [arg])
+                else:
+                    raise RuntimeError("Unsupported precision: %x", expr.size)
+
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
+                self.update_cache(expr, ret)
+                return ret
+
+            if op in ["fadd", "fmul", "fsub", "fdiv"]:
+                # More than 2 args not yet supported
+                assert len(expr.args) == 2
+                arg1 = self.add_ir(expr.args[0])
+                arg2 = self.add_ir(expr.args[1])
+                precision = LLVMType.fptype(expr.size)
+                arg1 = builder.bitcast(arg1, precision)
+                arg2 = builder.bitcast(arg2, precision)
+                if op == "fadd":
+                    ret = builder.fadd(arg1, arg2)
+                elif op == "fmul":
+                    ret = builder.fmul(arg1, arg2)
+                elif op == "fsub":
+                    ret = builder.fsub(arg1, arg2)
+                elif op == "fdiv":
+                    ret = builder.fdiv(arg1, arg2)
+                ret = builder.bitcast(ret, llvm_ir.IntType(expr.size))
                 self.update_cache(expr, ret)
                 return ret
 
@@ -832,10 +967,6 @@ class LLVMFunction():
                     callback = builder.urem
                 elif op == "/":
                     callback = builder.udiv
-                elif op == "fadd":
-                    callback = builder.fadd
-                elif op == "fdiv":
-                    callback = builder.fdiv
                 else:
                     raise NotImplementedError('Unknown op: %s' % op)