about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- miasm2/jitter/jitcore_llvm.py 81
-rw-r--r-- miasm2/jitter/llvmconvert.py 52
2 files changed, 111 insertions, 22 deletions
diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py
index 9f9a63e4..6a5c2036 100644
--- a/miasm2/jitter/jitcore_llvm.py
+++ b/miasm2/jitter/jitcore_llvm.py
@@ -1,6 +1,7 @@
 import os
 import importlib
-import hashlib
+import tempfile
+from hashlib import md5
 from miasm2.jitter.llvmconvert import *
 import miasm2.jitter.jitcore as jitcore
 import Jitllvm
@@ -28,9 +29,18 @@ class JitCore_LLVM(jitcore.JitCore):
                              })
 
         self.exec_wrapper = Jitllvm.llvm_exec_bloc
-        self.exec_engines = []
         self.ir_arch = ir_arch
 
+        # Cache temporary dir
+        self.tempdir = os.path.join(tempfile.gettempdir(), "miasm_cache")
+        try:
+            os.mkdir(self.tempdir, 0755)
+        except OSError:
+            pass
+        if not os.access(self.tempdir, os.R_OK | os.W_OK):
+            raise RuntimeError(
+                'Cannot access cache directory %s ' % self.tempdir)
+
     def load(self):
 
         # Library to load within Jit context
@@ -60,38 +70,65 @@ class JitCore_LLVM(jitcore.JitCore):
         mod = importlib.import_module(mod_name)
         self.context.set_vmcpu(mod.get_gpreg_offset_all())
 
+        # Enable caching
+        self.context.enable_cache()
+
     def add_bloc(self, block):
         """Add a block to JiT and JiT it.
         @block: the block to add
         """
-        # TODO: caching using hash
+        block_hash = self.hash_block(block)
+        fname_out = os.path.join(self.tempdir, "%s.bc" % block_hash)
+
+        if not os.access(fname_out, os.R_OK):
+            # Build a function in the context
+            func = LLVMFunction(self.context, block.label.name)
 
-        # Build a function in the context
-        func = LLVMFunction(self.context, block.label.name)
+            # Set log level
+            func.log_regs = self.log_regs
+            func.log_mn = self.log_mn
 
-        # Set log level
-        func.log_regs = self.log_regs
-        func.log_mn = self.log_mn
+            # Import asm block
+            func.from_asmblock(block)
 
-        # Import asm block
-        func.from_asmblock(block)
+            # Verify
+            if self.options["safe_mode"] is True:
+                func.verify()
 
-        # Verify
-        if self.options["safe_mode"] is True:
-            func.verify()
+            # Optimise
+            if self.options["optimise"] is True:
+                func.optimise()
 
-        # Optimise
-        if self.options["optimise"] is True:
-            func.optimise()
+            # Log
+            if self.options["log_func"] is True:
+                print func
+            if self.options["log_assembly"] is True:
+                print func.get_assembly()
 
-        # Log
-        if self.options["log_func"] is True:
-            print func
-        if self.options["log_assembly"] is True:
-            print func.get_assembly()
+            # Propagate the cache filename to the function's module
+            self.context.set_cache_filename(func, fname_out)
+
+            # Get a pointer on the function for JiT
+            ptr = func.get_function_pointer()
+
+        else:
+            # The cache file exists: function can be loaded from cache
+            ptr = self.context.get_ptr_from_cache(fname_out, block.label.name)
 
         # Store a pointer on the function jitted code
-        self.lbl2jitbloc[block.label.offset] = func.get_function_pointer()
+        self.lbl2jitbloc[block.label.offset] = ptr
+
+    def hash_block(self, block):
+        """
+        Build a hash of the block @block
+        @block: asmbloc
+        """
+        block_raw = "".join(line.b for line in block.lines)
+        block_hash = md5("%X_%s_%s_%s" % (block.label.offset,
+                                          self.log_mn,
+                                          self.log_regs,
+                                          block_raw)).hexdigest()
+        return block_hash
 
     def jit_call(self, label, cpu, _vmmngr, breakpoints):
         """Call the function label with cpu and vmmngr states
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index cfe89059..4031d8f2 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -11,6 +11,7 @@
 #
 #
 
+import os
 from llvmlite import binding as llvm
 from llvmlite import ir as llvm_ir
 import miasm2.expression.expression as m2_expr
@@ -294,6 +295,57 @@ class LLVMContext_JIT(LLVMContext):
                               value])
 
 
+    @staticmethod
+    def cache_notify(module, buffer):
+        """Called when @module has been compiled to @buffer"""
+        if not hasattr(module, "fname_out"):
+            return
+        fname_out = module.fname_out
+
+        if os.access(fname_out, os.R_OK):
+            # No need to overwrite
+            return
+
+        open(fname_out, "w").write(buffer)
+
+    @staticmethod
+    def cache_getbuffer(module):
+        """Return a compiled buffer for @module if available"""
+        if not hasattr(module, "fname_out"):
+            return None
+
+        fname_out = module.fname_out
+        if os.access(fname_out, os.R_OK):
+            return open(fname_out).read()
+        return None
+
+    def enable_cache(self):
+        "Enable cache of compiled object"
+        # Load shared libraries
+        for lib_fname in self.library_filenames:
+            self.add_shared_library(lib_fname)
+
+        # Activate cache
+        self.exec_engine.set_object_cache(self.cache_notify,
+                                          self.cache_getbuffer)
+
+    def set_cache_filename(self, func, fname_out):
+        "Set the filename @fname_out to use for cache for @func"
+        # Use a custom attribute to propagate the cache filename
+        func.as_llvm_mod().fname_out = fname_out
+
+    def get_ptr_from_cache(self, file_name, func_name):
+        "Load @file_name and return a pointer on the jitted @func_name"
+        # We use an empty module to avoid losing time on function building
+        empty_module = llvm.parse_assembly("")
+        empty_module.fname_out = file_name
+
+        engine = self.exec_engine
+        engine.add_module(empty_module)
+        engine.finalize_object()
+        return engine.get_function_address(func_name)
+
+
 class LLVMContext_IRCompilation(LLVMContext):
 
     """Extend LLVMContext in order to handle memory management and custom