about summary refs log tree commit diff stats
path: root/miasm2/jitter/llvmconvert.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/jitter/llvmconvert.py')
-rw-r--r-- miasm2/jitter/llvmconvert.py 323
1 files changed, 218 insertions, 105 deletions
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index fd32001c..bea8cd36 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -11,9 +11,15 @@
 #
 #
 
+from builtins import zip
+from builtins import range
 import os
 from llvmlite import binding as llvm
 from llvmlite import ir as llvm_ir
+from builtins import int as int_types
+
+from future.utils import viewitems, viewvalues
+
 from miasm2.expression.expression import ExprId, ExprInt, ExprMem, ExprSlice, \
     ExprCond, ExprLoc, ExprOp, ExprCompose, LocKey, Expr, \
     TOK_EQUAL, \
@@ -67,7 +73,7 @@ class LLVMType(llvm_ir.Type):
         return precision
 
 
-class LLVMContext():
+class LLVMContext(object):
 
     "Context for llvm binding. Stand for a LLVM Module"
 
@@ -139,7 +145,7 @@ class LLVMContext():
     def add_fc(self, fc, readonly=False):
         "Add function into known_fc"
 
-        for name, detail in fc.iteritems():
+        for name, detail in viewitems(fc):
             fnty = llvm_ir.FunctionType(detail["ret"], detail["args"])
             fn = llvm_ir.Function(self.mod, fnty, name=name)
             if readonly:
@@ -444,8 +450,10 @@ class LLVMContext_JIT(LLVMContext):
             self.add_shared_library(lib_fname)
 
         # Activate cache
-        self.exec_engine.set_object_cache(self.cache_notify,
-                                          self.cache_getbuffer)
+        self.exec_engine.set_object_cache(
+            self.cache_notify,
+            self.cache_getbuffer
+        )
 
     def set_cache_filename(self, func, fname_out):
         "Set the filename @fname_out to use for cache for @func"
@@ -473,16 +481,20 @@ class LLVMContext_IRCompilation(LLVMContext):
         """Perform a memory lookup at @addr of size @size (in bit)"""
         builder = func.builder
         int_size = LLVMType.IntType(size)
-        ptr_casted = builder.inttoptr(addr,
-                                      llvm_ir.PointerType(int_size))
+        ptr_casted = builder.inttoptr(
+            addr,
+            llvm_ir.PointerType(int_size)
+        )
         return builder.load(ptr_casted)
 
     def memory_write(self, func, addr, size, value):
         """Perform a memory write at @addr of size @size (in bit) with LLVM IR @value"""
         builder = func.builder
         int_size = LLVMType.IntType(size)
-        ptr_casted = builder.inttoptr(addr,
-                                      llvm_ir.PointerType(int_size))
+        ptr_casted = builder.inttoptr(
+            addr,
+            llvm_ir.PointerType(int_size)
+        )
         return builder.store(value, ptr_casted)
 
 
@@ -504,8 +516,9 @@ class LLVMFunction(object):
     ## Add the size as first argument
     op_translate_with_size = {}
     ## Add the size as suffix
-    op_translate_with_suffix_size = {'bcdadd': 'bcdadd',
-                                     'bcdadd_cf': 'bcdadd_cf',
+    op_translate_with_suffix_size = {
+        'bcdadd': 'bcdadd',
+        'bcdadd_cf': 'bcdadd_cf',
     }
 
     def __init__(self, llvm_context, name="fc", new_module=True):
@@ -582,12 +595,20 @@ class LLVMFunction(object):
         offset = self.llvm_context.vmcpu[name]
 
         # Pointer cast
-        ptr = builder.gep(self.local_vars["vmcpu"],
-                          [llvm_ir.Constant(LLVMType.IntType(),
-                                            offset)])
+        ptr = builder.gep(
+            self.local_vars["vmcpu"],
+            [
+                llvm_ir.Constant(
+                    LLVMType.IntType(),
+                    offset
+                )
+            ]
+        )
         pointee_type = LLVMType.IntType(expr.size)
-        ptr_casted = builder.bitcast(ptr,
-                                     llvm_ir.PointerType(pointee_type))
+        ptr_casted = builder.bitcast(
+            ptr,
+            llvm_ir.PointerType(pointee_type)
+        )
         # Store in cache
         self.local_vars_pointers[name] = ptr_casted
 
@@ -612,7 +633,10 @@ class LLVMFunction(object):
 
     def get_basic_block_by_loc_key(self, loc_key):
         "Return the bbl corresponding to label, None otherwise"
-        return self.name2bbl.get(self.llvm_context.canonize_label_name(loc_key), None)
+        return self.name2bbl.get(
+            self.llvm_context.canonize_label_name(loc_key),
+            None
+        )
 
     def global_constant(self, name, value):
         """
@@ -658,10 +682,15 @@ class LLVMFunction(object):
         count = 0
         while "%s_%d" % (base_name, count) in self.mod.globals:
             count += 1
-        global_fmt = self.global_constant("%s_%d" % (base_name, count),
-                                          fmt_bytes)
-        fnty = llvm_ir.FunctionType(llvm_ir.IntType(32), [cstring],
-                                    var_arg=True)
+        global_fmt = self.global_constant(
+            "%s_%d" % (base_name, count),
+            fmt_bytes
+        )
+        fnty = llvm_ir.FunctionType(
+            llvm_ir.IntType(32),
+            [cstring],
+            var_arg=True
+        )
         # Insert printf()
         fn = mod.globals.get('printf', None)
         if fn is None:
@@ -692,7 +721,10 @@ class LLVMFunction(object):
         "Init the function"
 
         # Build type for fc signature
-        fc_type = llvm_ir.FunctionType(self.ret_type, [k[1] for k in self.my_args])
+        fc_type = llvm_ir.FunctionType(
+            self.ret_type,
+            [k[1] for k in self.my_args]
+        )
 
         # Add fc in module
         try:
@@ -741,7 +773,9 @@ class LLVMFunction(object):
             return ret
 
         if expr.is_loc():
-            offset = self.llvm_context.ir_arch.loc_db.get_location_offset(expr.loc_key)
+            offset = self.llvm_context.ir_arch.loc_db.get_location_offset(
+                expr.loc_key
+            )
             ret = llvm_ir.Constant(LLVMType.IntType(expr.size), offset)
             self.update_cache(expr, ret)
             return ret
@@ -785,7 +819,12 @@ class LLVMFunction(object):
                 casted_args = []
                 for i, arg in enumerate(args):
                     if arg.type.width < fc_ptr.args[i].type.width:
-                        casted_args.append(builder.zext(arg, fc_ptr.args[i].type))
+                        casted_args.append(
+                            builder.zext(
+                                arg,
+                                fc_ptr.args[i].type
+                            )
+                        )
                     else:
                         casted_args.append(arg)
                 ret = builder.call(fc_ptr, casted_args)
@@ -810,8 +849,10 @@ class LLVMFunction(object):
                 assert len(expr.args) == 1
                 arg = self.add_ir(expr.args[0])
                 truncated = builder.trunc(arg, LLVMType.IntType(8))
-                bitcount = builder.call(self.mod.get_global("llvm.ctpop.i8"),
-                                        [truncated])
+                bitcount = builder.call(
+                    self.mod.get_global("llvm.ctpop.i8"),
+                    [truncated]
+                )
                 ret = builder.not_(builder.trunc(bitcount, LLVMType.IntType(1)))
                 self.update_cache(expr, ret)
                 return ret
@@ -824,16 +865,20 @@ class LLVMFunction(object):
                     "cnttrailzeros": "cttz",
                 }[op]
                 func_llvm_name = "llvm.%s.i%d" % (func_name, expr.size)
-                func_sig = {func_llvm_name: {
-                    "ret": LLVMType.IntType(expr.size),
-                    "args": [LLVMType.IntType(expr.args[0].size)]
-                }}
+                func_sig = {
+                    func_llvm_name: {
+                        "ret": LLVMType.IntType(expr.size),
+                        "args": [LLVMType.IntType(expr.args[0].size)]
+                    }
+                }
                 try:
                     self.mod.get_global(func_llvm_name)
                 except KeyError:
                     self.llvm_context.add_fc(func_sig, readonly=True)
-                ret = builder.call(self.mod.get_global(func_llvm_name),
-                                   [arg])
+                ret = builder.call(
+                    self.mod.get_global(func_llvm_name),
+                    [arg]
+                )
                 self.update_cache(expr, ret)
                 return ret
 
@@ -867,12 +912,19 @@ class LLVMFunction(object):
                 casted_args = []
                 for i, arg in enumerate(args, 1):
                     if arg.type.width < fc_ptr.args[i].type.width:
-                        casted_args.append(builder.zext(arg, fc_ptr.args[i].type))
+                        casted_args.append(
+                            builder.zext(
+                                arg,
+                                fc_ptr.args[i].type
+                            )
+                        )
                     else:
                         casted_args.append(arg)
 
-                ret = builder.call(fc_ptr,
-                                   [self.local_vars["jitcpu"]] + casted_args)
+                ret = builder.call(
+                    fc_ptr,
+                    [self.local_vars["jitcpu"]] + casted_args
+                )
                 if ret.type.width > expr.size:
                     ret = builder.trunc(ret, LLVMType.IntType(expr.size))
                 self.update_cache(expr, ret)
@@ -905,11 +957,14 @@ class LLVMFunction(object):
             if op in unsigned_cmps:
                 op = unsigned_cmps[op]
                 args = [self.add_ir(arg) for arg in expr.args]
-                ret = builder.select(builder.icmp_unsigned(op,
-                                                           args[0],
-                                                           args[1]),
-                                     llvm_ir.IntType(expr.size)(1),
-                                     llvm_ir.IntType(expr.size)(0))
+                ret = builder.select(
+                    builder.icmp_unsigned(op,
+                                          args[0],
+                                          args[1]
+                    ),
+                    llvm_ir.IntType(expr.size)(1),
+                    llvm_ir.IntType(expr.size)(0)
+                )
                 self.update_cache(expr, ret)
                 return ret
 
@@ -919,8 +974,11 @@ class LLVMFunction(object):
                 count = self.add_ir(expr.args[1])
                 value = self.add_ir(expr.args[0])
                 itype = LLVMType.IntType(expr.size)
-                cond_ok = self.builder.icmp_unsigned("<", count,
-                                                     itype(expr.size))
+                cond_ok = self.builder.icmp_unsigned(
+                    "<",
+                    count,
+                    itype(expr.size)
+                )
                 zero = itype(0)
                 if op == ">>":
                     callback = builder.lshr
@@ -932,8 +990,11 @@ class LLVMFunction(object):
                     cond_neg = self.builder.icmp_signed("<", value, zero)
                     zero = self.builder.select(cond_neg, itype(-1), zero)
 
-                ret = self.builder.select(cond_ok, callback(value, count),
-                                          zero)
+                ret = self.builder.select(
+                    cond_ok,
+                    callback(value, count),
+                    zero
+                )
                 self.update_cache(expr, ret)
                 return ret
 
@@ -948,8 +1009,10 @@ class LLVMFunction(object):
 
                 # As shift of expr_size is undefined, we urem the shifters
                 shift = builder.urem(count, expr_size)
-                shift_inv = builder.urem(builder.sub(expr_size, shift),
-                                         expr_size)
+                shift_inv = builder.urem(
+                    builder.sub(expr_size, shift),
+                    expr_size
+                )
 
                 if op == '<<<':
                     part_a = builder.shl(value, shift)
@@ -1045,12 +1108,16 @@ class LLVMFunction(object):
                 # Apply the correct func
                 if expr.size == 32:
                     arg = builder.bitcast(arg, llvm_ir.FloatType())
-                    ret = builder.call(self.mod.get_global("llvm.%s.f32" % op),
-                                       [arg])
+                    ret = builder.call(
+                        self.mod.get_global("llvm.%s.f32" % op),
+                        [arg]
+                    )
                 elif expr.size == 64:
                     arg = builder.bitcast(arg, llvm_ir.DoubleType())
-                    ret = builder.call(self.mod.get_global("llvm.%s.f64" % op),
-                                       [arg])
+                    ret = builder.call(
+                        self.mod.get_global("llvm.%s.f64" % op),
+                        [arg]
+                    )
                 else:
                     raise RuntimeError("Unsupported precision: %x", expr.size)
 
@@ -1164,22 +1231,27 @@ class LLVMFunction(object):
 
             # Remove trailing bits
             if expr.start != 0:
-                to_shr = llvm_ir.Constant(LLVMType.IntType(expr.arg.size),
-                                          expr.start)
-                shred = builder.lshr(src,
-                                     to_shr)
+                to_shr = llvm_ir.Constant(
+                    LLVMType.IntType(expr.arg.size),
+                    expr.start
+                )
+                shred = builder.lshr(src, to_shr)
             else:
                 shred = src
 
             # Remove leading bits
-            to_and = llvm_ir.Constant(LLVMType.IntType(expr.arg.size),
-                                      (1 << (expr.stop - expr.start)) - 1)
+            to_and = llvm_ir.Constant(
+                LLVMType.IntType(expr.arg.size),
+                (1 << (expr.stop - expr.start)) - 1
+            )
             anded = builder.and_(shred,
                                  to_and)
 
             # Cast into e.size
-            ret = builder.trunc(anded,
-                                LLVMType.IntType(expr.size))
+            ret = builder.trunc(
+                anded,
+                LLVMType.IntType(expr.size)
+            )
 
             self.update_cache(expr, ret)
             return ret
@@ -1192,17 +1264,23 @@ class LLVMFunction(object):
             for start, src in expr.iter_args():
                 # src & size
                 src = self.add_ir(src)
-                src_casted = builder.zext(src,
-                                          LLVMType.IntType(expr.size))
-                to_and = llvm_ir.Constant(LLVMType.IntType(expr.size),
-                                          (1 << src.type.width) - 1)
+                src_casted = builder.zext(
+                    src,
+                    LLVMType.IntType(expr.size)
+                )
+                to_and = llvm_ir.Constant(
+                    LLVMType.IntType(expr.size),
+                    (1 << src.type.width) - 1
+                )
                 anded = builder.and_(src_casted,
                                      to_and)
 
                 if (start != 0):
                     # result << start
-                    to_shl = llvm_ir.Constant(LLVMType.IntType(expr.size),
-                                              start)
+                    to_shl = llvm_ir.Constant(
+                        LLVMType.IntType(expr.size),
+                        start
+                    )
                     shled = builder.shl(anded, to_shl)
                     final = shled
                 else:
@@ -1213,7 +1291,7 @@ class LLVMFunction(object):
 
             # result = part1 | part2 | ...
             last = args[0]
-            for i in xrange(1, len(expr.args)):
+            for i in range(1, len(expr.args)):
                 last = builder.or_(last, args[i])
 
             self.update_cache(expr, last)
@@ -1246,9 +1324,11 @@ class LLVMFunction(object):
 
         # Compute cond
         zero_casted = llvm_ir.Constant(t_size, 0)
-        condition_bool = builder.icmp_unsigned("!=",
-                                               exceptionflag,
-                                               zero_casted)
+        condition_bool = builder.icmp_unsigned(
+            "!=",
+            exceptionflag,
+            zero_casted
+        )
 
         # Create bbls
         branch_id = self.new_branch_name()
@@ -1264,7 +1344,7 @@ class LLVMFunction(object):
         # Then Bloc
         builder.position_at_end(then_block)
         PC = self.llvm_context.PC
-        if isinstance(offset, (int, long)):
+        if isinstance(offset, int_types):
             offset = self.add_ir(ExprInt(offset, PC.size))
         self.assign(offset, PC)
         self.assign(self.add_ir(ExprInt(1, 8)), ExprId("status", 32))
@@ -1289,13 +1369,18 @@ class LLVMFunction(object):
         # Compute cond
         if restricted_exception is True:
             flag = m2_csts.EXCEPT_NUM_UPDT_EIP
-            condition_bool = builder.icmp_unsigned(">", exceptionflag,
-                                                   llvm_ir.Constant(t_size, flag))
+            condition_bool = builder.icmp_unsigned(
+                ">",
+                exceptionflag,
+                llvm_ir.Constant(t_size, flag)
+            )
         else:
             zero_casted = llvm_ir.Constant(t_size, 0)
-            condition_bool = builder.icmp_unsigned("!=",
-                                                   exceptionflag,
-                                                   zero_casted)
+            condition_bool = builder.icmp_unsigned(
+                "!=",
+                exceptionflag,
+                zero_casted
+            )
 
         # Create bbls
         branch_id = self.new_branch_name()
@@ -1311,7 +1396,7 @@ class LLVMFunction(object):
         # Then Bloc
         builder.position_at_end(then_block)
         PC = self.llvm_context.PC
-        if isinstance(offset, (int, long)):
+        if isinstance(offset, int_types):
             offset = self.add_ir(ExprInt(offset, PC.size))
         self.assign(offset, PC)
         self.assign(self.add_ir(ExprInt(1, 8)), ExprId("status", 32))
@@ -1324,8 +1409,12 @@ class LLVMFunction(object):
     def gen_pre_code(self, instr_attrib):
         if instr_attrib.log_mn:
             loc_db = self.llvm_context.ir_arch.loc_db
-            self.printf("%.8X %s\n" % (instr_attrib.instr.offset,
-                                       instr_attrib.instr.to_string(loc_db)))
+            self.printf(
+                "%.8X %s\n" % (
+                    instr_attrib.instr.offset,
+                    instr_attrib.instr.to_string(loc_db)
+                )
+            )
 
     def gen_post_code(self, attributes, pc_value):
         if attributes.log_regs:
@@ -1464,7 +1553,7 @@ class LLVMFunction(object):
 
             # Evaluate expressions
             values = {}
-            for dst, src in assignblk.iteritems():
+            for dst, src in viewitems(assignblk):
                 if dst == self.llvm_context.ir_arch.IRDst:
                     case2dst, case_value = self.expr2cases(src)
                 else:
@@ -1472,40 +1561,51 @@ class LLVMFunction(object):
 
             # Check memory access exception
             if attributes[index].mem_read:
-                self.check_memory_exception(instr.offset,
-                                            restricted_exception=True)
+                self.check_memory_exception(
+                    instr.offset,
+                    restricted_exception=True
+                )
 
             # Update the memory
-            for dst, src in values.iteritems():
+            for dst, src in viewitems(values):
                 if isinstance(dst, ExprMem):
                     self.assign(src, dst)
 
             # Check memory write exception
             if attributes[index].mem_write:
-                self.check_memory_exception(instr.offset,
-                                            restricted_exception=True)
+                self.check_memory_exception(
+                    instr.offset,
+                    restricted_exception=True
+                )
 
             # Update registers values
-            for dst, src in values.iteritems():
+            for dst, src in viewitems(values):
                 if not isinstance(dst, ExprMem):
                     self.assign(src, dst)
 
             # Check post assignblk exception flags
             if attributes[index].set_exception:
-                self.check_cpu_exception(instr.offset, restricted_exception=True)
+                self.check_cpu_exception(
+                    instr.offset,
+                    restricted_exception=True
+                )
 
         # Destination
         assert case2dst is not None
         if len(case2dst) == 1:
             # Avoid switch in this common case
-            self.gen_jump2dst(instr_attrib, instr_offsets, case2dst.values()[0])
+            self.gen_jump2dst(
+                instr_attrib,
+                instr_offsets,
+                next(iter(viewvalues(case2dst)))
+            )
         else:
             current_bbl = self.builder.basic_block
 
             # Gen the out cases
             branch_id = self.new_branch_name()
             case2bbl = {}
-            for case, dst in case2dst.iteritems():
+            for case, dst in list(viewitems(case2dst)):
                 name = "switch_%s_%d" % (branch_id, case)
                 bbl = self.append_basic_block(name)
                 case2bbl[case] = bbl
@@ -1515,7 +1615,7 @@ class LLVMFunction(object):
             # Jump on the correct output
             self.builder.position_at_end(current_bbl)
             switch = self.builder.switch(case_value, case2bbl[0])
-            for i, bbl in case2bbl.iteritems():
+            for i, bbl in viewitems(case2bbl):
                 if i == 0:
                     # Default case is case 0, arbitrary
                     continue
@@ -1528,17 +1628,23 @@ class LLVMFunction(object):
         builder = self.builder
         m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags
         t_size = LLVMType.IntType(m2_exception_flag.size)
-        self.assign(self.add_ir(ExprInt(1, 8)),
-                    ExprId("status", 32))
-        self.assign(t_size(m2_csts.EXCEPT_UNK_MNEMO),
-                    m2_exception_flag)
-        offset = self.llvm_context.ir_arch.loc_db.get_location_offset(asmblock.loc_key)
+        self.assign(
+            self.add_ir(ExprInt(1, 8)),
+            ExprId("status", 32)
+        )
+        self.assign(
+            t_size(m2_csts.EXCEPT_UNK_MNEMO),
+            m2_exception_flag
+        )
+        offset = self.llvm_context.ir_arch.loc_db.get_location_offset(
+            asmblock.loc_key
+        )
         self.set_ret(LLVMType.IntType(64)(offset))
 
     def gen_finalize(self, asmblock, codegen):
         """
-        In case of delayslot, generate a dummy BBL which return on the computed IRDst
-        or on next_label
+        In case of delayslot, generate a dummy BBL which returns on the
+        computed IRDst or on next_label
         """
         if self.llvm_context.has_delayslot:
             next_label = codegen.get_block_post_label(asmblock)
@@ -1552,9 +1658,11 @@ class LLVMFunction(object):
 
             # Check if IRDst has been set
             zero_casted = LLVMType.IntType(codegen.delay_slot_set.size)(0)
-            condition_bool = builder.icmp_unsigned("!=",
-                                                   self.add_ir(codegen.delay_slot_set),
-                                                   zero_casted)
+            condition_bool = builder.icmp_unsigned(
+                "!=",
+                self.add_ir(codegen.delay_slot_set),
+                zero_casted
+            )
 
             # Create bbls
             branch_id = self.new_branch_name()
@@ -1627,8 +1735,10 @@ class LLVMFunction(object):
         if self.llvm_context.has_delayslot:
             for element in (codegen.delay_slot_dst, codegen.delay_slot_set):
                 eltype = LLVMType.IntType(element.size)
-                ptr = self.CreateEntryBlockAlloca(eltype,
-                                                  default_value=eltype(0))
+                ptr = self.CreateEntryBlockAlloca(
+                    eltype,
+                    default_value=eltype(0)
+                )
                 self.local_vars_pointers[element.name] = ptr
             loc_key = codegen.get_block_post_label(asmblock)
             offset = self.llvm_context.ir_arch.loc_db.get_location_offset(loc_key)
@@ -1640,9 +1750,12 @@ class LLVMFunction(object):
 
 
         for instr, irblocks in zip(asmblock.lines, irblocks_list):
-            instr_attrib, irblocks_attributes = codegen.get_attributes(instr, irblocks,
-                                                                       self.log_mn,
-                                                                       self.log_regs)
+            instr_attrib, irblocks_attributes = codegen.get_attributes(
+                instr,
+                irblocks,
+                self.log_mn,
+                self.log_regs
+            )
 
             # Pre-create basic blocks
             for irblock in irblocks:
@@ -1783,7 +1896,7 @@ class LLVMFunction_IRCompilation(LLVMFunction):
 
     def gen_irblock(self, irblock):
         instr_attrib = Attributes()
-        attributes = [Attributes() for _ in xrange(len(irblock.assignblks))]
+        attributes = [Attributes() for _ in range(len(irblock.assignblks))]
         instr_offsets = None
         return super(LLVMFunction_IRCompilation, self).gen_irblock(
             instr_attrib, attributes, instr_offsets, irblock
@@ -1791,11 +1904,11 @@ class LLVMFunction_IRCompilation(LLVMFunction):
 
     def from_ircfg(self, ircfg, append_ret=True):
         # Create basic blocks
-        for loc_key, irblock in ircfg.blocks.iteritems():
+        for loc_key, irblock in viewitems(ircfg.blocks):
             self.append_basic_block(loc_key)
 
         # Add IRBlocks
-        for label, irblock in ircfg.blocks.iteritems():
+        for label, irblock in viewitems(ircfg.blocks):
             self.builder.position_at_end(self.get_basic_block_by_loc_key(label))
             self.gen_irblock(irblock)