| field | value | |
|---|---|---|
| author | serpilliere <serpilliere@users.noreply.github.com> | 2017-01-05 10:41:32 +0100 |
| committer | GitHub <noreply@github.com> | 2017-01-05 10:41:32 +0100 |
| commit | c551e1e761b14b962ae327e6752599b7934ba9cc (patch) | |
| tree | 4ed186bb5cef2793569d125f42b681dbac4b2fb4 | |
| parent | 2b93bc6682f3a08a5eccccefa135535708434f9e (diff) | |
| parent | 66d61fa8b799214f42be7ba42fe8138e302978bb (diff) | |
| download | miasm-c551e1e761b14b962ae327e6752599b7934ba9cc.tar.gz | miasm-c551e1e761b14b962ae327e6752599b7934ba9cc.zip |
Merge pull request #458 from commial/feature-llvm-jitter
Feature llvm jitter
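This merge replaces the old llvmpy-based jitter backend with one built on Numba's llvmlite (see `miasm2/jitter/llvmconvert.py` and `miasm2/jitter/jitcore_llvm.py` below). For orientation, here is a minimal usage sketch, not taken from the patch, of how the LLVM backend is selected once the merge is applied; the stub bytes, addresses and breakpoint are made up, and it assumes the usual miasm2 `Machine`/jitter API of that era (Python 2).

```python
# Hypothetical usage sketch: jit a tiny x86 stub with the llvmlite backend.
from miasm2.analysis.machine import Machine
from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE

machine = Machine("x86_32")
jitter = machine.jitter(jit_type="llvm")       # was "tcc"/"gcc" before this PR

run_addr = 0x40000000
stub = "\x31\xc0\x40\xc3"                      # xor eax, eax ; inc eax ; ret
jitter.vm.add_memory_page(run_addr, PAGE_READ | PAGE_WRITE, stub)

jitter.init_stack()
jitter.jit.log_mn = True                       # per-instruction logging
jitter.push_uint32_t(0x1337beef)               # fake return address
jitter.add_breakpoint(0x1337beef, lambda j: False)
jitter.init_run(run_addr)
jitter.continue_run()
print hex(jitter.cpu.EAX)                      # expected: 0x1
```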
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | .travis.yml | 11 |
| -rw-r--r-- | README.md | 8 |
| -rw-r--r-- | miasm2/ir/translators/C.py | 6 |
| -rw-r--r-- | miasm2/jitter/Jitllvm.c | 18 |
| -rw-r--r-- | miasm2/jitter/arch/JitCore_x86.c | 5 |
| -rw-r--r-- | miasm2/jitter/jitcore_llvm.py | 99 |
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 1352 |
| -rw-r--r-- | miasm2/jitter/vm_mngr.c | 9 |
| -rw-r--r-- | miasm2/jitter/vm_mngr.h | 2 |
| -rw-r--r-- | test/test_all.py | 9 |
10 files changed, 897 insertions, 622 deletions
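Most of the churn is in `miasm2/jitter/llvmconvert.py`, where the llvmpy ExecutionEngine/FunctionPassManager setup is replaced by llvmlite's binding layer: initialise the native target once, create an MCJIT engine seeded with an empty module, and drive optimisation through a pass-manager builder with a numeric level (see `LLVMContext.__init__`, `init_exec_engine` and `optimise_level` in the diff below). The following self-contained sketch, assuming llvmlite ~0.15 as noted in the patch header, condenses that pattern on a throwaway function; the `incr` function is illustrative and not part of the patch.

```python
# Condensed sketch of the llvmlite pattern adopted by llvmconvert.py:
# native target + MCJIT engine + level-based pass manager.
import ctypes
from llvmlite import binding as llvm
from llvmlite import ir as llvm_ir

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

target_machine = llvm.Target.from_default_triple().create_target_machine()
engine = llvm.create_mcjit_compiler(llvm.parse_assembly(""), target_machine)

# Build a trivial module: i64 @incr(i64 %x) returning %x + 1 (illustrative).
mod = llvm_ir.Module(name="mod")
i64 = llvm_ir.IntType(64)
fn = llvm_ir.Function(mod, llvm_ir.FunctionType(i64, [i64]), name="incr")
builder = llvm_ir.IRBuilder(fn.append_basic_block("entry"))
builder.ret(builder.add(fn.args[0], i64(1)))

# Optimise with a numeric level, as in LLVMContext.optimise_level
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 2
pm = llvm.create_module_pass_manager()
pmb.populate(pm)

llvm_mod = llvm.parse_assembly(str(mod))
pm.run(llvm_mod)

# JiT and call, mirroring LLVMFunction.get_function_pointer
engine.add_module(llvm_mod)
engine.finalize_object()
ptr = engine.get_function_address("incr")
cfunc = ctypes.CFUNCTYPE(ctypes.c_uint64, ctypes.c_uint64)(ptr)
print cfunc(41)  # 42
```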
diff --git a/.travis.yml b/.travis.yml index a0bf7d06..ca4ca8ce 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,13 +4,19 @@ python: - "2.7" addons: apt: + sources: ['llvm-toolchain-precise-3.8', 'ubuntu-toolchain-r-test'] packages: - make - gcc - python-virtualenv - unzip + - llvm-3.8 + - llvm-3.8-dev + - g++-5 before_script: - "cd .." +- "export LLVM_CONFIG=$(which llvm-config-3.8)" +- "export CXX=$(which g++-5)" # make virtual env - "python /usr/lib/python2.7/dist-packages/virtualenv.py virtualenv;" - "cd virtualenv;" @@ -21,6 +27,11 @@ before_script: - "make && export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd);cd ..;" - "cp tinycc/libtcc.h include" - "cp tinycc/libtcc.so.1.0 tinycc/libtcc.so" +# install llvmlite, using the system libdtc++ instead of statically linking it +- "pip install enum34" +- "git clone https://github.com/numba/llvmlite llvmlite && cd llvmlite" +- "sed -i 's/-static-libstdc++ //' ffi/Makefile.linux" +- "python setup.py install && cd .." # install elfesteem - "git clone https://github.com/serpilliere/elfesteem elfesteem && cd elfesteem && python setup.py install && cd ..;" # install pyparsing diff --git a/README.md b/README.md index 3d0dda88..a57ef87f 100644 --- a/README.md +++ b/README.md @@ -454,7 +454,7 @@ Miasm uses: To enable code JIT, one of the following module is mandatory: * GCC * Clang -* LLVM v3.2 with python-llvm, see below +* LLVM with Numba llvmlite, see below * LibTCC [tinycc (ONLY version 0.9.26)](http://repo.or.cz/w/tinycc.git) 'optional' Miasm can also use: @@ -483,9 +483,9 @@ To use the jitter, GCC, TCC or LLVM is recommended * `sudo make install` * There may be an error on documentation generation * LLVM - * Debian (testing/unstable): install python-llvm - * Debian stable/Ubuntu/Kali/whatever: install from [llvmpy](http://www.llvmpy.org/) - * Windows: python-llvm is not supported :/ + * Debian (testing/unstable): Not tested + * Debian stable/Ubuntu/Kali/whatever: `pip install llvmlite` or install from [llvmlite](https://github.com/numba/llvmlite) + * Windows: Not tested * Build and install Miasm: ``` $ cd miasm_directory diff --git a/miasm2/ir/translators/C.py b/miasm2/ir/translators/C.py index 57859f9c..c7913ea8 100644 --- a/miasm2/ir/translators/C.py +++ b/miasm2/ir/translators/C.py @@ -47,9 +47,9 @@ class TranslatorC(Translator): return "parity(%s&0x%x)" % (self.from_expr(expr.args[0]), size2mask(expr.args[0].size)) elif expr.op in ['bsr', 'bsf']: - return "x86_%s(%s, 0x%x)" % (expr.op, - self.from_expr(expr.args[0]), - expr.args[0].size) + return "x86_%s(0x%x, %s)" % (expr.op, + expr.args[0].size, + self.from_expr(expr.args[0])) elif expr.op in ['clz']: return "%s(%s)" % (expr.op, self.from_expr(expr.args[0])) diff --git a/miasm2/jitter/Jitllvm.c b/miasm2/jitter/Jitllvm.c index 6622e615..c176a4b2 100644 --- a/miasm2/jitter/Jitllvm.c +++ b/miasm2/jitter/Jitllvm.c @@ -3,19 +3,27 @@ #include <inttypes.h> #include <stdint.h> +#include "queue.h" +#include "vm_mngr.h" +#include "vm_mngr_py.h" +#include "JitCore.h" +// Needed to get the JitCpu.cpu offset, arch independent +#include "arch/JitCore_x86.h" PyObject* llvm_exec_bloc(PyObject* self, PyObject* args) { uint64_t func_addr; - uint64_t (*func)(void*, void*); + uint64_t (*func)(void*, void*, void*, uint8_t*); uint64_t vm; - uint64_t cpu; uint64_t ret; - - if (!PyArg_ParseTuple(args, "KKK", &func_addr, &cpu, &vm)) + JitCpu* jitcpu; + uint8_t status; + + if (!PyArg_ParseTuple(args, "KOK", &func_addr, &jitcpu, &vm)) return NULL; + vm_cpu_t* cpu = jitcpu->cpu; func = (void *) (intptr_t) func_addr; - 
ret = func((void*)(intptr_t) cpu, (void*)(intptr_t) vm); + ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm, &status); return PyLong_FromUnsignedLongLong(ret); } diff --git a/miasm2/jitter/arch/JitCore_x86.c b/miasm2/jitter/arch/JitCore_x86.c index 94729b90..66c3fb56 100644 --- a/miasm2/jitter/arch/JitCore_x86.c +++ b/miasm2/jitter/arch/JitCore_x86.c @@ -599,6 +599,11 @@ PyObject* get_gpreg_offset_all(void) get_reg_off(tsc1); get_reg_off(tsc2); + get_reg_off(interrupt_num); + get_reg_off(exception_flags); + + get_reg_off(float_stack_ptr); + return dict; } diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py index acf91d15..0f265073 100644 --- a/miasm2/jitter/jitcore_llvm.py +++ b/miasm2/jitter/jitcore_llvm.py @@ -14,16 +14,17 @@ class JitCore_LLVM(jitcore.JitCore): arch_dependent_libs = {"x86": "JitCore_x86.so", "arm": "JitCore_arm.so", "msp430": "JitCore_msp430.so", - "mips32": "JitCore_mips32.so"} + "mips32": "JitCore_mips32.so", + "aarch64": "JitCore_aarch64.so", + } def __init__(self, ir_arch, bs=None): super(JitCore_LLVM, self).__init__(ir_arch, bs) - self.options.update({"safe_mode": False, # Verify each function - "optimise": False, # Optimise functions + self.options.update({"safe_mode": True, # Verify each function + "optimise": True, # Optimise functions "log_func": False, # Print LLVM functions "log_assembly": False, # Print assembly executed - "cache_ir": None # SaveDir for cached .ll }) self.exec_wrapper = Jitllvm.llvm_exec_bloc @@ -46,7 +47,7 @@ class JitCore_LLVM(jitcore.JitCore): pass # Create a context - self.context = LLVMContext_JIT(libs_to_load) + self.context = LLVMContext_JIT(libs_to_load, self.ir_arch) # Set the optimisation level self.context.optimise_level() @@ -59,83 +60,21 @@ class JitCore_LLVM(jitcore.JitCore): mod = importlib.import_module(mod_name) self.context.set_vmcpu(mod.get_gpreg_offset_all()) - # Save module base - self.mod_base_str = str(self.context.mod) - - # Set IRs transformation to apply - self.context.set_IR_transformation(self.ir_arch.expr_fix_regs_for_mode) - - def add_bloc(self, bloc): - - # Search in IR cache - if self.options["cache_ir"] is not None: - - # /!\ This part is under development - # Use it at your own risk - - # Compute Hash : label + bloc binary - func_name = bloc.label.name - to_hash = func_name - - # Get binary from bloc - for line in bloc.lines: - b = line.b - to_hash += b - - # Compute Hash - md5 = hashlib.md5(to_hash).hexdigest() - - # Try to load the function from cache - filename = self.options["cache_ir"] + md5 + ".ll" - - try: - fcontent = open(filename) - content = fcontent.read() - fcontent.close() - - except IOError: - content = None - - if content is None: - # Compute the IR - super(JitCore_LLVM, self).add_bloc(bloc) - - # Save it - fdest = open(filename, "w") - dump = str(self.context.mod.get_function_named(func_name)) - my = "declare i16 @llvm.bswap.i16(i16) nounwind readnone\n" - - fdest.write(self.mod_base_str + my + dump) - fdest.close() - - else: - import llvm.core as llvm_c - import llvm.ee as llvm_e - my_mod = llvm_c.Module.from_assembly(content) - func = my_mod.get_function_named(func_name) - exec_en = llvm_e.ExecutionEngine.new(my_mod) - self.exec_engines.append(exec_en) - - # We can use the same exec_engine - ptr = self.exec_engines[0].get_pointer_to_function(func) - - # Store a pointer on the function jitted code - self.lbl2jitbloc[bloc.label.offset] = ptr - - else: - super(JitCore_LLVM, self).add_bloc(bloc) - - def jitirblocs(self, label, irblocs): + def 
add_bloc(self, block): + """Add a block to JiT and JiT it. + @block: the block to add + """ + # TODO: caching using hash # Build a function in the context - func = LLVMFunction(self.context, label.name) + func = LLVMFunction(self.context, block.label.name) # Set log level func.log_regs = self.log_regs func.log_mn = self.log_mn - # Import irblocs - func.from_blocs(irblocs) + # Import asm block + func.from_asmblock(block) # Verify if self.options["safe_mode"] is True: @@ -152,4 +91,12 @@ class JitCore_LLVM(jitcore.JitCore): print func.get_assembly() # Store a pointer on the function jitted code - self.lbl2jitbloc[label.offset] = func.get_function_pointer() + self.lbl2jitbloc[block.label.offset] = func.get_function_pointer() + + def jit_call(self, label, cpu, _vmmngr, breakpoints): + """Call the function label with cpu and vmmngr states + @label: function's label + @cpu: JitCpu instance + @breakpoints: Dict instance of used breakpoints + """ + return self.exec_wrapper(self.lbl2jitbloc[label], cpu, cpu.vmmngr.vmmngr) diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index 3ac75cd7..08b1986b 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -5,33 +5,33 @@ # - JiT # # # Requires: # -# - llvmpy (tested on v0.11.2) # +# - llvmlite (tested on v0.15) # # # Authors : Fabrice DESCLAUX (CEA/DAM), Camille MOUGEY (CEA/DAM) # # # -import llvm -import llvm.core as llvm_c -import llvm.ee as llvm_e -import llvm.passes as llvm_p +from llvmlite import binding as llvm +from llvmlite import ir as llvm_ir import miasm2.expression.expression as m2_expr import miasm2.jitter.csts as m2_csts import miasm2.core.asmbloc as m2_asmbloc +from miasm2.jitter.codegen import CGen +from miasm2.expression.expression_helper import possible_values -class LLVMType(llvm_c.Type): +class LLVMType(llvm_ir.Type): "Handle LLVM Type" int_cache = {} @classmethod - def int(cls, size=32): + def IntType(cls, size=32): try: return cls.int_cache[size] except KeyError: - cls.int_cache[size] = llvm_c.Type.int(size) + cls.int_cache[size] = llvm_ir.IntType(size) return cls.int_cache[size] @classmethod @@ -43,7 +43,7 @@ class LLVMType(llvm_c.Type): def generic(cls, e): "Generic value for execution" if isinstance(e, m2_expr.ExprInt): - return llvm_e.GenericValue.int(LLVMType.int(e.size), int(e)) + return llvm_e.GenericValue.int(LLVMType.IntType(e.size), int(e.arg)) elif isinstance(e, llvm_e.GenericValue): return e else: @@ -58,39 +58,39 @@ class LLVMContext(): def __init__(self, name="mod"): "Initialize a context with a module named 'name'" - self.mod = llvm_c.Module.new(name) - self.pass_manager = llvm_p.FunctionPassManager.new(self.mod) - self.exec_engine = llvm_e.ExecutionEngine.new(self.mod) - self.add_fc(self.known_fc) - - def optimise_level(self, classic_passes=True, dead_passes=True): - """Set the optimisation level : - classic_passes : - - combine instruction - - reassociate - - global value numbering - - simplify cfg - - dead_passes : - - dead code - - dead store - - dead instructions + # Initialize llvm + llvm.initialize() + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + + # Initilize target for compilation + target = llvm.Target.from_default_triple() + self.target_machine = target.create_target_machine() + self.init_exec_engine() + + def optimise_level(self, level=2): + """Set the optimisation level to @level from 0 to 2 + 0: non-optimized + 2: optimized """ # Set up the optimiser pipeline - - if classic_passes is True: - # 
self.pass_manager.add(llvm_p.PASS_INSTCOMBINE) - self.pass_manager.add(llvm_p.PASS_REASSOCIATE) - self.pass_manager.add(llvm_p.PASS_GVN) - self.pass_manager.add(llvm_p.PASS_SIMPLIFYCFG) - - if dead_passes is True: - self.pass_manager.add(llvm_p.PASS_DCE) - self.pass_manager.add(llvm_p.PASS_DSE) - self.pass_manager.add(llvm_p.PASS_DIE) - - self.pass_manager.initialize() + pmb = llvm.create_pass_manager_builder() + pmb.opt_level = level + pm = llvm.create_module_pass_manager() + pmb.populate(pm) + self.pass_manager = pm + + def init_exec_engine(self): + mod = llvm.parse_assembly("") + engine = llvm.create_mcjit_compiler(mod, + self.target_machine) + self.exec_engine = engine + + def new_module(self, name="mod"): + """Create a module, with needed functions""" + self.mod = llvm_ir.Module(name=name) + self.add_fc(self.known_fc) def get_execengine(self): "Return the Execution Engine associated with this context" @@ -98,7 +98,7 @@ class LLVMContext(): def get_passmanager(self): "Return the Pass Manager associated with this context" - return self.exec_engine + return self.pass_manager def get_module(self): "Return the module associated with this context" @@ -106,123 +106,246 @@ class LLVMContext(): def add_shared_library(self, filename): "Load the shared library 'filename'" - return llvm_c.load_library_permanently(filename) + return llvm.load_library_permanently(filename) def add_fc(self, fc): "Add function into known_fc" - for name, detail in fc.items(): - self.mod.add_function(LLVMType.function(detail["ret"], - detail["args"]), - name) + for name, detail in fc.iteritems(): + fnty = llvm_ir.FunctionType(detail["ret"], detail["args"]) + llvm_ir.Function(self.mod, fnty, name=name) + + def memory_lookup(self, func, addr, size): + """Perform a memory lookup at @addr of size @size (in bit)""" + raise NotImplementedError("Abstract method") + + def memory_write(self, func, addr, size, value): + """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" + raise NotImplementedError("Abstract method") class LLVMContext_JIT(LLVMContext): - "Extend LLVMContext_JIT in order to handle memory management" + """Extend LLVMContext_JIT in order to handle memory management and custom + operations""" - def __init__(self, library_filenames, name="mod"): + def __init__(self, library_filenames, ir_arch, name="mod"): "Init a LLVMContext object, and load the mem management shared library" + self.library_filenames = library_filenames + self.ir_arch = ir_arch + self.arch_specific() LLVMContext.__init__(self, name) - for lib_fname in library_filenames: + self.vmcpu = {} + + def new_module(self, name="mod"): + LLVMContext.new_module(self, name) + for lib_fname in self.library_filenames: self.add_shared_library(lib_fname) self.add_memlookups() self.add_get_exceptionflag() self.add_op() self.add_log_functions() - self.vmcpu = {} + + def arch_specific(self): + arch = self.ir_arch.arch + if arch.name == "x86": + self.PC = arch.regs.RIP + self.logging_func = "dump_gpregs_%d" % self.ir_arch.attrib + else: + self.PC = self.ir_arch.pc + self.logging_func = "dump_gpregs" + if arch.name == "mips32": + from miasm2.arch.mips32.jit import mipsCGen + self.cgen_class = mipsCGen + self.has_delayslot = True + else: + self.cgen_class = CGen + self.has_delayslot = False def add_memlookups(self): "Add MEM_LOOKUP functions" fc = {} - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) for i in [8, 16, 32, 64]: - fc["MEM_LOOKUP_%02d" % i] = {"ret": LLVMType.int(i), + 
fc["MEM_LOOKUP_%02d" % i] = {"ret": LLVMType.IntType(i), "args": [p8, - LLVMType.int(64)]} + LLVMType.IntType(64)]} - fc["MEM_WRITE_%02d" % i] = {"ret": LLVMType.void(), + fc["MEM_WRITE_%02d" % i] = {"ret": llvm_ir.VoidType(), "args": [p8, - LLVMType.int(64), - LLVMType.int(i)]} - + LLVMType.IntType(64), + LLVMType.IntType(i)]} + fc["reset_memory_access"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} + fc["check_memory_breakpoint"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} + fc["check_invalid_code_blocs"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} self.add_fc(fc) def add_get_exceptionflag(self): "Add 'get_exception_flag' function" - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) - self.add_fc({"get_exception_flag": {"ret": LLVMType.int(64), + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) + self.add_fc({"get_exception_flag": {"ret": LLVMType.IntType(64), "args": [p8]}}) def add_op(self): "Add operations functions" - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) - self.add_fc({"parity": {"ret": LLVMType.int(), - "args": [LLVMType.int()]}}) - self.add_fc({"rot_left": {"ret": LLVMType.int(), - "args": [LLVMType.int(), - LLVMType.int(), - LLVMType.int()]}}) - self.add_fc({"rot_right": {"ret": LLVMType.int(), - "args": [LLVMType.int(), - LLVMType.int(), - LLVMType.int()]}}) - - self.add_fc({"segm2addr": {"ret": LLVMType.int(64), + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) + itype = LLVMType.IntType(64) + self.add_fc({"parity": {"ret": LLVMType.IntType(1), + "args": [itype]}}) + self.add_fc({"rot_left": {"ret": itype, + "args": [itype, + itype, + itype]}}) + self.add_fc({"rot_right": {"ret": itype, + "args": [itype, + itype, + itype]}}) + self.add_fc({"rcr_rez_op": {"ret": itype, + "args": [itype, + itype, + itype, + itype]}}) + self.add_fc({"rcl_rez_op": {"ret": itype, + "args": [itype, + itype, + itype, + itype]}}) + self.add_fc({"x86_bsr": {"ret": itype, + "args": [itype, + itype]}}) + self.add_fc({"x86_bsf": {"ret": itype, + "args": [itype, + itype]}}) + self.add_fc({"segm2addr": {"ret": itype, "args": [p8, - LLVMType.int(64), - LLVMType.int(64)]}}) + itype, + itype]}}) for k in [8, 16]: - self.add_fc({"bcdadd_%s" % k: {"ret": LLVMType.int(k), - "args": [LLVMType.int(k), - LLVMType.int(k)]}}) - self.add_fc({"bcdadd_cf_%s" % k: {"ret": LLVMType.int(k), - "args": [LLVMType.int(k), - LLVMType.int(k)]}}) + self.add_fc({"bcdadd_%s" % k: {"ret": LLVMType.IntType(k), + "args": [LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"bcdadd_cf_%s" % k: {"ret": LLVMType.IntType(k), + "args": [LLVMType.IntType(k), + LLVMType.IntType(k)]}}) for k in [16, 32, 64]: - self.add_fc({"imod%s" % k: {"ret": LLVMType.int(k), + self.add_fc({"imod%s" % k: {"ret": LLVMType.IntType(k), + "args": [p8, + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"idiv%s" % k: {"ret": LLVMType.IntType(k), "args": [p8, - LLVMType.int(k), - LLVMType.int(k)]}}) - self.add_fc({"idiv%s" % k: {"ret": LLVMType.int(k), + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"umod%s" % k: {"ret": LLVMType.IntType(k), "args": [p8, - LLVMType.int(k), - LLVMType.int(k)]}}) + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"udiv%s" % k: {"ret": LLVMType.IntType(k), + "args": [p8, + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) def add_log_functions(self): "Add functions for state logging" - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) - self.add_fc({"dump_gpregs": {"ret": LLVMType.void(), - "args": [p8]}}) + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) + 
self.add_fc({self.logging_func: {"ret": llvm_ir.VoidType(), + "args": [p8]}}) def set_vmcpu(self, lookup_table): "Set the correspondance between register name and vmcpu offset" self.vmcpu = lookup_table - def set_IR_transformation(self, *args): - """Set a list of transformation to apply on expression before their - treatments. - args: function Expr(Expr)""" - self.IR_transformation_functions = args + def memory_lookup(self, func, addr, size): + """Perform a memory lookup at @addr of size @size (in bit)""" + builder = func.builder + fc_name = "MEM_LOOKUP_%02d" % size + fc_ptr = self.mod.get_global(fc_name) + addr_casted = builder.zext(addr, + LLVMType.IntType(64)) + + ret = builder.call(fc_ptr, [func.local_vars["jitcpu"], + addr_casted]) + return ret + + def memory_write(self, func, addr, size, value): + """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" + # Function call + builder = func.builder + fc_name = "MEM_WRITE_%02d" % size + fc_ptr = self.mod.get_global(fc_name) + dst_casted = builder.zext(addr, LLVMType.IntType(64)) + builder.call(fc_ptr, [func.local_vars["jitcpu"], + dst_casted, + value]) + + +class LLVMContext_IRCompilation(LLVMContext): + + """Extend LLVMContext in order to handle memory management and custom + operations for Miasm IR compilation""" + + def memory_lookup(self, func, addr, size): + """Perform a memory lookup at @addr of size @size (in bit)""" + builder = func.builder + int_size = LLVMType.IntType(size) + ptr_casted = builder.inttoptr(addr, + llvm_ir.PointerType(int_size)) + return builder.load(ptr_casted) + + def memory_write(self, func, addr, size, value): + """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" + builder = func.builder + int_size = LLVMType.IntType(size) + ptr_casted = builder.inttoptr(addr, + llvm_ir.PointerType(int_size)) + return builder.store(value, ptr_casted) class LLVMFunction(): + """Represent a LLVM function - "Represent a llvm function" + Implementation note: + A new module is created each time to avoid cumulative lag (if @new_module) + """ # Default logging values log_mn = False - log_regs = False - - def __init__(self, llvm_context, name="fc"): - "Create a new function with name fc" + log_regs = True + + # Operation translation + ## Basics + op_translate = {'parity': 'parity', + } + ## Add the size as first argument + op_translate_with_size = {'<<<': 'rot_left', + '>>>': 'rot_right', + '<<<c_rez': 'rcl_rez_op', + '>>>c_rez': 'rcr_rez_op', + 'bsr': 'x86_bsr', + 'bsf': 'x86_bsf', + } + ## Add the size as suffix + op_translate_with_suffix_size = {'bcdadd': 'bcdadd', + 'bcdadd_cf': 'bcdadd_cf', + } + + def __init__(self, llvm_context, name="fc", new_module=True): + "Create a new function with name @name" self.llvm_context = llvm_context + if new_module: + self.llvm_context.new_module() self.mod = self.llvm_context.get_module() self.my_args = [] # (Expr, LLVMType, Name) @@ -232,40 +355,195 @@ class LLVMFunction(): self.branch_counter = 0 self.name = name + self._llvm_mod = None + + # Constructor utils def new_branch_name(self): "Return a new branch name" - self.branch_counter += 1 - return "%s" % self.branch_counter + return str(self.branch_counter) - def viewCFG(self): - "Show the CFG of the current function" - self.fc.viewCFG() - - def append_basic_block(self, label): + def append_basic_block(self, label, overwrite=True): """Add a new basic block to the current function. 
@label: str or asmlabel + @overwrite: if False, do nothing if a bbl with the same name already exists Return the corresponding LLVM Basic Block""" name = self.canonize_label_name(label) + bbl = self.name2bbl.get(name, None) + if not overwrite and bbl is not None: + return bbl bbl = self.fc.append_basic_block(name) - self.name2bbl[label] = bbl + self.name2bbl[name] = bbl return bbl + def CreateEntryBlockAlloca(self, var_type, default_value=None): + """Create an alloca instruction at the beginning of the current fc + @default_value: if set, store the default_value just after the allocation + """ + builder = self.builder + current_bbl = builder.basic_block + builder.position_at_start(self.entry_bbl) + + ret = builder.alloca(var_type) + if default_value is not None: + builder.store(default_value, ret) + builder.position_at_end(current_bbl) + return ret + + def get_ptr_by_expr(self, expr): + """"Return a pointer casted corresponding to ExprId expr. If it is not + already computed, compute it at the end of entry_bloc""" + + name = expr.name + + ptr_casted = self.local_vars_pointers.get(name, None) + if ptr_casted is not None: + # If the pointer has already been computed + return ptr_casted + + # Get current objects + builder = self.builder + current_bbl = builder.basic_block + + # Go at the right position + entry_bloc_bbl = self.entry_bbl + builder.position_at_end(entry_bloc_bbl) + + # Compute the pointer address + offset = self.llvm_context.vmcpu[name] + + # Pointer cast + ptr = builder.gep(self.local_vars["vmcpu"], + [llvm_ir.Constant(LLVMType.IntType(), + offset)]) + int_size = LLVMType.IntType(expr.size) + ptr_casted = builder.bitcast(ptr, + llvm_ir.PointerType(int_size)) + # Store in cache + self.local_vars_pointers[name] = ptr_casted + + # Reset builder + builder.position_at_end(current_bbl) + + return ptr_casted + + def update_cache(self, name, value): + "Add 'name' = 'value' to the cache iff main_stream = True" + + if self.main_stream is True: + self.expr_cache[name] = value + + def set_ret(self, var): + "Cast @var and return it at the end of current bbl" + if var.type.width < 64: + var_casted = self.builder.zext(var, LLVMType.IntType(64)) + else: + var_casted = var + self.builder.ret(var_casted) + + def canonize_label_name(self, label): + """Canonize @label names to a common form. + @label: str or asmlabel instance""" + if isinstance(label, str): + return label + elif isinstance(label, m2_asmbloc.asm_label): + return "label_%s" % label.name + elif m2_asmbloc.expr_is_label(label): + return "label_%s" % label.name.name + else: + raise ValueError("label must either be str or asmlabel") + + def get_basic_bloc_by_label(self, label): + "Return the bbl corresponding to label, None otherwise" + return self.name2bbl.get(self.canonize_label_name(label), None) + + def global_constant(self, name, value): + """ + Inspired from numba/cgutils.py + + Get or create a (LLVM module-)global constant with *name* or *value*. + """ + module = self.mod + data = llvm_ir.GlobalVariable(self.mod, value.type, name=name) + data.global_constant = True + data.initializer = value + return data + + def make_bytearray(self, buf): + """ + Inspired from numba/cgutils.py + + Make a byte array constant from *buf*. + """ + b = bytearray(buf) + n = len(b) + return llvm_ir.Constant(llvm_ir.ArrayType(llvm_ir.IntType(8), n), b) + + def printf(self, format, *args): + """ + Inspired from numba/cgutils.py + + Calls printf(). + Argument `format` is expected to be a Python string. + Values to be printed are listed in `args`. 
+ + Note: There is no checking to ensure there is correct number of values + in `args` and there type matches the declaration in the format string. + """ + assert isinstance(format, str) + mod = self.mod + # Make global constant for format string + cstring = llvm_ir.IntType(8).as_pointer() + fmt_bytes = self.make_bytearray((format + '\00').encode('ascii')) + + base_name = "printf_format" + count = 0 + while self.mod.get_global("%s_%d" % (base_name, count)): + count += 1 + global_fmt = self.global_constant("%s_%d" % (base_name, count), + fmt_bytes) + fnty = llvm_ir.FunctionType(llvm_ir.IntType(32), [cstring], + var_arg=True) + # Insert printf() + fn = mod.get_global('printf') + if fn is None: + fn = llvm_ir.Function(mod, fnty, name="printf") + # Call + ptr_fmt = self.builder.bitcast(global_fmt, cstring) + return self.builder.call(fn, [ptr_fmt] + list(args)) + + # Effective constructors + + def affect(self, src, dst): + "Affect from LLVM src to M2 dst" + + # Destination + builder = self.builder + + if isinstance(dst, m2_expr.ExprId): + ptr_casted = self.get_ptr_by_expr(dst) + builder.store(src, ptr_casted) + + elif isinstance(dst, m2_expr.ExprMem): + addr = self.add_ir(dst.arg) + self.llvm_context.memory_write(self, addr, dst.size, src) + else: + raise Exception("UnknownAffectationType") + def init_fc(self): "Init the function" # Build type for fc signature - fc_type = LLVMType.function( - self.ret_type, [k[1] for k in self.my_args]) + fc_type = llvm_ir.FunctionType(self.ret_type, [k[1] for k in self.my_args]) # Add fc in module try: - fc = self.mod.add_function(fc_type, self.name) + fc = llvm_ir.Function(self.mod, fc_type, name=self.name) except llvm.LLVMException: # Overwrite the previous function - previous_fc = self.mod.get_function_named(self.name) + previous_fc = self.mod.get_global(self.name) previous_fc.delete() fc = self.mod.add_function(fc_type, self.name) @@ -283,7 +561,6 @@ class LLVMFunction(): self.expr_cache = {} self.main_stream = True self.name2bbl = {} - self.offsets_jitted = set() # Function link self.fc = fc @@ -292,69 +569,7 @@ class LLVMFunction(): self.entry_bbl = self.append_basic_block("entry") # Instruction builder - self.builder = llvm_c.Builder.new(self.entry_bbl) - - def CreateEntryBlockAlloca(self, var_type): - "Create an alloca instruction at the beginning of the current fc" - builder = self.builder - current_bbl = builder.basic_block - builder.position_at_end(self.entry_bbl) - - ret = builder.alloca(var_type) - builder.position_at_end(current_bbl) - return ret - - def get_ptr_by_expr(self, expr): - """"Return a pointer casted corresponding to ExprId expr. 
If it is not - already computed, compute it at the end of entry_bloc""" - - name = expr.name - - try: - # If the pointer has already been computed - ptr_casted = self.local_vars_pointers[name] - - except KeyError: - # Get current objects - builder = self.builder - current_bbl = builder.basic_block - - # Go at the right position - entry_bloc_bbl = self.entry_bbl - builder.position_at_end(entry_bloc_bbl) - - # Compute the pointer address - offset = self.llvm_context.vmcpu[name] - - # Pointer cast - ptr = builder.gep(self.local_vars["vmcpu"], - [llvm_c.Constant.int(LLVMType.int(), - offset)]) - int_size = LLVMType.int(expr.size) - ptr_casted = builder.bitcast(ptr, - llvm_c.PointerType.pointer(int_size)) - # Store in cache - self.local_vars_pointers[name] = ptr_casted - - # Reset builder - builder.position_at_end(current_bbl) - - return ptr_casted - - def clear_cache(self, regs_updated): - "Remove from the cache values which depends on regs_updated" - - regs_updated_set = set(regs_updated) - - for expr in self.expr_cache.keys(): - if expr.get_r(True).isdisjoint(regs_updated_set) is not True: - self.expr_cache.pop(expr) - - def update_cache(self, name, value): - "Add 'name' = 'value' to the cache iff main_stream = True" - - if self.main_stream is True: - self.expr_cache[name] = value + self.builder = llvm_ir.IRBuilder(self.entry_bbl) def add_ir(self, expr): "Add a Miasm2 IR to the last bbl. Return the var created" @@ -365,7 +580,7 @@ class LLVMFunction(): builder = self.builder if isinstance(expr, m2_expr.ExprInt): - ret = llvm_c.Constant.int(LLVMType.int(expr.size), int(expr)) + ret = llvm_ir.Constant(LLVMType.IntType(expr.size), int(expr.arg)) self.update_cache(expr, ret) return ret @@ -374,7 +589,7 @@ class LLVMFunction(): if not isinstance(name, str): # Resolve label offset = name.offset - ret = llvm_c.Constant.int(LLVMType.int(expr.size), offset) + ret = llvm_ir.Constant(LLVMType.IntType(expr.size), offset) self.update_cache(expr, ret) return ret @@ -393,69 +608,73 @@ class LLVMFunction(): if isinstance(expr, m2_expr.ExprOp): op = expr.op - if op == "parity": - fc_ptr = self.mod.get_function_named("parity") - arg = builder.zext(self.add_ir(expr.args[0]), - LLVMType.int()) - ret = builder.call(fc_ptr, [arg]) - ret = builder.trunc(ret, LLVMType.int(expr.size)) - self.update_cache(expr, ret) - return ret - - if op in ["<<<", ">>>"]: - fc_name = "rot_left" if op == "<<<" else "rot_right" - fc_ptr = self.mod.get_function_named(fc_name) + if (op in self.op_translate or + op in self.op_translate_with_size or + op in self.op_translate_with_suffix_size): args = [self.add_ir(arg) for arg in expr.args] arg_size = expr.args[0].size - if arg_size < 32: - # Cast args - args = [builder.zext(arg, LLVMType.int(32)) - for arg in args] - arg_size_cst = llvm_c.Constant.int(LLVMType.int(), - arg_size) - ret = builder.call(fc_ptr, [arg_size_cst] + args) - if arg_size < 32: - # Cast ret - ret = builder.trunc(ret, LLVMType.int(arg_size)) - self.update_cache(expr, ret) - return ret - if op == "bcdadd": - size = expr.args[0].size - fc_ptr = self.mod.get_function_named("bcdadd_%s" % size) - args = [self.add_ir(arg) for arg in expr.args] - ret = builder.call(fc_ptr, args) - self.update_cache(expr, ret) - return ret + if op in self.op_translate_with_size: + fc_name = self.op_translate_with_size[op] + arg_size_cst = llvm_ir.Constant(LLVMType.IntType(64), + arg_size) + args = [arg_size_cst] + args + elif op in self.op_translate: + fc_name = self.op_translate[op] + elif op in self.op_translate_with_suffix_size: + fc_name 
= "%s_%s" % (self.op_translate[op], arg_size) + + fc_ptr = self.mod.get_global(fc_name) + + # Cast args if needed + casted_args = [] + for i, arg in enumerate(args): + if arg.type.width < fc_ptr.args[i].type.width: + casted_args.append(builder.zext(arg, fc_ptr.args[i].type)) + else: + casted_args.append(arg) + ret = builder.call(fc_ptr, casted_args) + + # Cast ret if needed + ret_size = fc_ptr.return_value.type.width + if ret_size > expr.size: + ret = builder.trunc(ret, LLVMType.IntType(expr.size)) - if op == "bcdadd_cf": - size = expr.args[0].size - fc_ptr = self.mod.get_function_named("bcdadd_cf_%s" % size) - args = [self.add_ir(arg) for arg in expr.args] - ret = builder.call(fc_ptr, args) - ret = builder.trunc(ret, LLVMType.int(expr.size)) self.update_cache(expr, ret) return ret if op == "-": - zero = llvm_c.Constant.int(LLVMType.int(expr.size), - 0) + # Unsupported op '-' with more than 1 arg + assert len(expr.args) == 1 + zero = LLVMType.IntType(expr.size)(0) ret = builder.sub(zero, self.add_ir(expr.args[0])) self.update_cache(expr, ret) return ret if op == "segm": - fc_ptr = self.mod.get_function_named("segm2addr") - args_casted = [builder.zext(self.add_ir(arg), LLVMType.int(64)) - for arg in expr.args] - args = [self.local_vars["vmcpu"]] + args_casted - ret = builder.call(fc_ptr, args) - ret = builder.trunc(ret, LLVMType.int(expr.size)) + fc_ptr = self.mod.get_global("segm2addr") + + # Cast args if needed + args = [self.add_ir(arg) for arg in expr.args] + casted_args = [] + for i, arg in enumerate(args, 1): + if arg.type.width < fc_ptr.args[i].type.width: + casted_args.append(builder.zext(arg, fc_ptr.args[i].type)) + else: + casted_args.append(arg) + + ret = builder.call(fc_ptr, + [self.local_vars["jitcpu"]] + casted_args) + # Ret size is not expr.size on segm2addr (which is the size of + # the segment, for instance 16 bits), but the size of an addr + ret_size = self.llvm_context.PC.size + if ret.type.width > ret_size: + ret = builder.trunc(ret, LLVMType.IntType(ret_size)) self.update_cache(expr, ret) return ret - if op in ["imod", "idiv"]: - fc_ptr = self.mod.get_function_named( + if op in ["imod", "idiv", "umod", "udiv"]: + fc_ptr = self.mod.get_global( "%s%s" % (op, expr.args[0].size)) args_casted = [self.add_ir(arg) for arg in expr.args] args = [self.local_vars["vmcpu"]] + args_casted @@ -463,6 +682,26 @@ class LLVMFunction(): self.update_cache(expr, ret) return ret + if op in [">>", "<<", "a>>"]: + assert len(expr.args) == 2 + # Undefined behavior must be enforced to 0 + count = self.add_ir(expr.args[1]) + value = self.add_ir(expr.args[0]) + itype = LLVMType.IntType(expr.size) + cond_ok = self.builder.icmp_unsigned("<", count, + itype(expr.size)) + if op == ">>": + callback = builder.lshr + elif op == "<<": + callback = builder.shl + elif op == "a>>": + callback = builder.ashr + + ret = self.builder.select(cond_ok, callback(value, count), + itype(0)) + self.update_cache(expr, ret) + return ret + if len(expr.args) > 1: if op == "*": @@ -475,16 +714,10 @@ class LLVMFunction(): callback = builder.xor elif op == "|": callback = builder.or_ - elif op == ">>": - callback = builder.lshr - elif op == "<<": - callback = builder.shl - elif op == "a>>": - callback = builder.ashr - elif op == "udiv": - callback = builder.udiv - elif op == "umod": + elif op == "%": callback = builder.urem + elif op == "/": + callback = builder.udiv else: raise NotImplementedError('Unknown op: %s' % op) @@ -502,58 +735,18 @@ class LLVMFunction(): if isinstance(expr, m2_expr.ExprMem): - fc_name = 
"MEM_LOOKUP_%02d" % expr.size - fc_ptr = self.mod.get_function_named(fc_name) - addr_casted = builder.zext(self.add_ir(expr.arg), - LLVMType.int(64)) - - ret = builder.call(fc_ptr, [self.local_vars["vmmngr"], - addr_casted]) - - # Do not update memory cache to avoid pointer collision - return ret + addr = self.add_ir(expr.arg) + return self.llvm_context.memory_lookup(self, addr, expr.size) if isinstance(expr, m2_expr.ExprCond): # Compute cond cond = self.add_ir(expr.cond) - zero_casted = llvm_c.Constant.int(LLVMType.int(expr.cond.size), - 0) - condition_bool = builder.icmp(llvm_c.ICMP_NE, cond, - zero_casted) - - # Alloc return var - alloca = self.CreateEntryBlockAlloca(LLVMType.int(expr.size)) - - # Create bbls - branch_id = self.new_branch_name() - then_block = self.append_basic_block('then%s' % branch_id) - else_block = self.append_basic_block('else%s' % branch_id) - merge_block = self.append_basic_block('ifcond%s' % branch_id) - - builder.cbranch(condition_bool, then_block, else_block) - - # Deactivate object caching - current_main_stream = self.main_stream - self.main_stream = False - - # Then Bloc - builder.position_at_end(then_block) + zero_casted = LLVMType.IntType(expr.cond.size)(0) + condition_bool = builder.icmp_unsigned("!=", cond, + zero_casted) then_value = self.add_ir(expr.src1) - builder.store(then_value, alloca) - builder.branch(merge_block) - - # Else Bloc - builder.position_at_end(else_block) else_value = self.add_ir(expr.src2) - builder.store(else_value, alloca) - builder.branch(merge_block) - - # Merge bloc - builder.position_at_end(merge_block) - ret = builder.load(alloca) - - # Reactivate object caching - self.main_stream = current_main_stream + ret = builder.select(condition_bool, then_value, else_value) self.update_cache(expr, ret) return ret @@ -564,22 +757,22 @@ class LLVMFunction(): # Remove trailing bits if expr.start != 0: - to_shr = llvm_c.Constant.int(LLVMType.int(expr.arg.size), - expr.start) + to_shr = llvm_ir.Constant(LLVMType.IntType(expr.arg.size), + expr.start) shred = builder.lshr(src, to_shr) else: shred = src # Remove leading bits - to_and = llvm_c.Constant.int(LLVMType.int(expr.arg.size), - (1 << (expr.stop - expr.start)) - 1) + to_and = llvm_ir.Constant(LLVMType.IntType(expr.arg.size), + (1 << (expr.stop - expr.start)) - 1) anded = builder.and_(shred, to_and) # Cast into e.size ret = builder.trunc(anded, - LLVMType.int(expr.size)) + LLVMType.IntType(expr.size)) self.update_cache(expr, ret) return ret @@ -589,22 +782,20 @@ class LLVMFunction(): args = [] # Build each part - for arg in expr.args: - src, start, stop = arg - - # src & (stop - start) + for start, src in expr.iter_args(): + # src & size src = self.add_ir(src) src_casted = builder.zext(src, - LLVMType.int(expr.size)) - to_and = llvm_c.Constant.int(LLVMType.int(expr.size), - (1 << (stop - start)) - 1) + LLVMType.IntType(expr.size)) + to_and = llvm_ir.Constant(LLVMType.IntType(expr.size), + (1 << src.type.width) - 1) anded = builder.and_(src_casted, to_and) if (start != 0): # result << start - to_shl = llvm_c.Constant.int(LLVMType.int(expr.size), - start) + to_shl = llvm_ir.Constant(LLVMType.IntType(expr.size), + start) shled = builder.shl(anded, to_shl) final = shled else: @@ -623,92 +814,34 @@ class LLVMFunction(): raise Exception("UnkownExpression", expr.__class__.__name__) - def set_ret(self, var): - "Cast @var and return it at the end of current bbl" - if var.type.width < 64: - var_casted = self.builder.zext(var, LLVMType.int(64)) - else: - var_casted = var - 
self.builder.ret(var_casted) - - def from_expr(self, expr): - "Build the function from an expression" - - # Build function signature - args = expr.get_r(True) - for a in args: - if not isinstance(a, m2_expr.ExprMem): - self.my_args.append((a, LLVMType.int(a.size), a.name)) - - self.ret_type = LLVMType.int(expr.size) - - # Initialise the function - self.init_fc() - - ret = self.add_ir(expr) - - self.set_ret(ret) - - def affect(self, src, dst, add_new=True): - "Affect from M2 src to M2 dst. If add_new, add a suffix '_new' to dest" - - # Source - src = self.add_ir(src) - - # Destination - builder = self.builder - self.add_ir(m2_expr.ExprId("vmcpu")) - - if isinstance(dst, m2_expr.ExprId): - dst_name = dst.name + "_new" if add_new else dst.name + # JiT specifics - ptr_casted = self.get_ptr_by_expr( - m2_expr.ExprId(dst_name, dst.size)) - builder.store(src, ptr_casted) - - elif isinstance(dst, m2_expr.ExprMem): - self.add_ir(dst.arg) - - # Function call - fc_name = "MEM_WRITE_%02d" % dst.size - fc_ptr = self.mod.get_function_named(fc_name) - dst = self.add_ir(dst.arg) - dst_casted = builder.zext(dst, LLVMType.int(64)) - builder.call(fc_ptr, [self.local_vars["vmmngr"], - dst_casted, - src]) - - else: - raise Exception("UnknownAffectationType") - - def check_error(self, line, except_do_not_update_pc=False): + def check_memory_exception(self, offset, restricted_exception=False): """Add a check for memory errors. - @line: Irbloc line corresponding to the current instruction - If except_do_not_update_pc, check only for exception which do not - require a pc update""" + @offset: offset of the current exception (int or Instruction) + If restricted_exception, check only for exception which do not + require a pc update, and do not consider automod exception""" # VmMngr "get_exception_flag" return's size size = 64 - t_size = LLVMType.int(size) - - # Current address - pc_to_return = line.offset + t_size = LLVMType.IntType(size) # Get exception flag value + # TODO: avoid costly call using a structure deref builder = self.builder - fc_ptr = self.mod.get_function_named("get_exception_flag") + fc_ptr = self.mod.get_global("get_exception_flag") exceptionflag = builder.call(fc_ptr, [self.local_vars["vmmngr"]]) - if except_do_not_update_pc is True: - auto_mod_flag = m2_csts.EXCEPT_DO_NOT_UPDATE_PC - m2_flag = llvm_c.Constant.int(t_size, auto_mod_flag) + if restricted_exception is True: + flag = ~m2_csts.EXCEPT_CODE_AUTOMOD & m2_csts.EXCEPT_DO_NOT_UPDATE_PC + m2_flag = llvm_ir.Constant(t_size, flag) exceptionflag = builder.and_(exceptionflag, m2_flag) # Compute cond - zero_casted = llvm_c.Constant.int(t_size, 0) - condition_bool = builder.icmp(llvm_c.ICMP_NE, - exceptionflag, - zero_casted) + zero_casted = llvm_ir.Constant(t_size, 0) + condition_bool = builder.icmp_unsigned("!=", + exceptionflag, + zero_casted) # Create bbls branch_id = self.new_branch_name() @@ -723,168 +856,276 @@ class LLVMFunction(): # Then Bloc builder.position_at_end(then_block) - self.set_ret(llvm_c.Constant.int(self.ret_type, pc_to_return)) + PC = self.llvm_context.PC + if isinstance(offset, (int, long)): + offset = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + self.affect(offset, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), m2_expr.ExprId("status")) + self.set_ret(offset) builder.position_at_end(merge_block) - # Reactivate object caching self.main_stream = current_main_stream - def log_instruction(self, instruction, line): - "Print current instruction and registers if options are set" + def check_cpu_exception(self, offset, 
restricted_exception=False): + """Add a check for CPU errors. + @offset: offset of the current exception (int or Instruction) + If restricted_exception, check only for exception which do not + require a pc update""" - # Get builder + # Get exception flag value builder = self.builder + m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags + t_size = LLVMType.IntType(m2_exception_flag.size) + exceptionflag = self.add_ir(m2_exception_flag) - if self.log_mn is True: - print instruction # TODO + # Compute cond + if restricted_exception is True: + flag = m2_csts.EXCEPT_NUM_UPDT_EIP + condition_bool = builder.icmp_unsigned(">", exceptionflag, + llvm_ir.Constant(t_size, flag)) + else: + zero_casted = llvm_ir.Constant(t_size, 0) + condition_bool = builder.icmp_unsigned("!=", + exceptionflag, + zero_casted) - if self.log_regs is True: - # Call dump general purpose registers - fc_ptr = self.mod.get_function_named("dump_gpregs") - builder.call(fc_ptr, [self.local_vars["vmcpu"]]) + # Create bbls + branch_id = self.new_branch_name() + then_block = self.append_basic_block('then%s' % branch_id) + merge_block = self.append_basic_block('ifcond%s' % branch_id) - def add_bloc(self, bloc, lines): - "Add a bloc of instruction in the current function" + builder.cbranch(condition_bool, then_block, merge_block) - for instruction, line in zip(bloc, lines): - new_reg = set() + # Deactivate object caching + current_main_stream = self.main_stream + self.main_stream = False - # Check general errors only at the beggining of instruction - if line.offset not in self.offsets_jitted: - self.offsets_jitted.add(line.offset) - self.check_error(line) + # Then Bloc + builder.position_at_end(then_block) + PC = self.llvm_context.PC + if isinstance(offset, (int, long)): + offset = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + self.affect(offset, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), m2_expr.ExprId("status")) + self.set_ret(offset) - # Log mn and registers if options is set - self.log_instruction(instruction, line) + builder.position_at_end(merge_block) + # Reactivate object caching + self.main_stream = current_main_stream + def gen_pre_code(self, attributes): + if attributes.log_mn: + self.printf("%.8X %s\n" % (attributes.instr.offset, + attributes.instr)) + + def gen_post_code(self, attributes): + if attributes.log_regs: + fc_ptr = self.mod.get_global(self.llvm_context.logging_func) + self.builder.call(fc_ptr, [self.local_vars["vmcpu"]]) + + def gen_post_instr_checks(self, attrib, next_instr): + if attrib.mem_read | attrib.mem_write: + fc_ptr = self.mod.get_global("check_memory_breakpoint") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + fc_ptr = self.mod.get_global("check_invalid_code_blocs") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + self.check_memory_exception(next_instr, restricted_exception=False) + if attrib.set_exception or attrib.op_set_exception: + self.check_cpu_exception(next_instr, restricted_exception=False) + + if attrib.mem_read | attrib.mem_write: + fc_ptr = self.mod.get_global("reset_memory_access") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + + def expr2cases(self, expr): + """ + Evaluate @expr and return: + - switch value -> dst + - evaluation of the switch value (if any) + """ - # Pass on empty instruction - if len(instruction) == 0: - continue + to_eval = expr + dst2case = {} + case2dst = {} + for i, solution in enumerate(possible_values(expr)): + value = solution.value + index = dst2case.get(value, i) + to_eval = 
to_eval.replace_expr({value: m2_expr.ExprInt(index, value.size)}) + dst2case[value] = index + if m2_asmbloc.expr_is_int_or_label(value): + case2dst[i] = value + else: + case2dst[i] = self.add_ir(value) - for expression in instruction: - # Apply preinit transformation - for func in self.llvm_context.IR_transformation_functions: - expression = func(expression) - # Treat current expression - self.affect(expression.src, expression.dst) + evaluated = self.add_ir(to_eval) + return case2dst, evaluated - # Save registers updated - new_reg.update(expression.dst.get_w()) + def gen_jump2dst(self, attrib, dst): + """Generate the code for a jump to @dst with final check for error - # Check for errors (without updating PC) - self.check_error(line, except_do_not_update_pc=True) + Several cases have to be considered: + - jump to an offset out of the current ASM BBL (JMP 0x11223344) + - jump to an offset inside the current ASM BBL (Go to next instruction) + - jump to a generated IR label, which must be jitted in this same + function (REP MOVSB) + - jump to a computed offset (CALL @32[0x11223344]) + """ + PC = self.llvm_context.PC + # We are no longer in the main stream, deactivate cache + self.main_stream = False - # new -> normal - reg_written = [] - for r in new_reg: - if isinstance(r, m2_expr.ExprId): - r_new = m2_expr.ExprId(r.name + "_new", r.size) - reg_written += [r, r_new] - self.affect(r_new, r, add_new=False) + if isinstance(dst, m2_expr.ExprInt): + dst = m2_expr.ExprId(self.llvm_context.ir_arch.symbol_pool.getby_offset_create(int(dst)), + dst.size) + + if m2_asmbloc.expr_is_label(dst): + bbl = self.get_basic_bloc_by_label(dst) + if bbl is not None: + # "local" jump, inside this function + if dst.name.offset is not None: + # Avoid checks on generated label + self.gen_post_code(attrib) + self.gen_post_instr_checks(attrib, dst.name.offset) + self.builder.branch(bbl) + return + else: + # "extern" jump on a defined offset, return to the caller + offset = dst.name.offset + dst = self.add_ir(m2_expr.ExprInt(offset, PC.size)) - # Clear cache - self.clear_cache(reg_written) - self.main_stream = True + # "extern" jump with a computed value, return to the caller + assert isinstance(dst, (llvm_ir.Instruction, llvm_ir.Value)) + # Cast @dst, if needed + # for instance, x86_32: IRDst is 32 bits, so is @dst; PC is 64 bits + if dst.type.width != PC.size: + dst = self.builder.zext(dst, LLVMType.IntType(PC.size)) - def from_bloc(self, bloc, final_expr): - """Build the function from a bloc, with the dst equation. 
- Prototype : f(i8* vmcpu, i8* vmmngr)""" + self.gen_post_code(attrib) + self.affect(dst, PC) + self.gen_post_instr_checks(attrib, dst) + self.affect(self.add_ir(m2_expr.ExprInt8(0)), m2_expr.ExprId("status")) + self.set_ret(dst) - # Build function signature - self.my_args.append((m2_expr.ExprId("vmcpu"), - llvm_c.PointerType.pointer(LLVMType.int(8)), - "vmcpu")) - self.my_args.append((m2_expr.ExprId("vmmngr"), - llvm_c.PointerType.pointer(LLVMType.int(8)), - "vmmngr")) - self.ret_type = LLVMType.int(final_expr.size) - # Initialise the function - self.init_fc() + def gen_irblock(self, attrib, instr, irblock): + """ + Generate the code for an @irblock + @instr: the current instruction to translate + @irblock: an irbloc instance + @attrib: an Attributs instance + """ - # Add content - self.add_bloc(bloc, []) + case2dst = None + case_value = None - # Finalise the function - self.set_ret(self.add_ir(final_expr)) + for assignblk in irblock.irs: + # Enable cache + self.main_stream = True + self.expr_cache = {} + + # Prefetch memory + for element in assignblk.get_r(mem_read=True): + if isinstance(element, m2_expr.ExprMem): + self.add_ir(element) + + # Evaluate expressions + values = {} + for dst, src in assignblk.iteritems(): + if dst == self.llvm_context.ir_arch.IRDst: + case2dst, case_value = self.expr2cases(src) + else: + values[dst] = self.add_ir(src) - raise NotImplementedError("Not tested") + # Check memory access exception + if assignblk.mem_read: + self.check_memory_exception(instr.offset, + restricted_exception=True) - def canonize_label_name(self, label): - """Canonize @label names to a common form. - @label: str or asmlabel instance""" - if isinstance(label, str): - return label - elif isinstance(label, m2_asmbloc.asm_label): - return "label_%s" % label.name - else: - raise ValueError("label must either be str or asmlabel") + # Check operation exception + if assignblk.op_set_exception: + self.check_cpu_exception(instr.offset, restricted_exception=True) - def get_basic_bloc_by_label(self, label): - "Return the bbl corresponding to label, None otherwise" - return self.name2bbl.get(self.canonize_label_name(label), None) + # Update the memory + for dst, src in values.iteritems(): + if isinstance(dst, m2_expr.ExprMem): + self.affect(src, dst) - def gen_ret_or_branch(self, dest): - """Manage the dest ExprId. If label, branch on it if it is known. 
- Otherwise, return the ExprId or the offset value""" + # Check memory write exception + if assignblk.mem_write: + self.check_memory_exception(instr.offset, + restricted_exception=True) - builder = self.builder + # Update registers values + for dst, src in values.iteritems(): + if not isinstance(dst, m2_expr.ExprMem): + self.affect(src, dst) - if isinstance(dest, m2_expr.ExprId): - dest_name = dest.name - elif isinstance(dest, m2_expr.ExprSlice) and \ - isinstance(dest.arg, m2_expr.ExprId): - # Manage ExprId mask case - dest_name = dest.arg.name - else: - raise ValueError() + # Check post assignblk exception flags + if assignblk.set_exception: + self.check_cpu_exception(instr.offset, restricted_exception=True) - if not isinstance(dest_name, str): - label = dest_name - target_bbl = self.get_basic_bloc_by_label(label) - if target_bbl is None: - self.set_ret(self.add_ir(dest)) - else: - builder.branch(target_bbl) + # Destination + assert case2dst is not None + if len(case2dst) == 1: + # Avoid switch in this common case + self.gen_jump2dst(attrib, case2dst.values()[0]) else: - self.set_ret(self.add_ir(dest)) + current_bbl = self.builder.basic_block - def add_irbloc(self, irbloc): - "Add the content of irbloc at the corresponding labeled block" + # Gen the out cases + branch_id = self.new_branch_name() + case2bbl = {} + for case, dst in case2dst.iteritems(): + name = "switch_%s_%d" % (branch_id, case) + bbl = self.append_basic_block(name) + case2bbl[case] = bbl + self.builder.position_at_start(bbl) + self.gen_jump2dst(attrib, dst) + + # Jump on the correct output + self.builder.position_at_end(current_bbl) + switch = self.builder.switch(case_value, case2bbl[0]) + for i, bbl in case2bbl.iteritems(): + if i == 0: + # Default case is case 0, arbitrary + continue + switch.add_case(i, bbl) + + def gen_bad_block(self, asmblock): + """ + Translate an asm_bad_block into a CPU exception + """ builder = self.builder + m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags + t_size = LLVMType.IntType(m2_exception_flag.size) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), + m2_expr.ExprId("status")) + self.affect(t_size(m2_csts.EXCEPT_UNK_MNEMO), + m2_exception_flag) + self.set_ret(LLVMType.IntType(64)(asmblock.label.offset)) + + def gen_finalize(self, asmblock, codegen): + """ + In case of delayslot, generate a dummy BBL which return on the computed IRDst + or on next_label + """ + if self.llvm_context.has_delayslot: + next_label = codegen.get_block_post_label(asmblock) + builder = self.builder - bloc = irbloc.irs - dest = irbloc.dst - label = irbloc.label - lines = irbloc.lines - - # Get labeled basic bloc - label_block = self.get_basic_bloc_by_label(label) - builder.position_at_end(label_block) - - # Erase cache - self.expr_cache = {} - - # Add the content of the bloc with corresponding lines - self.add_bloc(bloc, lines) - - # Erase cache - self.expr_cache = {} + builder.position_at_end(self.get_basic_bloc_by_label(next_label)) - # Manage ret - for func in self.llvm_context.IR_transformation_functions: - dest = func(dest) + # Common code + self.affect(self.add_ir(m2_expr.ExprInt8(0)), + m2_expr.ExprId("status")) - if isinstance(dest, m2_expr.ExprCond): - # Compute cond - cond = self.add_ir(dest.cond) - zero_casted = llvm_c.Constant.int(LLVMType.int(dest.cond.size), - 0) - condition_bool = builder.icmp(llvm_c.ICMP_NE, cond, - zero_casted) + # Check if IRDst has been set + zero_casted = LLVMType.IntType(codegen.delay_slot_set.size)(0) + condition_bool = builder.icmp_unsigned("!=", + 
self.add_ir(codegen.delay_slot_set), + zero_casted) # Create bbls branch_id = self.new_branch_name() @@ -893,78 +1134,138 @@ class LLVMFunction(): builder.cbranch(condition_bool, then_block, else_block) - # Then Bloc - builder.position_at_end(then_block) - self.gen_ret_or_branch(dest.src1) + # Deactivate object caching + self.main_stream = False - # Else Bloc + # Then Block + builder.position_at_end(then_block) + PC = self.llvm_context.PC + to_ret = self.add_ir(codegen.delay_slot_dst) + self.affect(to_ret, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(0)), + m2_expr.ExprId("status")) + self.set_ret(to_ret) + + # Else Block builder.position_at_end(else_block) - self.gen_ret_or_branch(dest.src2) - - elif isinstance(dest, m2_expr.ExprId): - self.gen_ret_or_branch(dest) - - elif isinstance(dest, m2_expr.ExprSlice): - self.gen_ret_or_branch(dest) + PC = self.llvm_context.PC + to_ret = LLVMType.IntType(PC.size)(next_label.offset) + self.affect(to_ret, PC) + self.set_ret(to_ret) - else: - raise Exception("Bloc dst has to be an ExprId or an ExprCond") - - def from_blocs(self, blocs): - """Build the function from a list of bloc (irbloc instances). - Prototype : f(i8* vmcpu, i8* vmmngr)""" + def from_asmblock(self, asmblock): + """Build the function from an asmblock (asm_block instance). + Prototype : f(i8* jitcpu, i8* vmcpu, i8* vmmngr, i8* status)""" # Build function signature + self.my_args.append((m2_expr.ExprId("jitcpu"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "jitcpu")) self.my_args.append((m2_expr.ExprId("vmcpu"), - llvm_c.PointerType.pointer(LLVMType.int(8)), + llvm_ir.PointerType(LLVMType.IntType(8)), "vmcpu")) self.my_args.append((m2_expr.ExprId("vmmngr"), - llvm_c.PointerType.pointer(LLVMType.int(8)), + llvm_ir.PointerType(LLVMType.IntType(8)), "vmmngr")) + self.my_args.append((m2_expr.ExprId("status"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "status")) ret_size = 64 - self.ret_type = LLVMType.int(ret_size) + self.ret_type = LLVMType.IntType(ret_size) # Initialise the function self.init_fc() + self.local_vars_pointers["status"] = self.local_vars["status"] + + if isinstance(asmblock, m2_asmbloc.asm_block_bad): + self.gen_bad_block(asmblock) + return # Create basic blocks (for label branchs) entry_bbl, builder = self.entry_bbl, self.builder - - for irbloc in blocs: - name = self.canonize_label_name(irbloc.label) - self.append_basic_block(name) + for instr in asmblock.lines: + lbl = self.llvm_context.ir_arch.symbol_pool.getby_offset_create(instr.offset) + self.append_basic_block(lbl) + + # TODO: merge duplicate code with CGen + codegen = self.llvm_context.cgen_class(self.llvm_context.ir_arch) + irblocks_list = codegen.block2assignblks(asmblock) + + # Prepare for delayslot + if self.llvm_context.has_delayslot: + for element in (codegen.delay_slot_dst, codegen.delay_slot_set): + eltype = LLVMType.IntType(element.size) + ptr = self.CreateEntryBlockAlloca(eltype, + default_value=eltype(0)) + self.local_vars_pointers[element.name] = ptr + lbl = codegen.get_block_post_label(asmblock) + self.append_basic_block(lbl) # Add content builder.position_at_end(entry_bbl) - for irbloc in blocs: - self.add_irbloc(irbloc) + for instr, irblocks in zip(asmblock.lines, irblocks_list): + attrib = codegen.get_attributes(instr, irblocks, self.log_mn, + self.log_regs) + + # Pre-create basic blocks + for irblock in irblocks: + self.append_basic_block(irblock.label, overwrite=False) + + # Generate the corresponding code + for index, irblock in enumerate(irblocks): + 
self.llvm_context.ir_arch.irbloc_fix_regs_for_mode( + irblock, self.llvm_context.ir_arch.attrib) + + # Set the builder at the begining of the correct bbl + name = self.canonize_label_name(irblock.label) + self.builder.position_at_end(self.get_basic_bloc_by_label(name)) + + if index == 0: + self.gen_pre_code(attrib) + self.gen_irblock(attrib, instr, irblock) + + # Gen finalize (see codegen::CGen) is unrecheable, except with delayslot + self.gen_finalize(asmblock, codegen) # Branch entry_bbl on first label builder.position_at_end(entry_bbl) - first_label_bbl = self.get_basic_bloc_by_label(blocs[0].label) + first_label_bbl = self.get_basic_bloc_by_label(asmblock.label) builder.branch(first_label_bbl) + + # LLVMFunction manipulation + def __str__(self): "Print the llvm IR corresponding to the current module" + return str(self.mod) + + def dot(self): + "Return the CFG of the current function" + return llvm.get_function_cfg(self.fc) - return str(self.fc) + def as_llvm_mod(self): + """Return a ModuleRef standing for the current function""" + if self._llvm_mod is None: + self._llvm_mod = llvm.parse_assembly(str(self.mod)) + return self._llvm_mod def verify(self): "Verify the module syntax" + return self.as_llvm_mod().verify() - return self.mod.verify() + def get_bytecode(self): + "Return LLVM bitcode corresponding to the current module" + return self.as_llvm_mod().as_bitcode() def get_assembly(self): "Return native assembly corresponding to the current module" - - return self.mod.to_native_assembly() + return self.llvm_context.target_machine.emit_assembly(self.as_llvm_mod()) def optimise(self): "Optimise the function in place" - while self.llvm_context.pass_manager.run(self.fc): - continue + return self.llvm_context.pass_manager.run(self.as_llvm_mod()) def __call__(self, *args): "Eval the function with arguments args" @@ -978,9 +1279,10 @@ class LLVMFunction(): def get_function_pointer(self): "Return a pointer on the Jitted function" - e = self.llvm_context.get_execengine() + engine = self.llvm_context.get_execengine() - return e.get_pointer_to_function(self.fc) + # Add the module and make sure it is ready for execution + engine.add_module(self.as_llvm_mod()) + engine.finalize_object() -# TODO: -# - Add more expressions + return engine.get_function_address(self.fc.name) diff --git a/miasm2/jitter/vm_mngr.c b/miasm2/jitter/vm_mngr.c index a8cc7639..42f91f72 100644 --- a/miasm2/jitter/vm_mngr.c +++ b/miasm2/jitter/vm_mngr.c @@ -76,6 +76,11 @@ const uint8_t parity_table[256] = { 0, CC_P, CC_P, 0, CC_P, 0, 0, CC_P, }; +uint8_t parity(uint64_t a) { + return parity_table[(a) & 0xFF]; +} + + // #define DEBUG_MIASM_AUTOMOD_CODE void memory_access_list_init(struct memory_access_list * access) @@ -916,7 +921,7 @@ unsigned int rcr_rez_op(unsigned int size, unsigned int a, unsigned int b, unsig return tmp; } -unsigned int x86_bsr(uint64_t src, unsigned int size) +unsigned int x86_bsr(unsigned int size, uint64_t src) { int i; @@ -928,7 +933,7 @@ unsigned int x86_bsr(uint64_t src, unsigned int size) exit(0); } -unsigned int x86_bsf(uint64_t src, unsigned int size) +unsigned int x86_bsf(unsigned int size, uint64_t src) { int i; diff --git a/miasm2/jitter/vm_mngr.h b/miasm2/jitter/vm_mngr.h index d3583b52..88ecf34d 100644 --- a/miasm2/jitter/vm_mngr.h +++ b/miasm2/jitter/vm_mngr.h @@ -194,7 +194,7 @@ int vm_write_mem(vm_mngr_t* vm_mngr, uint64_t addr, char *buffer, uint64_t size) extern const uint8_t parity_table[256]; -#define parity(a) (parity_table[(a) & 0xFF]) +uint8_t parity(uint64_t a); unsigned int 
my_imul08(unsigned int a, unsigned int b); diff --git a/test/test_all.py b/test/test_all.py index 62f1cd4b..7cc8f6eb 100644 --- a/test/test_all.py +++ b/test/test_all.py @@ -55,7 +55,7 @@ class ArchUnitTest(RegressionTest): # script -> blacklisted jitter blacklist = { - "x86/unit/mn_float.py": ["python"], + "x86/unit/mn_float.py": ["python", "llvm"], } for script in ["x86/sem.py", "x86/unit/mn_strings.py", @@ -684,7 +684,7 @@ By default, no tag is omitted." % ", ".join(TAGS.keys()), default="") # Handle llvm modularity llvm = True try: - import llvm + import llvmlite except ImportError: llvm = False @@ -695,12 +695,9 @@ By default, no tag is omitted." % ", ".join(TAGS.keys()), default="") except ImportError: tcc = False - # TODO XXX: fix llvm jitter (deactivated for the moment) - llvm = False - if llvm is False: print "%(red)s[LLVM]%(end)s Python" % cosmetics.colors + \ - "'py-llvm 3.2' module is required for llvm tests" + "'llvmlite' module is required for llvm tests" # Remove llvm tests if TAGS["llvm"] not in exclude_tags: |
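Finally, `test/test_all.py` now detects the backend by probing llvmlite instead of the old llvmpy `llvm` module, and the hard `llvm = False` deactivation is removed. A trimmed sketch of that detection idiom follows; the gcc fallback is illustrative and not part of the patch.

```python
# Availability probe, mirroring the check in test/test_all.py after this merge.
def llvm_backend_available():
    try:
        import llvmlite  # probing only; the module itself is not used here
        return True
    except ImportError:
        return False

jit_type = "llvm" if llvm_backend_available() else "gcc"  # fallback is illustrative
print "selected jitter: %s" % jit_type
```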