about summary refs log tree commit diff stats
path: root/miasm2/core/asmbloc.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/core/asmbloc.py')
-rw-r--r-- miasm2/core/asmbloc.py 429
1 files changed, 172 insertions, 257 deletions
diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py
index dd18d1f8..612a2e3b 100644
--- a/miasm2/core/asmbloc.py
+++ b/miasm2/core/asmbloc.py
@@ -100,9 +100,9 @@ class asm_constraint_bad(asm_constraint):
             label, c_t=asm_constraint.c_bad)
 
 
-class asm_bloc:
+class asm_bloc(object):
 
-    def __init__(self, label=None, alignment = 1):
+    def __init__(self, label=None, alignment=1):
         self.bto = set()
         self.lines = []
         self.label = label
@@ -587,25 +587,25 @@ def conservative_asm(mnemo, instr, symbols, conservative):
                 return c, candidates
     return candidates[0], candidates
 
-def fix_expr_val(e, symbols):
+def fix_expr_val(expr, symbols):
+    """Resolve an expression @expr using @symbols"""
     def expr_calc(e):
         if isinstance(e, m2_expr.ExprId):
             s = symbols._name2label[e.name]
             e = m2_expr.ExprInt_from(e, s.offset)
         return e
-    e = e.visit(expr_calc)
-    e = expr_simp(e)
-    return e
+    result = expr.visit(expr_calc)
+    result = expr_simp(result)
+    if not isinstance(result, m2_expr.ExprInt):
+        raise RuntimeError('Cannot resolve symbol %s' % expr)
+    return result
 
 
 def guess_blocks_size(mnemo, blocks):
-    """
-    Asm and compute max bloc size
-    """
+    """Asm and compute max block size"""
+
     for block in blocks:
-        log_asmbloc.debug('---')
         size = 0
-        max_size = 0
         for instr in block.lines:
             if isinstance(instr, asm_raw):
                 # for special asm_raw, only extract len
@@ -635,11 +635,14 @@ def guess_blocks_size(mnemo, blocks):
             size += l
 
         block.size = size
-        # bloc with max rel values encoded
-        block.max_size = size + max_size
+        block.max_size = size
         log_asmbloc.info("size: %d max: %d", block.size, block.max_size)
 
 def fix_label_offset(symbol_pool, label, offset, modified):
+    """Fix the @label offset to @offset. If the @offset has changed, add @label
+    to @modified
+    @symbol_pool: current symbol_pool
+    """
     if label.offset == offset:
         return
     symbol_pool.set_offset(label, offset)
@@ -647,58 +650,73 @@ def fix_label_offset(symbol_pool, label, offset, modified):
 
 
 class BlockChain(object):
-    """Manage blocks linked with a "next" constraint"""
+    """Manage blocks linked with an asm_constraint_next"""
 
     def __init__(self, symbol_pool, blocks):
         self.symbol_pool = symbol_pool
         self.blocks = blocks
         self.place()
+
     @property
     def pinned(self):
+        """Return True iff at least one block is pinned"""
         return self.pinned_block_idx is not None
 
-    def get_pinned_block_idx(self):
-        pinned_block_idx = None
+    def _set_pinned_block_idx(self):
+        self.pinned_block_idx = None
         for i, block in enumerate(self.blocks):
             if is_int(block.label.offset):
-                if pinned_block_idx is not None:
+                if self.pinned_block_idx is not None:
                     raise ValueError("Multiples pinned block detected")
-                pinned_block_idx = i
+                self.pinned_block_idx = i
 
-        self.pinned_block_idx = pinned_block_idx
 
     def place(self):
-        self.get_pinned_block_idx()
-        self.max_size = reduce(lambda x, block: x + block.max_size,
-                               self.blocks, 0)
+        """Compute BlockChain min_offset and max_offset using pinned block and
+        blocks' size
+        """
+        self._set_pinned_block_idx()
+        self.max_size = 0
+        for block in self.blocks:
+            self.max_size += block.max_size + block.alignment - 1
 
         # Check if chain has one block pinned
         if not self.pinned:
             return
 
-        size = 0
-        for block in self.blocks[:self.pinned_block_idx]:
-            size += block.max_size
-        self.offset_min = self.blocks[self.pinned_block_idx].label.offset - size
 
-        size = 0
+        offset_base = self.blocks[self.pinned_block_idx].label.offset
+        assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0)
+
+        self.offset_min = offset_base
+        for block in self.blocks[:self.pinned_block_idx-1:-1]:
+            self.offset_min -= block.max_size + (block.alignment - block.max_size) % block.alignment
+
+        self.offset_max = offset_base
         for block in self.blocks[self.pinned_block_idx:]:
-            size += block.max_size
-        self.offset_max = self.blocks[self.pinned_block_idx].label.offset + size
+            self.offset_max += block.max_size + (block.alignment - block.max_size) % block.alignment
 
     def merge(self, chain):
+        """Best effort merge two block chains
+        Return the list of resulting blockchains"""
         self.blocks += chain.blocks
         self.place()
         return [self]
 
     def fix_blocks(self, modified_labels):
+        """Propagate the pinned block's offset to its neighbouring blocks
+        @modified_labels: store new pinned labels"""
+
         if not self.pinned:
             raise ValueError('Trying to fix unpinned block')
+
         # Propagate offset to blocks before pinned block
         pinned_block = self.blocks[self.pinned_block_idx]
         offset = pinned_block.label.offset
-        assert(offset % pinned_block.alignment == 0)
-        for block in self.blocks[self.pinned_block_idx-1:-1:-1]:
+        if offset % pinned_block.alignment != 0:
+            raise RuntimeError('Bad alignment')
+
+        for block in self.blocks[:self.pinned_block_idx-1:-1]:
             new_offset = offset - block.size
             new_offset = new_offset - new_offset % pinned_block.alignment
             fix_label_offset(self.symbol_pool,
@@ -706,70 +724,76 @@ class BlockChain(object):
                              new_offset,
                              modified_labels)
 
-        # Propagate offset to blocks before pinned block
-        pblock = pinned_block
-        offset = pblock.label.offset + pblock.size
+        # Propagate offset to blocks after pinned block
+        offset = pinned_block.label.offset + pinned_block.size
 
+        last_block = pinned_block
         for block in self.blocks[self.pinned_block_idx+1:]:
-            pad = pinned_block.alignment - (offset % pinned_block.alignment)
-            offset += pad % pinned_block.alignment
+            offset += (- offset) % last_block.alignment
             fix_label_offset(self.symbol_pool,
                              block.label,
                              offset,
                              modified_labels)
             offset += block.size
+            last_block = block
         return modified_labels
 
+
 class BlockChainWedge(object):
+    """Stand for wedges between blocks"""
+
     def __init__(self, symbol_pool, offset, size):
         self.symbol_pool = symbol_pool
         self.offset = offset
-        self.max_len = size
+        self.max_size = size
         self.offset_min = offset
         self.offset_max = offset + size
 
     def merge(self, chain):
+        """Best effort merge two block chains
+        Return the list of resulting blockchains"""
         chain.blocks[0].label.offset = self.offset_max
         chain.place()
         return [self, chain]
 
+
 def group_constrained_blocks(symbol_pool, blocks):
     """
-    Return a list of grouped asm blocks linked by "next_constraints"
+    Return the BlockChains list built from grouped asm blocks linked by
+    asm_constraint_next
     @blocks: a list of asm block
-
     """
     log_asmbloc.info('group_constrained_blocks')
 
-    # group adjacent blocks
-    remaining_blocks = blocks[:]
+    # Group adjacent blocks
+    remaining_blocks = list(blocks)
     known_block_chains = {}
     lbl2block = {block.label:block for block in blocks}
 
-
     while remaining_blocks:
         # Create a new block chain
-        block_chain = [remaining_blocks.pop()]
+        block_list = [remaining_blocks.pop()]
 
-        # Find son in remainings blocks linked with a next constraint
+        # Find sons in remaining blocks linked with a next constraint
         while True:
-            next_label = block_chain[-1].get_next()
+            # Get next block
+            next_label = block_list[-1].get_next()
             if next_label is None or next_label not in lbl2block:
                 break
             next_block = lbl2block[next_label]
-            if next_block in remaining_blocks:
-                block_chain.append(next_block)
-                remaining_blocks.remove(next_block)
-                next_label = next_block.get_next()
-            else:
+
+            # Add the block at the end of the current chain
+            if next_block not in remaining_blocks:
                 break
+            block_list.append(next_block)
+            remaining_blocks.remove(next_block)
 
-        # Check if son is in a known block group:
+        # Check if son is in a known block group
         if next_label is not None and next_label in known_block_chains:
-            block_chain += known_block_chains[next_label]
+            block_list += known_block_chains[next_label]
             del known_block_chains[next_label]
 
-        known_block_chains[block_chain[0].label] = block_chain
+        known_block_chains[block_list[0].label] = block_list
 
     out_block_chains = []
     for label in known_block_chains:
@@ -777,105 +801,45 @@ def group_constrained_blocks(symbol_pool, blocks):
         out_block_chains.append(chain)
     return out_block_chains
 
-def add_dont_erase(f, dont_erase=[]):
-    tmp_symbol_pool = asm_symbol_pool()
-    for a, b in dont_erase:
-        l = tmp_symbol_pool.add_label(a, a)
-        l.offset_min = a
-        f[l] = b - a
-    return
-
-
-def gen_non_free_mapping(blockChains, dont_erase=[]):
-    non_free_mapping = {}
-    # calculate free space for bloc placing
-    for chain in blockChains:
-        # if a label in the group is fixed
-        diff_offset = 0
-        for block in chain.blocks:
-            if not is_int(block.label.offset):
-                diff_offset += b.size_max
-                continue
-            chain.pinned = True
-            chain.offset_min = block.label.offset - diff_offset
-            break
-        if chain.pinned:
-            non_free_mapping[chain] = chain.chain_max_size
-
-    log_asmbloc.debug("non free bloc:")
-    log_asmbloc.debug(non_free_mapping)
-    add_dont_erase(non_free_mapping, dont_erase)
-    log_asmbloc.debug("non free more:")
-    log_asmbloc.debug(non_free_mapping)
-    return non_free_mapping
-
-
-
-class AsmBlockLink(object):
-    """Location contraint between blocks"""
-
-    def __init__(self, label):
-        self.label = label
-
-    def resolve(self, parent_label, label2block):
-        """
-        Resolve the @parent_label.offset_g
-        @parent_label: parent label
-        @label2block: dictionnary which links labels to blocks
-        """
-        raise NotImplementedError("Abstract method")
-
-class AsmBlockLinkNext(AsmBlockLink):
-
-    def resolve(self, parent_label, label2block):
-        parent_label.offset_g = self.label.offset_g + label2block[self.label].size
-
-class AsmBlockLinkPrev(AsmBlockLink):
-
-    def resolve(self, parent_label, label2block):
-        parent_label.offset_g = self.label.offset_g - label2block[parent_label].size
-
 
 def get_blockchains_address_interval(blockChains, dst_interval):
+    """Compute the interval used by the pinned @blockChains
+    Check if the placed chains are in the @dst_interval"""
+
     allocated_interval = interval()
     for chain in blockChains:
         if not chain.pinned:
             continue
         chain_interval = interval([(chain.offset_min, chain.offset_max-1)])
-        if (dst_interval - chain_interval).hull() == (None, None):
+        if chain_interval not in dst_interval:
             raise ValueError('Chain placed out of destination interval')
         allocated_interval += chain_interval
     return allocated_interval
 
-def resolve_symbol(blockChains, blocks, symbol_pool, dst_interval=None):
-    """
-    place all asmblocks
-    """
+def resolve_symbol(blockChains, symbol_pool, dst_interval=None):
+    """Place @blockChains in the @dst_interval"""
+
     log_asmbloc.info('resolve_symbol')
     if dst_interval is None:
         dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)])
 
     forbidden_interval = interval([(-1, 0xFFFFFFFFFFFFFFFF+1)]) - dst_interval
-
-    bloc_list = []
-    unr_bloc = blocks[:]
-
     allocated_interval = get_blockchains_address_interval(blockChains,
                                                           dst_interval)
-    log_asmbloc.debug('allocated interval: %s'%allocated_interval)
+    log_asmbloc.debug('allocated interval: %s', allocated_interval)
 
     pinned_chains = [chain for chain in blockChains if chain.pinned]
 
     # Add wedge in forbidden intervals
-    for a, b in forbidden_interval.intervals:
-        wedge = BlockChainWedge(symbol_pool, offset=a, size=b+1-a)
+    for start, stop in forbidden_interval.intervals:
+        wedge = BlockChainWedge(symbol_pool, offset=start, size=stop+1-start)
         pinned_chains.append(wedge)
 
-    pinned_chains.sort(key=lambda x:x.offset_min)
     # Try to place bigger blockChains first
+    pinned_chains.sort(key=lambda x:x.offset_min)
     blockChains.sort(key=lambda x:-x.max_size)
 
-    fixed_chains = pinned_chains[:]
+    fixed_chains = list(pinned_chains)
 
     log_asmbloc.debug("place chains")
     for chain in blockChains:
@@ -886,69 +850,86 @@ def resolve_symbol(blockChains, blocks, symbol_pool, dst_interval=None):
             prev_chain = fixed_chains[i-1]
             next_chain = fixed_chains[i]
 
-            if prev_chain.offset_max + chain.max_size <= next_chain.offset_min:
+            if prev_chain.offset_max + chain.max_size < next_chain.offset_min:
                 new_chains = prev_chain.merge(chain)
                 fixed_chains[i-1:i] = new_chains
                 fixed = True
                 break
-        assert(fixed)
-
-    final_chains = [chain for chain in fixed_chains if isinstance(chain, BlockChain)]
-    return final_chains
+        if not fixed:
+            raise RuntimeError('Cannot find enough space to place blocks')
 
-def calc_symbol_offset(symbol_pool, blocks):
-    """Resolve dependencies between @blocks"""
-
-    # Labels resolved
-    pinned_labels = set()
-    # Link an unreferenced label to its reference label
-    linked_labels = {}
-    # Label -> block
-    label2block = dict((block.label, block) for block in blocks)
-
-    # Find pinned labels and labels to resolve
-    for label in symbol_pool.items:
-        if label.offset is None:
-            pass
-        elif is_int(label.offset):
-            pinned_labels.add(label)
-        elif isinstance(label.offset, AsmBlockLink):
-            # construct dependant blocks tree
-            linked_labels.setdefault(label.offset.label, set()).add(label)
-        else:
-            raise ValueError('Unknown offset type')
-        label.offset_g = label.offset
-
-    # Resolve labels
-    while pinned_labels:
-        ref_label = pinned_labels.pop()
-        for unresolved_label in linked_labels.get(ref_label, []):
-            if ref_label.offset_g is None:
-                raise ValueError("unknown symbol: %s" % str(ref_label.name))
-            unresolved_label.offset.resolve(unresolved_label, label2block)
-            pinned_labels.add(unresolved_label)
+    return [chain for chain in fixed_chains if isinstance(chain, BlockChain)]
 
 def filter_exprid_label(exprs):
+    """Extract labels from list of ExprId @exprs"""
     return set(expr.name for expr in exprs if isinstance(expr.name, asm_label))
 
 def get_block_labels(block):
+    """Extract labels used by @block"""
     symbols = set()
     for instr in block.lines:
         if isinstance(instr, asm_raw):
             if isinstance(instr.raw, list):
-                for x in instr.raw:
-                    symbols.update(m2_expr.get_expr_ids(x))
+                for expr in instr.raw:
+                    symbols.update(m2_expr.get_expr_ids(expr))
         else:
             for arg in instr.args:
                 symbols.update(m2_expr.get_expr_ids(arg))
     labels = filter_exprid_label(symbols)
     return labels
 
-def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None,
-                  conservative=False):
-    log_asmbloc.debug("asmbloc_final")
+def assemble_block(mnemo, block, symbol_pool, conservative=False):
+    """Assemble a @block using @symbol_pool
+    @conservative: (optional) use original bytes when possible
+    """
+    offset_i = 0
+
+    for instr in block.lines:
+        if isinstance(instr, asm_raw):
+            if isinstance(instr.raw, list):
+                # Fix special asm_raw
+                data = ""
+                for expr in instr.raw:
+                    expr_int = fix_expr_val(expr, symbol_pool)
+                    data += pck[expr_int.size](expr_int.arg)
+                instr.data = data
+
+            instr.offset = offset_i
+            offset_i += instr.l
+            continue
+
+        # Assemble an instruction
+        saved_args = list(instr.args)
+        instr.offset = block.label.offset + offset_i
 
+        # Replace instruction's arguments by resolved ones
+        instr.args = instr.resolve_args_with_symbols(symbol_pool)
 
+        if instr.dstflow():
+            instr.fixDstOffset()
+
+        old_l = instr.l
+        cached_candidate, candidates = conservative_asm(
+            mnemo, instr, symbol_pool, conservative)
+
+        # Restore original arguments
+        instr.args = saved_args
+
+        # We need to update the block size
+        block.size = block.size - old_l + len(cached_candidate)
+        instr.data = cached_candidate
+        instr.l = len(cached_candidate)
+
+        offset_i += instr.l
+
+
+def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, conservative=False):
+    """Resolve and assemble @blockChains using @symbol_pool until fixed point is
+    reached"""
+
+    log_asmbloc.debug("asmbloc_final")
+
+    # Init structures
     lbl2block = {block.label:block for block in blocks}
     blocks_using_label = {}
     for block in blocks:
@@ -961,21 +942,17 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None,
         for block in chain.blocks:
             block2chain[block] = chain
 
+    # Init worklist
     blocks_to_rework = set(blocks)
-    fini = False
-    while True:
 
-        fini = True
-        my_symb_reloc_off = {}
+    # Fix and re-assemble blocks until fixed point is reached
+    while True:
 
         # Propagate pinned blocks into chains
         modified_labels = set()
         for chain in blockChains:
             chain.fix_blocks(modified_labels)
 
-        if not modified_labels and not blocks_to_rework:
-            break
-
         for label in modified_labels:
             # Retrive block with modified reference
             if label in lbl2block:
@@ -987,101 +964,39 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None,
             for block in blocks_using_label[label]:
                 blocks_to_rework.add(block)
 
-        #symbols = asm_symbol_pool()
-        #for s, v in symbol_pool._name2label.items():
-        #    symbols.add_label(s, v.offset_g)
+        # No more work
+        if not blocks_to_rework:
+            break
 
         while blocks_to_rework:
             block = blocks_to_rework.pop()
-            offset_i = 0
-            my_symb_reloc_off[block.label] = []
-
-            len_modified = False
-
-            for instr in block.lines:
-                if isinstance(instr, asm_raw):
-                    if isinstance(instr.raw, list):
-                        # fix special asm_raw
-                        data = ""
-                        for x in instr.raw:
-                            e = fix_expr_val(x, symbol_pool)
-                            data+= pck[e.size](e.arg)
-                        instr.data = data
-
-                    instr.offset = offset_i
-                    offset_i += instr.l
-                    continue
-                sav_a = instr.args[:]
-                instr.offset = block.label.offset + offset_i
-                args_e = instr.resolve_args_with_symbols(symbol_pool)
-                for i, e in enumerate(args_e):
-                    instr.args[i] = e
-
-                if instr.dstflow():
-                    instr.fixDstOffset()
-
-                symbol_reloc_off = []
-                old_l = instr.l
-                c, candidates = conservative_asm(
-                    mnemo, instr, symbol_reloc_off, conservative)
-
-                for i, e in enumerate(sav_a):
-                    instr.args[i] = e
-
-                if len(c) != instr.l:
-                    # good len, bad offset...XXX
-                    block.size = block.size - old_l + len(c)
-                    instr.data = c
-                    instr.l = len(c)
-                    fini = False
-                    len_modified = True
-                    continue
-                found = False
-                for cpos, c in enumerate(candidates):
-                    if len(c) == instr.l:
-                        instr.data = c
-                        instr.l = len(c)
-
-                        found = True
-                        break
-                if not found:
-                    raise ValueError('something wrong in instr.data')
-
-                if cpos < len(symbol_reloc_off):
-                    my_s = symbol_reloc_off[cpos]
-                else:
-                    my_s = None
+            assemble_block(mnemo, block, symbol_pool, conservative)
 
-                if my_s is not None:
-                    my_symb_reloc_off[block.label].append(offset_i + my_s)
-                offset_i += instr.l
-                assert len(instr.data) == instr.l
+def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None):
+    """Resolve and assemble @blocks using @symbol_pool into interval
+    @dst_interval"""
 
-
-def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None,
-                      symb_reloc_off=None):
-    if symb_reloc_off is None:
-        symb_reloc_off = {}
     guess_blocks_size(mnemo, blocks)
     blockChains = group_constrained_blocks(symbol_pool, blocks)
+    resolved_blockChains = resolve_symbol(blockChains, symbol_pool, dst_interval)
 
-    blockChains = resolve_symbol(blockChains, blocks, symbol_pool, dst_interval)
-
-    asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off)
-    written_bytes = {}
+    asmbloc_final(mnemo, blocks, resolved_blockChains, symbol_pool)
     patches = {}
+    output_interval = interval()
+
     for block in blocks:
         offset = block.label.offset
-        for line in block.lines:
-            assert line.data is not None
-            patches[offset] = line.data
-            for cur_pos in xrange(line.l):
-                if offset + cur_pos in written_bytes:
-                    raise ValueError(
-                        "overlapping bytes in asssembly %X" % int(offset))
-                written_bytes[offset + cur_pos] = 1
-            line.offset = offset
-            offset += line.l
+        for instr in block.lines:
+            if not instr.data:
+                # Empty line
+                continue
+            assert len(instr.data) == instr.l
+            patches[offset] = instr.data
+            instruction_interval = interval([(offset, offset + instr.l-1)])
+            if not (instruction_interval & output_interval).empty:
+                raise RuntimeError("overlapping bytes %X" % int(offset))
+            instr.offset = offset
+            offset += instr.l
     return patches
 
 def blist2graph(ab):