about summary refs log tree commit diff stats
path: root/miasm2/jitter/llvmconvert.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/jitter/llvmconvert.py')
-rw-r--r--miasm2/jitter/llvmconvert.py154
1 files changed, 67 insertions, 87 deletions
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index 3f7d0c6d..2045f083 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -14,7 +14,8 @@
 import os
 from llvmlite import binding as llvm
 from llvmlite import ir as llvm_ir
-import miasm2.expression.expression as m2_expr
+from miasm2.expression.expression import ExprId, ExprInt, ExprMem, ExprSlice, \
+    ExprCond, ExprLoc, ExprOp, ExprCompose, LocKey
 import miasm2.jitter.csts as m2_csts
 import miasm2.core.asmblock as m2_asmblock
 from miasm2.jitter.codegen import CGen
@@ -43,7 +44,7 @@ class LLVMType(llvm_ir.Type):
     @classmethod
     def generic(cls, e):
         "Generic value for execution"
-        if isinstance(e, m2_expr.ExprInt):
+        if isinstance(e, ExprInt):
             return llvm_e.GenericValue.int(LLVMType.IntType(e.size), int(e.arg))
         elif isinstance(e, llvm_e.GenericValue):
             return e
@@ -69,25 +70,21 @@ class LLVMContext():
         self.target_machine = target.create_target_machine()
         self.init_exec_engine()
 
-
     def canonize_label_name(self, label):
         """Canonize @label names to a common form.
         @label: str or asmlabel instance"""
         if isinstance(label, str):
             return label
-        if isinstance(label, m2_expr.Expr) and expr.is_label():
-            label = self.llvm_context.ir_arch.symbol_pool.loc_key_to_label(label.index)
-        if isinstance(label, (int, long)):
-            fds
-            label = self.llvm_context.ir_arch.symbol_pool.loc_key_to_label(label)
-
-        if isinstance(label, m2_asmblock.AsmLabel):
-            if label.offset is None:
-                return "label_%s" % label.name
-            else:
-                return "label_%X" % label.offset
+        if not isinstance(label, LocKey):
+            raise ValueError("label must either be str or LocKey")
+
+        offset = self.ir_arch.symbol_pool.loc_key_to_offset(label)
+
+        if offset is None:
+            name = self.ir_arch.symbol_pool.loc_key_to_name(label)
+            return "%s" % name
         else:
-            raise ValueError("label must either be str or asmlabel")
+            return "label_off_%X" % offset
 
     def optimise_level(self, level=2):
         """Set the optimisation level to @level from 0 to 2
@@ -400,8 +397,8 @@ class LLVMFunction():
 
     def __init__(self, llvm_context, name="fc", new_module=True):
         "Create a new function with name @name"
-        name = self.canonize_label_name(name)
         self.llvm_context = llvm_context
+        name = self.llvm_context.canonize_label_name(name)
         if new_module:
             self.llvm_context.new_module()
         self.mod = self.llvm_context.get_module()
@@ -427,7 +424,7 @@ class LLVMFunction():
         @label: str or asmlabel
         @overwrite: if False, do nothing if a bbl with the same name already exists
         Return the corresponding LLVM Basic Block"""
-        name = self.canonize_label_name(label)
+        name = self.llvm_context.canonize_label_name(label)
         bbl = self.name2bbl.get(name, None)
         if not overwrite and bbl is not None:
             return bbl
@@ -505,27 +502,9 @@ class LLVMFunction():
             var_casted = var
         self.builder.ret(var_casted)
 
-    def canonize_label_name(self, label):
-        """Canonize @label names to a common form.
-        @label: str or asmlabel instance"""
-        if isinstance(label, str):
-            return label
-        if isinstance(label, m2_expr.Expr) and expr.is_label():
-            label = self.llvm_context.ir_arch.symbol_pool.loc_key_to_label(label.index)
-        if isinstance(label, m2_expr.LocKey):
-            label = self.llvm_context.ir_arch.symbol_pool.loc_key_to_label(label)
-
-        if isinstance(label, m2_asmblock.AsmLabel):
-            if label.offset is None:
-                return "label_%s" % label.name
-            else:
-                return "label_%X" % label.offset
-        else:
-            raise ValueError("label must either be str or asmlabel")
-
-    def get_basic_bloc_by_label(self, label):
+    def get_basic_block_by_loc_key(self, loc_key):
         "Return the bbl corresponding to label, None otherwise"
-        return self.name2bbl.get(self.canonize_label_name(label), None)
+        return self.name2bbl.get(self.llvm_context.canonize_label_name(loc_key), None)
 
     def global_constant(self, name, value):
         """
@@ -591,11 +570,11 @@ class LLVMFunction():
         # Destination
         builder = self.builder
 
-        if isinstance(dst, m2_expr.ExprId):
+        if isinstance(dst, ExprId):
             ptr_casted = self.get_ptr_by_expr(dst)
             builder.store(src, ptr_casted)
 
-        elif isinstance(dst, m2_expr.ExprMem):
+        elif isinstance(dst, ExprMem):
             addr = self.add_ir(dst.arg)
             self.llvm_context.memory_write(self, addr, dst.size, src)
         else:
@@ -648,19 +627,18 @@ class LLVMFunction():
 
         builder = self.builder
 
-        if isinstance(expr, m2_expr.ExprInt):
+        if isinstance(expr, ExprInt):
             ret = llvm_ir.Constant(LLVMType.IntType(expr.size), int(expr.arg))
             self.update_cache(expr, ret)
             return ret
 
-        if expr.is_label():
-            label = self.llvm_context.ir_arch.symbol_pool.loc_key_to_label(expr.loc_key)
-            offset = label.offset
+        if expr.is_loc():
+            offset = self.llvm_context.ir_arch.symbol_pool.loc_key_to_offset(expr.loc_key)
             ret = llvm_ir.Constant(LLVMType.IntType(expr.size), offset)
             self.update_cache(expr, ret)
             return ret
 
-        if isinstance(expr, m2_expr.ExprId):
+        if isinstance(expr, ExprId):
             name = expr.name
             try:
                 # If expr.name is already known (args)
@@ -674,7 +652,7 @@ class LLVMFunction():
             self.update_cache(expr, var)
             return var
 
-        if isinstance(expr, m2_expr.ExprOp):
+        if isinstance(expr, ExprOp):
             op = expr.op
 
             if (op in self.op_translate or
@@ -881,12 +859,12 @@ class LLVMFunction():
 
             raise NotImplementedError()
 
-        if isinstance(expr, m2_expr.ExprMem):
+        if isinstance(expr, ExprMem):
 
             addr = self.add_ir(expr.arg)
             return self.llvm_context.memory_lookup(self, addr, expr.size)
 
-        if isinstance(expr, m2_expr.ExprCond):
+        if isinstance(expr, ExprCond):
             # Compute cond
             cond = self.add_ir(expr.cond)
             zero_casted = LLVMType.IntType(expr.cond.size)(0)
@@ -899,7 +877,7 @@ class LLVMFunction():
             self.update_cache(expr, ret)
             return ret
 
-        if isinstance(expr, m2_expr.ExprSlice):
+        if isinstance(expr, ExprSlice):
 
             src = self.add_ir(expr.arg)
 
@@ -925,7 +903,7 @@ class LLVMFunction():
             self.update_cache(expr, ret)
             return ret
 
-        if isinstance(expr, m2_expr.ExprCompose):
+        if isinstance(expr, ExprCompose):
 
             args = []
 
@@ -1006,9 +984,9 @@ class LLVMFunction():
         builder.position_at_end(then_block)
         PC = self.llvm_context.PC
         if isinstance(offset, (int, long)):
-            offset = self.add_ir(m2_expr.ExprInt(offset, PC.size))
+            offset = self.add_ir(ExprInt(offset, PC.size))
         self.affect(offset, PC)
-        self.affect(self.add_ir(m2_expr.ExprInt(1, 8)), m2_expr.ExprId("status", 32))
+        self.affect(self.add_ir(ExprInt(1, 8)), ExprId("status", 32))
         self.set_ret(offset)
 
         builder.position_at_end(merge_block)
@@ -1053,9 +1031,9 @@ class LLVMFunction():
         builder.position_at_end(then_block)
         PC = self.llvm_context.PC
         if isinstance(offset, (int, long)):
-            offset = self.add_ir(m2_expr.ExprInt(offset, PC.size))
+            offset = self.add_ir(ExprInt(offset, PC.size))
         self.affect(offset, PC)
-        self.affect(self.add_ir(m2_expr.ExprInt(1, 8)), m2_expr.ExprId("status", 32))
+        self.affect(self.add_ir(ExprInt(1, 8)), ExprId("status", 32))
         self.set_ret(offset)
 
         builder.position_at_end(merge_block)
@@ -1100,9 +1078,9 @@ class LLVMFunction():
         for i, solution in enumerate(possible_values(expr)):
             value = solution.value
             index = dst2case.get(value, i)
-            to_eval = to_eval.replace_expr({value: m2_expr.ExprInt(index, value.size)})
+            to_eval = to_eval.replace_expr({value: ExprInt(index, value.size)})
             dst2case[value] = index
-            if value.is_int() or value.is_label():
+            if value.is_int() or value.is_loc():
                 case2dst[i] = value
             else:
                 case2dst[i] = self.add_ir(value)
@@ -1128,14 +1106,14 @@ class LLVMFunction():
         # We are no longer in the main stream, deactivate cache
         self.main_stream = False
 
-        if isinstance(dst, m2_expr.ExprInt):
-            label = self.llvm_context.ir_arch.symbol_pool.getby_offset_create(int(dst))
-            dst = m2_expr.ExprLoc(label.loc_key, dst.size)
+        if isinstance(dst, ExprInt):
+            loc_key = self.llvm_context.ir_arch.symbol_pool.getby_offset_create(int(dst))
+            dst = ExprLoc(loc_key, dst.size)
 
-        if isinstance(dst, m2_expr.ExprLoc):
-            label = self.llvm_context.ir_arch.symbol_pool.loc_key_to_label(dst.loc_key)
-            bbl = self.get_basic_bloc_by_label(label)
-            offset = label.offset
+        if isinstance(dst, ExprLoc):
+            loc_key = dst.loc_key
+            bbl = self.get_basic_block_by_loc_key(loc_key)
+            offset = self.llvm_context.ir_arch.symbol_pool.loc_key_to_offset(loc_key)
             if bbl is not None:
                 # "local" jump, inside this function
                 if offset is None:
@@ -1155,7 +1133,7 @@ class LLVMFunction():
                 # extern
 
             # "extern" jump on a defined offset, return to the caller
-            dst = self.add_ir(m2_expr.ExprInt(offset, PC.size))
+            dst = self.add_ir(ExprInt(offset, PC.size))
 
         # "extern" jump with a computed value, return to the caller
         assert isinstance(dst, (llvm_ir.Instruction, llvm_ir.Value))
@@ -1167,7 +1145,7 @@ class LLVMFunction():
         self.gen_post_code(attrib)
         self.affect(dst, PC)
         self.gen_post_instr_checks(attrib, dst)
-        self.affect(self.add_ir(m2_expr.ExprInt(0, 8)), m2_expr.ExprId("status", 32))
+        self.affect(self.add_ir(ExprInt(0, 8)), ExprId("status", 32))
         self.set_ret(dst)
 
 
@@ -1191,7 +1169,7 @@ class LLVMFunction():
 
             # Prefetch memory
             for element in assignblk.get_r(mem_read=True):
-                if isinstance(element, m2_expr.ExprMem):
+                if isinstance(element, ExprMem):
                     self.add_ir(element)
 
             # Evaluate expressions
@@ -1209,7 +1187,7 @@ class LLVMFunction():
 
             # Update the memory
             for dst, src in values.iteritems():
-                if isinstance(dst, m2_expr.ExprMem):
+                if isinstance(dst, ExprMem):
                     self.affect(src, dst)
 
             # Check memory write exception
@@ -1219,7 +1197,7 @@ class LLVMFunction():
 
             # Update registers values
             for dst, src in values.iteritems():
-                if not isinstance(dst, m2_expr.ExprMem):
+                if not isinstance(dst, ExprMem):
                     self.affect(src, dst)
 
             # Check post assignblk exception flags
@@ -1260,11 +1238,12 @@ class LLVMFunction():
         builder = self.builder
         m2_exception_flag = self.llvm_context.ir_arch.arch.regs.exception_flags
         t_size = LLVMType.IntType(m2_exception_flag.size)
-        self.affect(self.add_ir(m2_expr.ExprInt(1, 8)),
-                    m2_expr.ExprId("status", 32))
+        self.affect(self.add_ir(ExprInt(1, 8)),
+                    ExprId("status", 32))
         self.affect(t_size(m2_csts.EXCEPT_UNK_MNEMO),
                     m2_exception_flag)
-        self.set_ret(LLVMType.IntType(64)(asmblock.label.offset))
+        offset = self.llvm_context.ir_arch.symbol_pool.loc_key_to_offset(asmblock.loc_key)
+        self.set_ret(LLVMType.IntType(64)(offset))
 
     def gen_finalize(self, asmblock, codegen):
         """
@@ -1275,11 +1254,11 @@ class LLVMFunction():
             next_label = codegen.get_block_post_label(asmblock)
             builder = self.builder
 
-            builder.position_at_end(self.get_basic_bloc_by_label(next_label))
+            builder.position_at_end(self.get_basic_block_by_loc_key(next_label))
 
             # Common code
-            self.affect(self.add_ir(m2_expr.ExprInt(0, 8)),
-                        m2_expr.ExprId("status", 32))
+            self.affect(self.add_ir(ExprInt(0, 8)),
+                        ExprId("status", 32))
 
             # Check if IRDst has been set
             zero_casted = LLVMType.IntType(codegen.delay_slot_set.size)(0)
@@ -1302,14 +1281,15 @@ class LLVMFunction():
             PC = self.llvm_context.PC
             to_ret = self.add_ir(codegen.delay_slot_dst)
             self.affect(to_ret, PC)
-            self.affect(self.add_ir(m2_expr.ExprInt(0, 8)),
-                        m2_expr.ExprId("status", 32))
+            self.affect(self.add_ir(ExprInt(0, 8)),
+                        ExprId("status", 32))
             self.set_ret(to_ret)
 
             # Else Block
             builder.position_at_end(else_block)
             PC = self.llvm_context.PC
-            to_ret = LLVMType.IntType(PC.size)(next_label.offset)
+            next_label_offset = self.llvm_context.ir_arch.symbol_pool.loc_key_to_offset(next_label)
+            to_ret = LLVMType.IntType(PC.size)(next_label_offset)
             self.affect(to_ret, PC)
             self.set_ret(to_ret)
 
@@ -1318,16 +1298,16 @@ class LLVMFunction():
         Prototype : f(i8* jitcpu, i8* vmcpu, i8* vmmngr, i8* status)"""
 
         # Build function signature
-        self.my_args.append((m2_expr.ExprId("jitcpu", 32),
+        self.my_args.append((ExprId("jitcpu", 32),
                              llvm_ir.PointerType(LLVMType.IntType(8)),
                              "jitcpu"))
-        self.my_args.append((m2_expr.ExprId("vmcpu", 32),
+        self.my_args.append((ExprId("vmcpu", 32),
                              llvm_ir.PointerType(LLVMType.IntType(8)),
                              "vmcpu"))
-        self.my_args.append((m2_expr.ExprId("vmmngr", 32),
+        self.my_args.append((ExprId("vmmngr", 32),
                              llvm_ir.PointerType(LLVMType.IntType(8)),
                              "vmmngr"))
-        self.my_args.append((m2_expr.ExprId("status", 32),
+        self.my_args.append((ExprId("status", 32),
                              llvm_ir.PointerType(LLVMType.IntType(8)),
                              "status"))
         ret_size = 64
@@ -1360,9 +1340,10 @@ class LLVMFunction():
                 ptr = self.CreateEntryBlockAlloca(eltype,
                                                   default_value=eltype(0))
                 self.local_vars_pointers[element.name] = ptr
-            lbl = codegen.get_block_post_label(asmblock)
-            instr_offsets.append(lbl.offset)
-            self.append_basic_block(lbl)
+            loc_key = codegen.get_block_post_label(asmblock)
+            offset = self.llvm_context.ir_arch.symbol_pool.loc_key_to_offset(loc_key)
+            instr_offsets.append(offset)
+            self.append_basic_block(loc_key)
 
         # Add content
         builder.position_at_end(entry_bbl)
@@ -1375,7 +1356,7 @@ class LLVMFunction():
 
             # Pre-create basic blocks
             for irblock in irblocks:
-                self.append_basic_block(irblock.label, overwrite=False)
+                self.append_basic_block(irblock.loc_key, overwrite=False)
 
             # Generate the corresponding code
             for index, irblock in enumerate(irblocks):
@@ -1383,8 +1364,7 @@ class LLVMFunction():
                     irblock, self.llvm_context.ir_arch.attrib)
 
                 # Set the builder at the begining of the correct bbl
-                name = self.canonize_label_name(new_irblock.label)
-                self.builder.position_at_end(self.get_basic_bloc_by_label(name))
+                self.builder.position_at_end(self.get_basic_block_by_loc_key(new_irblock.loc_key))
 
                 if index == 0:
                     self.gen_pre_code(instr_attrib)
@@ -1395,7 +1375,7 @@ class LLVMFunction():
 
         # Branch entry_bbl on first label
         builder.position_at_end(entry_bbl)
-        first_label_bbl = self.get_basic_bloc_by_label(asmblock.label)
+        first_label_bbl = self.get_basic_block_by_loc_key(asmblock.loc_key)
         builder.branch(first_label_bbl)