diff options
Diffstat (limited to 'miasm2/core/asmbloc.py')
| -rw-r--r-- | miasm2/core/asmbloc.py | 429 |
1 files changed, 172 insertions, 257 deletions
diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py index dd18d1f8..612a2e3b 100644 --- a/miasm2/core/asmbloc.py +++ b/miasm2/core/asmbloc.py @@ -100,9 +100,9 @@ class asm_constraint_bad(asm_constraint): label, c_t=asm_constraint.c_bad) -class asm_bloc: +class asm_bloc(object): - def __init__(self, label=None, alignment = 1): + def __init__(self, label=None, alignment=1): self.bto = set() self.lines = [] self.label = label @@ -587,25 +587,25 @@ def conservative_asm(mnemo, instr, symbols, conservative): return c, candidates return candidates[0], candidates -def fix_expr_val(e, symbols): +def fix_expr_val(expr, symbols): + """Resolve an expression @expr using @symbols""" def expr_calc(e): if isinstance(e, m2_expr.ExprId): s = symbols._name2label[e.name] e = m2_expr.ExprInt_from(e, s.offset) return e - e = e.visit(expr_calc) - e = expr_simp(e) - return e + result = expr.visit(expr_calc) + result = expr_simp(result) + if not isinstance(result, m2_expr.ExprInt): + raise RuntimeError('Cannot resolve symbol %s' % expr) + return result def guess_blocks_size(mnemo, blocks): - """ - Asm and compute max bloc size - """ + """Asm and compute max block size""" + for block in blocks: - log_asmbloc.debug('---') size = 0 - max_size = 0 for instr in block.lines: if isinstance(instr, asm_raw): # for special asm_raw, only extract len @@ -635,11 +635,14 @@ def guess_blocks_size(mnemo, blocks): size += l block.size = size - # bloc with max rel values encoded - block.max_size = size + max_size + block.max_size = size log_asmbloc.info("size: %d max: %d", block.size, block.max_size) def fix_label_offset(symbol_pool, label, offset, modified): + """Fix the @label offset to @offset. If the @offset has changed, add @label + to @modified + @symbol_pool: current symbol_pool + """ if label.offset == offset: return symbol_pool.set_offset(label, offset) @@ -647,58 +650,73 @@ def fix_label_offset(symbol_pool, label, offset, modified): class BlockChain(object): - """Manage blocks linked with a "next" constraint""" + """Manage blocks linked with an asm_constraint_next""" def __init__(self, symbol_pool, blocks): self.symbol_pool = symbol_pool self.blocks = blocks self.place() + @property def pinned(self): + """Return True iff at least one block is pinned""" return self.pinned_block_idx is not None - def get_pinned_block_idx(self): - pinned_block_idx = None + def _set_pinned_block_idx(self): + self.pinned_block_idx = None for i, block in enumerate(self.blocks): if is_int(block.label.offset): - if pinned_block_idx is not None: + if self.pinned_block_idx is not None: raise ValueError("Multiples pinned block detected") - pinned_block_idx = i + self.pinned_block_idx = i - self.pinned_block_idx = pinned_block_idx def place(self): - self.get_pinned_block_idx() - self.max_size = reduce(lambda x, block: x + block.max_size, - self.blocks, 0) + """Compute BlockChain min_offset and max_offset using pinned block and + blocks' size + """ + self._set_pinned_block_idx() + self.max_size = 0 + for block in self.blocks: + self.max_size += block.max_size + block.alignment - 1 # Check if chain has one block pinned if not self.pinned: return - size = 0 - for block in self.blocks[:self.pinned_block_idx]: - size += block.max_size - self.offset_min = self.blocks[self.pinned_block_idx].label.offset - size - size = 0 + offset_base = self.blocks[self.pinned_block_idx].label.offset + assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0) + + self.offset_min = offset_base + for block in self.blocks[:self.pinned_block_idx-1:-1]: + self.offset_min -= block.max_size + (block.alignment - block.max_size) % block.alignment + + self.offset_max = offset_base for block in self.blocks[self.pinned_block_idx:]: - size += block.max_size - self.offset_max = self.blocks[self.pinned_block_idx].label.offset + size + self.offset_max += block.max_size + (block.alignment - block.max_size) % block.alignment def merge(self, chain): + """Best effort merge two block chains + Return the list of resulting blockchains""" self.blocks += chain.blocks self.place() return [self] def fix_blocks(self, modified_labels): + """Propagate a pinned to its blocks' neighbour + @modified_labels: store new pinned labels""" + if not self.pinned: raise ValueError('Trying to fix unpinned block') + # Propagate offset to blocks before pinned block pinned_block = self.blocks[self.pinned_block_idx] offset = pinned_block.label.offset - assert(offset % pinned_block.alignment == 0) - for block in self.blocks[self.pinned_block_idx-1:-1:-1]: + if offset % pinned_block.alignment != 0: + raise RuntimeError('Bad alignment') + + for block in self.blocks[:self.pinned_block_idx-1:-1]: new_offset = offset - block.size new_offset = new_offset - new_offset % pinned_block.alignment fix_label_offset(self.symbol_pool, @@ -706,70 +724,76 @@ class BlockChain(object): new_offset, modified_labels) - # Propagate offset to blocks before pinned block - pblock = pinned_block - offset = pblock.label.offset + pblock.size + # Propagate offset to blocks after pinned block + offset = pinned_block.label.offset + pinned_block.size + last_block = pinned_block for block in self.blocks[self.pinned_block_idx+1:]: - pad = pinned_block.alignment - (offset % pinned_block.alignment) - offset += pad % pinned_block.alignment + offset += (- offset) % last_block.alignment fix_label_offset(self.symbol_pool, block.label, offset, modified_labels) offset += block.size + last_block = block return modified_labels + class BlockChainWedge(object): + """Stand for wedges between blocks""" + def __init__(self, symbol_pool, offset, size): self.symbol_pool = symbol_pool self.offset = offset - self.max_len = size + self.max_size = size self.offset_min = offset self.offset_max = offset + size def merge(self, chain): + """Best effort merge two block chains + Return the list of resulting blockchains""" chain.blocks[0].label.offset = self.offset_max chain.place() return [self, chain] + def group_constrained_blocks(symbol_pool, blocks): """ - Return a list of grouped asm blocks linked by "next_constraints" + Return the BlockChains list built from grouped asm blocks linked by + asm_constraint_next @blocks: a list of asm block - """ log_asmbloc.info('group_constrained_blocks') - # group adjacent blocks - remaining_blocks = blocks[:] + # Group adjacent blocks + remaining_blocks = list(blocks) known_block_chains = {} lbl2block = {block.label:block for block in blocks} - while remaining_blocks: # Create a new block chain - block_chain = [remaining_blocks.pop()] + block_list = [remaining_blocks.pop()] - # Find son in remainings blocks linked with a next constraint + # Find sons in remainings blocks linked with a next constraint while True: - next_label = block_chain[-1].get_next() + # Get next block + next_label = block_list[-1].get_next() if next_label is None or next_label not in lbl2block: break next_block = lbl2block[next_label] - if next_block in remaining_blocks: - block_chain.append(next_block) - remaining_blocks.remove(next_block) - next_label = next_block.get_next() - else: + + # Add the block at the end of the current chain + if next_block not in remaining_blocks: break + block_list.append(next_block) + remaining_blocks.remove(next_block) - # Check if son is in a known block group: + # Check if son is in a known block group if next_label is not None and next_label in known_block_chains: - block_chain += known_block_chains[next_label] + block_list += known_block_chains[next_label] del known_block_chains[next_label] - known_block_chains[block_chain[0].label] = block_chain + known_block_chains[block_list[0].label] = block_list out_block_chains = [] for label in known_block_chains: @@ -777,105 +801,45 @@ def group_constrained_blocks(symbol_pool, blocks): out_block_chains.append(chain) return out_block_chains -def add_dont_erase(f, dont_erase=[]): - tmp_symbol_pool = asm_symbol_pool() - for a, b in dont_erase: - l = tmp_symbol_pool.add_label(a, a) - l.offset_min = a - f[l] = b - a - return - - -def gen_non_free_mapping(blockChains, dont_erase=[]): - non_free_mapping = {} - # calculate free space for bloc placing - for chain in blockChains: - # if a label in the group is fixed - diff_offset = 0 - for block in chain.blocks: - if not is_int(block.label.offset): - diff_offset += b.size_max - continue - chain.pinned = True - chain.offset_min = block.label.offset - diff_offset - break - if chain.pinned: - non_free_mapping[chain] = chain.chain_max_size - - log_asmbloc.debug("non free bloc:") - log_asmbloc.debug(non_free_mapping) - add_dont_erase(non_free_mapping, dont_erase) - log_asmbloc.debug("non free more:") - log_asmbloc.debug(non_free_mapping) - return non_free_mapping - - - -class AsmBlockLink(object): - """Location contraint between blocks""" - - def __init__(self, label): - self.label = label - - def resolve(self, parent_label, label2block): - """ - Resolve the @parent_label.offset_g - @parent_label: parent label - @label2block: dictionnary which links labels to blocks - """ - raise NotImplementedError("Abstract method") - -class AsmBlockLinkNext(AsmBlockLink): - - def resolve(self, parent_label, label2block): - parent_label.offset_g = self.label.offset_g + label2block[self.label].size - -class AsmBlockLinkPrev(AsmBlockLink): - - def resolve(self, parent_label, label2block): - parent_label.offset_g = self.label.offset_g - label2block[parent_label].size - def get_blockchains_address_interval(blockChains, dst_interval): + """Compute the interval used by the pinned @blockChains + Check if the placed chains are in the @dst_interval""" + allocated_interval = interval() for chain in blockChains: if not chain.pinned: continue chain_interval = interval([(chain.offset_min, chain.offset_max-1)]) - if (dst_interval - chain_interval).hull() == (None, None): + if chain_interval not in dst_interval: raise ValueError('Chain placed out of destination interval') allocated_interval += chain_interval return allocated_interval -def resolve_symbol(blockChains, blocks, symbol_pool, dst_interval=None): - """ - place all asmblocks - """ +def resolve_symbol(blockChains, symbol_pool, dst_interval=None): + """Place @blockChains in the @dst_interval""" + log_asmbloc.info('resolve_symbol') if dst_interval is None: dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)]) forbidden_interval = interval([(-1, 0xFFFFFFFFFFFFFFFF+1)]) - dst_interval - - bloc_list = [] - unr_bloc = blocks[:] - allocated_interval = get_blockchains_address_interval(blockChains, dst_interval) - log_asmbloc.debug('allocated interval: %s'%allocated_interval) + log_asmbloc.debug('allocated interval: %s', allocated_interval) pinned_chains = [chain for chain in blockChains if chain.pinned] # Add wedge in forbidden intervals - for a, b in forbidden_interval.intervals: - wedge = BlockChainWedge(symbol_pool, offset=a, size=b+1-a) + for start, stop in forbidden_interval.intervals: + wedge = BlockChainWedge(symbol_pool, offset=start, size=stop+1-start) pinned_chains.append(wedge) - pinned_chains.sort(key=lambda x:x.offset_min) # Try to place bigger blockChains first + pinned_chains.sort(key=lambda x:x.offset_min) blockChains.sort(key=lambda x:-x.max_size) - fixed_chains = pinned_chains[:] + fixed_chains = list(pinned_chains) log_asmbloc.debug("place chains") for chain in blockChains: @@ -886,69 +850,86 @@ def resolve_symbol(blockChains, blocks, symbol_pool, dst_interval=None): prev_chain = fixed_chains[i-1] next_chain = fixed_chains[i] - if prev_chain.offset_max + chain.max_size <= next_chain.offset_min: + if prev_chain.offset_max + chain.max_size < next_chain.offset_min: new_chains = prev_chain.merge(chain) fixed_chains[i-1:i] = new_chains fixed = True break - assert(fixed) - - final_chains = [chain for chain in fixed_chains if isinstance(chain, BlockChain)] - return final_chains + if not fixed: + raise RuntimeError('Cannot find enough space to place blocks') -def calc_symbol_offset(symbol_pool, blocks): - """Resolve dependencies between @blocks""" - - # Labels resolved - pinned_labels = set() - # Link an unreferenced label to its reference label - linked_labels = {} - # Label -> block - label2block = dict((block.label, block) for block in blocks) - - # Find pinned labels and labels to resolve - for label in symbol_pool.items: - if label.offset is None: - pass - elif is_int(label.offset): - pinned_labels.add(label) - elif isinstance(label.offset, AsmBlockLink): - # construct dependant blocks tree - linked_labels.setdefault(label.offset.label, set()).add(label) - else: - raise ValueError('Unknown offset type') - label.offset_g = label.offset - - # Resolve labels - while pinned_labels: - ref_label = pinned_labels.pop() - for unresolved_label in linked_labels.get(ref_label, []): - if ref_label.offset_g is None: - raise ValueError("unknown symbol: %s" % str(ref_label.name)) - unresolved_label.offset.resolve(unresolved_label, label2block) - pinned_labels.add(unresolved_label) + return [chain for chain in fixed_chains if isinstance(chain, BlockChain)] def filter_exprid_label(exprs): + """Extract labels from list of ExprId @exprs""" return set(expr.name for expr in exprs if isinstance(expr.name, asm_label)) def get_block_labels(block): + """Extract labels used by @block""" symbols = set() for instr in block.lines: if isinstance(instr, asm_raw): if isinstance(instr.raw, list): - for x in instr.raw: - symbols.update(m2_expr.get_expr_ids(x)) + for expr in instr.raw: + symbols.update(m2_expr.get_expr_ids(expr)) else: for arg in instr.args: symbols.update(m2_expr.get_expr_ids(arg)) labels = filter_exprid_label(symbols) return labels -def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None, - conservative=False): - log_asmbloc.debug("asmbloc_final") +def assemble_block(mnemo, block, symbol_pool, conservative=False): + """Assemble a @block using @symbol_pool + @conservative: (optional) use original bytes when possible + """ + offset_i = 0 + + for instr in block.lines: + if isinstance(instr, asm_raw): + if isinstance(instr.raw, list): + # Fix special asm_raw + data = "" + for expr in instr.raw: + expr_int = fix_expr_val(expr, symbol_pool) + data += pck[expr_int.size](expr_int.arg) + instr.data = data + + instr.offset = offset_i + offset_i += instr.l + continue + + # Assemble an instruction + saved_args = list(instr.args) + instr.offset = block.label.offset + offset_i + # Replace instruction's arguments by resolved ones + instr.args = instr.resolve_args_with_symbols(symbol_pool) + if instr.dstflow(): + instr.fixDstOffset() + + old_l = instr.l + cached_candidate, candidates = conservative_asm( + mnemo, instr, symbol_pool, conservative) + + # Restore original arguments + instr.args = saved_args + + # We need to update the block size + block.size = block.size - old_l + len(cached_candidate) + instr.data = cached_candidate + instr.l = len(cached_candidate) + + offset_i += instr.l + + +def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, conservative=False): + """Resolve and assemble @blockChains using @symbol_pool until fixed point is + reached""" + + log_asmbloc.debug("asmbloc_final") + + # Init structures lbl2block = {block.label:block for block in blocks} blocks_using_label = {} for block in blocks: @@ -961,21 +942,17 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None, for block in chain.blocks: block2chain[block] = chain + # Init worklist blocks_to_rework = set(blocks) - fini = False - while True: - fini = True - my_symb_reloc_off = {} + # Fix and re-assemble blocks until fixed point is reached + while True: # Propagate pinned blocks into chains modified_labels = set() for chain in blockChains: chain.fix_blocks(modified_labels) - if not modified_labels and not blocks_to_rework: - break - for label in modified_labels: # Retrive block with modified reference if label in lbl2block: @@ -987,101 +964,39 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None, for block in blocks_using_label[label]: blocks_to_rework.add(block) - #symbols = asm_symbol_pool() - #for s, v in symbol_pool._name2label.items(): - # symbols.add_label(s, v.offset_g) + # No more work + if not blocks_to_rework: + break while blocks_to_rework: block = blocks_to_rework.pop() - offset_i = 0 - my_symb_reloc_off[block.label] = [] - - len_modified = False - - for instr in block.lines: - if isinstance(instr, asm_raw): - if isinstance(instr.raw, list): - # fix special asm_raw - data = "" - for x in instr.raw: - e = fix_expr_val(x, symbol_pool) - data+= pck[e.size](e.arg) - instr.data = data - - instr.offset = offset_i - offset_i += instr.l - continue - sav_a = instr.args[:] - instr.offset = block.label.offset + offset_i - args_e = instr.resolve_args_with_symbols(symbol_pool) - for i, e in enumerate(args_e): - instr.args[i] = e - - if instr.dstflow(): - instr.fixDstOffset() - - symbol_reloc_off = [] - old_l = instr.l - c, candidates = conservative_asm( - mnemo, instr, symbol_reloc_off, conservative) - - for i, e in enumerate(sav_a): - instr.args[i] = e - - if len(c) != instr.l: - # good len, bad offset...XXX - block.size = block.size - old_l + len(c) - instr.data = c - instr.l = len(c) - fini = False - len_modified = True - continue - found = False - for cpos, c in enumerate(candidates): - if len(c) == instr.l: - instr.data = c - instr.l = len(c) - - found = True - break - if not found: - raise ValueError('something wrong in instr.data') - - if cpos < len(symbol_reloc_off): - my_s = symbol_reloc_off[cpos] - else: - my_s = None + assemble_block(mnemo, block, symbol_pool, conservative) - if my_s is not None: - my_symb_reloc_off[block.label].append(offset_i + my_s) - offset_i += instr.l - assert len(instr.data) == instr.l +def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): + """Resolve and assemble @blocks using @symbol_pool into interval + @dst_interval""" - -def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None, - symb_reloc_off=None): - if symb_reloc_off is None: - symb_reloc_off = {} guess_blocks_size(mnemo, blocks) blockChains = group_constrained_blocks(symbol_pool, blocks) + resolved_blockChains = resolve_symbol(blockChains, symbol_pool, dst_interval) - blockChains = resolve_symbol(blockChains, blocks, symbol_pool, dst_interval) - - asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off) - written_bytes = {} + asmbloc_final(mnemo, blocks, resolved_blockChains, symbol_pool) patches = {} + output_interval = interval() + for block in blocks: offset = block.label.offset - for line in block.lines: - assert line.data is not None - patches[offset] = line.data - for cur_pos in xrange(line.l): - if offset + cur_pos in written_bytes: - raise ValueError( - "overlapping bytes in asssembly %X" % int(offset)) - written_bytes[offset + cur_pos] = 1 - line.offset = offset - offset += line.l + for instr in block.lines: + if not instr.data: + # Empty line + continue + assert len(instr.data) == instr.l + patches[offset] = instr.data + instruction_interval = interval([(offset, offset + instr.l-1)]) + if not (instruction_interval & output_interval).empty: + raise RuntimeError("overlapping bytes %X" % int(offset)) + instr.offset = offset + offset += instr.l return patches def blist2graph(ab): |