about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAjax <commial@gmail.com>2017-01-03 14:16:16 +0100
committerAjax <commial@gmail.com>2017-01-04 17:14:55 +0100
commitc6dc6cb345393a95135254420cdd5452c7ff779a (patch)
treeafa353c31c581cc091caadb0cb75c113a1298b1d
parent7f7bab09b3fdf47ace9236bf5bdeacab6bd4c907 (diff)
downloadmiasm-c6dc6cb345393a95135254420cdd5452c7ff779a.tar.gz
miasm-c6dc6cb345393a95135254420cdd5452c7ff779a.zip
Clean-up LLVM ModuleRef manipulation
Diffstat (limited to '')
-rw-r--r--miasm2/jitter/jitcore_llvm.py10
-rw-r--r--miasm2/jitter/llvmconvert.py78
2 files changed, 44 insertions, 44 deletions
diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py
index 6f3eca88..0f265073 100644
--- a/miasm2/jitter/jitcore_llvm.py
+++ b/miasm2/jitter/jitcore_llvm.py
@@ -21,8 +21,8 @@ class JitCore_LLVM(jitcore.JitCore):
     def __init__(self, ir_arch, bs=None):
         super(JitCore_LLVM, self).__init__(ir_arch, bs)
 
-        self.options.update({"safe_mode": False,   # Verify each function
-                             "optimise": False,     # Optimise functions
+        self.options.update({"safe_mode": True,   # Verify each function
+                             "optimise": True,     # Optimise functions
                              "log_func": False,    # Print LLVM functions
                              "log_assembly": False,  # Print assembly executed
                              })
@@ -60,12 +60,6 @@ class JitCore_LLVM(jitcore.JitCore):
         mod = importlib.import_module(mod_name)
         self.context.set_vmcpu(mod.get_gpreg_offset_all())
 
-        # Save module base
-        self.mod_base_str = str(self.context.mod)
-
-        # Set IRs transformation to apply
-        self.context.set_IR_transformation(self.ir_arch.expr_fix_regs_for_mode)
-
     def add_bloc(self, block):
         """Add a block to JiT and JiT it.
         @block: the block to add
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index dbc0d1b9..187562b8 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -58,7 +58,15 @@ class LLVMContext():
 
     def __init__(self, name="mod"):
         "Initialize a context with a module named 'name'"
-        self.new_module(name)
+        # Initialize llvm
+        llvm.initialize()
+        llvm.initialize_native_target()
+        llvm.initialize_native_asmprinter()
+
+        # Initilize target for compilation
+        target = llvm.Target.from_default_triple()
+        self.target_machine = target.create_target_machine()
+        self.init_exec_engine()
 
     def optimise_level(self, level=2):
         """Set the optimisation level to @level from 0 to 2
@@ -73,16 +81,15 @@ class LLVMContext():
         pmb.populate(pm)
         self.pass_manager = pm
 
+    def init_exec_engine(self):
+        mod = llvm.parse_assembly("")
+        engine = llvm.create_mcjit_compiler(mod,
+                                            self.target_machine)
+        self.exec_engine = engine
+
     def new_module(self, name="mod"):
+        """Create a module, with needed functions"""
         self.mod = llvm_ir.Module(name=name)
-        llvm.initialize()
-        llvm.initialize_native_target()
-        llvm.initialize_native_asmprinter()
-        target = llvm.Target.from_default_triple()
-        target_machine = target.create_target_machine()
-        backing_mod = llvm.parse_assembly("")
-        self.exec_engine = llvm.create_mcjit_compiler(backing_mod,
-                                                      target_machine)
         self.add_fc(self.known_fc)
 
     def get_execengine(self):
@@ -129,7 +136,6 @@ class LLVMContext_JIT(LLVMContext):
         self.arch_specific()
         LLVMContext.__init__(self, name)
         self.vmcpu = {}
-        self.engines = []
 
     def new_module(self, name="mod"):
         LLVMContext.new_module(self, name)
@@ -313,8 +319,11 @@ class LLVMContext_IRCompilation(LLVMContext):
         return builder.store(value, ptr_casted)
 
 class LLVMFunction():
+    """Represent a LLVM function
 
-    "Represent a llvm function"
+    Implementation note:
+    A new module is created each time to avoid cumulative lag (if @new_module)
+    """
 
     # Default logging values
     log_mn = False
@@ -337,10 +346,11 @@ class LLVMFunction():
                                      'bcdadd_cf': 'bcdadd_cf',
     }
 
-    def __init__(self, llvm_context, name="fc"):
-        "Create a new function with name fc"
+    def __init__(self, llvm_context, name="fc", new_module=True):
+        "Create a new function with name @name"
         self.llvm_context = llvm_context
-        self.llvm_context.new_module()
+        if new_module:
+            self.llvm_context.new_module()
         self.mod = self.llvm_context.get_module()
 
         self.my_args = []  # (Expr, LLVMType, Name)
@@ -350,12 +360,12 @@ class LLVMFunction():
 
         self.branch_counter = 0
         self.name = name
+        self._llvm_mod = None
 
     def new_branch_name(self):
         "Return a new branch name"
-
         self.branch_counter += 1
-        return "%s" % self.branch_counter
+        return str(self.branch_counter)
 
     def viewCFG(self):
         "Show the CFG of the current function"
@@ -1477,23 +1487,23 @@ class LLVMFunction():
 
     def __str__(self):
         "Print the llvm IR corresponding to the current module"
-
-        return str(self.fc)
+        return str(self.mod)
 
     def verify(self):
         "Verify the module syntax"
+        return self.as_llvm_mod().verify()
 
-        return self.mod.verify()
+    def get_bytecode(self):
+        "Return LLVM bitcode corresponding to the current module"
+        return self.as_llvm_mod().as_bitcode()
 
     def get_assembly(self):
         "Return native assembly corresponding to the current module"
-
-        return self.mod.to_native_assembly()
+        return self.llvm_context.target_machine.emit_assembly(self.as_llvm_mod())
 
     def optimise(self):
         "Optimise the function in place"
-        while self.llvm_context.pass_manager.run(self.fc):
-            continue
+        return self.llvm_context.pass_manager.run(self.as_llvm_mod())
 
     def __call__(self, *args):
         "Eval the function with arguments args"
@@ -1505,24 +1515,20 @@ class LLVMFunction():
 
         return ret.as_int()
 
+    def as_llvm_mod(self):
+        """Return a ModuleRef standing for the current function"""
+        if self._llvm_mod is None:
+            self._llvm_mod = llvm.parse_assembly(str(self.mod))
+        return self._llvm_mod
+
     def get_function_pointer(self):
         "Return a pointer on the Jitted function"
-        # Parse our generated module
-        mod = llvm.parse_assembly( str( self.mod ) )
-        mod.verify()
-
-        # Apply optimisation
-        self.llvm_context.get_passmanager().run(mod)
+        engine = self.llvm_context.get_execengine()
 
-        # Now add the module and make sure it is ready for execution
-        target = llvm.Target.from_default_triple()
-        target_machine = target.create_target_machine()
-        engine = llvm.create_mcjit_compiler(mod,
-                                            target_machine)
+        # Add the module and make sure it is ready for execution
+        engine.add_module(self.as_llvm_mod())
         engine.finalize_object()
 
-        # For debug: obj_bin = target_machine.emit_object(mod)
-        self.llvm_context.engines.append(engine)
         return engine.get_function_address(self.fc.name)
 
 # TODO: