diff options
| -rw-r--r-- | miasm2/core/asmbloc.py | 493 | ||||
| -rw-r--r-- | miasm2/core/parse_asm.py | 14 |
2 files changed, 268 insertions, 239 deletions
diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py index cc05510b..dd18d1f8 100644 --- a/miasm2/core/asmbloc.py +++ b/miasm2/core/asmbloc.py @@ -10,19 +10,14 @@ from miasm2.expression.simplifications import expr_simp from miasm2.expression.modint import moduint, modint from miasm2.core.utils import Disasm_Exception, pck from miasm2.core.graph import DiGraph +from miasm2.core.interval import interval - -log_asmbloc = logging.getLogger("asmbloc") +log_asmbloc = logging.getLogger("asmblock") console_handler = logging.StreamHandler() console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) log_asmbloc.addHandler(console_handler) log_asmbloc.setLevel(logging.WARNING) - -def whoami(): - return inspect.stack()[2][3] - - def is_int(a): return isinstance(a, int) or isinstance(a, long) or \ isinstance(a, moduint) or isinstance(a, modint) @@ -50,7 +45,6 @@ class asm_label: self.offset = offset else: self.offset = int(offset) - self._hash = hash((self.name, self.offset)) def __str__(self): if isinstance(self.offset, (int, long)): @@ -65,16 +59,6 @@ class asm_label: rep += '>' return rep - def __hash__(self): - return self._hash - - def __eq__(self, a): - if isinstance(a, asm_label): - return self._hash == hash(a) - else: - return False - - class asm_raw: def __init__(self, raw=""): self.raw = raw @@ -82,7 +66,6 @@ class asm_raw: def __str__(self): return repr(self.raw) - class asm_constraint(object): c_to = "c_to" c_next = "c_next" @@ -91,20 +74,10 @@ class asm_constraint(object): def __init__(self, label=None, c_t=c_to): self.label = label self.c_t = c_t - self._hash = hash((self.label, self.c_t)) def __str__(self): return "%s:%s" % (str(self.c_t), str(self.label)) - def __hash__(self): - return self._hash - - def __eq__(self, a): - if isinstance(a, asm_constraint): - return self._hash == a._hash - else: - return False - class asm_constraint_next(asm_constraint): @@ -129,10 +102,11 @@ class asm_constraint_bad(asm_constraint): class asm_bloc: - def __init__(self, label=None): + def __init__(self, label=None, alignment = 1): self.bto = set() self.lines = [] self.label = label + self.alignment = alignment def __str__(self): out = [] @@ -320,6 +294,8 @@ class asm_symbol_pool: """ if not label.name in self._name2label: raise ValueError('label %s not in symbol pool' % label) + if offset is not None and offset in self._offset2label: + raise ValueError('Conflict in label %s' % label) self._offset2label.pop(label.offset, None) label.offset = offset if is_int(label.offset): @@ -503,7 +479,6 @@ def split_bloc(mnemo, attrib, pool_bin, blocs, return blocs - def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[], split_dis=[], follow_call=False, dontdis_retcall=False, blocs_wd=None, lines_wd=None, blocs=None, @@ -623,15 +598,15 @@ def fix_expr_val(e, symbols): return e -def guess_blocs_size(mnemo, blocs): +def guess_blocks_size(mnemo, blocks): """ Asm and compute max bloc size """ - for b in blocs: + for block in blocks: log_asmbloc.debug('---') size = 0 max_size = 0 - for instr in b.lines: + for instr in block.lines: if isinstance(instr, asm_raw): # for special asm_raw, only extract len if isinstance(instr.raw, list): @@ -646,27 +621,127 @@ def guess_blocs_size(mnemo, blocs): else: raise NotImplementedError('asm raw') else: - l = mnemo.max_instruction_len + # Assemble the instruction to retrieve its len. + # If the instruction uses symbol it will fail + # In this case, the max_instruction_len is used + try: + candidates = mnemo.asm(instr) + l = len(candidates[-1]) + except: + l = mnemo.max_instruction_len data = None instr.data = data instr.l = l size += l - b.size = size + block.size = size # bloc with max rel values encoded - b.size_max = size + max_size - log_asmbloc.info("size: %d max: %d", b.size, b.size_max) + block.max_size = size + max_size + log_asmbloc.info("size: %d max: %d", block.size, block.max_size) + +def fix_label_offset(symbol_pool, label, offset, modified): + if label.offset == offset: + return + symbol_pool.set_offset(label, offset) + modified.add(label) + +class BlockChain(object): + """Manage blocks linked with a "next" constraint""" -def group_constrained_blocs(blocks): + def __init__(self, symbol_pool, blocks): + self.symbol_pool = symbol_pool + self.blocks = blocks + self.place() + @property + def pinned(self): + return self.pinned_block_idx is not None + + def get_pinned_block_idx(self): + pinned_block_idx = None + for i, block in enumerate(self.blocks): + if is_int(block.label.offset): + if pinned_block_idx is not None: + raise ValueError("Multiples pinned block detected") + pinned_block_idx = i + + self.pinned_block_idx = pinned_block_idx + + def place(self): + self.get_pinned_block_idx() + self.max_size = reduce(lambda x, block: x + block.max_size, + self.blocks, 0) + + # Check if chain has one block pinned + if not self.pinned: + return + + size = 0 + for block in self.blocks[:self.pinned_block_idx]: + size += block.max_size + self.offset_min = self.blocks[self.pinned_block_idx].label.offset - size + + size = 0 + for block in self.blocks[self.pinned_block_idx:]: + size += block.max_size + self.offset_max = self.blocks[self.pinned_block_idx].label.offset + size + + def merge(self, chain): + self.blocks += chain.blocks + self.place() + return [self] + + def fix_blocks(self, modified_labels): + if not self.pinned: + raise ValueError('Trying to fix unpinned block') + # Propagate offset to blocks before pinned block + pinned_block = self.blocks[self.pinned_block_idx] + offset = pinned_block.label.offset + assert(offset % pinned_block.alignment == 0) + for block in self.blocks[self.pinned_block_idx-1:-1:-1]: + new_offset = offset - block.size + new_offset = new_offset - new_offset % pinned_block.alignment + fix_label_offset(self.symbol_pool, + block.label, + new_offset, + modified_labels) + + # Propagate offset to blocks before pinned block + pblock = pinned_block + offset = pblock.label.offset + pblock.size + + for block in self.blocks[self.pinned_block_idx+1:]: + pad = pinned_block.alignment - (offset % pinned_block.alignment) + offset += pad % pinned_block.alignment + fix_label_offset(self.symbol_pool, + block.label, + offset, + modified_labels) + offset += block.size + return modified_labels + +class BlockChainWedge(object): + def __init__(self, symbol_pool, offset, size): + self.symbol_pool = symbol_pool + self.offset = offset + self.max_len = size + self.offset_min = offset + self.offset_max = offset + size + + def merge(self, chain): + chain.blocks[0].label.offset = self.offset_max + chain.place() + return [self, chain] + +def group_constrained_blocks(symbol_pool, blocks): """ Return a list of grouped asm blocks linked by "next_constraints" @blocks: a list of asm block """ - log_asmbloc.info('group_constrained_blocs') + log_asmbloc.info('group_constrained_blocks') - # group adjacent blocs + # group adjacent blocks remaining_blocks = blocks[:] known_block_chains = {} lbl2block = {block.label:block for block in blocks} @@ -696,71 +771,36 @@ def group_constrained_blocs(blocks): known_block_chains[block_chain[0].label] = block_chain - # Compute max len for each block chain + out_block_chains = [] for label in known_block_chains: - label.chain_max_size = reduce(lambda x, block: x + block.size_max, - known_block_chains[label], 0) - log_asmbloc.debug(("offset maxlen", label.offset, label.chain_max_size)) - - return known_block_chains - - -def gen_free_space_intervals(f, max_offset=0xFFFFFFFF): - interval = {} - offset_label = dict([(x.offset_free, x) for x in f]) - offset_label_order = offset_label.keys() - offset_label_order.sort() - offset_label_order.append(max_offset) - offset_label_order.reverse() - - unfree_stop = 0L - while len(offset_label_order) > 1: - offset = offset_label_order.pop() - offset_end = offset + f[offset_label[offset]] - prev = 0 - if unfree_stop > offset_end: - space = 0 - else: - space = offset_label_order[-1] - offset_end - if space < 0: - space = 0 - interval[offset_label[offset]] = space - if offset_label_order[-1] in offset_label: - prev = offset_label[offset_label_order[-1]] - prev = f[prev] - - interval[offset_label[offset]] = space - - unfree_stop = max( - unfree_stop, offset_end, offset_label_order[-1] + prev) - return interval - + chain = BlockChain(symbol_pool, known_block_chains[label]) + out_block_chains.append(chain) + return out_block_chains def add_dont_erase(f, dont_erase=[]): tmp_symbol_pool = asm_symbol_pool() for a, b in dont_erase: l = tmp_symbol_pool.add_label(a, a) - l.offset_free = a + l.offset_min = a f[l] = b - a return -def gen_non_free_mapping(group_bloc, dont_erase=[]): +def gen_non_free_mapping(blockChains, dont_erase=[]): non_free_mapping = {} # calculate free space for bloc placing - for g in group_bloc: - g.fixedblocs = False + for chain in blockChains: # if a label in the group is fixed diff_offset = 0 - for b in group_bloc[g]: - if not is_int(b.label.offset): + for block in chain.blocks: + if not is_int(block.label.offset): diff_offset += b.size_max continue - g.fixedblocs = True - g.offset_free = b.label.offset - diff_offset + chain.pinned = True + chain.offset_min = block.label.offset - diff_offset break - if g.fixedblocs: - non_free_mapping[g] = g.chain_max_size + if chain.pinned: + non_free_mapping[chain] = chain.chain_max_size log_asmbloc.debug("non free bloc:") log_asmbloc.debug(non_free_mapping) @@ -795,125 +835,66 @@ class AsmBlockLinkPrev(AsmBlockLink): def resolve(self, parent_label, label2block): parent_label.offset_g = self.label.offset_g - label2block[parent_label].size -def resolve_symbol(group_bloc, symbol_pool, dont_erase=[], - max_offset=0xFFFFFFFF): + +def get_blockchains_address_interval(blockChains, dst_interval): + allocated_interval = interval() + for chain in blockChains: + if not chain.pinned: + continue + chain_interval = interval([(chain.offset_min, chain.offset_max-1)]) + if (dst_interval - chain_interval).hull() == (None, None): + raise ValueError('Chain placed out of destination interval') + allocated_interval += chain_interval + return allocated_interval + +def resolve_symbol(blockChains, blocks, symbol_pool, dst_interval=None): """ - place all asmblocs + place all asmblocks """ log_asmbloc.info('resolve_symbol') - log_asmbloc.info(str(dont_erase)) - bloc_list = [] - unr_bloc = reduce(lambda x, y: x + group_bloc[y], group_bloc, []) + if dst_interval is None: + dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)]) - non_free_mapping = gen_non_free_mapping(group_bloc, dont_erase) - free_interval = gen_free_space_intervals(non_free_mapping, max_offset) - log_asmbloc.debug(free_interval) + forbidden_interval = interval([(-1, 0xFFFFFFFFFFFFFFFF+1)]) - dst_interval - # first big ones - g_tab = [(x.chain_max_size, x) for x in group_bloc] - g_tab.sort() - g_tab.reverse() - g_tab = [x[1] for x in g_tab] + bloc_list = [] + unr_bloc = blocks[:] - # g_tab => label of grouped blov - # group_bloc => dict of grouped bloc labeled-key + allocated_interval = get_blockchains_address_interval(blockChains, + dst_interval) + log_asmbloc.debug('allocated interval: %s'%allocated_interval) - # first, near callee placing algo - for g in g_tab: - if g.fixedblocs: - continue - finish = False - for x in group_bloc: - if not x in free_interval.keys(): - continue - if free_interval[x] < g.chain_max_size: - continue + pinned_chains = [chain for chain in blockChains if chain.pinned] - for b in group_bloc[x]: - for c in b.bto: - if c.label == g: - tmp = free_interval[x] - g.chain_max_size - log_asmbloc.debug( - "consumed %d rest: %d", g.chain_max_size, int(tmp)) - free_interval[g] = tmp - del free_interval[x] - symbol_pool.set_offset( - g, AsmBlockLinkNext(group_bloc[x][-1].label)) - g.fixedblocs = True - finish = True - break - if finish: - break - if finish: - break + # Add wedge in forbidden intervals + for a, b in forbidden_interval.intervals: + wedge = BlockChainWedge(symbol_pool, offset=a, size=b+1-a) + pinned_chains.append(wedge) - # second, bigger in smaller algo - for g in g_tab: - if g.fixedblocs: - continue - # chose smaller free_interval first - k_tab = [(free_interval[x], x) for x in free_interval] - k_tab.sort() - k_tab = [x[1] for x in k_tab] - # choose free_interval - for k in k_tab: - if g.chain_max_size > free_interval[k]: - continue - symbol_pool.set_offset( - g, AsmBlockLinkNext(group_bloc[k][-1].label)) - tmp = free_interval[k] - g.chain_max_size - log_asmbloc.debug( - "consumed %d rest: %d", g.chain_max_size, int(tmp)) - free_interval[g] = tmp - del free_interval[k] - - g.fixedblocs = True - break + pinned_chains.sort(key=lambda x:x.offset_min) + # Try to place bigger blockChains first + blockChains.sort(key=lambda x:-x.max_size) - while unr_bloc: - # propagate know offset - resolving = False - i = 0 - while i < len(unr_bloc): - if unr_bloc[i].label.offset is None: - i += 1 - continue - resolving = True - log_asmbloc.info("bloc %s resolved", unr_bloc[i].label) - bloc_list.append(unr_bloc[i]) - g_found = None - for g in g_tab: - if unr_bloc[i] in group_bloc[g]: - if g_found is not None: - raise ValueError('blocin multiple group!!!') - g_found = g - my_group = group_bloc[g_found] - - index = my_group.index(unr_bloc[i]) - if index > 0 and my_group[index - 1] in unr_bloc: - symbol_pool.set_offset( - my_group[index - 1].label, - AsmBlockLinkPrev(unr_bloc[i].label)) - if index < len(my_group) - 1 and my_group[index + 1] in unr_bloc: - symbol_pool.set_offset( - my_group[index + 1].label, - AsmBlockLinkNext(unr_bloc[i].label)) - del unr_bloc[i] - - if not resolving: - log_asmbloc.warn("cannot resolve symbol! (no symbol fix found)") - else: - continue + fixed_chains = pinned_chains[:] - for g in g_tab: - log_asmbloc.debug(g) - if g.fixedblocs: - log_asmbloc.debug("fixed") - else: - log_asmbloc.debug("not fixed") - raise ValueError('enable to fix bloc') - return bloc_list + log_asmbloc.debug("place chains") + for chain in blockChains: + if chain.pinned: + continue + fixed = False + for i in xrange(1, len(fixed_chains)): + prev_chain = fixed_chains[i-1] + next_chain = fixed_chains[i] + + if prev_chain.offset_max + chain.max_size <= next_chain.offset_min: + new_chains = prev_chain.merge(chain) + fixed_chains[i-1:i] = new_chains + fixed = True + break + assert(fixed) + final_chains = [chain for chain in fixed_chains if isinstance(chain, BlockChain)] + return final_chains def calc_symbol_offset(symbol_pool, blocks): """Resolve dependencies between @blocks""" @@ -932,7 +913,7 @@ def calc_symbol_offset(symbol_pool, blocks): elif is_int(label.offset): pinned_labels.add(label) elif isinstance(label.offset, AsmBlockLink): - # construct dependant blocs tree + # construct dependant blocks tree linked_labels.setdefault(label.offset.label, set()).add(label) else: raise ValueError('Unknown offset type') @@ -947,47 +928,92 @@ def calc_symbol_offset(symbol_pool, blocks): unresolved_label.offset.resolve(unresolved_label, label2block) pinned_labels.add(unresolved_label) +def filter_exprid_label(exprs): + return set(expr.name for expr in exprs if isinstance(expr.name, asm_label)) + +def get_block_labels(block): + symbols = set() + for instr in block.lines: + if isinstance(instr, asm_raw): + if isinstance(instr.raw, list): + for x in instr.raw: + symbols.update(m2_expr.get_expr_ids(x)) + else: + for arg in instr.args: + symbols.update(m2_expr.get_expr_ids(arg)) + labels = filter_exprid_label(symbols) + return labels -def asmbloc_final(mnemo, blocs, symbol_pool, symb_reloc_off=None, +def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off=None, conservative=False): - log_asmbloc.info("asmbloc_final") - if symb_reloc_off is None: - symb_reloc_off = {} + log_asmbloc.debug("asmbloc_final") + + + lbl2block = {block.label:block for block in blocks} + blocks_using_label = {} + for block in blocks: + labels = get_block_labels(block) + for label in labels: + blocks_using_label.setdefault(label, set()).add(block) + + block2chain = {} + for chain in blockChains: + for block in chain.blocks: + block2chain[block] = chain + + blocks_to_rework = set(blocks) fini = False - # asm with minimal instr len - # check if dst label are ok to this encoded form - # recompute if not - # TODO XXXX: implement todo list to remove n^high complexity! - while fini is not True: + while True: fini = True my_symb_reloc_off = {} - calc_symbol_offset(symbol_pool, blocs) + # Propagate pinned blocks into chains + modified_labels = set() + for chain in blockChains: + chain.fix_blocks(modified_labels) + + if not modified_labels and not blocks_to_rework: + break + + for label in modified_labels: + # Retrive block with modified reference + if label in lbl2block: + blocks_to_rework.add(lbl2block[label]) - symbols = asm_symbol_pool() - for s, v in symbol_pool._name2label.items(): - symbols.add_label(s, v.offset_g) - # test if bad encoded relative - for bloc in blocs: + # Enqueue blocks referencing a modified label + if label not in blocks_using_label: + continue + for block in blocks_using_label[label]: + blocks_to_rework.add(block) + + #symbols = asm_symbol_pool() + #for s, v in symbol_pool._name2label.items(): + # symbols.add_label(s, v.offset_g) + while blocks_to_rework: + block = blocks_to_rework.pop() offset_i = 0 - my_symb_reloc_off[bloc.label] = [] - for instr in bloc.lines: + my_symb_reloc_off[block.label] = [] + + len_modified = False + + for instr in block.lines: if isinstance(instr, asm_raw): if isinstance(instr.raw, list): # fix special asm_raw data = "" for x in instr.raw: - e = fix_expr_val(x, symbols) + e = fix_expr_val(x, symbol_pool) data+= pck[e.size](e.arg) instr.data = data + instr.offset = offset_i offset_i += instr.l continue sav_a = instr.args[:] - instr.offset = bloc.label.offset_g + offset_i - args_e = instr.resolve_args_with_symbols(symbols) + instr.offset = block.label.offset + offset_i + args_e = instr.resolve_args_with_symbols(symbol_pool) for i, e in enumerate(args_e): instr.args[i] = e @@ -1004,10 +1030,11 @@ def asmbloc_final(mnemo, blocs, symbol_pool, symb_reloc_off=None, if len(c) != instr.l: # good len, bad offset...XXX - bloc.size = bloc.size - old_l + len(c) + block.size = block.size - old_l + len(c) instr.data = c instr.l = len(c) fini = False + len_modified = True continue found = False for cpos, c in enumerate(candidates): @@ -1026,34 +1053,26 @@ def asmbloc_final(mnemo, blocs, symbol_pool, symb_reloc_off=None, my_s = None if my_s is not None: - my_symb_reloc_off[bloc.label].append(offset_i + my_s) + my_symb_reloc_off[block.label].append(offset_i + my_s) offset_i += instr.l assert len(instr.data) == instr.l - # we have fixed all relative values - # recompute good offsets - for label in symbol_pool.items: - symbol_pool.set_offset(label, label.offset_g) - - for a, b in my_symb_reloc_off.items(): - symb_reloc_off[a] = b -def asm_resolve_final(mnemo, blocs, symbol_pool, dont_erase=[], - max_offset=0xFFFFFFFF, symb_reloc_off=None): +def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None, + symb_reloc_off=None): if symb_reloc_off is None: symb_reloc_off = {} - guess_blocs_size(mnemo, blocs) - bloc_g = group_constrained_blocs(blocs) + guess_blocks_size(mnemo, blocks) + blockChains = group_constrained_blocks(symbol_pool, blocks) - resolved_b = resolve_symbol(bloc_g, symbol_pool, dont_erase=dont_erase, - max_offset=max_offset) + blockChains = resolve_symbol(blockChains, blocks, symbol_pool, dst_interval) - asmbloc_final(mnemo, resolved_b, symbol_pool, symb_reloc_off) + asmbloc_final(mnemo, blocks, blockChains, symbol_pool, symb_reloc_off) written_bytes = {} patches = {} - for bloc in resolved_b: - offset = bloc.label.offset - for line in bloc.lines: + for block in blocks: + offset = block.label.offset + for line in block.lines: assert line.data is not None patches[offset] = line.data for cur_pos in xrange(line.l): @@ -1063,10 +1082,8 @@ def asm_resolve_final(mnemo, blocs, symbol_pool, dont_erase=[], written_bytes[offset + cur_pos] = 1 line.offset = offset offset += line.l - return patches - def blist2graph(ab): """ ab: list of asmbloc diff --git a/miasm2/core/parse_asm.py b/miasm2/core/parse_asm.py index 6bec9651..a56bcd9a 100644 --- a/miasm2/core/parse_asm.py +++ b/miasm2/core/parse_asm.py @@ -19,6 +19,11 @@ size2pck = {8: 'B', 64: 'Q', } +class directive_align: + def __init__(self, alignment=1): + self.alignment = alignment + def __str__(self): + return "alignment %s"%self.alignment def guess_next_new_label(symbol_pool, gen_label_index=0): i = 0 @@ -145,6 +150,10 @@ def parse_txt(mnemo, attrib, txt, symbol_pool=None, gen_label_index=0): if directive == 'dontsplit': # custom command lines.append(asmbloc.asm_raw()) continue + if directive == "align": + align_value = int(line[r.end():]) + lines.append(directive_align(align_value)) + continue if directive in ['file', 'intel_syntax', 'globl', 'local', 'type', 'size', 'align', 'ident', 'section']: continue @@ -195,7 +204,7 @@ def parse_txt(mnemo, attrib, txt, symbol_pool=None, gen_label_index=0): lines[i:i] = [l] else: l = lines[i] - b = asmbloc.asm_bloc(l) + b = asmbloc.asm_bloc(l, alignment=mnemo.alignment) b.bloc_num = bloc_num bloc_num += 1 blocs.append(b) @@ -218,6 +227,9 @@ def parse_txt(mnemo, attrib, txt, symbol_pool=None, gen_label_index=0): block_may_link = True b.addline(lines[i]) i += 1 + elif isinstance(lines[i], directive_align): + b.alignment = lines[i].alignment + i += 1 # asmbloc.asm_label elif isinstance(lines[i], asmbloc.asm_label): if block_may_link: |