about summary refs log tree commit diff stats
path: root/miasm2/core/asmbloc.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/core/asmbloc.py')
-rw-r--r--miasm2/core/asmbloc.py171
1 files changed, 108 insertions, 63 deletions
diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py
index 1a2d7a91..31e4bdd7 100644
--- a/miasm2/core/asmbloc.py
+++ b/miasm2/core/asmbloc.py
@@ -3,6 +3,7 @@
 
 import logging
 import inspect
+import re
 
 
 import miasm2.expression.expression as m2_expr
@@ -18,6 +19,7 @@ console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
 log_asmbloc.addHandler(console_handler)
 log_asmbloc.setLevel(logging.WARNING)
 
+
 def is_int(a):
     return isinstance(a, int) or isinstance(a, long) or \
         isinstance(a, moduint) or isinstance(a, modint)
@@ -33,6 +35,7 @@ def expr_is_int_or_label(e):
 
 
 class asm_label:
+
     "Stand for an assembly label"
 
     def __init__(self, name="", offset=None):
@@ -59,13 +62,16 @@ class asm_label:
         rep += '>'
         return rep
 
+
 class asm_raw:
+
     def __init__(self, raw=""):
         self.raw = raw
 
     def __str__(self):
         return repr(self.raw)
 
+
 class asm_constraint(object):
     c_to = "c_to"
     c_next = "c_next"
@@ -194,7 +200,6 @@ class asm_bloc(object):
             if l.splitflow() or l.breakflow():
                 raise NotImplementedError('not fully functional')
 
-
     def get_subcall_instr(self):
         if not self.lines:
             return None
@@ -228,11 +233,11 @@ class asm_symbol_pool:
 
         # Test for collisions
         if (label.offset in self._offset2label and
-            label != self._offset2label[label.offset]):
+                label != self._offset2label[label.offset]):
             raise ValueError('symbol %s has same offset as %s' %
                              (label, self._offset2label[label.offset]))
         if (label.name in self._name2label and
-            label != self._name2label[label.name]):
+                label != self._name2label[label.name]):
             raise ValueError('symbol %s has same name as %s' %
                              (label, self._name2label[label.name]))
 
@@ -438,7 +443,7 @@ def dis_bloc(mnemo, pool_bin, cur_bloc, offset, job_done, symbol_pool,
 
 
 def split_bloc(mnemo, attrib, pool_bin, blocs,
-    symbol_pool, more_ref=None, dis_bloc_callback=None):
+               symbol_pool, more_ref=None, dis_bloc_callback=None):
     if not more_ref:
         more_ref = []
 
@@ -472,7 +477,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs,
             if dis_bloc_callback:
                 offsets_to_dis = set(
                     [x.label.offset for x in new_b.bto
-                    if isinstance(x.label, asm_label)])
+                     if isinstance(x.label, asm_label)])
                 dis_bloc_callback(
                     mnemo, attrib, pool_bin, new_b, offsets_to_dis,
                     symbol_pool)
@@ -481,6 +486,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs,
 
     return blocs
 
+
 def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[],
                  split_dis=[], follow_call=False, dontdis_retcall=False,
                  blocs_wd=None, lines_wd=None, blocs=None,
@@ -527,48 +533,75 @@ def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[],
         blocs.append(cur_bloc)
 
     return split_bloc(mnemo, attrib, pool_bin, blocs,
-    symbol_pool, dis_bloc_callback=dis_bloc_callback)
-
-
-def bloc2graph(blocs, label=False, lines=True):
-    # rankdir=LR;
-    out = """
-digraph asm_graph {
-size="80,50";
-node [
-fontsize = "16",
-shape = "box"
-];
-"""
-    for b in blocs:
-        out += '%s [\n' % b.label.name
-        out += 'label = "'
+                      symbol_pool, dis_bloc_callback=dis_bloc_callback)
+
 
-        out += b.label.name + "\\l\\\n"
+def bloc2graph(blocks, label=False, lines=True):
+    """Render dot graph of @blocks"""
+
+    escape_chars = re.compile('[' + re.escape('{}') + ']')
+    label_attr = 'colspan="2" align="center" bgcolor="grey"'
+    edge_attr = 'label = "%s" color="%s" style="bold"'
+    td_attr = 'align="left"'
+    block_attr = 'shape="Mrecord" fontname="Courier New"'
+
+    out = ["digraph asm_graph {"]
+    fix_chars = lambda x: '\\' + x.group()
+
+    # Generate basic blocks
+    out_blocks = []
+    for block in blocks:
+        out_block = '%s [\n' % block.label.name
+        out_block += "%s " % block_attr
+        out_block += 'label =<<table border="0" cellborder="0" cellpadding="3">'
+
+        block_label = '<tr><td %s>%s</td></tr>' % (
+            label_attr, block.label.name)
+        block_html_lines = []
         if lines:
-            for l in b.lines:
+            for line in block.lines:
                 if label:
-                    out += "%.8X " % l.offset
-                out += ("%s\\l\\\n" % l).replace('"', '\\"')
-        out += '"\n];\n'
-
-    for b in blocs:
-        for n in b.bto:
-            # print 'xxxx', n.label, n.label.__class__
-            # if isinstance(n.label, ExprId):
-            #    print n.label.name, n.label.name.__class__
-            if isinstance(n.label, m2_expr.ExprId):
-                dst, name, cst = b.label.name, n.label.name, n.c_t
-                # out+='%s -> %s [ label = "%s" ];\n'%(b.label.name,
-                # n.label.name, n.c_t)
-            elif isinstance(n.label, asm_label):
-                dst, name, cst = b.label.name, n.label.name, n.c_t
+                    out_render = "%.8X</td><td %s> " % (line.offset, td_attr)
+                else:
+                    out_render = ""
+                out_render += escape_chars.sub(fix_chars, str(line))
+                block_html_lines.append(out_render)
+        block_html_lines = ('<tr><td %s>' % td_attr +
+                            ('</td></tr><tr><td %s>' % td_attr).join(block_html_lines) +
+                            '</td></tr>')
+        out_block += "%s " % block_label
+        out_block += block_html_lines + "</table>> ];"
+        out_blocks.append(out_block)
+
+    out += out_blocks
+
+    # Generate links
+    for block in blocks:
+        for next_b in block.bto:
+            if (isinstance(next_b.label, m2_expr.ExprId) or
+                    isinstance(next_b.label, asm_label)):
+                src, dst, cst = block.label.name, next_b.label.name, next_b.c_t
             else:
                 continue
-            out += '%s -> %s [ label = "%s" ];\n' % (dst, name, cst)
+            if isinstance(src, asm_label):
+                src = src.name
+            if isinstance(dst, asm_label):
+                dst = dst.name
+
+            edge_color = "black"
+            if next_b.c_t == asm_constraint.c_next:
+                edge_color = "red"
+            elif next_b.c_t == asm_constraint.c_to:
+                edge_color = "limegreen"
+            # special case
+            if len(block.bto) == 1:
+                edge_color = "blue"
 
-    out += "}"
-    return out
+            out.append('%s -> %s' % (src, dst) +
+                       '[' + edge_attr % (cst, edge_color) + '];')
+
+    out.append("}")
+    return '\n'.join(out)
 
 
 def conservative_asm(mnemo, instr, symbols, conservative):
@@ -589,6 +622,7 @@ def conservative_asm(mnemo, instr, symbols, conservative):
                 return c, candidates
     return candidates[0], candidates
 
+
 def fix_expr_val(expr, symbols):
     """Resolve an expression @expr using @symbols"""
     def expr_calc(e):
@@ -616,7 +650,7 @@ def guess_blocks_size(mnemo, blocks):
                     if len(instr.raw) == 0:
                         l = 0
                     else:
-                        l = instr.raw[0].size/8 * len(instr.raw)
+                        l = instr.raw[0].size / 8 * len(instr.raw)
                 elif isinstance(instr.raw, str):
                     data = instr.raw
                     l = len(data)
@@ -640,6 +674,7 @@ def guess_blocks_size(mnemo, blocks):
         block.max_size = size
         log_asmbloc.info("size: %d max: %d", block.size, block.max_size)
 
+
 def fix_label_offset(symbol_pool, label, offset, modified):
     """Fix the @label offset to @offset. If the @offset has changed, add @label
     to @modified
@@ -652,6 +687,7 @@ def fix_label_offset(symbol_pool, label, offset, modified):
 
 
 class BlockChain(object):
+
     """Manage blocks linked with an asm_constraint_next"""
 
     def __init__(self, symbol_pool, blocks):
@@ -672,7 +708,6 @@ class BlockChain(object):
                     raise ValueError("Multiples pinned block detected")
                 self.pinned_block_idx = i
 
-
     def place(self):
         """Compute BlockChain min_offset and max_offset using pinned block and
         blocks' size
@@ -686,17 +721,18 @@ class BlockChain(object):
         if not self.pinned:
             return
 
-
         offset_base = self.blocks[self.pinned_block_idx].label.offset
         assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0)
 
         self.offset_min = offset_base
-        for block in self.blocks[:self.pinned_block_idx-1:-1]:
-            self.offset_min -= block.max_size + (block.alignment - block.max_size) % block.alignment
+        for block in self.blocks[:self.pinned_block_idx - 1:-1]:
+            self.offset_min -= block.max_size + \
+                (block.alignment - block.max_size) % block.alignment
 
         self.offset_max = offset_base
         for block in self.blocks[self.pinned_block_idx:]:
-            self.offset_max += block.max_size + (block.alignment - block.max_size) % block.alignment
+            self.offset_max += block.max_size + \
+                (block.alignment - block.max_size) % block.alignment
 
     def merge(self, chain):
         """Best effort merge two block chains
@@ -718,7 +754,7 @@ class BlockChain(object):
         if offset % pinned_block.alignment != 0:
             raise RuntimeError('Bad alignment')
 
-        for block in self.blocks[:self.pinned_block_idx-1:-1]:
+        for block in self.blocks[:self.pinned_block_idx - 1:-1]:
             new_offset = offset - block.size
             new_offset = new_offset - new_offset % pinned_block.alignment
             fix_label_offset(self.symbol_pool,
@@ -730,7 +766,7 @@ class BlockChain(object):
         offset = pinned_block.label.offset + pinned_block.size
 
         last_block = pinned_block
-        for block in self.blocks[self.pinned_block_idx+1:]:
+        for block in self.blocks[self.pinned_block_idx + 1:]:
             offset += (- offset) % last_block.alignment
             fix_label_offset(self.symbol_pool,
                              block.label,
@@ -742,6 +778,7 @@ class BlockChain(object):
 
 
 class BlockChainWedge(object):
+
     """Stand for wedges between blocks"""
 
     def __init__(self, symbol_pool, offset, size):
@@ -770,7 +807,7 @@ def group_constrained_blocks(symbol_pool, blocks):
     # Group adjacent blocks
     remaining_blocks = list(blocks)
     known_block_chains = {}
-    lbl2block = {block.label:block for block in blocks}
+    lbl2block = {block.label: block for block in blocks}
 
     while remaining_blocks:
         # Create a new block chain
@@ -812,12 +849,13 @@ def get_blockchains_address_interval(blockChains, dst_interval):
     for chain in blockChains:
         if not chain.pinned:
             continue
-        chain_interval = interval([(chain.offset_min, chain.offset_max-1)])
+        chain_interval = interval([(chain.offset_min, chain.offset_max - 1)])
         if chain_interval not in dst_interval:
             raise ValueError('Chain placed out of destination interval')
         allocated_interval += chain_interval
     return allocated_interval
 
+
 def resolve_symbol(blockChains, symbol_pool, dst_interval=None):
     """Place @blockChains in the @dst_interval"""
 
@@ -825,7 +863,8 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None):
     if dst_interval is None:
         dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)])
 
-    forbidden_interval = interval([(-1, 0xFFFFFFFFFFFFFFFF+1)]) - dst_interval
+    forbidden_interval = interval(
+        [(-1, 0xFFFFFFFFFFFFFFFF + 1)]) - dst_interval
     allocated_interval = get_blockchains_address_interval(blockChains,
                                                           dst_interval)
     log_asmbloc.debug('allocated interval: %s', allocated_interval)
@@ -834,12 +873,13 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None):
 
     # Add wedge in forbidden intervals
     for start, stop in forbidden_interval.intervals:
-        wedge = BlockChainWedge(symbol_pool, offset=start, size=stop+1-start)
+        wedge = BlockChainWedge(
+            symbol_pool, offset=start, size=stop + 1 - start)
         pinned_chains.append(wedge)
 
     # Try to place bigger blockChains first
-    pinned_chains.sort(key=lambda x:x.offset_min)
-    blockChains.sort(key=lambda x:-x.max_size)
+    pinned_chains.sort(key=lambda x: x.offset_min)
+    blockChains.sort(key=lambda x: -x.max_size)
 
     fixed_chains = list(pinned_chains)
 
@@ -849,12 +889,12 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None):
             continue
         fixed = False
         for i in xrange(1, len(fixed_chains)):
-            prev_chain = fixed_chains[i-1]
+            prev_chain = fixed_chains[i - 1]
             next_chain = fixed_chains[i]
 
             if prev_chain.offset_max + chain.max_size < next_chain.offset_min:
                 new_chains = prev_chain.merge(chain)
-                fixed_chains[i-1:i] = new_chains
+                fixed_chains[i - 1:i] = new_chains
                 fixed = True
                 break
         if not fixed:
@@ -862,10 +902,12 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None):
 
     return [chain for chain in fixed_chains if isinstance(chain, BlockChain)]
 
+
 def filter_exprid_label(exprs):
     """Extract labels from list of ExprId @exprs"""
     return set(expr.name for expr in exprs if isinstance(expr.name, asm_label))
 
+
 def get_block_labels(block):
     """Extract labels used by @block"""
     symbols = set()
@@ -880,6 +922,7 @@ def get_block_labels(block):
     labels = filter_exprid_label(symbols)
     return labels
 
+
 def assemble_block(mnemo, block, symbol_pool, conservative=False):
     """Assemble a @block using @symbol_pool
     @conservative: (optional) use original bytes when possible
@@ -932,7 +975,7 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, conservative=False):
     log_asmbloc.debug("asmbloc_final")
 
     # Init structures
-    lbl2block = {block.label:block for block in blocks}
+    lbl2block = {block.label: block for block in blocks}
     blocks_using_label = {}
     for block in blocks:
         labels = get_block_labels(block)
@@ -992,7 +1035,8 @@ def sanity_check_blocks(blocks):
             if blocks_graph.blocs[pred].get_next() == label:
                 pred_next.add(pred)
         if len(pred_next) > 1:
-            raise RuntimeError("Too many next constraints for bloc %r"%label)
+            raise RuntimeError("Too many next constraints for bloc %r" % label)
+
 
 def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None):
     """Resolve and assemble @blocks using @symbol_pool into interval
@@ -1002,7 +1046,8 @@ def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None):
 
     guess_blocks_size(mnemo, blocks)
     blockChains = group_constrained_blocks(symbol_pool, blocks)
-    resolved_blockChains = resolve_symbol(blockChains, symbol_pool, dst_interval)
+    resolved_blockChains = resolve_symbol(
+        blockChains, symbol_pool, dst_interval)
 
     asmbloc_final(mnemo, blocks, resolved_blockChains, symbol_pool)
     patches = {}
@@ -1016,13 +1061,14 @@ def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None):
                 continue
             assert len(instr.data) == instr.l
             patches[offset] = instr.data
-            instruction_interval = interval([(offset, offset + instr.l-1)])
+            instruction_interval = interval([(offset, offset + instr.l - 1)])
             if not (instruction_interval & output_interval).empty:
                 raise RuntimeError("overlapping bytes %X" % int(offset))
             instr.offset = offset
             offset += instr.l
     return patches
 
+
 def blist2graph(ab):
     """
     ab: list of asmbloc
@@ -1126,7 +1172,7 @@ def getbloc_parents(blocs, a, level=3, done=None, blocby_label=None):
 
 
 def getbloc_parents_strict(
-    blocs, a, level=3, rez=None, done=None, blocby_label=None):
+        blocs, a, level=3, rez=None, done=None, blocby_label=None):
 
     if not blocby_label:
         blocby_label = {}
@@ -1280,4 +1326,3 @@ class disasmEngine(object):
                              dont_dis_nulstart_bloc=self.dont_dis_nulstart_bloc,
                              attrib=self.attrib)
         return blocs
-