about summary refs log tree commit diff stats
path: root/miasm2/core/asmbloc.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/core/asmbloc.py')
-rw-r--r--miasm2/core/asmbloc.py270
1 files changed, 144 insertions, 126 deletions
diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py
index 4770b597..04e2a605 100644
--- a/miasm2/core/asmbloc.py
+++ b/miasm2/core/asmbloc.py
@@ -238,132 +238,122 @@ class asm_bloc:
 
 class asm_symbol_pool:
 
-    def __init__(self, no_collision=True):
-        self.labels = []
-        self.s = {}
-        self.s_offset = {}
-        self.no_collision = no_collision
-        self.label_num = 0
-
-    def add_label(self, name="", offset=None):
+    def __init__(self):
+        self._labels = []
+        self._name2label = {}
+        self._offset2label = {}
+        self._label_num = 0
+
+    def add_label(self, name, offset=None):
         """
-        This should be the only method to create new asm_label objects
+        Create and add a label to the symbol_pool
+        @name: label's name
+        @offset: (optional) label's offset
         """
-        l = asm_label(name, offset)
-        collision = None
-        if l.offset in self.s_offset and l != self.s_offset[l.offset]:
-            collision = 'offset'
-        if l.name in self.s and l != self.s[l.name]:
-            collision = 'name'
-        if self.no_collision and collision == 'offset':
+        label = asm_label(name, offset)
+
+        # Test for collisions
+        if (label.offset in self._offset2label and
+            label != self._offset2label[label.offset]):
             raise ValueError('symbol %s has same offset as %s' %
-                             (l, self.s_offset[l.offset]))
-        if self.no_collision and collision == 'name':
-            raise ValueError(
-                'symbol %s has same name as %s' % (l, self.s[l.name]))
-        self.labels.append(l)
-        if l.offset is not None:
-            self.s_offset[l.offset] = l
-        if l.name != "":
-            self.s[l.name] = l
-        return l
-
-    def remove(self, obj):
+                             (label, self._offset2label[label.offset]))
+        if (label.name in self._name2label and
+            label != self._name2label[label.name]):
+            raise ValueError('symbol %s has same name as %s' %
+                             (label, self._name2label[label.name]))
+
+        self._labels.append(label)
+        if label.offset is not None:
+            self._offset2label[label.offset] = label
+        if label.name != "":
+            self._name2label[label.name] = label
+        return label
+
+    def remove_label(self, label):
         """
-        obj can be an asm_label or an offset
+        Delete a @label
         """
-        if isinstance(obj, asm_label):
-            if obj.name in self.s:
-                del self.s[obj.name]
-            if obj.offset is not None and obj.offset in self.s_offset:
-                del self.s_offset[obj.offset]
-        else:
-            offset = int(obj)
-            if offset in self.s_offset:
-                obj = self.s_offset[offset]
-                del self.s_offset[offset]
-            if obj.name in self.s:
-                del self.s[obj.name]
-
-    def del_offset(self, l=None):
-        if l is not None:
-            if l.offset in self.s_offset:
-                del self.s_offset[l.offset]
-            l.offset = None
-        else:
-            self.s_offset = {}
-            for l in self.s:
-                self.s[l].offset = None
+        self._name2label.pop(label.name, None)
+        self._offset2label.pop(label.offset, None)
+        if label in self._labels:
+            self._labels.remove(label)
+
+    def del_label_offset(self, label):
+        """Unpin the @label from its offset"""
+        self._offset2label.pop(label.offset, None)
+        label.offset = None
 
     def getby_offset(self, offset):
-        return self.s_offset.get(offset, None)
+        """Retrieve label using its @offset"""
+        return self._offset2label.get(offset, None)
 
     def getby_name(self, name):
-        return self.s.get(name, None)
+        """Retrieve label using its @name"""
+        return self._name2label.get(name, None)
 
     def getby_name_create(self, name):
-        l = self.getby_name(name)
-        if l is None:
-            l = self.add_label(name)
-        return l
+        """Get a label from its @name, create it if it doesn't exist"""
+        label = self.getby_name(name)
+        if label is None:
+            label = self.add_label(name)
+        return label
 
     def getby_offset_create(self, offset):
-        l = self.getby_offset(offset)
-        if l is None:
-            l = self.add_label(offset, offset)
-        return l
-
-    def rename(self, s, newname):
-        if not s.name in self.s:
-            log_asmbloc.warn('unk symb')
-            return
-        del self.s[s.name]
-        s.name = newname
-        self.s[s.name] = s
+        """Get a label from its @offset, create it if it doesn't exist"""
+        label = self.getby_offset(offset)
+        if label is None:
+            label = self.add_label(offset, offset)
+        return label
+
+    def rename_label(self, label, newname):
+        """Rename the @label name to @newname"""
+        if newname in self._name2label:
+            raise ValueError('Symbol already known')
+        self._name2label.pop(label.name, None)
+        label.name = newname
+        self._name2label[label.name] = label
 
     def set_offset(self, label, offset):
-        # Note that there is a special case when the offset is a list
-        # it happens when offsets are recomputed in resolve_symbol*
-        if not label.name in self.s:
+        """Pin the @label from at @offset
+        Note that there is a special case when the offset is a list
+        it happens when offsets are recomputed in resolve_symbol*
+        """
+        if not label.name in self._name2label:
             raise ValueError('label %s not in symbol pool' % label)
-        if not isinstance(label.offset, list) and label.offset in self.s_offset:
-            del self.s_offset[label.offset]
+        self._offset2label.pop(label.offset, None)
         label.offset = offset
-        if not isinstance(label.offset, list):
-            self.s_offset[label.offset] = label
+        if is_int(label.offset):
+            self._offset2label[label.offset] = label
 
+    @property
     def items(self):
-        return self.labels[:]
+        """Return all labels"""
+        return self._labels
 
     def __str__(self):
-        return reduce(lambda x, y: x + str(y) + '\n', self.labels, "")
-
-    def __in__(self, obj):
-        if obj in self.s:
-            return True
-        if obj in self.s_offset:
-            return True
-        return False
+        return reduce(lambda x, y: x + str(y) + '\n', self._labels, "")
 
     def __getitem__(self, item):
-        if item in self.s:
-            return self.s[item]
-        if item in self.s_offset:
-            return self.s_offset[item]
+        if item in self._name2label:
+            return self._name2label[item]
+        if item in self._offset2label:
+            return self._offset2label[item]
         raise KeyError('unknown symbol %r' % item)
 
     def __contains__(self, item):
-        return item in self.s or item in self.s_offset
+        return item in self._name2label or item in self._offset2label
 
     def merge(self, symbol_pool):
-        self.labels += symbol_pool.labels
-        self.s.update(symbol_pool.s)
-        self.s_offset.update(symbol_pool.s_offset)
+        """Merge with another @symbol_pool"""
+        self._labels += symbol_pool._labels
+        self._name2label.update(symbol_pool._name2label)
+        self._offset2label.update(symbol_pool._offset2label)
 
     def gen_label(self):
-        l = self.add_label("lbl_gen_%.8X" % (self.label_num))
-        self.label_num += 1
-        return l
+        """Generate a new unpinned label"""
+        label = self.add_label("lbl_gen_%.8X" % (self._label_num))
+        self._label_num += 1
+        return label
 
 
 def dis_bloc(mnemo, pool_bin, cur_bloc, offset, job_done, symbol_pool,
@@ -475,7 +465,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs,
         more_ref = []
 
     # get all possible dst
-    bloc_dst = [symbol_pool.s_offset[x] for x in more_ref]
+    bloc_dst = [symbol_pool._offset2label[x] for x in more_ref]
     for b in blocs:
         for c in b.bto:
             if not isinstance(c.label, asm_label):
@@ -625,7 +615,7 @@ def conservative_asm(mnemo, instr, symbols, conservative):
 def fix_expr_val(e, symbols):
     def expr_calc(e):
         if isinstance(e, m2_expr.ExprId):
-            s = symbols.s[e.name]
+            s = symbols._name2label[e.name]
             e = m2_expr.ExprInt_from(e, s.offset)
         return e
     e = e.visit(expr_calc)
@@ -785,6 +775,31 @@ def gen_non_free_mapping(group_bloc, dont_erase=[]):
     return non_free_mapping
 
 
+
+class AsmBlockLink(object):
+    """Location contraint between blocks"""
+
+    def __init__(self, label):
+        self.label = label
+
+    def resolve(self, parent_label, label2block):
+        """
+        Resolve the @parent_label.offset_g
+        @parent_label: parent label
+        @label2block: dictionnary which links labels to blocks
+        """
+        raise NotImplementedError("Abstract method")
+
+class AsmBlockLinkNext(AsmBlockLink):
+
+    def resolve(self, parent_label, label2block):
+        parent_label.offset_g = self.label.offset_g + label2block[self.label].blen
+
+class AsmBlockLinkPrev(AsmBlockLink):
+
+    def resolve(self, parent_label, label2block):
+        parent_label.offset_g = self.label.offset_g - label2block[parent_label].blen
+
 def resolve_symbol(group_bloc, symbol_pool, dont_erase=[],
                    max_offset=0xFFFFFFFF):
     """
@@ -828,7 +843,7 @@ def resolve_symbol(group_bloc, symbol_pool, dont_erase=[],
                         free_interval[g] = tmp
                         del free_interval[x]
                         symbol_pool.set_offset(
-                            g, [group_bloc[x][-1].label, group_bloc[x][-1], 1])
+                            g, AsmBlockLinkNext(group_bloc[x][-1].label))
                         g.fixedblocs = True
                         finish = True
                         break
@@ -850,7 +865,7 @@ def resolve_symbol(group_bloc, symbol_pool, dont_erase=[],
             if g.total_max_l > free_interval[k]:
                 continue
             symbol_pool.set_offset(
-                g, [group_bloc[k][-1].label, group_bloc[k][-1], 1])
+                g, AsmBlockLinkNext(group_bloc[k][-1].label))
             tmp = free_interval[k] - g.total_max_l
             log_asmbloc.debug(
                 "consumed %d rest: %d" % (g.total_max_l, int(tmp)))
@@ -883,11 +898,11 @@ def resolve_symbol(group_bloc, symbol_pool, dont_erase=[],
             if index > 0 and my_group[index - 1] in unr_bloc:
                 symbol_pool.set_offset(
                     my_group[index - 1].label,
-                    [unr_bloc[i].label, unr_bloc[i - 1], -1])
+                    AsmBlockLinkPrev(unr_bloc[i].label))
             if index < len(my_group) - 1 and my_group[index + 1] in unr_bloc:
                 symbol_pool.set_offset(
                     my_group[index + 1].label,
-                    [unr_bloc[i].label, unr_bloc[i], 1])
+                    AsmBlockLinkNext(unr_bloc[i].label))
             del unr_bloc[i]
 
         if not resolving:
@@ -905,34 +920,37 @@ def resolve_symbol(group_bloc, symbol_pool, dont_erase=[],
     return bloc_list
 
 
-def calc_symbol_offset(symbol_pool):
-    s_to_use = set()
+def calc_symbol_offset(symbol_pool, blocks):
+    """Resolve dependencies between @blocks"""
 
-    s_dependent = {}
+    # Labels resolved
+    pinned_labels = set()
+    # Link an unreferenced label to its reference label
+    linked_labels = {}
+    # Label -> block
+    label2block = dict((block.label, block) for block in blocks)
 
-    for label in symbol_pool.items():
+    # Find pinned labels and labels to resolve
+    for label in symbol_pool.items:
         if label.offset is None:
-            label.offset_g = None
-            continue
-        if not is_int(label.offset):
+            pass
+        elif is_int(label.offset):
+            pinned_labels.add(label)
+        elif isinstance(label.offset, AsmBlockLink):
             # construct dependant blocs tree
-            s_d = label.offset[0]
-            if not s_d in s_dependent:
-                s_dependent[s_d] = set()
-            s_dependent[s_d].add(label)
+            linked_labels.setdefault(label.offset.label, set()).add(label)
         else:
-            s_to_use.add(label)
+            raise ValueError('Unknown offset type')
         label.offset_g = label.offset
 
-    while s_to_use:
-        label = s_to_use.pop()
-        if not label in s_dependent:
-            continue
-        for l in s_dependent[label]:
-            if label.offset_g is None:
-                raise ValueError("unknown symbol: %s" % str(label.name))
-            l.offset_g = label.offset_g + l.offset_g[1].blen * l.offset_g[2]
-            s_to_use.add(l)
+    # Resolve labels
+    while pinned_labels:
+        ref_label = pinned_labels.pop()
+        for unresolved_label in linked_labels.get(ref_label, []):
+            if ref_label.offset_g is None:
+                raise ValueError("unknown symbol: %s" % str(ref_label.name))
+            unresolved_label.offset.resolve(unresolved_label, label2block)
+            pinned_labels.add(unresolved_label)
 
 
 def asmbloc_final(mnemo, blocs, symbol_pool, symb_reloc_off=None,
@@ -950,10 +968,10 @@ def asmbloc_final(mnemo, blocs, symbol_pool, symb_reloc_off=None,
         fini = True
         my_symb_reloc_off = {}
 
-        calc_symbol_offset(symbol_pool)
+        calc_symbol_offset(symbol_pool, blocs)
 
         symbols = asm_symbol_pool()
-        for s, v in symbol_pool.s.items():
+        for s, v in symbol_pool._name2label.items():
             symbols.add_label(s, v.offset_g)
         # test if bad encoded relative
         for bloc in blocs:
@@ -1018,7 +1036,7 @@ def asmbloc_final(mnemo, blocs, symbol_pool, symb_reloc_off=None,
                 assert len(instr.data) == instr.l
     # we have fixed all relative values
     # recompute good offsets
-    for label in symbol_pool.items():
+    for label in symbol_pool.items:
         symbol_pool.set_offset(label, label.offset_g)
 
     for a, b in my_symb_reloc_off.items():