diff options
Diffstat (limited to '')
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 1352 |
1 files changed, 827 insertions, 525 deletions
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index 3ac75cd7..08b1986b 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -5,33 +5,33 @@ # - JiT # # # Requires: # -# - llvmpy (tested on v0.11.2) # +# - llvmlite (tested on v0.15) # # # Authors : Fabrice DESCLAUX (CEA/DAM), Camille MOUGEY (CEA/DAM) # # # -import llvm -import llvm.core as llvm_c -import llvm.ee as llvm_e -import llvm.passes as llvm_p +from llvmlite import binding as llvm +from llvmlite import ir as llvm_ir import miasm2.expression.expression as m2_expr import miasm2.jitter.csts as m2_csts import miasm2.core.asmbloc as m2_asmbloc +from miasm2.jitter.codegen import CGen +from miasm2.expression.expression_helper import possible_values -class LLVMType(llvm_c.Type): +class LLVMType(llvm_ir.Type): "Handle LLVM Type" int_cache = {} @classmethod - def int(cls, size=32): + def IntType(cls, size=32): try: return cls.int_cache[size] except KeyError: - cls.int_cache[size] = llvm_c.Type.int(size) + cls.int_cache[size] = llvm_ir.IntType(size) return cls.int_cache[size] @classmethod @@ -43,7 +43,7 @@ class LLVMType(llvm_c.Type): def generic(cls, e): "Generic value for execution" if isinstance(e, m2_expr.ExprInt): - return llvm_e.GenericValue.int(LLVMType.int(e.size), int(e)) + return llvm_e.GenericValue.int(LLVMType.IntType(e.size), int(e.arg)) elif isinstance(e, llvm_e.GenericValue): return e else: @@ -58,39 +58,39 @@ class LLVMContext(): def __init__(self, name="mod"): "Initialize a context with a module named 'name'" - self.mod = llvm_c.Module.new(name) - self.pass_manager = llvm_p.FunctionPassManager.new(self.mod) - self.exec_engine = llvm_e.ExecutionEngine.new(self.mod) - self.add_fc(self.known_fc) - - def optimise_level(self, classic_passes=True, dead_passes=True): - """Set the optimisation level : - classic_passes : - - combine instruction - - reassociate - - global value numbering - - simplify cfg - - dead_passes : - - dead code - - dead store - - dead instructions + # Initialize llvm + llvm.initialize() + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + + # Initilize target for compilation + target = llvm.Target.from_default_triple() + self.target_machine = target.create_target_machine() + self.init_exec_engine() + + def optimise_level(self, level=2): + """Set the optimisation level to @level from 0 to 2 + 0: non-optimized + 2: optimized """ # Set up the optimiser pipeline - - if classic_passes is True: - # self.pass_manager.add(llvm_p.PASS_INSTCOMBINE) - self.pass_manager.add(llvm_p.PASS_REASSOCIATE) - self.pass_manager.add(llvm_p.PASS_GVN) - self.pass_manager.add(llvm_p.PASS_SIMPLIFYCFG) - - if dead_passes is True: - self.pass_manager.add(llvm_p.PASS_DCE) - self.pass_manager.add(llvm_p.PASS_DSE) - self.pass_manager.add(llvm_p.PASS_DIE) - - self.pass_manager.initialize() + pmb = llvm.create_pass_manager_builder() + pmb.opt_level = level + pm = llvm.create_module_pass_manager() + pmb.populate(pm) + self.pass_manager = pm + + def init_exec_engine(self): + mod = llvm.parse_assembly("") + engine = llvm.create_mcjit_compiler(mod, + self.target_machine) + self.exec_engine = engine + + def new_module(self, name="mod"): + """Create a module, with needed functions""" + self.mod = llvm_ir.Module(name=name) + self.add_fc(self.known_fc) def get_execengine(self): "Return the Execution Engine associated with this context" @@ -98,7 +98,7 @@ class LLVMContext(): def get_passmanager(self): "Return the Pass Manager associated with this context" - return self.exec_engine + return self.pass_manager def get_module(self): "Return the module associated with this context" @@ -106,123 +106,246 @@ class LLVMContext(): def add_shared_library(self, filename): "Load the shared library 'filename'" - return llvm_c.load_library_permanently(filename) + return llvm.load_library_permanently(filename) def add_fc(self, fc): "Add function into known_fc" - for name, detail in fc.items(): - self.mod.add_function(LLVMType.function(detail["ret"], - detail["args"]), - name) + for name, detail in fc.iteritems(): + fnty = llvm_ir.FunctionType(detail["ret"], detail["args"]) + llvm_ir.Function(self.mod, fnty, name=name) + + def memory_lookup(self, func, addr, size): + """Perform a memory lookup at @addr of size @size (in bit)""" + raise NotImplementedError("Abstract method") + + def memory_write(self, func, addr, size, value): + """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" + raise NotImplementedError("Abstract method") class LLVMContext_JIT(LLVMContext): - "Extend LLVMContext_JIT in order to handle memory management" + """Extend LLVMContext_JIT in order to handle memory management and custom + operations""" - def __init__(self, library_filenames, name="mod"): + def __init__(self, library_filenames, ir_arch, name="mod"): "Init a LLVMContext object, and load the mem management shared library" + self.library_filenames = library_filenames + self.ir_arch = ir_arch + self.arch_specific() LLVMContext.__init__(self, name) - for lib_fname in library_filenames: + self.vmcpu = {} + + def new_module(self, name="mod"): + LLVMContext.new_module(self, name) + for lib_fname in self.library_filenames: self.add_shared_library(lib_fname) self.add_memlookups() self.add_get_exceptionflag() self.add_op() self.add_log_functions() - self.vmcpu = {} + + def arch_specific(self): + arch = self.ir_arch.arch + if arch.name == "x86": + self.PC = arch.regs.RIP + self.logging_func = "dump_gpregs_%d" % self.ir_arch.attrib + else: + self.PC = self.ir_arch.pc + self.logging_func = "dump_gpregs" + if arch.name == "mips32": + from miasm2.arch.mips32.jit import mipsCGen + self.cgen_class = mipsCGen + self.has_delayslot = True + else: + self.cgen_class = CGen + self.has_delayslot = False def add_memlookups(self): "Add MEM_LOOKUP functions" fc = {} - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) for i in [8, 16, 32, 64]: - fc["MEM_LOOKUP_%02d" % i] = {"ret": LLVMType.int(i), + fc["MEM_LOOKUP_%02d" % i] = {"ret": LLVMType.IntType(i), "args": [p8, - LLVMType.int(64)]} + LLVMType.IntType(64)]} - fc["MEM_WRITE_%02d" % i] = {"ret": LLVMType.void(), + fc["MEM_WRITE_%02d" % i] = {"ret": llvm_ir.VoidType(), "args": [p8, - LLVMType.int(64), - LLVMType.int(i)]} - + LLVMType.IntType(64), + LLVMType.IntType(i)]} + fc["reset_memory_access"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} + fc["check_memory_breakpoint"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} + fc["check_invalid_code_blocs"] = {"ret": llvm_ir.VoidType(), + "args": [p8, + ]} self.add_fc(fc) def add_get_exceptionflag(self): "Add 'get_exception_flag' function" - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) - self.add_fc({"get_exception_flag": {"ret": LLVMType.int(64), + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) + self.add_fc({"get_exception_flag": {"ret": LLVMType.IntType(64), "args": [p8]}}) def add_op(self): "Add operations functions" - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) - self.add_fc({"parity": {"ret": LLVMType.int(), - "args": [LLVMType.int()]}}) - self.add_fc({"rot_left": {"ret": LLVMType.int(), - "args": [LLVMType.int(), - LLVMType.int(), - LLVMType.int()]}}) - self.add_fc({"rot_right": {"ret": LLVMType.int(), - "args": [LLVMType.int(), - LLVMType.int(), - LLVMType.int()]}}) - - self.add_fc({"segm2addr": {"ret": LLVMType.int(64), + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) + itype = LLVMType.IntType(64) + self.add_fc({"parity": {"ret": LLVMType.IntType(1), + "args": [itype]}}) + self.add_fc({"rot_left": {"ret": itype, + "args": [itype, + itype, + itype]}}) + self.add_fc({"rot_right": {"ret": itype, + "args": [itype, + itype, + itype]}}) + self.add_fc({"rcr_rez_op": {"ret": itype, + "args": [itype, + itype, + itype, + itype]}}) + self.add_fc({"rcl_rez_op": {"ret": itype, + "args": [itype, + itype, + itype, + itype]}}) + self.add_fc({"x86_bsr": {"ret": itype, + "args": [itype, + itype]}}) + self.add_fc({"x86_bsf": {"ret": itype, + "args": [itype, + itype]}}) + self.add_fc({"segm2addr": {"ret": itype, "args": [p8, - LLVMType.int(64), - LLVMType.int(64)]}}) + itype, + itype]}}) for k in [8, 16]: - self.add_fc({"bcdadd_%s" % k: {"ret": LLVMType.int(k), - "args": [LLVMType.int(k), - LLVMType.int(k)]}}) - self.add_fc({"bcdadd_cf_%s" % k: {"ret": LLVMType.int(k), - "args": [LLVMType.int(k), - LLVMType.int(k)]}}) + self.add_fc({"bcdadd_%s" % k: {"ret": LLVMType.IntType(k), + "args": [LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"bcdadd_cf_%s" % k: {"ret": LLVMType.IntType(k), + "args": [LLVMType.IntType(k), + LLVMType.IntType(k)]}}) for k in [16, 32, 64]: - self.add_fc({"imod%s" % k: {"ret": LLVMType.int(k), + self.add_fc({"imod%s" % k: {"ret": LLVMType.IntType(k), + "args": [p8, + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"idiv%s" % k: {"ret": LLVMType.IntType(k), "args": [p8, - LLVMType.int(k), - LLVMType.int(k)]}}) - self.add_fc({"idiv%s" % k: {"ret": LLVMType.int(k), + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"umod%s" % k: {"ret": LLVMType.IntType(k), "args": [p8, - LLVMType.int(k), - LLVMType.int(k)]}}) + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) + self.add_fc({"udiv%s" % k: {"ret": LLVMType.IntType(k), + "args": [p8, + LLVMType.IntType(k), + LLVMType.IntType(k)]}}) def add_log_functions(self): "Add functions for state logging" - p8 = llvm_c.PointerType.pointer(LLVMType.int(8)) - self.add_fc({"dump_gpregs": {"ret": LLVMType.void(), - "args": [p8]}}) + p8 = llvm_ir.PointerType(LLVMType.IntType(8)) + self.add_fc({self.logging_func: {"ret": llvm_ir.VoidType(), + "args": [p8]}}) def set_vmcpu(self, lookup_table): "Set the correspondance between register name and vmcpu offset" self.vmcpu = lookup_table - def set_IR_transformation(self, *args): - """Set a list of transformation to apply on expression before their - treatments. - args: function Expr(Expr)""" - self.IR_transformation_functions = args + def memory_lookup(self, func, addr, size): + """Perform a memory lookup at @addr of size @size (in bit)""" + builder = func.builder + fc_name = "MEM_LOOKUP_%02d" % size + fc_ptr = self.mod.get_global(fc_name) + addr_casted = builder.zext(addr, + LLVMType.IntType(64)) + + ret = builder.call(fc_ptr, [func.local_vars["jitcpu"], + addr_casted]) + return ret + + def memory_write(self, func, addr, size, value): + """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" + # Function call + builder = func.builder + fc_name = "MEM_WRITE_%02d" % size + fc_ptr = self.mod.get_global(fc_name) + dst_casted = builder.zext(addr, LLVMType.IntType(64)) + builder.call(fc_ptr, [func.local_vars["jitcpu"], + dst_casted, + value]) + + +class LLVMContext_IRCompilation(LLVMContext): + + """Extend LLVMContext in order to handle memory management and custom + operations for Miasm IR compilation""" + + def memory_lookup(self, func, addr, size): + """Perform a memory lookup at @addr of size @size (in bit)""" + builder = func.builder + int_size = LLVMType.IntType(size) + ptr_casted = builder.inttoptr(addr, + llvm_ir.PointerType(int_size)) + return builder.load(ptr_casted) + + def memory_write(self, func, addr, size, value): + """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" + builder = func.builder + int_size = LLVMType.IntType(size) + ptr_casted = builder.inttoptr(addr, + llvm_ir.PointerType(int_size)) + return builder.store(value, ptr_casted) class LLVMFunction(): + """Represent a LLVM function - "Represent a llvm function" + Implementation note: + A new module is created each time to avoid cumulative lag (if @new_module) + """ # Default logging values log_mn = False - log_regs = False - - def __init__(self, llvm_context, name="fc"): - "Create a new function with name fc" + log_regs = True + + # Operation translation + ## Basics + op_translate = {'parity': 'parity', + } + ## Add the size as first argument + op_translate_with_size = {'<<<': 'rot_left', + '>>>': 'rot_right', + '<<<c_rez': 'rcl_rez_op', + '>>>c_rez': 'rcr_rez_op', + 'bsr': 'x86_bsr', + 'bsf': 'x86_bsf', + } + ## Add the size as suffix + op_translate_with_suffix_size = {'bcdadd': 'bcdadd', + 'bcdadd_cf': 'bcdadd_cf', + } + + def __init__(self, llvm_context, name="fc", new_module=True): + "Create a new function with name @name" self.llvm_context = llvm_context + if new_module: + self.llvm_context.new_module() self.mod = self.llvm_context.get_module() self.my_args = [] # (Expr, LLVMType, Name) @@ -232,40 +355,195 @@ class LLVMFunction(): self.branch_counter = 0 self.name = name + self._llvm_mod = None + + # Constructor utils def new_branch_name(self): "Return a new branch name" - self.branch_counter += 1 - return "%s" % self.branch_counter + return str(self.branch_counter) - def viewCFG(self): - "Show the CFG of the current function" - self.fc.viewCFG() - - def append_basic_block(self, label): + def append_basic_block(self, label, overwrite=True): """Add a new basic block to the current function. @label: str or asmlabel + @overwrite: if False, do nothing if a bbl with the same name already exists Return the corresponding LLVM Basic Block""" name = self.canonize_label_name(label) + bbl = self.name2bbl.get(name, None) + if not overwrite and bbl is not None: + return bbl bbl = self.fc.append_basic_block(name) - self.name2bbl[label] = bbl + self.name2bbl[name] = bbl return bbl + def CreateEntryBlockAlloca(self, var_type, default_value=None): + """Create an alloca instruction at the beginning of the current fc + @default_value: if set, store the default_value just after the allocation + """ + builder = self.builder + current_bbl = builder.basic_block + builder.position_at_start(self.entry_bbl) + + ret = builder.alloca(var_type) + if default_value is not None: + builder.store(default_value, ret) + builder.position_at_end(current_bbl) + return ret + + def get_ptr_by_expr(self, expr): + """"Return a pointer casted corresponding to ExprId expr. If it is not + already computed, compute it at the end of entry_bloc""" + + name = expr.name + + ptr_casted = self.local_vars_pointers.get(name, None) + if ptr_casted is not None: + # If the pointer has already been computed + return ptr_casted + + # Get current objects + builder = self.builder + current_bbl = builder.basic_block + + # Go at the right position + entry_bloc_bbl = self.entry_bbl + builder.position_at_end(entry_bloc_bbl) + + # Compute the pointer address + offset = self.llvm_context.vmcpu[name] + + # Pointer cast + ptr = builder.gep(self.local_vars["vmcpu"], + [llvm_ir.Constant(LLVMType.IntType(), + offset)]) + int_size = LLVMType.IntType(expr.size) + ptr_casted = builder.bitcast(ptr, + llvm_ir.PointerType(int_size)) + # Store in cache + self.local_vars_pointers[name] = ptr_casted + + # Reset builder + builder.position_at_end(current_bbl) + + return ptr_casted + + def update_cache(self, name, value): + "Add 'name' = 'value' to the cache iff main_stream = True" + + if self.main_stream is True: + self.expr_cache[name] = value + + def set_ret(self, var): + "Cast @var and return it at the end of current bbl" + if var.type.width < 64: + var_casted = self.builder.zext(var, LLVMType.IntType(64)) + else: + var_casted = var + self.builder.ret(var_casted) + + def canonize_label_name(self, label): + """Canonize @label names to a common form. + @label: str or asmlabel instance""" + if isinstance(label, str): + return label + elif isinstance(label, m2_asmbloc.asm_label): + return "label_%s" % label.name + elif m2_asmbloc.expr_is_label(label): + return "label_%s" % label.name.name + else: + raise ValueError("label must either be str or asmlabel") + + def get_basic_bloc_by_label(self, label): + "Return the bbl corresponding to label, None otherwise" + return self.name2bbl.get(self.canonize_label_name(label), None) + + def global_constant(self, name, value): + """ + Inspired from numba/cgutils.py + + Get or create a (LLVM module-)global constant with *name* or *value*. + """ + module = self.mod + data = llvm_ir.GlobalVariable(self.mod, value.type, name=name) + data.global_constant = True + data.initializer = value + return data + + def make_bytearray(self, buf): + """ + Inspired from numba/cgutils.py + + Make a byte array constant from *buf*. + """ + b = bytearray(buf) + n = len(b) + return llvm_ir.Constant(llvm_ir.ArrayType(llvm_ir.IntType(8), n), b) + + def printf(self, format, *args): + """ + Inspired from numba/cgutils.py + + Calls printf(). + Argument `format` is expected to be a Python string. + Values to be printed are listed in `args`. + + Note: There is no checking to ensure there is correct number of values + in `args` and there type matches the declaration in the format string. + """ + assert isinstance(format, str) + mod = self.mod + # Make global constant for format string + cstring = llvm_ir.IntType(8).as_pointer() + fmt_bytes = self.make_bytearray((format + '\00').encode('ascii')) + + base_name = "printf_format" + count = 0 + while self.mod.get_global("%s_%d" % (base_name, count)): + count += 1 + global_fmt = self.global_constant("%s_%d" % (base_name, count), + fmt_bytes) + fnty = llvm_ir.FunctionType(llvm_ir.IntType(32), [cstring], + var_arg=True) + # Insert printf() + fn = mod.get_global('printf') + if fn is None: + fn = llvm_ir.Function(mod, fnty, name="printf") + # Call + ptr_fmt = self.builder.bitcast(global_fmt, cstring) + return self.builder.call(fn, [ptr_fmt] + list(args)) + + # Effective constructors + + def affect(self, src, dst): + "Affect from LLVM src to M2 dst" + + # Destination + builder = self.builder + + if isinstance(dst, m2_expr.ExprId): + ptr_casted = self.get_ptr_by_expr(dst) + builder.store(src, ptr_casted) + + elif isinstance(dst, m2_expr.ExprMem): + addr = self.add_ir(dst.arg) + self.llvm_context.memory_write(self, addr, dst.size, src) + else: + raise Exception("UnknownAffectationType") + def init_fc(self): "Init the function" # Build type for fc signature - fc_type = LLVMType.function( - self.ret_type, [k[1] for k in self.my_args]) + fc_type = llvm_ir.FunctionType(self.ret_type, [k[1] for k in self.my_args]) # Add fc in module try: - fc = self.mod.add_function(fc_type, self.name) + fc = llvm_ir.Function(self.mod, fc_type, name=self.name) except llvm.LLVMException: # Overwrite the previous function - previous_fc = self.mod.get_function_named(self.name) + previous_fc = self.mod.get_global(self.name) previous_fc.delete() fc = self.mod.add_function(fc_type, self.name) @@ -283,7 +561,6 @@ class LLVMFunction(): self.expr_cache = {} self.main_stream = True self.name2bbl = {} - self.offsets_jitted = set() # Function link self.fc = fc @@ -292,69 +569,7 @@ class LLVMFunction(): self.entry_bbl = self.append_basic_block("entry") # Instruction builder - self.builder = llvm_c.Builder.new(self.entry_bbl) - - def CreateEntryBlockAlloca(self, var_type): - "Create an alloca instruction at the beginning of the current fc" - builder = self.builder - current_bbl = builder.basic_block - builder.position_at_end(self.entry_bbl) - - ret = builder.alloca(var_type) - builder.position_at_end(current_bbl) - return ret - - def get_ptr_by_expr(self, expr): - """"Return a pointer casted corresponding to ExprId expr. If it is not - already computed, compute it at the end of entry_bloc""" - - name = expr.name - - try: - # If the pointer has already been computed - ptr_casted = self.local_vars_pointers[name] - - except KeyError: - # Get current objects - builder = self.builder - current_bbl = builder.basic_block - - # Go at the right position - entry_bloc_bbl = self.entry_bbl - builder.position_at_end(entry_bloc_bbl) - - # Compute the pointer address - offset = self.llvm_context.vmcpu[name] - - # Pointer cast - ptr = builder.gep(self.local_vars["vmcpu"], - [llvm_c.Constant.int(LLVMType.int(), - offset)]) - int_size = LLVMType.int(expr.size) - ptr_casted = builder.bitcast(ptr, - llvm_c.PointerType.pointer(int_size)) - # Store in cache - self.local_vars_pointers[name] = ptr_casted - - # Reset builder - builder.position_at_end(current_bbl) - - return ptr_casted - - def clear_cache(self, regs_updated): - "Remove from the cache values which depends on regs_updated" - - regs_updated_set = set(regs_updated) - - for expr in self.expr_cache.keys(): - if expr.get_r(True).isdisjoint(regs_updated_set) is not True: - self.expr_cache.pop(expr) - - def update_cache(self, name, value): - "Add 'name' = 'value' to the cache iff main_stream = True" - - if self.main_stream is True: - self.expr_cache[name] = value + self.builder = llvm_ir.IRBuilder(self.entry_bbl) def add_ir(self, expr): "Add a Miasm2 IR to the last bbl. Return the var created" @@ -365,7 +580,7 @@ class LLVMFunction(): builder = self.builder if isinstance(expr, m2_expr.ExprInt): - ret = llvm_c.Constant.int(LLVMType.int(expr.size), int(expr)) + ret = llvm_ir.Constant(LLVMType.IntType(expr.size), int(expr.arg)) self.update_cache(expr, ret) return ret @@ -374,7 +589,7 @@ class LLVMFunction(): if not isinstance(name, str): # Resolve label offset = name.offset - ret = llvm_c.Constant.int(LLVMType.int(expr.size), offset) + ret = llvm_ir.Constant(LLVMType.IntType(expr.size), offset) self.update_cache(expr, ret) return ret @@ -393,69 +608,73 @@ class LLVMFunction(): if isinstance(expr, m2_expr.ExprOp): op = expr.op - if op == "parity": - fc_ptr = self.mod.get_function_named("parity") - arg = builder.zext(self.add_ir(expr.args[0]), - LLVMType.int()) - ret = builder.call(fc_ptr, [arg]) - ret = builder.trunc(ret, LLVMType.int(expr.size)) - self.update_cache(expr, ret) - return ret - - if op in ["<<<", ">>>"]: - fc_name = "rot_left" if op == "<<<" else "rot_right" - fc_ptr = self.mod.get_function_named(fc_name) + if (op in self.op_translate or + op in self.op_translate_with_size or + op in self.op_translate_with_suffix_size): args = [self.add_ir(arg) for arg in expr.args] arg_size = expr.args[0].size - if arg_size < 32: - # Cast args - args = [builder.zext(arg, LLVMType.int(32)) - for arg in args] - arg_size_cst = llvm_c.Constant.int(LLVMType.int(), - arg_size) - ret = builder.call(fc_ptr, [arg_size_cst] + args) - if arg_size < 32: - # Cast ret - ret = builder.trunc(ret, LLVMType.int(arg_size)) - self.update_cache(expr, ret) - return ret - if op == "bcdadd": - size = expr.args[0].size - fc_ptr = self.mod.get_function_named("bcdadd_%s" % size) - args = [self.add_ir(arg) for arg in expr.args] - ret = builder.call(fc_ptr, args) - self.update_cache(expr, ret) - return ret + if op in self.op_translate_with_size: + fc_name = self.op_translate_with_size[op] + arg_size_cst = llvm_ir.Constant(LLVMType.IntType(64), + arg_size) + args = [arg_size_cst] + args + elif op in self.op_translate: + fc_name = self.op_translate[op] + elif op in self.op_translate_with_suffix_size: + fc_name = "%s_%s" % (self.op_translate[op], arg_size) + + fc_ptr = self.mod.get_global(fc_name) + + # Cast args if needed + casted_args = [] + for i, arg in enumerate(args): + if arg.type.width < fc_ptr.args[i].type.width: + casted_args.append(builder.zext(arg, fc_ptr.args[i].type)) + else: + casted_args.append(arg) + ret = builder.call(fc_ptr, casted_args) + + # Cast ret if needed + ret_size = fc_ptr.return_value.type.width + if ret_size > expr.size: + ret = builder.trunc(ret, LLVMType.IntType(expr.size)) - if op == "bcdadd_cf": - size = expr.args[0].size - fc_ptr = self.mod.get_function_named("bcdadd_cf_%s" % size) - args = [self.add_ir(arg) for arg in expr.args] - ret = builder.call(fc_ptr, args) - ret = builder.trunc(ret, LLVMType.int(expr.size)) self.update_cache(expr, ret) return ret if op == "-": - zero = llvm_c.Constant.int(LLVMType.int(expr.size), - 0) + # Unsupported op '-' with more than 1 arg + assert len(expr.args) == 1 + zero = LLVMType.IntType(expr.size)(0) ret = builder.sub(zero, self.add_ir(expr.args[0])) self.update_cache(expr, ret) return ret if op == "segm": - fc_ptr = self.mod.get_function_named("segm2addr") - args_casted = [builder.zext(self.add_ir(arg), LLVMType.int(64)) - for arg in expr.args] - args = [self.local_vars["vmcpu"]] + args_casted - ret = builder.call(fc_ptr, args) - ret = builder.trunc(ret, LLVMType.int(expr.size)) + fc_ptr = self.mod.get_global("segm2addr") + + # Cast args if needed + args = [self.add_ir(arg) for arg in expr.args] + casted_args = [] + for i, arg in enumerate(args, 1): + if arg.type.width < fc_ptr.args[i].type.width: + casted_args.append(builder.zext(arg, fc_ptr.args[i].type)) + else: + casted_args.append(arg) + + ret = builder.call(fc_ptr, + [self.local_vars["jitcpu"]] + casted_args) + # Ret size is not expr.size on segm2addr (which is the size of + # the segment, for instance 16 bits), but the size of an addr + ret_size = self.llvm_context.PC.size + if ret.type.width > ret_size: + ret = builder.trunc(ret, LLVMType.IntType(ret_size)) self.update_cache(expr, ret) return ret - if op in ["imod", "idiv"]: - fc_ptr = self.mod.get_function_named( + if op in ["imod", "idiv", "umod", "udiv"]: + fc_ptr = self.mod.get_global( "%s%s" % (op, expr.args[0].size)) args_casted = [self.add_ir(arg) for arg in expr.args] args = [self.local_vars["vmcpu"]] + args_casted @@ -463,6 +682,26 @@ class LLVMFunction(): self.update_cache(expr, ret) return ret + if op in [">>", "<<", "a>>"]: + assert len(expr.args) == 2 + # Undefined behavior must be enforced to 0 + count = self.add_ir(expr.args[1]) + value = self.add_ir(expr.args[0]) + itype = LLVMType.IntType(expr.size) + cond_ok = self.builder.icmp_unsigned("<", count, + itype(expr.size)) + if op == ">>": + callback = builder.lshr + elif op == "<<": + callback = builder.shl + elif op == "a>>": + callback = builder.ashr + + ret = self.builder.select(cond_ok, callback(value, count), + itype(0)) + self.update_cache(expr, ret) + return ret + if len(expr.args) > 1: if op == "*": @@ -475,16 +714,10 @@ class LLVMFunction(): callback = builder.xor elif op == "|": callback = builder.or_ - elif op == ">>": - callback = builder.lshr - elif op == "<<": - callback = builder.shl - elif op == "a>>": - callback = builder.ashr - elif op == "udiv": - callback = builder.udiv - elif op == "umod": + elif op == "%": callback = builder.urem + elif op == "/": + callback = builder.udiv else: raise NotImplementedError('Unknown op: %s' % op) @@ -502,58 +735,18 @@ class LLVMFunction(): if isinstance(expr, m2_expr.ExprMem): - fc_name = "MEM_LOOKUP_%02d" % expr.size - fc_ptr = self.mod.get_function_named(fc_name) - addr_casted = builder.zext(self.add_ir(expr.arg), - LLVMType.int(64)) - - ret = builder.call(fc_ptr, [self.local_vars["vmmngr"], - addr_casted]) - - # Do not update memory cache to avoid pointer collision - return ret + addr = self.add_ir(expr.arg) + return self.llvm_context.memory_lookup(self, addr, expr.size) if isinstance(expr, m2_expr.ExprCond): # Compute cond cond = self.add_ir(expr.cond) - zero_casted = llvm_c.Constant.int(LLVMType.int(expr.cond.size), - 0) - condition_bool = builder.icmp(llvm_c.ICMP_NE, cond, - zero_casted) - - # Alloc return var - alloca = self.CreateEntryBlockAlloca(LLVMType.int(expr.size)) - - # Create bbls - branch_id = self.new_branch_name() - then_block = self.append_basic_block('then%s' % branch_id) - else_block = self.append_basic_block('else%s' % branch_id) - merge_block = self.append_basic_block('ifcond%s' % branch_id) - - builder.cbranch(condition_bool, then_block, else_block) - - # Deactivate object caching - current_main_stream = self.main_stream - self.main_stream = False - - # Then Bloc - builder.position_at_end(then_block) + zero_casted = LLVMType.IntType(expr.cond.size)(0) + condition_bool = builder.icmp_unsigned("!=", cond, + zero_casted) then_value = self.add_ir(expr.src1) - builder.store(then_value, alloca) - builder.branch(merge_block) - - # Else Bloc - builder.position_at_end(else_block) else_value = self.add_ir(expr.src2) - builder.store(else_value, alloca) - builder.branch(merge_block) - - # Merge bloc - builder.position_at_end(merge_block) - ret = builder.load(alloca) - - # Reactivate object caching - self.main_stream = current_main_stream + ret = builder.select(condition_bool, then_value, else_value) self.update_cache(expr, ret) return ret @@ -564,22 +757,22 @@ class LLVMFunction(): # Remove trailing bits if expr.start != 0: - to_shr = llvm_c.Constant.int(LLVMType.int(expr.arg.size), - expr.start) + to_shr = llvm_ir.Constant(LLVMType.IntType(expr.arg.size), + expr.start) shred = builder.lshr(src, to_shr) else: shred = src # Remove leading bits - to_and = llvm_c.Constant.int(LLVMType.int(expr.arg.size), - (1 << (expr.stop - expr.start)) - 1) + to_and = llvm_ir.Constant(LLVMType.IntType(expr.arg.size), + (1 << (expr.stop - expr.start)) - 1) anded = builder.and_(shred, to_and) # Cast into e.size ret = builder.trunc(anded, - LLVMType.int(expr.size)) + LLVMType.IntType(expr.size)) self.update_cache(expr, ret) return ret @@ -589,22 +782,20 @@ class LLVMFunction(): args = [] # Build each part - for arg in expr.args: - src, start, stop = arg - - # src & (stop - start) + for start, src in expr.iter_args(): + # src & size src = self.add_ir(src) src_casted = builder.zext(src, - LLVMType.int(expr.size)) - to_and = llvm_c.Constant.int(LLVMType.int(expr.size), - (1 << (stop - start)) - 1) + LLVMType.IntType(expr.size)) + to_and = llvm_ir.Constant(LLVMType.IntType(expr.size), + (1 << src.type.width) - 1) anded = builder.and_(src_casted, to_and) if (start != 0): # result << start - to_shl = llvm_c.Constant.int(LLVMType.int(expr.size), - start) + to_shl = llvm_ir.Constant(LLVMType.IntType(expr.size), + start) shled = builder.shl(anded, to_shl) final = shled else: @@ -623,92 +814,34 @@ class LLVMFunction(): raise Exception("UnkownExpression", expr.__class__.__name__) - def set_ret(self, var): - "Cast @var and return it at the end of current bbl" - if var.type.width < 64: - var_casted = self.builder.zext(var, LLVMType.int(64)) - else: - var_casted = var - self.builder.ret(var_casted) - - def from_expr(self, expr): - "Build the function from an expression" - - # Build function signature - args = expr.get_r(True) - for a in args: - if not isinstance(a, m2_expr.ExprMem): - self.my_args.append((a, LLVMType.int(a.size), a.name)) - - self.ret_type = LLVMType.int(expr.size) - - # Initialise the function - self.init_fc() - - ret = self.add_ir(expr) - - self.set_ret(ret) - - def affect(self, src, dst, add_new=True): - "Affect from M2 src to M2 dst. If add_new, add a suffix '_new' to dest" - - # Source - src = self.add_ir(src) - - # Destination - builder = self.builder - self.add_ir(m2_expr.ExprId("vmcpu")) - - if isinstance(dst, m2_expr.ExprId): - dst_name = dst.name + "_new" if add_new else dst.name + # JiT specifics - ptr_casted = self.get_ptr_by_expr( - m2_expr.ExprId(dst_name, dst.size)) - builder.store(src, ptr_casted) - - elif isinstance(dst, m2_expr.ExprMem): - self.add_ir(dst.arg) - - # Function call - fc_name = "MEM_WRITE_%02d" % dst.size - fc_ptr = self.mod.get_function_named(fc_name) - dst = self.add_ir(dst.arg) - dst_casted = builder.zext(dst, LLVMType.int(64)) - builder.call(fc_ptr, [self.local_vars["vmmngr"], - dst_casted, - src]) - - else: - raise Exception("UnknownAffectationType") - - def check_error(self, line, except_do_not_update_pc=False): + def check_memory_exception(self, offset, restricted_exception=False): """Add a check for memory errors. - @line: Irbloc line corresponding to the current instruction - If except_do_not_update_pc, check only for exception which do not - require a pc update""" + @offset: offset of the current exception (int or Instruction) + If restricted_exception, check only for exception which do not + require a pc update, and do not consider automod exception""" # VmMngr "get_exception_flag" return's size size = 64 - t_size = LLVMType.int(size) - - # Current address - pc_to_return = line.offset + t_size = LLVMType.IntType(size) # Get exception flag value + # TODO: avoid costly call using a structure deref builder = self.builder - fc_ptr = self.mod.get_function_named("get_exception_flag") + fc_ptr = self.mod.get_global("get_exception_flag") exceptionflag = builder.call(fc_ptr, [self.local_vars["vmmngr"]]) - if except_do_not_update_pc is True: - auto_mod_flag = m2_csts.EXCEPT_DO_NOT_UPDATE_PC - m2_flag = llvm_c.Constant.int(t_size, auto_mod_flag) + if restricted_exception is True: + flag = ~m2_csts.EXCEPT_CODE_AUTOMOD & m2_csts.EXCEPT_DO_NOT_UPDATE_PC + m2_flag = llvm_ir.Constant(t_size, flag) exceptionflag = builder.and_(exceptionflag, m2_flag) # Compute cond - zero_casted = llvm_c.Constant.int(t_size, 0) - condition_bool = builder.icmp(llvm_c.ICMP_NE, - exceptionflag, - zero_casted) + zero_casted = llvm_ir.Constant(t_size, 0) + condition_bool = builder.icmp_unsigned("!=", + exceptionflag, + zero_casted) # Create bbls branch_id = self.new_branch_name() @@ -723,168 +856,276 @@ class LLVMFunction(): # Then Bloc builder.position_at_end(then_block) - self.set_ret(llvm_c.Constant.int(self.ret_type, pc_to_return)) + PC = self.llvm_context.PC + if isinstance(offset, (int, long)): + offset = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + self.affect(offset, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), m2_expr.ExprId("status")) + self.set_ret(offset) builder.position_at_end(merge_block) - # Reactivate object caching self.main_stream = current_main_stream - def log_instruction(self, instruction, line): - "Print current instruction and registers if options are set" + def check_cpu_exception(self, offset, restricted_exception=False): + """Add a check for CPU errors. + @offset: offset of the current exception (int or Instruction) + If restricted_exception, check only for exception which do not + require a pc update""" - # Get builder + # Get exception flag value builder = self.builder + m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags + t_size = LLVMType.IntType(m2_exception_flag.size) + exceptionflag = self.add_ir(m2_exception_flag) - if self.log_mn is True: - print instruction # TODO + # Compute cond + if restricted_exception is True: + flag = m2_csts.EXCEPT_NUM_UPDT_EIP + condition_bool = builder.icmp_unsigned(">", exceptionflag, + llvm_ir.Constant(t_size, flag)) + else: + zero_casted = llvm_ir.Constant(t_size, 0) + condition_bool = builder.icmp_unsigned("!=", + exceptionflag, + zero_casted) - if self.log_regs is True: - # Call dump general purpose registers - fc_ptr = self.mod.get_function_named("dump_gpregs") - builder.call(fc_ptr, [self.local_vars["vmcpu"]]) + # Create bbls + branch_id = self.new_branch_name() + then_block = self.append_basic_block('then%s' % branch_id) + merge_block = self.append_basic_block('ifcond%s' % branch_id) - def add_bloc(self, bloc, lines): - "Add a bloc of instruction in the current function" + builder.cbranch(condition_bool, then_block, merge_block) - for instruction, line in zip(bloc, lines): - new_reg = set() + # Deactivate object caching + current_main_stream = self.main_stream + self.main_stream = False - # Check general errors only at the beggining of instruction - if line.offset not in self.offsets_jitted: - self.offsets_jitted.add(line.offset) - self.check_error(line) + # Then Bloc + builder.position_at_end(then_block) + PC = self.llvm_context.PC + if isinstance(offset, (int, long)): + offset = self.add_ir(m2_expr.ExprInt(offset, PC.size)) + self.affect(offset, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), m2_expr.ExprId("status")) + self.set_ret(offset) - # Log mn and registers if options is set - self.log_instruction(instruction, line) + builder.position_at_end(merge_block) + # Reactivate object caching + self.main_stream = current_main_stream + def gen_pre_code(self, attributes): + if attributes.log_mn: + self.printf("%.8X %s\n" % (attributes.instr.offset, + attributes.instr)) + + def gen_post_code(self, attributes): + if attributes.log_regs: + fc_ptr = self.mod.get_global(self.llvm_context.logging_func) + self.builder.call(fc_ptr, [self.local_vars["vmcpu"]]) + + def gen_post_instr_checks(self, attrib, next_instr): + if attrib.mem_read | attrib.mem_write: + fc_ptr = self.mod.get_global("check_memory_breakpoint") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + fc_ptr = self.mod.get_global("check_invalid_code_blocs") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + self.check_memory_exception(next_instr, restricted_exception=False) + if attrib.set_exception or attrib.op_set_exception: + self.check_cpu_exception(next_instr, restricted_exception=False) + + if attrib.mem_read | attrib.mem_write: + fc_ptr = self.mod.get_global("reset_memory_access") + self.builder.call(fc_ptr, [self.local_vars["vmmngr"]]) + + def expr2cases(self, expr): + """ + Evaluate @expr and return: + - switch value -> dst + - evaluation of the switch value (if any) + """ - # Pass on empty instruction - if len(instruction) == 0: - continue + to_eval = expr + dst2case = {} + case2dst = {} + for i, solution in enumerate(possible_values(expr)): + value = solution.value + index = dst2case.get(value, i) + to_eval = to_eval.replace_expr({value: m2_expr.ExprInt(index, value.size)}) + dst2case[value] = index + if m2_asmbloc.expr_is_int_or_label(value): + case2dst[i] = value + else: + case2dst[i] = self.add_ir(value) - for expression in instruction: - # Apply preinit transformation - for func in self.llvm_context.IR_transformation_functions: - expression = func(expression) - # Treat current expression - self.affect(expression.src, expression.dst) + evaluated = self.add_ir(to_eval) + return case2dst, evaluated - # Save registers updated - new_reg.update(expression.dst.get_w()) + def gen_jump2dst(self, attrib, dst): + """Generate the code for a jump to @dst with final check for error - # Check for errors (without updating PC) - self.check_error(line, except_do_not_update_pc=True) + Several cases have to be considered: + - jump to an offset out of the current ASM BBL (JMP 0x11223344) + - jump to an offset inside the current ASM BBL (Go to next instruction) + - jump to a generated IR label, which must be jitted in this same + function (REP MOVSB) + - jump to a computed offset (CALL @32[0x11223344]) + """ + PC = self.llvm_context.PC + # We are no longer in the main stream, deactivate cache + self.main_stream = False - # new -> normal - reg_written = [] - for r in new_reg: - if isinstance(r, m2_expr.ExprId): - r_new = m2_expr.ExprId(r.name + "_new", r.size) - reg_written += [r, r_new] - self.affect(r_new, r, add_new=False) + if isinstance(dst, m2_expr.ExprInt): + dst = m2_expr.ExprId(self.llvm_context.ir_arch.symbol_pool.getby_offset_create(int(dst)), + dst.size) + + if m2_asmbloc.expr_is_label(dst): + bbl = self.get_basic_bloc_by_label(dst) + if bbl is not None: + # "local" jump, inside this function + if dst.name.offset is not None: + # Avoid checks on generated label + self.gen_post_code(attrib) + self.gen_post_instr_checks(attrib, dst.name.offset) + self.builder.branch(bbl) + return + else: + # "extern" jump on a defined offset, return to the caller + offset = dst.name.offset + dst = self.add_ir(m2_expr.ExprInt(offset, PC.size)) - # Clear cache - self.clear_cache(reg_written) - self.main_stream = True + # "extern" jump with a computed value, return to the caller + assert isinstance(dst, (llvm_ir.Instruction, llvm_ir.Value)) + # Cast @dst, if needed + # for instance, x86_32: IRDst is 32 bits, so is @dst; PC is 64 bits + if dst.type.width != PC.size: + dst = self.builder.zext(dst, LLVMType.IntType(PC.size)) - def from_bloc(self, bloc, final_expr): - """Build the function from a bloc, with the dst equation. - Prototype : f(i8* vmcpu, i8* vmmngr)""" + self.gen_post_code(attrib) + self.affect(dst, PC) + self.gen_post_instr_checks(attrib, dst) + self.affect(self.add_ir(m2_expr.ExprInt8(0)), m2_expr.ExprId("status")) + self.set_ret(dst) - # Build function signature - self.my_args.append((m2_expr.ExprId("vmcpu"), - llvm_c.PointerType.pointer(LLVMType.int(8)), - "vmcpu")) - self.my_args.append((m2_expr.ExprId("vmmngr"), - llvm_c.PointerType.pointer(LLVMType.int(8)), - "vmmngr")) - self.ret_type = LLVMType.int(final_expr.size) - # Initialise the function - self.init_fc() + def gen_irblock(self, attrib, instr, irblock): + """ + Generate the code for an @irblock + @instr: the current instruction to translate + @irblock: an irbloc instance + @attrib: an Attributs instance + """ - # Add content - self.add_bloc(bloc, []) + case2dst = None + case_value = None - # Finalise the function - self.set_ret(self.add_ir(final_expr)) + for assignblk in irblock.irs: + # Enable cache + self.main_stream = True + self.expr_cache = {} + + # Prefetch memory + for element in assignblk.get_r(mem_read=True): + if isinstance(element, m2_expr.ExprMem): + self.add_ir(element) + + # Evaluate expressions + values = {} + for dst, src in assignblk.iteritems(): + if dst == self.llvm_context.ir_arch.IRDst: + case2dst, case_value = self.expr2cases(src) + else: + values[dst] = self.add_ir(src) - raise NotImplementedError("Not tested") + # Check memory access exception + if assignblk.mem_read: + self.check_memory_exception(instr.offset, + restricted_exception=True) - def canonize_label_name(self, label): - """Canonize @label names to a common form. - @label: str or asmlabel instance""" - if isinstance(label, str): - return label - elif isinstance(label, m2_asmbloc.asm_label): - return "label_%s" % label.name - else: - raise ValueError("label must either be str or asmlabel") + # Check operation exception + if assignblk.op_set_exception: + self.check_cpu_exception(instr.offset, restricted_exception=True) - def get_basic_bloc_by_label(self, label): - "Return the bbl corresponding to label, None otherwise" - return self.name2bbl.get(self.canonize_label_name(label), None) + # Update the memory + for dst, src in values.iteritems(): + if isinstance(dst, m2_expr.ExprMem): + self.affect(src, dst) - def gen_ret_or_branch(self, dest): - """Manage the dest ExprId. If label, branch on it if it is known. - Otherwise, return the ExprId or the offset value""" + # Check memory write exception + if assignblk.mem_write: + self.check_memory_exception(instr.offset, + restricted_exception=True) - builder = self.builder + # Update registers values + for dst, src in values.iteritems(): + if not isinstance(dst, m2_expr.ExprMem): + self.affect(src, dst) - if isinstance(dest, m2_expr.ExprId): - dest_name = dest.name - elif isinstance(dest, m2_expr.ExprSlice) and \ - isinstance(dest.arg, m2_expr.ExprId): - # Manage ExprId mask case - dest_name = dest.arg.name - else: - raise ValueError() + # Check post assignblk exception flags + if assignblk.set_exception: + self.check_cpu_exception(instr.offset, restricted_exception=True) - if not isinstance(dest_name, str): - label = dest_name - target_bbl = self.get_basic_bloc_by_label(label) - if target_bbl is None: - self.set_ret(self.add_ir(dest)) - else: - builder.branch(target_bbl) + # Destination + assert case2dst is not None + if len(case2dst) == 1: + # Avoid switch in this common case + self.gen_jump2dst(attrib, case2dst.values()[0]) else: - self.set_ret(self.add_ir(dest)) + current_bbl = self.builder.basic_block - def add_irbloc(self, irbloc): - "Add the content of irbloc at the corresponding labeled block" + # Gen the out cases + branch_id = self.new_branch_name() + case2bbl = {} + for case, dst in case2dst.iteritems(): + name = "switch_%s_%d" % (branch_id, case) + bbl = self.append_basic_block(name) + case2bbl[case] = bbl + self.builder.position_at_start(bbl) + self.gen_jump2dst(attrib, dst) + + # Jump on the correct output + self.builder.position_at_end(current_bbl) + switch = self.builder.switch(case_value, case2bbl[0]) + for i, bbl in case2bbl.iteritems(): + if i == 0: + # Default case is case 0, arbitrary + continue + switch.add_case(i, bbl) + + def gen_bad_block(self, asmblock): + """ + Translate an asm_bad_block into a CPU exception + """ builder = self.builder + m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags + t_size = LLVMType.IntType(m2_exception_flag.size) + self.affect(self.add_ir(m2_expr.ExprInt8(1)), + m2_expr.ExprId("status")) + self.affect(t_size(m2_csts.EXCEPT_UNK_MNEMO), + m2_exception_flag) + self.set_ret(LLVMType.IntType(64)(asmblock.label.offset)) + + def gen_finalize(self, asmblock, codegen): + """ + In case of delayslot, generate a dummy BBL which return on the computed IRDst + or on next_label + """ + if self.llvm_context.has_delayslot: + next_label = codegen.get_block_post_label(asmblock) + builder = self.builder - bloc = irbloc.irs - dest = irbloc.dst - label = irbloc.label - lines = irbloc.lines - - # Get labeled basic bloc - label_block = self.get_basic_bloc_by_label(label) - builder.position_at_end(label_block) - - # Erase cache - self.expr_cache = {} - - # Add the content of the bloc with corresponding lines - self.add_bloc(bloc, lines) - - # Erase cache - self.expr_cache = {} + builder.position_at_end(self.get_basic_bloc_by_label(next_label)) - # Manage ret - for func in self.llvm_context.IR_transformation_functions: - dest = func(dest) + # Common code + self.affect(self.add_ir(m2_expr.ExprInt8(0)), + m2_expr.ExprId("status")) - if isinstance(dest, m2_expr.ExprCond): - # Compute cond - cond = self.add_ir(dest.cond) - zero_casted = llvm_c.Constant.int(LLVMType.int(dest.cond.size), - 0) - condition_bool = builder.icmp(llvm_c.ICMP_NE, cond, - zero_casted) + # Check if IRDst has been set + zero_casted = LLVMType.IntType(codegen.delay_slot_set.size)(0) + condition_bool = builder.icmp_unsigned("!=", + self.add_ir(codegen.delay_slot_set), + zero_casted) # Create bbls branch_id = self.new_branch_name() @@ -893,78 +1134,138 @@ class LLVMFunction(): builder.cbranch(condition_bool, then_block, else_block) - # Then Bloc - builder.position_at_end(then_block) - self.gen_ret_or_branch(dest.src1) + # Deactivate object caching + self.main_stream = False - # Else Bloc + # Then Block + builder.position_at_end(then_block) + PC = self.llvm_context.PC + to_ret = self.add_ir(codegen.delay_slot_dst) + self.affect(to_ret, PC) + self.affect(self.add_ir(m2_expr.ExprInt8(0)), + m2_expr.ExprId("status")) + self.set_ret(to_ret) + + # Else Block builder.position_at_end(else_block) - self.gen_ret_or_branch(dest.src2) - - elif isinstance(dest, m2_expr.ExprId): - self.gen_ret_or_branch(dest) - - elif isinstance(dest, m2_expr.ExprSlice): - self.gen_ret_or_branch(dest) + PC = self.llvm_context.PC + to_ret = LLVMType.IntType(PC.size)(next_label.offset) + self.affect(to_ret, PC) + self.set_ret(to_ret) - else: - raise Exception("Bloc dst has to be an ExprId or an ExprCond") - - def from_blocs(self, blocs): - """Build the function from a list of bloc (irbloc instances). - Prototype : f(i8* vmcpu, i8* vmmngr)""" + def from_asmblock(self, asmblock): + """Build the function from an asmblock (asm_block instance). + Prototype : f(i8* jitcpu, i8* vmcpu, i8* vmmngr, i8* status)""" # Build function signature + self.my_args.append((m2_expr.ExprId("jitcpu"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "jitcpu")) self.my_args.append((m2_expr.ExprId("vmcpu"), - llvm_c.PointerType.pointer(LLVMType.int(8)), + llvm_ir.PointerType(LLVMType.IntType(8)), "vmcpu")) self.my_args.append((m2_expr.ExprId("vmmngr"), - llvm_c.PointerType.pointer(LLVMType.int(8)), + llvm_ir.PointerType(LLVMType.IntType(8)), "vmmngr")) + self.my_args.append((m2_expr.ExprId("status"), + llvm_ir.PointerType(LLVMType.IntType(8)), + "status")) ret_size = 64 - self.ret_type = LLVMType.int(ret_size) + self.ret_type = LLVMType.IntType(ret_size) # Initialise the function self.init_fc() + self.local_vars_pointers["status"] = self.local_vars["status"] + + if isinstance(asmblock, m2_asmbloc.asm_block_bad): + self.gen_bad_block(asmblock) + return # Create basic blocks (for label branchs) entry_bbl, builder = self.entry_bbl, self.builder - - for irbloc in blocs: - name = self.canonize_label_name(irbloc.label) - self.append_basic_block(name) + for instr in asmblock.lines: + lbl = self.llvm_context.ir_arch.symbol_pool.getby_offset_create(instr.offset) + self.append_basic_block(lbl) + + # TODO: merge duplicate code with CGen + codegen = self.llvm_context.cgen_class(self.llvm_context.ir_arch) + irblocks_list = codegen.block2assignblks(asmblock) + + # Prepare for delayslot + if self.llvm_context.has_delayslot: + for element in (codegen.delay_slot_dst, codegen.delay_slot_set): + eltype = LLVMType.IntType(element.size) + ptr = self.CreateEntryBlockAlloca(eltype, + default_value=eltype(0)) + self.local_vars_pointers[element.name] = ptr + lbl = codegen.get_block_post_label(asmblock) + self.append_basic_block(lbl) # Add content builder.position_at_end(entry_bbl) - for irbloc in blocs: - self.add_irbloc(irbloc) + for instr, irblocks in zip(asmblock.lines, irblocks_list): + attrib = codegen.get_attributes(instr, irblocks, self.log_mn, + self.log_regs) + + # Pre-create basic blocks + for irblock in irblocks: + self.append_basic_block(irblock.label, overwrite=False) + + # Generate the corresponding code + for index, irblock in enumerate(irblocks): + self.llvm_context.ir_arch.irbloc_fix_regs_for_mode( + irblock, self.llvm_context.ir_arch.attrib) + + # Set the builder at the begining of the correct bbl + name = self.canonize_label_name(irblock.label) + self.builder.position_at_end(self.get_basic_bloc_by_label(name)) + + if index == 0: + self.gen_pre_code(attrib) + self.gen_irblock(attrib, instr, irblock) + + # Gen finalize (see codegen::CGen) is unrecheable, except with delayslot + self.gen_finalize(asmblock, codegen) # Branch entry_bbl on first label builder.position_at_end(entry_bbl) - first_label_bbl = self.get_basic_bloc_by_label(blocs[0].label) + first_label_bbl = self.get_basic_bloc_by_label(asmblock.label) builder.branch(first_label_bbl) + + # LLVMFunction manipulation + def __str__(self): "Print the llvm IR corresponding to the current module" + return str(self.mod) + + def dot(self): + "Return the CFG of the current function" + return llvm.get_function_cfg(self.fc) - return str(self.fc) + def as_llvm_mod(self): + """Return a ModuleRef standing for the current function""" + if self._llvm_mod is None: + self._llvm_mod = llvm.parse_assembly(str(self.mod)) + return self._llvm_mod def verify(self): "Verify the module syntax" + return self.as_llvm_mod().verify() - return self.mod.verify() + def get_bytecode(self): + "Return LLVM bitcode corresponding to the current module" + return self.as_llvm_mod().as_bitcode() def get_assembly(self): "Return native assembly corresponding to the current module" - - return self.mod.to_native_assembly() + return self.llvm_context.target_machine.emit_assembly(self.as_llvm_mod()) def optimise(self): "Optimise the function in place" - while self.llvm_context.pass_manager.run(self.fc): - continue + return self.llvm_context.pass_manager.run(self.as_llvm_mod()) def __call__(self, *args): "Eval the function with arguments args" @@ -978,9 +1279,10 @@ class LLVMFunction(): def get_function_pointer(self): "Return a pointer on the Jitted function" - e = self.llvm_context.get_execengine() + engine = self.llvm_context.get_execengine() - return e.get_pointer_to_function(self.fc) + # Add the module and make sure it is ready for execution + engine.add_module(self.as_llvm_mod()) + engine.finalize_object() -# TODO: -# - Add more expressions + return engine.get_function_address(self.fc.name) |