about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAjax <commial@gmail.com>2016-12-22 15:56:57 +0100
committerAjax <commial@gmail.com>2017-01-04 17:14:55 +0100
commit723a3d884aad1f8373bbc05365b0ebd5bc1b7e06 (patch)
treed71f6abbaf83cb7e1c6c0a0c5be52398ce7b8cd9
parent476cd9ac8a28ef106fbae13e83a810272d797f10 (diff)
downloadmiasm-723a3d884aad1f8373bbc05365b0ebd5bc1b7e06.tar.gz
miasm-723a3d884aad1f8373bbc05365b0ebd5bc1b7e06.zip
Refactor external function call for custom OP
-rw-r--r--miasm2/jitter/llvmconvert.py118
1 files changed, 67 insertions, 51 deletions
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index 2165fbc5..4f581db9 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -202,21 +202,31 @@ class LLVMContext_JIT(LLVMContext):
         "Add operations functions"
 
         p8 = llvm_ir.PointerType(LLVMType.IntType(8))
-        self.add_fc({"parity": {"ret": LLVMType.IntType(),
-                                "args": [LLVMType.IntType()]}})
-        self.add_fc({"rot_left": {"ret": LLVMType.IntType(),
-                                  "args": [LLVMType.IntType(),
-                                           LLVMType.IntType(),
-                                           LLVMType.IntType()]}})
-        self.add_fc({"rot_right": {"ret": LLVMType.IntType(),
-                                   "args": [LLVMType.IntType(),
-                                            LLVMType.IntType(),
-                                            LLVMType.IntType()]}})
-
-        self.add_fc({"segm2addr": {"ret": LLVMType.IntType(64),
+        itype = LLVMType.IntType(64)
+        self.add_fc({"parity": {"ret": LLVMType.IntType(1),
+                                "args": [itype]}})
+        self.add_fc({"rot_left": {"ret": itype,
+                                  "args": [itype,
+                                           itype,
+                                           itype]}})
+        self.add_fc({"rot_right": {"ret": itype,
+                                   "args": [itype,
+                                            itype,
+                                            itype]}})
+        self.add_fc({"rcr_rez_op": {"ret": itype,
+                                    "args": [itype,
+                                             itype,
+                                             itype,
+                                             itype]}})
+        self.add_fc({"rcl_rez_op": {"ret": itype,
+                                    "args": [itype,
+                                             itype,
+                                             itype,
+                                             itype]}})
+        self.add_fc({"segm2addr": {"ret": itype,
                                    "args": [p8,
-                                            LLVMType.IntType(64),
-                                            LLVMType.IntType(64)]}})
+                                            itype,
+                                            itype]}})
 
         for k in [8, 16]:
             self.add_fc({"bcdadd_%s" % k: {"ret": LLVMType.IntType(k),
@@ -307,6 +317,21 @@ class LLVMFunction():
     log_mn = False
     log_regs = True
 
+    # Operation translation
+    ## Basics
+    op_translate = {'parity': 'parity',
+    }
+    ## Add the size as first argument
+    op_translate_with_size = {'<<<': 'rot_left',
+                              '>>>': 'rot_right',
+                              '<<<c_rez': 'rcl_rez_op',
+                              '>>>c_rez': 'rcr_rez_op',
+    }
+    ## Add the size as suffix
+    op_translate_with_suffix_size = {'bcdadd': 'bcdadd',
+                                     'bcdadd_cf': 'bcdadd_cf',
+    }
+
     def __init__(self, llvm_context, name="fc"):
         "Create a new function with name fc"
         self.llvm_context = llvm_context
@@ -484,47 +509,38 @@ class LLVMFunction():
         if isinstance(expr, m2_expr.ExprOp):
             op = expr.op
 
-            if op == "parity":
-                fc_ptr = self.mod.get_global("parity")
-                arg = builder.zext(self.add_ir(expr.args[0]),
-                                   LLVMType.IntType())
-                ret = builder.call(fc_ptr, [arg])
-                ret = builder.trunc(ret, LLVMType.IntType(expr.size))
-                self.update_cache(expr, ret)
-                return ret
-
-            if op in ["<<<", ">>>"]:
-                fc_name = "rot_left" if op == "<<<" else "rot_right"
-                fc_ptr = self.mod.get_global(fc_name)
+            if (op in self.op_translate or
+                op in self.op_translate_with_size or
+                op in self.op_translate_with_suffix_size):
                 args = [self.add_ir(arg) for arg in expr.args]
                 arg_size = expr.args[0].size
-                if arg_size < 32:
-                    # Cast args
-                    args = [builder.zext(arg, LLVMType.IntType(32))
-                            for arg in args]
-                arg_size_cst = llvm_ir.Constant(LLVMType.IntType(),
-                                                   arg_size)
-                ret = builder.call(fc_ptr, [arg_size_cst] + args)
-                if arg_size < 32:
-                    # Cast ret
-                    ret = builder.trunc(ret, LLVMType.IntType(arg_size))
-                self.update_cache(expr, ret)
-                return ret
 
-            if op == "bcdadd":
-                size = expr.args[0].size
-                fc_ptr = self.mod.get_global("bcdadd_%s" % size)
-                args = [self.add_ir(arg) for arg in expr.args]
-                ret = builder.call(fc_ptr, args)
-                self.update_cache(expr, ret)
-                return ret
+                if op in self.op_translate_with_size:
+                    fc_name = self.op_translate_with_size[op]
+                    arg_size_cst = llvm_ir.Constant(LLVMType.IntType(64),
+                                                    arg_size)
+                    args = [arg_size_cst] + args
+                elif op in self.op_translate:
+                    fc_name = self.op_translate[op]
+                elif op in self.op_translate_with_suffix_size:
+                    fc_name = "%s_%s" % (self.op_translate[op], arg_size)
+
+                fc_ptr = self.mod.get_global(fc_name)
+
+                # Cast args if needed
+                casted_args = []
+                for i, arg in enumerate(args):
+                    if arg.type.width < fc_ptr.args[i].type.width:
+                        casted_args.append(builder.zext(arg, fc_ptr.args[i].type))
+                    else:
+                        casted_args.append(arg)
+                ret = builder.call(fc_ptr, casted_args)
+
+                # Cast ret if needed
+                ret_size = fc_ptr.return_value.type.width
+                if ret_size > expr.size:
+                    ret = builder.trunc(ret, LLVMType.IntType(expr.size))
 
-            if op == "bcdadd_cf":
-                size = expr.args[0].size
-                fc_ptr = self.mod.get_global("bcdadd_cf_%s" % size)
-                args = [self.add_ir(arg) for arg in expr.args]
-                ret = builder.call(fc_ptr, args)
-                ret = builder.trunc(ret, LLVMType.IntType(expr.size))
                 self.update_cache(expr, ret)
                 return ret