diff options
Diffstat (limited to '')
| -rw-r--r-- | miasm2/jitter/Jitllvm.c | 63 | ||||
| -rw-r--r-- | miasm2/jitter/jitcore.py | 21 | ||||
| -rw-r--r-- | miasm2/jitter/jitcore_llvm.py | 85 | ||||
| -rw-r--r-- | miasm2/jitter/jitcore_python.py | 13 | ||||
| -rw-r--r-- | miasm2/jitter/jitload.py | 2 | ||||
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 52 | ||||
| -rw-r--r-- | test/jitter/jit_options.py | 97 | ||||
| -rw-r--r-- | test/jitter/jitload.py | 3 | ||||
| -rw-r--r-- | test/jitter/vm_mngr.py | 3 | ||||
| -rwxr-xr-x | test/test_all.py | 5 |
10 files changed, 282 insertions, 62 deletions
diff --git a/miasm2/jitter/Jitllvm.c b/miasm2/jitter/Jitllvm.c index c176a4b2..b46f88e3 100644 --- a/miasm2/jitter/Jitllvm.c +++ b/miasm2/jitter/Jitllvm.c @@ -12,19 +12,66 @@ PyObject* llvm_exec_bloc(PyObject* self, PyObject* args) { - uint64_t func_addr; uint64_t (*func)(void*, void*, void*, uint8_t*); - uint64_t vm; + vm_cpu_t* cpu; + vm_mngr_t* vm; uint64_t ret; JitCpu* jitcpu; uint8_t status; - - if (!PyArg_ParseTuple(args, "KOK", &func_addr, &jitcpu, &vm)) + PyObject* func_py; + PyObject* lbl2ptr; + PyObject* breakpoints; + PyObject* retaddr = NULL; + uint64_t max_exec_per_call = 0; + uint64_t cpt; + int do_cpt; + + if (!PyArg_ParseTuple(args, "OOOO|K", + &retaddr, &jitcpu, &lbl2ptr, &breakpoints, + &max_exec_per_call)) return NULL; - vm_cpu_t* cpu = jitcpu->cpu; - func = (void *) (intptr_t) func_addr; - ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm, &status); - return PyLong_FromUnsignedLongLong(ret); + + cpu = jitcpu->cpu; + vm = &(jitcpu->pyvm->vm_mngr); + /* The loop will decref retaddr always once */ + Py_INCREF(retaddr); + + if (max_exec_per_call == 0) { + do_cpt = 0; + cpt = 1; + } else { + do_cpt = 1; + cpt = max_exec_per_call; + } + + for (;;) { + // Handle cpt + if (cpt == 0) + return retaddr; + if (do_cpt) + cpt --; + + // Get the expected jitted function address + func_py = PyDict_GetItem(lbl2ptr, retaddr); + if (func_py) + func = PyLong_AsVoidPtr((PyObject*) func_py); + else + // retaddr is not jitted yet + return retaddr; + + // Execute it + ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm, &status); + Py_DECREF(retaddr); + retaddr = PyLong_FromUnsignedLongLong(ret); + + // Check exception + if (status) + return retaddr; + + // Check breakpoint + if (PyDict_Contains(breakpoints, retaddr)) + return retaddr; + } } diff --git a/miasm2/jitter/jitcore.py b/miasm2/jitter/jitcore.py index f3a79bee..7e831280 100644 --- a/miasm2/jitter/jitcore.py +++ b/miasm2/jitter/jitcore.py @@ -165,33 +165,22 @@ class JitCore(object): # Update jitcode mem range self.add_bloc_to_mem_interval(vm, cur_bloc) - def jit_call(self, label, cpu, _vmmngr, breakpoints): - """Call the function label with cpu and vmmngr states - @label: function's label - @cpu: JitCpu instance - @breakpoints: Dict instance of used breakpoints - """ - return self.exec_wrapper(label, cpu, self.lbl2jitbloc.data, breakpoints, - self.options["max_exec_per_call"]) - - def runbloc(self, cpu, vm, lbl, breakpoints): + def runbloc(self, cpu, lbl, breakpoints): """Run the bloc starting at lbl. @cpu: JitCpu instance - @vm: VmMngr instance @lbl: target label """ if lbl is None: - lbl = cpu.get_gpreg()[self.ir_arch.pc.name] + lbl = getattr(cpu, self.ir_arch.pc.name) if not lbl in self.lbl2jitbloc: # Need to JiT the bloc - self.disbloc(lbl, vm) + self.disbloc(lbl, cpu.vmmngr) # Run the bloc and update cpu/vmmngr state - ret = self.jit_call(lbl, cpu, vm, breakpoints) - - return ret + return self.exec_wrapper(lbl, cpu, self.lbl2jitbloc.data, breakpoints, + self.options["max_exec_per_call"]) def blocs2memrange(self, blocs): """Return an interval instance standing for blocs addresses diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py index 0f265073..8f58f1da 100644 --- a/miasm2/jitter/jitcore_llvm.py +++ b/miasm2/jitter/jitcore_llvm.py @@ -1,6 +1,7 @@ import os import importlib -import hashlib +import tempfile +from hashlib import md5 from miasm2.jitter.llvmconvert import * import miasm2.jitter.jitcore as jitcore import Jitllvm @@ -28,9 +29,18 @@ class JitCore_LLVM(jitcore.JitCore): }) self.exec_wrapper = Jitllvm.llvm_exec_bloc - self.exec_engines = [] self.ir_arch = ir_arch + # Cache temporary dir + self.tempdir = os.path.join(tempfile.gettempdir(), "miasm_cache") + try: + os.mkdir(self.tempdir, 0755) + except OSError: + pass + if not os.access(self.tempdir, os.R_OK | os.W_OK): + raise RuntimeError( + 'Cannot access cache directory %s ' % self.tempdir) + def load(self): # Library to load within Jit context @@ -60,43 +70,62 @@ class JitCore_LLVM(jitcore.JitCore): mod = importlib.import_module(mod_name) self.context.set_vmcpu(mod.get_gpreg_offset_all()) + # Enable caching + self.context.enable_cache() + def add_bloc(self, block): """Add a block to JiT and JiT it. @block: the block to add """ - # TODO: caching using hash + block_hash = self.hash_block(block) + fname_out = os.path.join(self.tempdir, "%s.bc" % block_hash) - # Build a function in the context - func = LLVMFunction(self.context, block.label.name) + if not os.access(fname_out, os.R_OK): + # Build a function in the context + func = LLVMFunction(self.context, block.label.name) - # Set log level - func.log_regs = self.log_regs - func.log_mn = self.log_mn + # Set log level + func.log_regs = self.log_regs + func.log_mn = self.log_mn - # Import asm block - func.from_asmblock(block) + # Import asm block + func.from_asmblock(block) - # Verify - if self.options["safe_mode"] is True: - func.verify() + # Verify + if self.options["safe_mode"] is True: + func.verify() - # Optimise - if self.options["optimise"] is True: - func.optimise() + # Optimise + if self.options["optimise"] is True: + func.optimise() - # Log - if self.options["log_func"] is True: - print func - if self.options["log_assembly"] is True: - print func.get_assembly() + # Log + if self.options["log_func"] is True: + print func + if self.options["log_assembly"] is True: + print func.get_assembly() + + # Use propagate the cache filename + self.context.set_cache_filename(func, fname_out) + + # Get a pointer on the function for JiT + ptr = func.get_function_pointer() + + else: + # The cache file exists: function can be loaded from cache + ptr = self.context.get_ptr_from_cache(fname_out, block.label.name) # Store a pointer on the function jitted code - self.lbl2jitbloc[block.label.offset] = func.get_function_pointer() + self.lbl2jitbloc[block.label.offset] = ptr - def jit_call(self, label, cpu, _vmmngr, breakpoints): - """Call the function label with cpu and vmmngr states - @label: function's label - @cpu: JitCpu instance - @breakpoints: Dict instance of used breakpoints + def hash_block(self, block): + """ + Build a hash of the block @block + @block: asmbloc """ - return self.exec_wrapper(self.lbl2jitbloc[label], cpu, cpu.vmmngr.vmmngr) + block_raw = "".join(line.b for line in block.lines) + block_hash = md5("%X_%s_%s_%s" % (block.label.offset, + self.log_mn, + self.log_regs, + block_raw)).hexdigest() + return block_hash diff --git a/miasm2/jitter/jitcore_python.py b/miasm2/jitter/jitcore_python.py index 87259f71..27666ab4 100644 --- a/miasm2/jitter/jitcore_python.py +++ b/miasm2/jitter/jitcore_python.py @@ -38,11 +38,12 @@ class JitCore_Python(jitcore.JitCore): @irblocs: a gorup of irblocs """ - def myfunc(cpu, vmmngr): + def myfunc(cpu): """Execute the function according to cpu and vmmngr states @cpu: JitCpu instance - @vm: VmMngr instance """ + # Get virtual memory handler + vmmngr = cpu.vmmngr # Keep current location in irblocs cur_label = label @@ -125,15 +126,15 @@ class JitCore_Python(jitcore.JitCore): # Associate myfunc with current label self.lbl2jitbloc[label.offset] = myfunc - def jit_call(self, label, cpu, vmmngr, _breakpoints): - """Call the function label with cpu and vmmngr states + def exec_wrapper(self, label, cpu, _lbl2jitbloc, _breakpoints, + _max_exec_per_call): + """Call the function @label with @cpu @label: function's label @cpu: JitCpu instance - @vm: VmMngr instance """ # Get Python function corresponding to @label fc_ptr = self.lbl2jitbloc[label] # Execute the function - return fc_ptr(cpu, vmmngr) + return fc_ptr(cpu) diff --git a/miasm2/jitter/jitload.py b/miasm2/jitter/jitload.py index f23c78c0..bc09e1f2 100644 --- a/miasm2/jitter/jitload.py +++ b/miasm2/jitter/jitload.py @@ -296,7 +296,7 @@ class jitter: """Wrapper on JiT backend. Run the code at PC and return the next PC. @pc: address of code to run""" - return self.jit.runbloc(self.cpu, self.vm, pc, self.breakpoints_handler.callbacks) + return self.jit.runbloc(self.cpu, pc, self.breakpoints_handler.callbacks) def runiter_once(self, pc): """Iterator on callbacks results on code running from PC. diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index cfe89059..4031d8f2 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -11,6 +11,7 @@ # # +import os from llvmlite import binding as llvm from llvmlite import ir as llvm_ir import miasm2.expression.expression as m2_expr @@ -294,6 +295,57 @@ class LLVMContext_JIT(LLVMContext): value]) + @staticmethod + def cache_notify(module, buffer): + """Called when @module has been compiled to @buffer""" + if not hasattr(module, "fname_out"): + return + fname_out = module.fname_out + + if os.access(fname_out, os.R_OK): + # No need to overwrite + return + + open(fname_out, "w").write(buffer) + + @staticmethod + def cache_getbuffer(module): + """Return a compiled buffer for @module if available""" + if not hasattr(module, "fname_out"): + return None + + fname_out = module.fname_out + if os.access(fname_out, os.R_OK): + return open(fname_out).read() + return None + + def enable_cache(self): + "Enable cache of compiled object" + # Load shared libraries + for lib_fname in self.library_filenames: + self.add_shared_library(lib_fname) + + # Activate cache + self.exec_engine.set_object_cache(self.cache_notify, + self.cache_getbuffer) + + def set_cache_filename(self, func, fname_out): + "Set the filename @fname_out to use for cache for @func" + # Use a custom attribute to propagate the cache filename + func.as_llvm_mod().fname_out = fname_out + + def get_ptr_from_cache(self, file_name, func_name): + "Load @file_name and return a pointer on the jitter @func_name" + # We use an empty module to avoid loosing time on function building + empty_module = llvm.parse_assembly("") + empty_module.fname_out = file_name + + engine = self.exec_engine + engine.add_module(empty_module) + engine.finalize_object() + return engine.get_function_address(func_name) + + class LLVMContext_IRCompilation(LLVMContext): """Extend LLVMContext in order to handle memory management and custom diff --git a/test/jitter/jit_options.py b/test/jitter/jit_options.py new file mode 100644 index 00000000..cc955c64 --- /dev/null +++ b/test/jitter/jit_options.py @@ -0,0 +1,97 @@ +import os +import sys +from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE +from miasm2.analysis.machine import Machine +from pdb import pm + +# Shellcode + +# main: +# MOV EAX, 0x1 +# loop_main: +# CMP EAX, 0x10 +# JZ loop_end +# loop_inc: +# INC EAX +# JMP loop_main +# loop_end: +# RET +data = "b80100000083f810740340ebf8c3".decode("hex") +run_addr = 0x40000000 + +def code_sentinelle(jitter): + jitter.run = False + jitter.pc = 0 + return True + +def init_jitter(): + global data, run_addr + # Create jitter + myjit = Machine("x86_32").jitter(sys.argv[1]) + + myjit.vm.add_memory_page(run_addr, PAGE_READ | PAGE_WRITE, data) + + # Init jitter + myjit.init_stack() + myjit.jit.log_regs = True + myjit.jit.log_mn = True + myjit.push_uint32_t(0x1337beef) + + myjit.add_breakpoint(0x1337beef, code_sentinelle) + return myjit + +# Test 'max_exec_per_call' +print "[+] First run, to jit blocks" +myjit = init_jitter() +myjit.init_run(run_addr) +myjit.continue_run() + +assert myjit.run is False +assert myjit.cpu.EAX == 0x10 + +## Let's specify a max_exec_per_call +## 5: main, loop_main, loop_inc, loop_main, loop_inc +myjit.jit.options["max_exec_per_call"] = 5 + +first_call = True +def cb(jitter): + global first_call + if first_call: + # Avoid breaking on the first pass (before any execution) + first_call = False + return True + return False + +## Second run +print "[+] Second run" +myjit.push_uint32_t(0x1337beef) +myjit.cpu.EAX = 0 +myjit.init_run(run_addr) +myjit.exec_cb = cb +myjit.continue_run() + +assert myjit.run is True +# Use a '<=' because it's a 'max_...' +assert myjit.cpu.EAX <= 3 + +# Test 'jit_maxline' +print "[+] Run instr one by one" +myjit = init_jitter() +myjit.jit.options["jit_maxline"] = 1 +myjit.jit.options["max_exec_per_call"] = 1 + +counter = 0 +def cb(jitter): + global counter + counter += 1 + return True + +myjit.init_run(run_addr) +myjit.exec_cb = cb +myjit.continue_run() + +assert myjit.run is False +assert myjit.cpu.EAX == 0x10 +## dry(1) + main(1) + (loop_main(2) + loop_inc(2))*(0x10 - 1) + loop_main(2) + +## loop_end(1) = 65 +assert counter == 65 diff --git a/test/jitter/jitload.py b/test/jitter/jitload.py index 283298db..544e9d18 100644 --- a/test/jitter/jitload.py +++ b/test/jitter/jitload.py @@ -1,3 +1,4 @@ +import sys from pdb import pm from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE @@ -9,7 +10,7 @@ from miasm2.expression.expression import ExprId, ExprInt32, ExprInt64, ExprAff, data = "8d49048d5b0180f90174058d5bffeb038d5b0189d8c3".decode("hex") # Init jitter -myjit = Machine("x86_32").jitter() +myjit = Machine("x86_32").jitter(sys.argv[1]) myjit.init_stack() run_addr = 0x40000000 diff --git a/test/jitter/vm_mngr.py b/test/jitter/vm_mngr.py index b2b7336b..87bc6f8f 100644 --- a/test/jitter/vm_mngr.py +++ b/test/jitter/vm_mngr.py @@ -1,7 +1,8 @@ +import sys from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE from miasm2.analysis.machine import Machine -myjit = Machine("x86_32").jitter() +myjit = Machine("x86_32").jitter(sys.argv[1]) base_addr = 0x13371337 page_size = 0x1000 diff --git a/test/test_all.py b/test/test_all.py index bec0c78d..59624832 100755 --- a/test/test_all.py +++ b/test/test_all.py @@ -325,8 +325,11 @@ for i, test_args in enumerate(test_args): ## Jitter for script in ["jitload.py", "vm_mngr.py", + "jit_options.py", ]: - testset += RegressionTest([script], base_dir="jitter", tags=[TAGS["tcc"]]) + for engine in ArchUnitTest.jitter_engines: + testset += RegressionTest([script, engine], base_dir="jitter", + tags=[TAGS.get(engine,None)]) # Examples |