diff options
| author | Ajax <commial@gmail.com> | 2016-12-21 17:17:34 +0100 |
|---|---|---|
| committer | Ajax <commial@gmail.com> | 2017-01-04 17:14:55 +0100 |
| commit | 68e548b9b6c411cea13792a87291d0514fd45520 (patch) | |
| tree | cda7fa91c1d2ed6fdc6589da93c21a1399730f47 | |
| parent | d554129240394be47c9d99655e7d7feef5567795 (diff) | |
| download | miasm-68e548b9b6c411cea13792a87291d0514fd45520.tar.gz miasm-68e548b9b6c411cea13792a87291d0514fd45520.zip | |
Adapt codegen.CGen principles to LLVM
| -rw-r--r-- | miasm2/jitter/Jitllvm.c | 7 | ||||
| -rw-r--r-- | miasm2/jitter/jitcore_llvm.py | 16 | ||||
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 439 |
3 files changed, 418 insertions, 44 deletions
diff --git a/miasm2/jitter/Jitllvm.c b/miasm2/jitter/Jitllvm.c index 98e047bf..c176a4b2 100644 --- a/miasm2/jitter/Jitllvm.c +++ b/miasm2/jitter/Jitllvm.c @@ -13,16 +13,17 @@ PyObject* llvm_exec_bloc(PyObject* self, PyObject* args) { uint64_t func_addr; - uint64_t (*func)(void*, void*, void*); + uint64_t (*func)(void*, void*, void*, uint8_t*); uint64_t vm; uint64_t ret; JitCpu* jitcpu; - + uint8_t status; + if (!PyArg_ParseTuple(args, "KOK", &func_addr, &jitcpu, &vm)) return NULL; vm_cpu_t* cpu = jitcpu->cpu; func = (void *) (intptr_t) func_addr; - ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm); + ret = func((void*) jitcpu, (void*)(intptr_t) cpu, (void*)(intptr_t) vm, &status); return PyLong_FromUnsignedLongLong(ret); } diff --git a/miasm2/jitter/jitcore_llvm.py b/miasm2/jitter/jitcore_llvm.py index 88db199a..020f2d45 100644 --- a/miasm2/jitter/jitcore_llvm.py +++ b/miasm2/jitter/jitcore_llvm.py @@ -45,7 +45,7 @@ class JitCore_LLVM(jitcore.JitCore): pass # Create a context - self.context = LLVMContext_JIT(libs_to_load) + self.context = LLVMContext_JIT(libs_to_load, self.ir_arch) # Set the optimisation level self.context.optimise_level() @@ -64,17 +64,21 @@ class JitCore_LLVM(jitcore.JitCore): # Set IRs transformation to apply self.context.set_IR_transformation(self.ir_arch.expr_fix_regs_for_mode) - def jitirblocs(self, label, irblocs): + def add_bloc(self, block): + """Add a block to JiT and JiT it. + @block: the block to add + """ + # TODO: caching using hash # Build a function in the context - func = LLVMFunction(self.context, label.name) + func = LLVMFunction(self.context, block.label.name) # Set log level func.log_regs = self.log_regs func.log_mn = self.log_mn - # Import irblocs - func.from_blocs(irblocs) + # Import asm block + func.from_asmblock(block) # Verify if self.options["safe_mode"] is True: @@ -91,7 +95,7 @@ class JitCore_LLVM(jitcore.JitCore): print func.get_assembly() # Store a pointer on the function jitted code - self.lbl2jitbloc[label.offset] = func.get_function_pointer() + self.lbl2jitbloc[block.label.offset] = func.get_function_pointer() def jit_call(self, label, cpu, _vmmngr, breakpoints): """Call the function label with cpu and vmmngr states diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index 460d4e4f..2165fbc5 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -16,6 +16,8 @@ from llvmlite import ir as llvm_ir import miasm2.expression.expression as m2_expr import miasm2.jitter.csts as m2_csts import miasm2.core.asmbloc as m2_asmbloc +from miasm2.jitter.codegen import CGen +from miasm2.expression.expression_helper import possible_values class LLVMType(llvm_ir.Type): @@ -138,9 +140,11 @@ class LLVMContext_JIT(LLVMContext): """Extend LLVMContext_JIT in order to handle memory management and custom operations""" - def __init__(self, library_filenames, name="mod"): + def __init__(self, library_filenames, ir_arch, name="mod"): "Init a LLVMContext object, and load the mem management shared library" self.library_filenames = library_filenames + self.ir_arch = ir_arch + self.arch_specific() LLVMContext.__init__(self, name) self.vmcpu = {} self.engines = [] @@ -154,6 +158,15 @@ class LLVMContext_JIT(LLVMContext): self.add_op() self.add_log_functions() + def arch_specific(self): + arch = self.ir_arch.arch + if arch.name == "x86": + self.PC = arch.regs.RIP + self.logging_func = "dump_gpregs_%d" % self.ir_arch.attrib + else: + self.PC = self.ir_arch.pc + self.logging_func = "dump_gpregs" + def add_memlookups(self): "Add MEM_LOOKUP functions" @@ -168,7 +181,15 @@ class LLVMContext_JIT(LLVMContext): "args": [p8, LLVMType.IntType(64), LLVMType.IntType(i)]} - + fc["reset_memory_access"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} + fc["check_memory_breakpoint"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} + fc["check_invalid_code_blocs"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} self.add_fc(fc) def add_get_exceptionflag(self): @@ -219,8 +240,8 @@ class LLVMContext_JIT(LLVMContext): "Add functions for state logging" p8 = llvm_ir.PointerType(LLVMType.IntType(8)) - self.add_fc({"dump_gpregs": {"ret": llvm_ir.VoidType(), - "args": [p8]}}) + self.add_fc({self.logging_func: {"ret": llvm_ir.VoidType(), + "args": [p8]}}) def set_vmcpu(self, lookup_table): "Set the correspondance between register name and vmcpu offset" @@ -310,13 +331,17 @@ class LLVMFunction(): "Show the CFG of the current function" self.fc.viewCFG() - def append_basic_block(self, label): + def append_basic_block(self, label, overwrite=True): """Add a new basic block to the current function. @label: str or asmlabel + @overwrite: if False, do nothing if a bbl with the same name already exists Return the corresponding LLVM Basic Block""" name = self.canonize_label_name(label) + bbl = self.name2bbl.get(name, None) + if not overwrite and bbl is not None: + return bbl bbl = self.fc.append_basic_block(name) - self.name2bbl[label] = bbl + self.name2bbl[name] = bbl return bbl @@ -712,13 +737,8 @@ class LLVMFunction(): builder = self.builder if isinstance(dst, m2_expr.ExprId): - dst_name = dst.name - if dst_name == "IRDst": - self.local_vars[dst_name] = src - else: - ptr_casted = self.get_ptr_by_expr( - m2_expr.ExprId(dst_name, dst.size)) - builder.store(src, ptr_casted) + ptr_casted = self.get_ptr_by_expr(dst) + builder.store(src, ptr_casted) elif isinstance(dst, m2_expr.ExprMem): addr = self.add_ir(dst.arg) @@ -726,27 +746,25 @@ class LLVMFunction(): else: raise Exception("UnknownAffectationType") - def check_error(self, line, except_do_not_update_pc=False): + def check_memory_exception(self, offset, restricted_exception=False): """Add a check for memory errors. - @line: Irbloc line corresponding to the current instruction - If except_do_not_update_pc, check only for exception which do not - require a pc update""" + @offset: offset of the current exception (int or Instruction) + If restricted_exception, check only for exception which do not + require a pc update, and do not consider automod exception""" # VmMngr "get_exception_flag" return's size size = 64 t_size = LLVMType.IntType(size) - # Current address - pc_to_return = line.offset - # Get exception flag value + # TODO: avoid costly call using a structure deref builder = self.builder fc_ptr = self.mod.get_global("get_exception_flag") exceptionflag = builder.call(fc_ptr, [self.local_vars["vmmngr"]]) - if except_do_not_update_pc is True: - auto_mod_flag = m2_csts.EXCEPT_DO_NOT_UPDATE_PC - m2_flag = llvm_ir.Constant(t_size, auto_mod_flag) + if restricted_exception is True: + flag = ~m2_csts.EXCEPT_CODE_AUTOMOD & m2_csts.EXCEPT_DO_NOT_UPDATE_PC + m2_flag = llvm_ir.Constant(t_size, flag) exceptionflag = builder.and_(exceptionflag, m2_flag) # Compute cond @@ -768,26 +786,63 @@ class LLVMFunction(): # Then Bloc builder.position_at_end(then_block) - self.set_ret(llvm_ir.Constant(self.ret_type, pc_to_return)) + PC = self.llvm_context.PC + if isinstance(offset, (int, long)): + offset = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + self.affect(offset, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), m2_expr.ExprId("status")) + self.set_ret(offset) builder.position_at_end(merge_block) - # Reactivate object caching self.main_stream = current_main_stream - def log_instruction(self, instruction, line): - "Print current instruction and registers if options are set" + def check_cpu_exception(self, offset, restricted_exception=False): + """Add a check for CPU errors. + @offset: offset of the current exception (int or Instruction) + If restricted_exception, check only for exception which do not + require a pc update""" - # Get builder + # Get exception flag value builder = self.builder + m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags + t_size = LLVMType.IntType(m2_exception_flag.size) + exceptionflag = self.add_ir(m2_exception_flag) + + # Compute cond + if restricted_exception is True: + flag = m2_csts.EXCEPT_NUM_UPDT_EIP + condition_bool = builder.icmp_unsigned(">", exceptionflag, + llvm_ir.Constant(t_size, flag)) + else: + zero_casted = llvm_ir.Constant(t_size, 0) + condition_bool = builder.icmp_unsigned("!=", + exceptionflag, + zero_casted) + + # Create bbls + branch_id = self.new_branch_name() + then_block = self.append_basic_block('then%s' % branch_id) + merge_block = self.append_basic_block('ifcond%s' % branch_id) - if self.log_mn is True: - print instruction # TODO + builder.cbranch(condition_bool, then_block, merge_block) - if self.log_regs is True: - # Call dump general purpose registers - fc_ptr = self.mod.get_global("dump_gpregs") - builder.call(fc_ptr, [self.local_vars["vmcpu"]]) + # Deactivate object caching + current_main_stream = self.main_stream + self.main_stream = False + + # Then Bloc + builder.position_at_end(then_block) + PC = self.llvm_context.PC + if isinstance(offset, (int, long)): + offset = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + self.affect(offset, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), m2_expr.ExprId("status")) + self.set_ret(offset) + + builder.position_at_end(merge_block) + # Reactivate object caching + self.main_stream = current_main_stream def add_bloc(self, bloc, lines): "Add a bloc of instruction in the current function" @@ -868,6 +923,8 @@ class LLVMFunction(): return label elif isinstance(label, m2_asmbloc.asm_label): return "label_%s" % label.name + elif m2_asmbloc.expr_is_label(label): + return "label_%s" % label.name.name else: raise ValueError("label must either be str or asmlabel") @@ -956,11 +1013,323 @@ class LLVMFunction(): self.gen_ret_or_branch(dest) elif isinstance(dest, m2_expr.ExprMem): - self.set_ret(self.add_ir(m2_expr.ExprId("IRDst"))) + self.set_ret(self.add_ir(self.ir_arch.IRDst)) else: raise Exception("Bloc dst has to be an ExprId or an ExprCond") + def canonize_instr_bbl(self, instr): + if isinstance(instr, (int, long)): + return "instr_%s" % hex(instr) + return "instr_%s" % hex(instr.offset) + + def global_constant(self, name, value): + """ + Inspired from numba/cgutils.py + + Get or create a (LLVM module-)global constant with *name* or *value*. + """ + module = self.mod + data = llvm_ir.GlobalVariable(self.mod, value.type, name=name) + data.global_constant = True + data.initializer = value + return data + + def make_bytearray(self, buf): + """ + Inspired from numba/cgutils.py + + Make a byte array constant from *buf*. + """ + b = bytearray(buf) + n = len(b) + return llvm_ir.Constant(llvm_ir.ArrayType(llvm_ir.IntType(8), n), b) + + def printf(self, format, *args): + """ + Inspired from numba/cgutils.py + + Calls printf(). + Argument `format` is expected to be a Python string. + Values to be printed are listed in `args`. + + Note: There is no checking to ensure there is correct number of values + in `args` and there type matches the declaration in the format string. + """ + assert isinstance(format, str) + mod = self.mod + # Make global constant for format string + cstring = llvm_ir.IntType(8).as_pointer() + fmt_bytes = self.make_bytearray((format + '\00').encode('ascii')) + + base_name = "printf_format" + count = 0 + while self.mod.get_global("%s_%d" % (base_name, count)): + count += 1 + global_fmt = self.global_constant("%s_%d" % (base_name, count), + fmt_bytes) + fnty = llvm_ir.FunctionType(llvm_ir.IntType(32), [cstring], + var_arg=True) + # Insert printf() + fn = mod.get_global('printf') + if fn is None: + fn = llvm_ir.Function(mod, fnty, name="printf") + # Call + ptr_fmt = self.builder.bitcast(global_fmt, cstring) + return self.builder.call(fn, [ptr_fmt] + list(args)) + + def gen_pre_code(self, attributes): + if attributes.log_mn: + self.printf("%.8X %s\n" % (attributes.instr.offset, + attributes.instr)) + + def gen_post_code(self, attributes): + if attributes.log_regs: + fc_ptr = self.mod.get_global(self.llvm_context.logging_func) + self.builder.call(fc_ptr, [self.local_vars["vmcpu"]]) + + def gen_post_instr_checks(self, attrib, next_instr): + if attrib.mem_read | attrib.mem_write: + fc_ptr = self.mod.get_global("check_memory_breakpoint") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + fc_ptr = self.mod.get_global("check_invalid_code_blocs") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + self.check_memory_exception(next_instr, restricted_exception=False) + if attrib.set_exception or attrib.op_set_exception: + self.check_cpu_exception(next_instr, restricted_exception=False) + + if attrib.mem_read | attrib.mem_write: + fc_ptr = self.mod.get_global("reset_memory_access") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + + def expr2cases(self, expr): + """ + Evaluate @expr and return: + - switch value -> dst + - evaluation of the switch value (if any) + """ + + to_eval = expr + dst2case = {} + case2dst = {} + for i, solution in enumerate(possible_values(expr)): + value = solution.value + index = dst2case.get(value, i) + to_eval = to_eval.replace_expr({value: m2_expr.ExprInt(index, value.size)}) + dst2case[value] = index + if m2_asmbloc.expr_is_int_or_label(value): + case2dst[i] = value + else: + case2dst[i] = self.add_ir(value) + + + evaluated = self.add_ir(to_eval) + return case2dst, evaluated + + def gen_jump2dst(self, attrib, dst): + """Generate the code for a jump to @dst with final check for error + + Several cases have to be considered: + - jump to an offset out of the current ASM BBL (JMP 0x11223344) + - jump to an offset inside the current ASM BBL (Go to next instruction) + - jump to a generated IR label, which must be jitted in this same + function (REP MOVSB) + - jump to a computed offset (CALL @32[0x11223344]) + """ + PC = self.llvm_context.PC + # We are no longer in the main stream, deactivate cache + self.main_stream = False + + if isinstance(dst, m2_expr.ExprInt): + dst = m2_expr.ExprId(self.llvm_context.ir_arch.symbol_pool.getby_offset_create(int(dst)), + dst.size) + + if m2_asmbloc.expr_is_label(dst): + bbl = self.get_basic_bloc_by_label(dst) + if bbl is not None: + # "local" jump, inside this function + if dst.name.offset is not None: + # Avoid checks on generated label + self.gen_post_code(attrib) + self.gen_post_instr_checks(attrib, dst.name.offset) + self.builder.branch(bbl) + return + else: + # "extern" jump on a defined offset, return to the caller + offset = dst.name.offset + dst = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + + # "extern" jump with a computed value, return to the caller + assert isinstance(dst, (llvm_ir.Instruction, llvm_ir.Value)) + # Cast @dst, if needed + # for instance, x86_32: IRDst is 32 bits, so is @dst; PC is 64 bits + if dst.type.width != PC.size: + dst = self.builder.zext(dst, LLVMType.IntType(PC.size)) + + self.gen_post_code(attrib) + self.affect(dst, PC) + self.gen_post_instr_checks(attrib, dst) + self.affect(self.add_ir(m2_expr.ExprInt8(0)), m2_expr.ExprId("status")) + self.set_ret(dst) + + + def gen_irblock(self, attrib, instr, irblock): + """ + Generate the code for an @irblock + @instr: the current instruction to translate + @irblock: an irbloc instance + @attrib: an Attributs instance + """ + + case2dst = None + case_value = None + + for assignblk in irblock.irs: + # Enable cache + self.main_stream = True + self.expr_cache = {} + + # Prefetch memory + for element in assignblk.get_r(mem_read=True): + if isinstance(element, m2_expr.ExprMem): + self.add_ir(element) + + # Evaluate expressions + values = {} + for dst, src in assignblk.iteritems(): + if dst == self.llvm_context.ir_arch.IRDst: + case2dst, case_value = self.expr2cases(src) + else: + values[dst] = self.add_ir(src) + + # Check memory access exception + if assignblk.mem_read: + self.check_memory_exception(instr.offset, + restricted_exception=True) + + # Check operation exception + if assignblk.op_set_exception: + self.check_cpu_exception(instr.offset, restricted_exception=True) + + # Update the memory + for dst, src in values.iteritems(): + if isinstance(dst, m2_expr.ExprMem): + self.affect(src, dst) + + # Check memory write exception + if assignblk.mem_write: + self.check_memory_exception(instr.offset, + restricted_exception=True) + + # Update registers values + for dst, src in values.iteritems(): + if not isinstance(dst, m2_expr.ExprMem): + self.affect(src, dst) + + # Check post assignblk exception flags + if assignblk.set_exception: + self.check_cpu_exception(instr.offset, restricted_exception=True) + + # Destination + assert case2dst is not None + if len(case2dst) == 1: + # Avoid switch in this common case + self.gen_jump2dst(attrib, case2dst.values()[0]) + else: + current_bbl = self.builder.basic_block + + # Gen the out cases + branch_id = self.new_branch_name() + case2bbl = {} + for case, dst in case2dst.iteritems(): + name = "switch_%s_%d" % (branch_id, case) + bbl = self.append_basic_block(name) + case2bbl[case] = bbl + self.builder.position_at_start(bbl) + self.gen_jump2dst(attrib, dst) + + # Jump on the correct output + self.builder.position_at_end(current_bbl) + switch = self.builder.switch(case_value, case2bbl[0]) + for i, bbl in case2bbl.iteritems(): + if i == 0: + # Default case is case 0, arbitrary + continue + switch.add_case(i, bbl) + + def from_asmblock(self, asmblock): + """Build the function from an asmblock (asm_block instance). + Prototype : f(i8* jitcpu, i8* vmcpu, i8* vmmngr, i8* status)""" + + if isinstance(asmblock, m2_asmbloc.asm_block_bad): + raise NotImplementedError("TODO") + + # Build function signature + self.my_args.append((m2_expr.ExprId("jitcpu"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "jitcpu")) + self.my_args.append((m2_expr.ExprId("vmcpu"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "vmcpu")) + self.my_args.append((m2_expr.ExprId("vmmngr"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "vmmngr")) + self.my_args.append((m2_expr.ExprId("status"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "status")) + ret_size = 64 + + self.ret_type = LLVMType.IntType(ret_size) + + # Initialise the function + self.init_fc() + self.local_vars_pointers["status"] = self.local_vars["status"] + + # Create basic blocks (for label branchs) + entry_bbl, builder = self.entry_bbl, self.builder + + for instr in asmblock.lines: + lbl = self.llvm_context.ir_arch.symbol_pool.getby_offset_create(instr.offset) + name = self.canonize_label_name(lbl) + self.append_basic_block(name) + + # Add content + builder.position_at_end(entry_bbl) + + # TODO: merge duplicate code with CGen + codegen = CGen(self.llvm_context.ir_arch) + irblocks_list = codegen.block2assignblks(asmblock) + + for instr, irblocks in zip(asmblock.lines, irblocks_list): + attrib = codegen.get_attributes(instr, irblocks, self.log_mn, + self.log_regs) + + # Pre-create basic blocks + for irblock in irblocks: + name = self.canonize_label_name(irblock.label) + self.append_basic_block(name, overwrite=False) + + # Generate the corresponding code + for index, irblock in enumerate(irblocks): + self.llvm_context.ir_arch.irbloc_fix_regs_for_mode( + irblock, self.llvm_context.ir_arch.attrib) + + # Set the builder at the begining of the correct bbl + name = self.canonize_label_name(irblock.label) + self.builder.position_at_end(self.get_basic_bloc_by_label(name)) + + if index == 0: + self.gen_pre_code(attrib) + self.gen_irblock(attrib, instr, irblock) + + # Gen finalize (see codegen::CGen) is unrecheable + # self.gen_finalize(codegen.get_block_post_label(asmblock).offset) + + # Branch entry_bbl on first label + builder.position_at_end(entry_bbl) + first_label_bbl = self.get_basic_bloc_by_label(asmblock.label) + builder.branch(first_label_bbl) + def from_blocs(self, blocs): """Build the function from a list of bloc (irbloc instances). Prototype : f(i8* jitcpu, i8* vmcpu, i8* vmmngr)""" |