about summary refs log tree commit diff stats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- miasm2/jitter/Jitllvm.c 63
-rw-r--r-- miasm2/jitter/jitcore.py 21
-rw-r--r-- miasm2/jitter/jitcore_llvm.py 85
-rw-r--r-- miasm2/jitter/jitcore_python.py 13
-rw-r--r-- miasm2/jitter/jitload.py 2
-rw-r--r-- miasm2/jitter/llvmconvert.py 52
-rw-r--r-- test/jitter/jit_options.py 97
-rw-r--r-- test/jitter/jitload.py 3
-rw-r--r-- test/jitter/vm_mngr.py 3
-rwxr-xr-x test/test_all.py 5
10 files changed, 282 insertions, 62 deletions
diff --git a/miasm2/jitter/Jitllvm.c b/miasm2/jitter/Jitllvm.c
index c176a4b2..b46f88e3 100644
--- a/miasm2/jitter/Jitllvm.c
+++ b/miasm2/jitter/Jitllvm.c
@@ -12,19 +12,66 @@
 
 PyObject* llvm_exec_bloc(PyObject* self, PyObject* args)
 {
-	uint64_t func_addr;
 	uint64_t (*func)(void*, void*, void*, uint8_t*);
-	uint64_t vm;
+	vm_cpu_t* cpu;
+	vm_mngr_t* vm;
 	uint64_t ret;
 	JitCpu* jitcpu;
 	uint8_t status;
-	
-	if (!PyArg_ParseTuple(args, "KOK", &func_addr, &jitcpu, &vm))
+	PyObject* func_py;
+	PyObject* lbl2ptr;
+	PyObject* breakpoints;
+	PyObject* retaddr = NULL;
+	uint64_t max_exec_per_call = 0;
+	uint64_t cpt;
+	int do_cpt;
+
+	if (!PyArg_ParseTuple(args, "OOOO|K",
+			      &retaddr, &jitcpu, &lbl2ptr, &breakpoints,
+			      &max_exec_per_call))
 		return NULL;
-	vm_cpu_t* cpu = jitcpu->cpu;
-	func = (void *) (intptr_t) func_addr;
-	ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm, &status);
-	return PyLong_FromUnsignedLongLong(ret);
+
+	cpu = jitcpu->cpu;
+	vm = &(jitcpu->pyvm->vm_mngr);
+	/* The loop will decref retaddr always once */
+	Py_INCREF(retaddr);
+
+	if (max_exec_per_call == 0) {
+		do_cpt = 0;
+		cpt = 1;
+	} else {
+		do_cpt = 1;
+		cpt = max_exec_per_call;
+	}
+
+	for (;;) {
+		// Handle cpt
+		if (cpt == 0)
+			return retaddr;
+		if (do_cpt)
+			cpt --;
+
+		// Get the expected jitted function address
+		func_py = PyDict_GetItem(lbl2ptr, retaddr);
+		if (func_py)
+			func = PyLong_AsVoidPtr((PyObject*) func_py);
+		else
+			// retaddr is not jitted yet
+			return retaddr;
+
+		// Execute it
+		ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm, &status);
+		Py_DECREF(retaddr);
+		retaddr = PyLong_FromUnsignedLongLong(ret);
+
+		// Check exception
+		if (status)
+			return retaddr;
+
+		// Check breakpoint
+		if (PyDict_Contains(breakpoints, retaddr))
+			return retaddr;
+	}
 }
 
 
diff --git a/miasm2/jitter/jitcore.py b/miasm2/jitter/jitcore.py
index f3a79bee..7e831280 100644
--- a/miasm2/jitter/jitcore.py
+++ b/miasm2/jitter/jitcore.py
@@ -165,33 +165,22 @@ class JitCore(object):
         # Update jitcode mem range
         self.add_bloc_to_mem_interval(vm, cur_bloc)
 
-    def jit_call(self, label, cpu, _vmmngr, breakpoints):
-        """Call the function label with cpu and vmmngr states
-        @label: function's label
-        @cpu: JitCpu instance
-        @breakpoints: Dict instance of used breakpoints
-        """
-        return self.exec_wrapper(label, cpu, self.lbl2jitbloc.data, breakpoints,
-                                 self.options["max_exec_per_call"])
-
-    def runbloc(self, cpu, vm, lbl, breakpoints):
+    def runbloc(self, cpu, lbl, breakpoints):
         """Run the bloc starting at lbl.
         @cpu: JitCpu instance
-        @vm: VmMngr instance
         @lbl: target label
         """
 
         if lbl is None:
-            lbl = cpu.get_gpreg()[self.ir_arch.pc.name]
+            lbl = getattr(cpu, self.ir_arch.pc.name)
 
         if not lbl in self.lbl2jitbloc:
             # Need to JiT the bloc
-            self.disbloc(lbl, vm)
+            self.disbloc(lbl, cpu.vmmngr)
 
         # Run the bloc and update cpu/vmmngr state
-        ret = self.jit_call(lbl, cpu, vm, breakpoints)
-
-        return ret
+        return self.exec_wrapper(lbl, cpu, self.lbl2jitbloc.data, breakpoints,
+                                 self.options["max_exec_per_call"])
 
     def blocs2memrange(self, blocs):
         """Return an interval instance standing for blocs addresses
diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py
index 0f265073..8f58f1da 100644
--- a/miasm2/jitter/jitcore_llvm.py
+++ b/miasm2/jitter/jitcore_llvm.py
@@ -1,6 +1,7 @@
 import os
 import importlib
-import hashlib
+import tempfile
+from hashlib import md5
 from miasm2.jitter.llvmconvert import *
 import miasm2.jitter.jitcore as jitcore
 import Jitllvm
@@ -28,9 +29,18 @@ class JitCore_LLVM(jitcore.JitCore):
                              })
 
         self.exec_wrapper = Jitllvm.llvm_exec_bloc
-        self.exec_engines = []
         self.ir_arch = ir_arch
 
+        # Cache temporary dir
+        self.tempdir = os.path.join(tempfile.gettempdir(), "miasm_cache")
+        try:
+            os.mkdir(self.tempdir, 0755)
+        except OSError:
+            pass
+        if not os.access(self.tempdir, os.R_OK | os.W_OK):
+            raise RuntimeError(
+                'Cannot access cache directory %s ' % self.tempdir)
+
     def load(self):
 
         # Library to load within Jit context
@@ -60,43 +70,62 @@ class JitCore_LLVM(jitcore.JitCore):
         mod = importlib.import_module(mod_name)
         self.context.set_vmcpu(mod.get_gpreg_offset_all())
 
+        # Enable caching
+        self.context.enable_cache()
+
     def add_bloc(self, block):
         """Add a block to JiT and JiT it.
         @block: the block to add
         """
-        # TODO: caching using hash
+        block_hash = self.hash_block(block)
+        fname_out = os.path.join(self.tempdir, "%s.bc" % block_hash)
 
-        # Build a function in the context
-        func = LLVMFunction(self.context, block.label.name)
+        if not os.access(fname_out, os.R_OK):
+            # Build a function in the context
+            func = LLVMFunction(self.context, block.label.name)
 
-        # Set log level
-        func.log_regs = self.log_regs
-        func.log_mn = self.log_mn
+            # Set log level
+            func.log_regs = self.log_regs
+            func.log_mn = self.log_mn
 
-        # Import asm block
-        func.from_asmblock(block)
+            # Import asm block
+            func.from_asmblock(block)
 
-        # Verify
-        if self.options["safe_mode"] is True:
-            func.verify()
+            # Verify
+            if self.options["safe_mode"] is True:
+                func.verify()
 
-        # Optimise
-        if self.options["optimise"] is True:
-            func.optimise()
+            # Optimise
+            if self.options["optimise"] is True:
+                func.optimise()
 
-        # Log
-        if self.options["log_func"] is True:
-            print func
-        if self.options["log_assembly"] is True:
-            print func.get_assembly()
+            # Log
+            if self.options["log_func"] is True:
+                print func
+            if self.options["log_assembly"] is True:
+                print func.get_assembly()
+
+            # Propagate the cache filename
+            self.context.set_cache_filename(func, fname_out)
+
+            # Get a pointer on the function for JiT
+            ptr = func.get_function_pointer()
+
+        else:
+            # The cache file exists: function can be loaded from cache
+            ptr = self.context.get_ptr_from_cache(fname_out, block.label.name)
 
         # Store a pointer on the function jitted code
-        self.lbl2jitbloc[block.label.offset] = func.get_function_pointer()
+        self.lbl2jitbloc[block.label.offset] = ptr
 
-    def jit_call(self, label, cpu, _vmmngr, breakpoints):
-        """Call the function label with cpu and vmmngr states
-        @label: function's label
-        @cpu: JitCpu instance
-        @breakpoints: Dict instance of used breakpoints
+    def hash_block(self, block):
+        """
+        Build a hash of the block @block
+        @block: asmbloc
         """
-        return self.exec_wrapper(self.lbl2jitbloc[label], cpu, cpu.vmmngr.vmmngr)
+        block_raw = "".join(line.b for line in block.lines)
+        block_hash = md5("%X_%s_%s_%s" % (block.label.offset,
+                                          self.log_mn,
+                                          self.log_regs,
+                                          block_raw)).hexdigest()
+        return block_hash
diff --git a/miasm2/jitter/jitcore_python.py b/miasm2/jitter/jitcore_python.py
index 87259f71..27666ab4 100644
--- a/miasm2/jitter/jitcore_python.py
+++ b/miasm2/jitter/jitcore_python.py
@@ -38,11 +38,12 @@ class JitCore_Python(jitcore.JitCore):
         @irblocs: a gorup of irblocs
         """
 
-        def myfunc(cpu, vmmngr):
+        def myfunc(cpu):
             """Execute the function according to cpu and vmmngr states
             @cpu: JitCpu instance
-            @vm: VmMngr instance
             """
+            # Get virtual memory handler
+            vmmngr = cpu.vmmngr
 
             # Keep current location in irblocs
             cur_label = label
@@ -125,15 +126,15 @@ class JitCore_Python(jitcore.JitCore):
         # Associate myfunc with current label
         self.lbl2jitbloc[label.offset] = myfunc
 
-    def jit_call(self, label, cpu, vmmngr, _breakpoints):
-        """Call the function label with cpu and vmmngr states
+    def exec_wrapper(self, label, cpu, _lbl2jitbloc, _breakpoints,
+                     _max_exec_per_call):
+        """Call the function @label with @cpu
         @label: function's label
         @cpu: JitCpu instance
-        @vm: VmMngr instance
         """
 
         # Get Python function corresponding to @label
         fc_ptr = self.lbl2jitbloc[label]
 
         # Execute the function
-        return fc_ptr(cpu, vmmngr)
+        return fc_ptr(cpu)
diff --git a/miasm2/jitter/jitload.py b/miasm2/jitter/jitload.py
index f23c78c0..bc09e1f2 100644
--- a/miasm2/jitter/jitload.py
+++ b/miasm2/jitter/jitload.py
@@ -296,7 +296,7 @@ class jitter:
         """Wrapper on JiT backend. Run the code at PC and return the next PC.
         @pc: address of code to run"""
 
-        return self.jit.runbloc(self.cpu, self.vm, pc, self.breakpoints_handler.callbacks)
+        return self.jit.runbloc(self.cpu, pc, self.breakpoints_handler.callbacks)
 
     def runiter_once(self, pc):
         """Iterator on callbacks results on code running from PC.
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index cfe89059..4031d8f2 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -11,6 +11,7 @@
 #
 #
 
+import os
 from llvmlite import binding as llvm
 from llvmlite import ir as llvm_ir
 import miasm2.expression.expression as m2_expr
@@ -294,6 +295,57 @@ class LLVMContext_JIT(LLVMContext):
                               value])
 
 
+    @staticmethod
+    def cache_notify(module, buffer):
+        """Called when @module has been compiled to @buffer"""
+        if not hasattr(module, "fname_out"):
+            return
+        fname_out = module.fname_out
+
+        if os.access(fname_out, os.R_OK):
+            # No need to overwrite
+            return
+
+        open(fname_out, "w").write(buffer)
+
+    @staticmethod
+    def cache_getbuffer(module):
+        """Return a compiled buffer for @module if available"""
+        if not hasattr(module, "fname_out"):
+            return None
+
+        fname_out = module.fname_out
+        if os.access(fname_out, os.R_OK):
+            return open(fname_out).read()
+        return None
+
+    def enable_cache(self):
+        "Enable cache of compiled object"
+        # Load shared libraries
+        for lib_fname in self.library_filenames:
+            self.add_shared_library(lib_fname)
+
+        # Activate cache
+        self.exec_engine.set_object_cache(self.cache_notify,
+                                          self.cache_getbuffer)
+
+    def set_cache_filename(self, func, fname_out):
+        "Set the filename @fname_out to use for cache for @func"
+        # Use a custom attribute to propagate the cache filename
+        func.as_llvm_mod().fname_out = fname_out
+
+    def get_ptr_from_cache(self, file_name, func_name):
+        "Load @file_name and return a pointer on the jitted @func_name"
+        # We use an empty module to avoid losing time on function building
+        empty_module = llvm.parse_assembly("")
+        empty_module.fname_out = file_name
+
+        engine = self.exec_engine
+        engine.add_module(empty_module)
+        engine.finalize_object()
+        return engine.get_function_address(func_name)
+
+
 class LLVMContext_IRCompilation(LLVMContext):
 
     """Extend LLVMContext in order to handle memory management and custom
diff --git a/test/jitter/jit_options.py b/test/jitter/jit_options.py
new file mode 100644
index 00000000..cc955c64
--- /dev/null
+++ b/test/jitter/jit_options.py
@@ -0,0 +1,97 @@
+import os
+import sys
+from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE
+from miasm2.analysis.machine import Machine
+from pdb import pm
+
+# Shellcode
+
+# main:
+#       MOV EAX, 0x1
+# loop_main:
+#       CMP EAX, 0x10
+#       JZ loop_end
+# loop_inc:
+#       INC EAX
+#       JMP loop_main
+# loop_end:
+#       RET
+data = "b80100000083f810740340ebf8c3".decode("hex")
+run_addr = 0x40000000
+
+def code_sentinelle(jitter):
+    jitter.run = False
+    jitter.pc = 0
+    return True
+
+def init_jitter():
+    global data, run_addr
+    # Create jitter
+    myjit = Machine("x86_32").jitter(sys.argv[1])
+
+    myjit.vm.add_memory_page(run_addr, PAGE_READ | PAGE_WRITE, data)
+
+    # Init jitter
+    myjit.init_stack()
+    myjit.jit.log_regs = True
+    myjit.jit.log_mn = True
+    myjit.push_uint32_t(0x1337beef)
+
+    myjit.add_breakpoint(0x1337beef, code_sentinelle)
+    return myjit
+
+# Test 'max_exec_per_call'
+print "[+] First run, to jit blocks"
+myjit = init_jitter()
+myjit.init_run(run_addr)
+myjit.continue_run()
+
+assert myjit.run is False
+assert myjit.cpu.EAX  == 0x10
+
+## Let's specify a max_exec_per_call
+## 5: main, loop_main, loop_inc, loop_main, loop_inc
+myjit.jit.options["max_exec_per_call"] = 5
+
+first_call = True
+def cb(jitter):
+    global first_call
+    if first_call:
+        # Avoid breaking on the first pass (before any execution)
+        first_call = False
+        return True
+    return False
+
+## Second run
+print "[+] Second run"
+myjit.push_uint32_t(0x1337beef)
+myjit.cpu.EAX = 0
+myjit.init_run(run_addr)
+myjit.exec_cb = cb
+myjit.continue_run()
+
+assert myjit.run is True
+# Use a '<=' because it's a 'max_...'
+assert myjit.cpu.EAX <= 3
+
+# Test 'jit_maxline'
+print "[+] Run instr one by one"
+myjit = init_jitter()
+myjit.jit.options["jit_maxline"] = 1
+myjit.jit.options["max_exec_per_call"] = 1
+
+counter = 0
+def cb(jitter):
+    global counter
+    counter += 1
+    return True
+
+myjit.init_run(run_addr)
+myjit.exec_cb = cb
+myjit.continue_run()
+
+assert myjit.run is False
+assert myjit.cpu.EAX  == 0x10
+## dry(1) + main(1) + (loop_main(2) + loop_inc(2))*(0x10 - 1) + loop_main(2) +
+## loop_end(1) = 65
+assert counter == 65
diff --git a/test/jitter/jitload.py b/test/jitter/jitload.py
index 283298db..544e9d18 100644
--- a/test/jitter/jitload.py
+++ b/test/jitter/jitload.py
@@ -1,3 +1,4 @@
+import sys
 from pdb import pm
 
 from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE
@@ -9,7 +10,7 @@ from miasm2.expression.expression import ExprId, ExprInt32, ExprInt64, ExprAff,
 data = "8d49048d5b0180f90174058d5bffeb038d5b0189d8c3".decode("hex")
 
 # Init jitter
-myjit = Machine("x86_32").jitter()
+myjit = Machine("x86_32").jitter(sys.argv[1])
 myjit.init_stack()
 
 run_addr = 0x40000000
diff --git a/test/jitter/vm_mngr.py b/test/jitter/vm_mngr.py
index b2b7336b..87bc6f8f 100644
--- a/test/jitter/vm_mngr.py
+++ b/test/jitter/vm_mngr.py
@@ -1,7 +1,8 @@
+import sys
 from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE
 from miasm2.analysis.machine import Machine
 
-myjit = Machine("x86_32").jitter()
+myjit = Machine("x86_32").jitter(sys.argv[1])
 
 base_addr = 0x13371337
 page_size = 0x1000
diff --git a/test/test_all.py b/test/test_all.py
index bec0c78d..59624832 100755
--- a/test/test_all.py
+++ b/test/test_all.py
@@ -325,8 +325,11 @@ for i, test_args in enumerate(test_args):
 ## Jitter
 for script in ["jitload.py",
                "vm_mngr.py",
+               "jit_options.py",
                ]:
-    testset += RegressionTest([script], base_dir="jitter", tags=[TAGS["tcc"]])
+    for engine in ArchUnitTest.jitter_engines:
+        testset += RegressionTest([script, engine], base_dir="jitter",
+                                  tags=[TAGS.get(engine,None)])
 
 
 # Examples