diff options
Diffstat (limited to 'miasm2/jitter/llvmconvert.py')
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 323 |
1 files changed, 218 insertions, 105 deletions
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index fd32001c..bea8cd36 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -11,9 +11,15 @@ # # +from builtins import zip +from builtins import range import os from llvmlite import binding as llvm from llvmlite import ir as llvm_ir +from builtins import int as int_types + +from future.utils import viewitems, viewvalues + from miasm2.expression.expression import ExprId, ExprInt, ExprMem, ExprSlice, \ ExprCond, ExprLoc, ExprOp, ExprCompose, LocKey, Expr, \ TOK_EQUAL, \ @@ -67,7 +73,7 @@ class LLVMType(llvm_ir.Type): return precision -class LLVMContext(): +class LLVMContext(object): "Context for llvm binding. Stand for a LLVM Module" @@ -139,7 +145,7 @@ class LLVMContext(): def add_fc(self, fc, readonly=False): "Add function into known_fc" - for name, detail in fc.iteritems(): + for name, detail in viewitems(fc): fnty = llvm_ir.FunctionType(detail["ret"], detail["args"]) fn = llvm_ir.Function(self.mod, fnty, name=name) if readonly: @@ -444,8 +450,10 @@ class LLVMContext_JIT(LLVMContext): self.add_shared_library(lib_fname) # Activate cache - self.exec_engine.set_object_cache(self.cache_notify, - self.cache_getbuffer) + self.exec_engine.set_object_cache( + self.cache_notify, + self.cache_getbuffer + ) def set_cache_filename(self, func, fname_out): "Set the filename @fname_out to use for cache for @func" @@ -473,16 +481,20 @@ class LLVMContext_IRCompilation(LLVMContext): """Perform a memory lookup at @addr of size @size (in bit)""" builder = func.builder int_size = LLVMType.IntType(size) - ptr_casted = builder.inttoptr(addr, - llvm_ir.PointerType(int_size)) + ptr_casted = builder.inttoptr( + addr, + llvm_ir.PointerType(int_size) + ) return builder.load(ptr_casted) def memory_write(self, func, addr, size, value): """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value""" builder = func.builder int_size = LLVMType.IntType(size) - ptr_casted = builder.inttoptr(addr, - llvm_ir.PointerType(int_size)) + ptr_casted = builder.inttoptr( + addr, + llvm_ir.PointerType(int_size) + ) return builder.store(value, ptr_casted) @@ -504,8 +516,9 @@ class LLVMFunction(object): ## Add the size as first argument op_translate_with_size = {} ## Add the size as suffix - op_translate_with_suffix_size = {'bcdadd': 'bcdadd', - 'bcdadd_cf': 'bcdadd_cf', + op_translate_with_suffix_size = { + 'bcdadd': 'bcdadd', + 'bcdadd_cf': 'bcdadd_cf', } def __init__(self, llvm_context, name="fc", new_module=True): @@ -582,12 +595,20 @@ class LLVMFunction(object): offset = self.llvm_context.vmcpu[name] # Pointer cast - ptr = builder.gep(self.local_vars["vmcpu"], - [llvm_ir.Constant(LLVMType.IntType(), - offset)]) + ptr = builder.gep( + self.local_vars["vmcpu"], + [ + llvm_ir.Constant( + LLVMType.IntType(), + offset + ) + ] + ) pointee_type = LLVMType.IntType(expr.size) - ptr_casted = builder.bitcast(ptr, - llvm_ir.PointerType(pointee_type)) + ptr_casted = builder.bitcast( + ptr, + llvm_ir.PointerType(pointee_type) + ) # Store in cache self.local_vars_pointers[name] = ptr_casted @@ -612,7 +633,10 @@ class LLVMFunction(object): def get_basic_block_by_loc_key(self, loc_key): "Return the bbl corresponding to label, None otherwise" - return self.name2bbl.get(self.llvm_context.canonize_label_name(loc_key), None) + return self.name2bbl.get( + self.llvm_context.canonize_label_name(loc_key), + None + ) def global_constant(self, name, value): """ @@ -658,10 +682,15 @@ class LLVMFunction(object): count = 0 while "%s_%d" % (base_name, count) in self.mod.globals: count += 1 - global_fmt = self.global_constant("%s_%d" % (base_name, count), - fmt_bytes) - fnty = llvm_ir.FunctionType(llvm_ir.IntType(32), [cstring], - var_arg=True) + global_fmt = self.global_constant( + "%s_%d" % (base_name, count), + fmt_bytes + ) + fnty = llvm_ir.FunctionType( + llvm_ir.IntType(32), + [cstring], + var_arg=True + ) # Insert printf() fn = mod.globals.get('printf', None) if fn is None: @@ -692,7 +721,10 @@ class LLVMFunction(object): "Init the function" # Build type for fc signature - fc_type = llvm_ir.FunctionType(self.ret_type, [k[1] for k in self.my_args]) + fc_type = llvm_ir.FunctionType( + self.ret_type, + [k[1] for k in self.my_args] + ) # Add fc in module try: @@ -741,7 +773,9 @@ class LLVMFunction(object): return ret if expr.is_loc(): - offset = self.llvm_context.ir_arch.loc_db.get_location_offset(expr.loc_key) + offset = self.llvm_context.ir_arch.loc_db.get_location_offset( + expr.loc_key + ) ret = llvm_ir.Constant(LLVMType.IntType(expr.size), offset) self.update_cache(expr, ret) return ret @@ -785,7 +819,12 @@ class LLVMFunction(object): casted_args = [] for i, arg in enumerate(args): if arg.type.width < fc_ptr.args[i].type.width: - casted_args.append(builder.zext(arg, fc_ptr.args[i].type)) + casted_args.append( + builder.zext( + arg, + fc_ptr.args[i].type + ) + ) else: casted_args.append(arg) ret = builder.call(fc_ptr, casted_args) @@ -810,8 +849,10 @@ class LLVMFunction(object): assert len(expr.args) == 1 arg = self.add_ir(expr.args[0]) truncated = builder.trunc(arg, LLVMType.IntType(8)) - bitcount = builder.call(self.mod.get_global("llvm.ctpop.i8"), - [truncated]) + bitcount = builder.call( + self.mod.get_global("llvm.ctpop.i8"), + [truncated] + ) ret = builder.not_(builder.trunc(bitcount, LLVMType.IntType(1))) self.update_cache(expr, ret) return ret @@ -824,16 +865,20 @@ class LLVMFunction(object): "cnttrailzeros": "cttz", }[op] func_llvm_name = "llvm.%s.i%d" % (func_name, expr.size) - func_sig = {func_llvm_name: { - "ret": LLVMType.IntType(expr.size), - "args": [LLVMType.IntType(expr.args[0].size)] - }} + func_sig = { + func_llvm_name: { + "ret": LLVMType.IntType(expr.size), + "args": [LLVMType.IntType(expr.args[0].size)] + } + } try: self.mod.get_global(func_llvm_name) except KeyError: self.llvm_context.add_fc(func_sig, readonly=True) - ret = builder.call(self.mod.get_global(func_llvm_name), - [arg]) + ret = builder.call( + self.mod.get_global(func_llvm_name), + [arg] + ) self.update_cache(expr, ret) return ret @@ -867,12 +912,19 @@ class LLVMFunction(object): casted_args = [] for i, arg in enumerate(args, 1): if arg.type.width < fc_ptr.args[i].type.width: - casted_args.append(builder.zext(arg, fc_ptr.args[i].type)) + casted_args.append( + builder.zext( + arg, + fc_ptr.args[i].type + ) + ) else: casted_args.append(arg) - ret = builder.call(fc_ptr, - [self.local_vars["jitcpu"]] + casted_args) + ret = builder.call( + fc_ptr, + [self.local_vars["jitcpu"]] + casted_args + ) if ret.type.width > expr.size: ret = builder.trunc(ret, LLVMType.IntType(expr.size)) self.update_cache(expr, ret) @@ -905,11 +957,14 @@ class LLVMFunction(object): if op in unsigned_cmps: op = unsigned_cmps[op] args = [self.add_ir(arg) for arg in expr.args] - ret = builder.select(builder.icmp_unsigned(op, - args[0], - args[1]), - llvm_ir.IntType(expr.size)(1), - llvm_ir.IntType(expr.size)(0)) + ret = builder.select( + builder.icmp_unsigned(op, + args[0], + args[1] + ), + llvm_ir.IntType(expr.size)(1), + llvm_ir.IntType(expr.size)(0) + ) self.update_cache(expr, ret) return ret @@ -919,8 +974,11 @@ class LLVMFunction(object): count = self.add_ir(expr.args[1]) value = self.add_ir(expr.args[0]) itype = LLVMType.IntType(expr.size) - cond_ok = self.builder.icmp_unsigned("<", count, - itype(expr.size)) + cond_ok = self.builder.icmp_unsigned( + "<", + count, + itype(expr.size) + ) zero = itype(0) if op == ">>": callback = builder.lshr @@ -932,8 +990,11 @@ class LLVMFunction(object): cond_neg = self.builder.icmp_signed("<", value, zero) zero = self.builder.select(cond_neg, itype(-1), zero) - ret = self.builder.select(cond_ok, callback(value, count), - zero) + ret = self.builder.select( + cond_ok, + callback(value, count), + zero + ) self.update_cache(expr, ret) return ret @@ -948,8 +1009,10 @@ class LLVMFunction(object): # As shift of expr_size is undefined, we urem the shifters shift = builder.urem(count, expr_size) - shift_inv = builder.urem(builder.sub(expr_size, shift), - expr_size) + shift_inv = builder.urem( + builder.sub(expr_size, shift), + expr_size + ) if op == '<<<': part_a = builder.shl(value, shift) @@ -1045,12 +1108,16 @@ class LLVMFunction(object): # Apply the correct func if expr.size == 32: arg = builder.bitcast(arg, llvm_ir.FloatType()) - ret = builder.call(self.mod.get_global("llvm.%s.f32" % op), - [arg]) + ret = builder.call( + self.mod.get_global("llvm.%s.f32" % op), + [arg] + ) elif expr.size == 64: arg = builder.bitcast(arg, llvm_ir.DoubleType()) - ret = builder.call(self.mod.get_global("llvm.%s.f64" % op), - [arg]) + ret = builder.call( + self.mod.get_global("llvm.%s.f64" % op), + [arg] + ) else: raise RuntimeError("Unsupported precision: %x", expr.size) @@ -1164,22 +1231,27 @@ class LLVMFunction(object): # Remove trailing bits if expr.start != 0: - to_shr = llvm_ir.Constant(LLVMType.IntType(expr.arg.size), - expr.start) - shred = builder.lshr(src, - to_shr) + to_shr = llvm_ir.Constant( + LLVMType.IntType(expr.arg.size), + expr.start + ) + shred = builder.lshr(src, to_shr) else: shred = src # Remove leading bits - to_and = llvm_ir.Constant(LLVMType.IntType(expr.arg.size), - (1 << (expr.stop - expr.start)) - 1) + to_and = llvm_ir.Constant( + LLVMType.IntType(expr.arg.size), + (1 << (expr.stop - expr.start)) - 1 + ) anded = builder.and_(shred, to_and) # Cast into e.size - ret = builder.trunc(anded, - LLVMType.IntType(expr.size)) + ret = builder.trunc( + anded, + LLVMType.IntType(expr.size) + ) self.update_cache(expr, ret) return ret @@ -1192,17 +1264,23 @@ class LLVMFunction(object): for start, src in expr.iter_args(): # src & size src = self.add_ir(src) - src_casted = builder.zext(src, - LLVMType.IntType(expr.size)) - to_and = llvm_ir.Constant(LLVMType.IntType(expr.size), - (1 << src.type.width) - 1) + src_casted = builder.zext( + src, + LLVMType.IntType(expr.size) + ) + to_and = llvm_ir.Constant( + LLVMType.IntType(expr.size), + (1 << src.type.width) - 1 + ) anded = builder.and_(src_casted, to_and) if (start != 0): # result << start - to_shl = llvm_ir.Constant(LLVMType.IntType(expr.size), - start) + to_shl = llvm_ir.Constant( + LLVMType.IntType(expr.size), + start + ) shled = builder.shl(anded, to_shl) final = shled else: @@ -1213,7 +1291,7 @@ class LLVMFunction(object): # result = part1 | part2 | ... last = args[0] - for i in xrange(1, len(expr.args)): + for i in range(1, len(expr.args)): last = builder.or_(last, args[i]) self.update_cache(expr, last) @@ -1246,9 +1324,11 @@ class LLVMFunction(object): # Compute cond zero_casted = llvm_ir.Constant(t_size, 0) - condition_bool = builder.icmp_unsigned("!=", - exceptionflag, - zero_casted) + condition_bool = builder.icmp_unsigned( + "!=", + exceptionflag, + zero_casted + ) # Create bbls branch_id = self.new_branch_name() @@ -1264,7 +1344,7 @@ class LLVMFunction(object): # Then Bloc builder.position_at_end(then_block) PC = self.llvm_context.PC - if isinstance(offset, (int, long)): + if isinstance(offset, int_types): offset = self.add_ir(ExprInt(offset, PC.size)) self.assign(offset, PC) self.assign(self.add_ir(ExprInt(1, 8)), ExprId("status", 32)) @@ -1289,13 +1369,18 @@ class LLVMFunction(object): # Compute cond if restricted_exception is True: flag = m2_csts.EXCEPT_NUM_UPDT_EIP - condition_bool = builder.icmp_unsigned(">", exceptionflag, - llvm_ir.Constant(t_size, flag)) + condition_bool = builder.icmp_unsigned( + ">", + exceptionflag, + llvm_ir.Constant(t_size, flag) + ) else: zero_casted = llvm_ir.Constant(t_size, 0) - condition_bool = builder.icmp_unsigned("!=", - exceptionflag, - zero_casted) + condition_bool = builder.icmp_unsigned( + "!=", + exceptionflag, + zero_casted + ) # Create bbls branch_id = self.new_branch_name() @@ -1311,7 +1396,7 @@ class LLVMFunction(object): # Then Bloc builder.position_at_end(then_block) PC = self.llvm_context.PC - if isinstance(offset, (int, long)): + if isinstance(offset, int_types): offset = self.add_ir(ExprInt(offset, PC.size)) self.assign(offset, PC) self.assign(self.add_ir(ExprInt(1, 8)), ExprId("status", 32)) @@ -1324,8 +1409,12 @@ class LLVMFunction(object): def gen_pre_code(self, instr_attrib): if instr_attrib.log_mn: loc_db = self.llvm_context.ir_arch.loc_db - self.printf("%.8X %s\n" % (instr_attrib.instr.offset, - instr_attrib.instr.to_string(loc_db))) + self.printf( + "%.8X %s\n" % ( + instr_attrib.instr.offset, + instr_attrib.instr.to_string(loc_db) + ) + ) def gen_post_code(self, attributes, pc_value): if attributes.log_regs: @@ -1464,7 +1553,7 @@ class LLVMFunction(object): # Evaluate expressions values = {} - for dst, src in assignblk.iteritems(): + for dst, src in viewitems(assignblk): if dst == self.llvm_context.ir_arch.IRDst: case2dst, case_value = self.expr2cases(src) else: @@ -1472,40 +1561,51 @@ class LLVMFunction(object): # Check memory access exception if attributes[index].mem_read: - self.check_memory_exception(instr.offset, - restricted_exception=True) + self.check_memory_exception( + instr.offset, + restricted_exception=True + ) # Update the memory - for dst, src in values.iteritems(): + for dst, src in viewitems(values): if isinstance(dst, ExprMem): self.assign(src, dst) # Check memory write exception if attributes[index].mem_write: - self.check_memory_exception(instr.offset, - restricted_exception=True) + self.check_memory_exception( + instr.offset, + restricted_exception=True + ) # Update registers values - for dst, src in values.iteritems(): + for dst, src in viewitems(values): if not isinstance(dst, ExprMem): self.assign(src, dst) # Check post assignblk exception flags if attributes[index].set_exception: - self.check_cpu_exception(instr.offset, restricted_exception=True) + self.check_cpu_exception( + instr.offset, + restricted_exception=True + ) # Destination assert case2dst is not None if len(case2dst) == 1: # Avoid switch in this common case - self.gen_jump2dst(instr_attrib, instr_offsets, case2dst.values()[0]) + self.gen_jump2dst( + instr_attrib, + instr_offsets, + next(iter(viewvalues(case2dst))) + ) else: current_bbl = self.builder.basic_block # Gen the out cases branch_id = self.new_branch_name() case2bbl = {} - for case, dst in case2dst.iteritems(): + for case, dst in list(viewitems(case2dst)): name = "switch_%s_%d" % (branch_id, case) bbl = self.append_basic_block(name) case2bbl[case] = bbl @@ -1515,7 +1615,7 @@ class LLVMFunction(object): # Jump on the correct output self.builder.position_at_end(current_bbl) switch = self.builder.switch(case_value, case2bbl[0]) - for i, bbl in case2bbl.iteritems(): + for i, bbl in viewitems(case2bbl): if i == 0: # Default case is case 0, arbitrary continue @@ -1528,17 +1628,23 @@ class LLVMFunction(object): builder = self.builder m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags t_size = LLVMType.IntType(m2_exception_flag.size) - self.assign(self.add_ir(ExprInt(1, 8)), - ExprId("status", 32)) - self.assign(t_size(m2_csts.EXCEPT_UNK_MNEMO), - m2_exception_flag) - offset = self.llvm_context.ir_arch.loc_db.get_location_offset(asmblock.loc_key) + self.assign( + self.add_ir(ExprInt(1, 8)), + ExprId("status", 32) + ) + self.assign( + t_size(m2_csts.EXCEPT_UNK_MNEMO), + m2_exception_flag + ) + offset = self.llvm_context.ir_arch.loc_db.get_location_offset( + asmblock.loc_key + ) self.set_ret(LLVMType.IntType(64)(offset)) def gen_finalize(self, asmblock, codegen): """ - In case of delayslot, generate a dummy BBL which return on the computed IRDst - or on next_label + In case of delayslot, generate a dummy BBL which return on the computed + IRDst or on next_label """ if self.llvm_context.has_delayslot: next_label = codegen.get_block_post_label(asmblock) @@ -1552,9 +1658,11 @@ class LLVMFunction(object): # Check if IRDst has been set zero_casted = LLVMType.IntType(codegen.delay_slot_set.size)(0) - condition_bool = builder.icmp_unsigned("!=", - self.add_ir(codegen.delay_slot_set), - zero_casted) + condition_bool = builder.icmp_unsigned( + "!=", + self.add_ir(codegen.delay_slot_set), + zero_casted + ) # Create bbls branch_id = self.new_branch_name() @@ -1627,8 +1735,10 @@ class LLVMFunction(object): if self.llvm_context.has_delayslot: for element in (codegen.delay_slot_dst, codegen.delay_slot_set): eltype = LLVMType.IntType(element.size) - ptr = self.CreateEntryBlockAlloca(eltype, - default_value=eltype(0)) + ptr = self.CreateEntryBlockAlloca( + eltype, + default_value=eltype(0) + ) self.local_vars_pointers[element.name] = ptr loc_key = codegen.get_block_post_label(asmblock) offset = self.llvm_context.ir_arch.loc_db.get_location_offset(loc_key) @@ -1640,9 +1750,12 @@ class LLVMFunction(object): for instr, irblocks in zip(asmblock.lines, irblocks_list): - instr_attrib, irblocks_attributes = codegen.get_attributes(instr, irblocks, - self.log_mn, - self.log_regs) + instr_attrib, irblocks_attributes = codegen.get_attributes( + instr, + irblocks, + self.log_mn, + self.log_regs + ) # Pre-create basic blocks for irblock in irblocks: @@ -1783,7 +1896,7 @@ class LLVMFunction_IRCompilation(LLVMFunction): def gen_irblock(self, irblock): instr_attrib = Attributes() - attributes = [Attributes() for _ in xrange(len(irblock.assignblks))] + attributes = [Attributes() for _ in range(len(irblock.assignblks))] instr_offsets = None return super(LLVMFunction_IRCompilation, self).gen_irblock( instr_attrib, attributes, instr_offsets, irblock @@ -1791,11 +1904,11 @@ class LLVMFunction_IRCompilation(LLVMFunction): def from_ircfg(self, ircfg, append_ret=True): # Create basic blocks - for loc_key, irblock in ircfg.blocks.iteritems(): + for loc_key, irblock in viewitems(ircfg.blocks): self.append_basic_block(loc_key) # Add IRBlocks - for label, irblock in ircfg.blocks.iteritems(): + for label, irblock in viewitems(ircfg.blocks): self.builder.position_at_end(self.get_basic_block_by_loc_key(label)) self.gen_irblock(irblock) |