diff options
| -rw-r--r-- | miasm2/jitter/jitcore_llvm.py | 81 | ||||
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 52 |
2 files changed, 111 insertions, 22 deletions
diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py index 9f9a63e4..6a5c2036 100644 --- a/miasm2/jitter/jitcore_llvm.py +++ b/miasm2/jitter/jitcore_llvm.py @@ -1,6 +1,7 @@ import os import importlib -import hashlib +import tempfile +from hashlib import md5 from miasm2.jitter.llvmconvert import * import miasm2.jitter.jitcore as jitcore import Jitllvm @@ -28,9 +29,18 @@ class JitCore_LLVM(jitcore.JitCore): }) self.exec_wrapper = Jitllvm.llvm_exec_bloc - self.exec_engines = [] self.ir_arch = ir_arch + # Cache temporary dir + self.tempdir = os.path.join(tempfile.gettempdir(), "miasm_cache") + try: + os.mkdir(self.tempdir, 0755) + except OSError: + pass + if not os.access(self.tempdir, os.R_OK | os.W_OK): + raise RuntimeError( + 'Cannot access cache directory %s ' % self.tempdir) + def load(self): # Library to load within Jit context @@ -60,38 +70,65 @@ class JitCore_LLVM(jitcore.JitCore): mod = importlib.import_module(mod_name) self.context.set_vmcpu(mod.get_gpreg_offset_all()) + # Enable caching + self.context.enable_cache() + def add_bloc(self, block): """Add a block to JiT and JiT it. @block: the block to add """ - # TODO: caching using hash + block_hash = self.hash_block(block) + fname_out = os.path.join(self.tempdir, "%s.bc" % block_hash) + + if not os.access(fname_out, os.R_OK): + # Build a function in the context + func = LLVMFunction(self.context, block.label.name) - # Build a function in the context - func = LLVMFunction(self.context, block.label.name) + # Set log level + func.log_regs = self.log_regs + func.log_mn = self.log_mn - # Set log level - func.log_regs = self.log_regs - func.log_mn = self.log_mn + # Import asm block + func.from_asmblock(block) - # Import asm block - func.from_asmblock(block) + # Verify + if self.options["safe_mode"] is True: + func.verify() - # Verify - if self.options["safe_mode"] is True: - func.verify() + # Optimise + if self.options["optimise"] is True: + func.optimise() - # Optimise - if self.options["optimise"] is True: - func.optimise() + # Log + if self.options["log_func"] is True: + print func + if self.options["log_assembly"] is True: + print func.get_assembly() - # Log - if self.options["log_func"] is True: - print func - if self.options["log_assembly"] is True: - print func.get_assembly() + # Use propagate the cache filename + self.context.set_cache_filename(func, fname_out) + + # Get a pointer on the function for JiT + ptr = func.get_function_pointer() + + else: + # The cache file exists: function can be loaded from cache + ptr = self.context.get_ptr_from_cache(fname_out, block.label.name) # Store a pointer on the function jitted code - self.lbl2jitbloc[block.label.offset] = func.get_function_pointer() + self.lbl2jitbloc[block.label.offset] = ptr + + def hash_block(self, block): + """ + Build a hash of the block @block + @block: asmbloc + """ + block_raw = "".join(line.b for line in block.lines) + block_hash = md5("%X_%s_%s_%s" % (block.label.offset, + self.log_mn, + self.log_regs, + block_raw)).hexdigest() + return block_hash def jit_call(self, label, cpu, _vmmngr, breakpoints): """Call the function label with cpu and vmmngr states diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index cfe89059..4031d8f2 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -11,6 +11,7 @@ # # +import os from llvmlite import binding as llvm from llvmlite import ir as llvm_ir import miasm2.expression.expression as m2_expr @@ -294,6 +295,57 @@ class LLVMContext_JIT(LLVMContext): value]) + @staticmethod + def cache_notify(module, buffer): + """Called when @module has been compiled to @buffer""" + if not hasattr(module, "fname_out"): + return + fname_out = module.fname_out + + if os.access(fname_out, os.R_OK): + # No need to overwrite + return + + open(fname_out, "w").write(buffer) + + @staticmethod + def cache_getbuffer(module): + """Return a compiled buffer for @module if available""" + if not hasattr(module, "fname_out"): + return None + + fname_out = module.fname_out + if os.access(fname_out, os.R_OK): + return open(fname_out).read() + return None + + def enable_cache(self): + "Enable cache of compiled object" + # Load shared libraries + for lib_fname in self.library_filenames: + self.add_shared_library(lib_fname) + + # Activate cache + self.exec_engine.set_object_cache(self.cache_notify, + self.cache_getbuffer) + + def set_cache_filename(self, func, fname_out): + "Set the filename @fname_out to use for cache for @func" + # Use a custom attribute to propagate the cache filename + func.as_llvm_mod().fname_out = fname_out + + def get_ptr_from_cache(self, file_name, func_name): + "Load @file_name and return a pointer on the jitter @func_name" + # We use an empty module to avoid loosing time on function building + empty_module = llvm.parse_assembly("") + empty_module.fname_out = file_name + + engine = self.exec_engine + engine.add_module(empty_module) + engine.finalize_object() + return engine.get_function_address(func_name) + + class LLVMContext_IRCompilation(LLVMContext): """Extend LLVMContext in order to handle memory management and custom |