diff options
| -rw-r--r-- | example/jitter/unpack_upx.py | 3 | ||||
| -rw-r--r-- | miasm2/analysis/debugging.py | 23 | ||||
| -rw-r--r-- | miasm2/analysis/gdbserver.py | 9 | ||||
| -rw-r--r-- | miasm2/core/asmbloc.py | 171 | ||||
| -rw-r--r-- | miasm2/core/utils.py | 33 | ||||
| -rw-r--r-- | miasm2/ir/translators/z3_ir.py | 4 | ||||
| -rw-r--r-- | miasm2/jitter/jitload.py | 16 | ||||
| -rw-r--r-- | test/test_all.py | 8 |
8 files changed, 182 insertions, 85 deletions
diff --git a/example/jitter/unpack_upx.py b/example/jitter/unpack_upx.py index 08b733a4..2d0a02ea 100644 --- a/example/jitter/unpack_upx.py +++ b/example/jitter/unpack_upx.py @@ -81,6 +81,9 @@ def update_binary(jitter): sdata = sb.jitter.vm.get_mem(sb.pe.rva2virt(s.addr), s.rawsize) sb.pe.virt[sb.pe.rva2virt(s.addr)] = sdata + # Stop execution + jitter.run = False + return False # Set callbacks sb.jitter.add_breakpoint(end_label, update_binary) diff --git a/miasm2/analysis/debugging.py b/miasm2/analysis/debugging.py index 4e6982b3..3fffbf66 100644 --- a/miasm2/analysis/debugging.py +++ b/miasm2/analysis/debugging.py @@ -22,6 +22,16 @@ class DebugBreakpointSoft(DebugBreakpoint): return "Soft BP @0x%08x" % self.addr +class DebugBreakpointTerminate(DebugBreakpoint): + "Stand for an execution termination" + + def __init__(self, status): + self.status = status + + def __str__(self): + return "Terminate with %s" % self.status + + class DebugBreakpointMemory(DebugBreakpoint): "Stand for memory breakpoint" @@ -131,8 +141,9 @@ class Debugguer(object): self.myjit.jit.log_newbloc = newbloc def handle_exception(self, res): - if res is None: - return + if not res: + # A breakpoint has stopped the execution + return DebugBreakpointTerminate(res) if isinstance(res, DebugBreakpointSoft): print "Breakpoint reached @0x%08x" % res.addr @@ -149,6 +160,9 @@ class Debugguer(object): else: raise NotImplementedError("type res") + # Repropagate res + return res + def step(self): "Step in jit" @@ -165,9 +179,8 @@ class Debugguer(object): return res def run(self): - res = self.myjit.continue_run() - self.handle_exception(res) - return res + status = self.myjit.continue_run() + return self.handle_exception(status) def get_mem(self, addr, size=0xF): "hexdump @addr, size" diff --git a/miasm2/analysis/gdbserver.py b/miasm2/analysis/gdbserver.py index a930cc88..cbc8fe8d 100644 --- a/miasm2/analysis/gdbserver.py +++ b/miasm2/analysis/gdbserver.py @@ -134,7 +134,8 @@ class GdbServer(object): elif msg_type == "k": # Kill self.sock.close() - exit(1) + self.send_queue = [] + self.sock = None elif msg_type == "!": # Extending debugging will be used @@ -245,6 +246,12 @@ class GdbServer(object): self.send_queue.append("S05") else: raise NotImplementedError("Unknown Except") + elif isinstance(ret, debugging.DebugBreakpointTerminate): + # Connexion should close, but keep it running as a TRAP + # The connexion will be close on instance destruction + print ret + self.status = "S05" + self.send_queue.append("S05") else: raise NotImplementedError() diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py index 1a2d7a91..31e4bdd7 100644 --- a/miasm2/core/asmbloc.py +++ b/miasm2/core/asmbloc.py @@ -3,6 +3,7 @@ import logging import inspect +import re import miasm2.expression.expression as m2_expr @@ -18,6 +19,7 @@ console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) log_asmbloc.addHandler(console_handler) log_asmbloc.setLevel(logging.WARNING) + def is_int(a): return isinstance(a, int) or isinstance(a, long) or \ isinstance(a, moduint) or isinstance(a, modint) @@ -33,6 +35,7 @@ def expr_is_int_or_label(e): class asm_label: + "Stand for an assembly label" def __init__(self, name="", offset=None): @@ -59,13 +62,16 @@ class asm_label: rep += '>' return rep + class asm_raw: + def __init__(self, raw=""): self.raw = raw def __str__(self): return repr(self.raw) + class asm_constraint(object): c_to = "c_to" c_next = "c_next" @@ -194,7 +200,6 @@ class asm_bloc(object): if l.splitflow() or l.breakflow(): raise NotImplementedError('not fully functional') - def get_subcall_instr(self): if not self.lines: return None @@ -228,11 +233,11 @@ class asm_symbol_pool: # Test for collisions if (label.offset in self._offset2label and - label != self._offset2label[label.offset]): + label != self._offset2label[label.offset]): raise ValueError('symbol %s has same offset as %s' % (label, self._offset2label[label.offset])) if (label.name in self._name2label and - label != self._name2label[label.name]): + label != self._name2label[label.name]): raise ValueError('symbol %s has same name as %s' % (label, self._name2label[label.name])) @@ -438,7 +443,7 @@ def dis_bloc(mnemo, pool_bin, cur_bloc, offset, job_done, symbol_pool, def split_bloc(mnemo, attrib, pool_bin, blocs, - symbol_pool, more_ref=None, dis_bloc_callback=None): + symbol_pool, more_ref=None, dis_bloc_callback=None): if not more_ref: more_ref = [] @@ -472,7 +477,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs, if dis_bloc_callback: offsets_to_dis = set( [x.label.offset for x in new_b.bto - if isinstance(x.label, asm_label)]) + if isinstance(x.label, asm_label)]) dis_bloc_callback( mnemo, attrib, pool_bin, new_b, offsets_to_dis, symbol_pool) @@ -481,6 +486,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs, return blocs + def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[], split_dis=[], follow_call=False, dontdis_retcall=False, blocs_wd=None, lines_wd=None, blocs=None, @@ -527,48 +533,75 @@ def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[], blocs.append(cur_bloc) return split_bloc(mnemo, attrib, pool_bin, blocs, - symbol_pool, dis_bloc_callback=dis_bloc_callback) - - -def bloc2graph(blocs, label=False, lines=True): - # rankdir=LR; - out = """ -digraph asm_graph { -size="80,50"; -node [ -fontsize = "16", -shape = "box" -]; -""" - for b in blocs: - out += '%s [\n' % b.label.name - out += 'label = "' + symbol_pool, dis_bloc_callback=dis_bloc_callback) + - out += b.label.name + "\\l\\\n" +def bloc2graph(blocks, label=False, lines=True): + """Render dot graph of @blocks""" + + escape_chars = re.compile('[' + re.escape('{}') + ']') + label_attr = 'colspan="2" align="center" bgcolor="grey"' + edge_attr = 'label = "%s" color="%s" style="bold"' + td_attr = 'align="left"' + block_attr = 'shape="Mrecord" fontname="Courier New"' + + out = ["digraph asm_graph {"] + fix_chars = lambda x: '\\' + x.group() + + # Generate basic blocks + out_blocks = [] + for block in blocks: + out_block = '%s [\n' % block.label.name + out_block += "%s " % block_attr + out_block += 'label =<<table border="0" cellborder="0" cellpadding="3">' + + block_label = '<tr><td %s>%s</td></tr>' % ( + label_attr, block.label.name) + block_html_lines = [] if lines: - for l in b.lines: + for line in block.lines: if label: - out += "%.8X " % l.offset - out += ("%s\\l\\\n" % l).replace('"', '\\"') - out += '"\n];\n' - - for b in blocs: - for n in b.bto: - # print 'xxxx', n.label, n.label.__class__ - # if isinstance(n.label, ExprId): - # print n.label.name, n.label.name.__class__ - if isinstance(n.label, m2_expr.ExprId): - dst, name, cst = b.label.name, n.label.name, n.c_t - # out+='%s -> %s [ label = "%s" ];\n'%(b.label.name, - # n.label.name, n.c_t) - elif isinstance(n.label, asm_label): - dst, name, cst = b.label.name, n.label.name, n.c_t + out_render = "%.8X</td><td %s> " % (line.offset, td_attr) + else: + out_render = "" + out_render += escape_chars.sub(fix_chars, str(line)) + block_html_lines.append(out_render) + block_html_lines = ('<tr><td %s>' % td_attr + + ('</td></tr><tr><td %s>' % td_attr).join(block_html_lines) + + '</td></tr>') + out_block += "%s " % block_label + out_block += block_html_lines + "</table>> ];" + out_blocks.append(out_block) + + out += out_blocks + + # Generate links + for block in blocks: + for next_b in block.bto: + if (isinstance(next_b.label, m2_expr.ExprId) or + isinstance(next_b.label, asm_label)): + src, dst, cst = block.label.name, next_b.label.name, next_b.c_t else: continue - out += '%s -> %s [ label = "%s" ];\n' % (dst, name, cst) + if isinstance(src, asm_label): + src = src.name + if isinstance(dst, asm_label): + dst = dst.name + + edge_color = "black" + if next_b.c_t == asm_constraint.c_next: + edge_color = "red" + elif next_b.c_t == asm_constraint.c_to: + edge_color = "limegreen" + # special case + if len(block.bto) == 1: + edge_color = "blue" - out += "}" - return out + out.append('%s -> %s' % (src, dst) + + '[' + edge_attr % (cst, edge_color) + '];') + + out.append("}") + return '\n'.join(out) def conservative_asm(mnemo, instr, symbols, conservative): @@ -589,6 +622,7 @@ def conservative_asm(mnemo, instr, symbols, conservative): return c, candidates return candidates[0], candidates + def fix_expr_val(expr, symbols): """Resolve an expression @expr using @symbols""" def expr_calc(e): @@ -616,7 +650,7 @@ def guess_blocks_size(mnemo, blocks): if len(instr.raw) == 0: l = 0 else: - l = instr.raw[0].size/8 * len(instr.raw) + l = instr.raw[0].size / 8 * len(instr.raw) elif isinstance(instr.raw, str): data = instr.raw l = len(data) @@ -640,6 +674,7 @@ def guess_blocks_size(mnemo, blocks): block.max_size = size log_asmbloc.info("size: %d max: %d", block.size, block.max_size) + def fix_label_offset(symbol_pool, label, offset, modified): """Fix the @label offset to @offset. If the @offset has changed, add @label to @modified @@ -652,6 +687,7 @@ def fix_label_offset(symbol_pool, label, offset, modified): class BlockChain(object): + """Manage blocks linked with an asm_constraint_next""" def __init__(self, symbol_pool, blocks): @@ -672,7 +708,6 @@ class BlockChain(object): raise ValueError("Multiples pinned block detected") self.pinned_block_idx = i - def place(self): """Compute BlockChain min_offset and max_offset using pinned block and blocks' size @@ -686,17 +721,18 @@ class BlockChain(object): if not self.pinned: return - offset_base = self.blocks[self.pinned_block_idx].label.offset assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0) self.offset_min = offset_base - for block in self.blocks[:self.pinned_block_idx-1:-1]: - self.offset_min -= block.max_size + (block.alignment - block.max_size) % block.alignment + for block in self.blocks[:self.pinned_block_idx - 1:-1]: + self.offset_min -= block.max_size + \ + (block.alignment - block.max_size) % block.alignment self.offset_max = offset_base for block in self.blocks[self.pinned_block_idx:]: - self.offset_max += block.max_size + (block.alignment - block.max_size) % block.alignment + self.offset_max += block.max_size + \ + (block.alignment - block.max_size) % block.alignment def merge(self, chain): """Best effort merge two block chains @@ -718,7 +754,7 @@ class BlockChain(object): if offset % pinned_block.alignment != 0: raise RuntimeError('Bad alignment') - for block in self.blocks[:self.pinned_block_idx-1:-1]: + for block in self.blocks[:self.pinned_block_idx - 1:-1]: new_offset = offset - block.size new_offset = new_offset - new_offset % pinned_block.alignment fix_label_offset(self.symbol_pool, @@ -730,7 +766,7 @@ class BlockChain(object): offset = pinned_block.label.offset + pinned_block.size last_block = pinned_block - for block in self.blocks[self.pinned_block_idx+1:]: + for block in self.blocks[self.pinned_block_idx + 1:]: offset += (- offset) % last_block.alignment fix_label_offset(self.symbol_pool, block.label, @@ -742,6 +778,7 @@ class BlockChain(object): class BlockChainWedge(object): + """Stand for wedges between blocks""" def __init__(self, symbol_pool, offset, size): @@ -770,7 +807,7 @@ def group_constrained_blocks(symbol_pool, blocks): # Group adjacent blocks remaining_blocks = list(blocks) known_block_chains = {} - lbl2block = {block.label:block for block in blocks} + lbl2block = {block.label: block for block in blocks} while remaining_blocks: # Create a new block chain @@ -812,12 +849,13 @@ def get_blockchains_address_interval(blockChains, dst_interval): for chain in blockChains: if not chain.pinned: continue - chain_interval = interval([(chain.offset_min, chain.offset_max-1)]) + chain_interval = interval([(chain.offset_min, chain.offset_max - 1)]) if chain_interval not in dst_interval: raise ValueError('Chain placed out of destination interval') allocated_interval += chain_interval return allocated_interval + def resolve_symbol(blockChains, symbol_pool, dst_interval=None): """Place @blockChains in the @dst_interval""" @@ -825,7 +863,8 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): if dst_interval is None: dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)]) - forbidden_interval = interval([(-1, 0xFFFFFFFFFFFFFFFF+1)]) - dst_interval + forbidden_interval = interval( + [(-1, 0xFFFFFFFFFFFFFFFF + 1)]) - dst_interval allocated_interval = get_blockchains_address_interval(blockChains, dst_interval) log_asmbloc.debug('allocated interval: %s', allocated_interval) @@ -834,12 +873,13 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): # Add wedge in forbidden intervals for start, stop in forbidden_interval.intervals: - wedge = BlockChainWedge(symbol_pool, offset=start, size=stop+1-start) + wedge = BlockChainWedge( + symbol_pool, offset=start, size=stop + 1 - start) pinned_chains.append(wedge) # Try to place bigger blockChains first - pinned_chains.sort(key=lambda x:x.offset_min) - blockChains.sort(key=lambda x:-x.max_size) + pinned_chains.sort(key=lambda x: x.offset_min) + blockChains.sort(key=lambda x: -x.max_size) fixed_chains = list(pinned_chains) @@ -849,12 +889,12 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): continue fixed = False for i in xrange(1, len(fixed_chains)): - prev_chain = fixed_chains[i-1] + prev_chain = fixed_chains[i - 1] next_chain = fixed_chains[i] if prev_chain.offset_max + chain.max_size < next_chain.offset_min: new_chains = prev_chain.merge(chain) - fixed_chains[i-1:i] = new_chains + fixed_chains[i - 1:i] = new_chains fixed = True break if not fixed: @@ -862,10 +902,12 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): return [chain for chain in fixed_chains if isinstance(chain, BlockChain)] + def filter_exprid_label(exprs): """Extract labels from list of ExprId @exprs""" return set(expr.name for expr in exprs if isinstance(expr.name, asm_label)) + def get_block_labels(block): """Extract labels used by @block""" symbols = set() @@ -880,6 +922,7 @@ def get_block_labels(block): labels = filter_exprid_label(symbols) return labels + def assemble_block(mnemo, block, symbol_pool, conservative=False): """Assemble a @block using @symbol_pool @conservative: (optional) use original bytes when possible @@ -932,7 +975,7 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, conservative=False): log_asmbloc.debug("asmbloc_final") # Init structures - lbl2block = {block.label:block for block in blocks} + lbl2block = {block.label: block for block in blocks} blocks_using_label = {} for block in blocks: labels = get_block_labels(block) @@ -992,7 +1035,8 @@ def sanity_check_blocks(blocks): if blocks_graph.blocs[pred].get_next() == label: pred_next.add(pred) if len(pred_next) > 1: - raise RuntimeError("Too many next constraints for bloc %r"%label) + raise RuntimeError("Too many next constraints for bloc %r" % label) + def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): """Resolve and assemble @blocks using @symbol_pool into interval @@ -1002,7 +1046,8 @@ def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): guess_blocks_size(mnemo, blocks) blockChains = group_constrained_blocks(symbol_pool, blocks) - resolved_blockChains = resolve_symbol(blockChains, symbol_pool, dst_interval) + resolved_blockChains = resolve_symbol( + blockChains, symbol_pool, dst_interval) asmbloc_final(mnemo, blocks, resolved_blockChains, symbol_pool) patches = {} @@ -1016,13 +1061,14 @@ def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): continue assert len(instr.data) == instr.l patches[offset] = instr.data - instruction_interval = interval([(offset, offset + instr.l-1)]) + instruction_interval = interval([(offset, offset + instr.l - 1)]) if not (instruction_interval & output_interval).empty: raise RuntimeError("overlapping bytes %X" % int(offset)) instr.offset = offset offset += instr.l return patches + def blist2graph(ab): """ ab: list of asmbloc @@ -1126,7 +1172,7 @@ def getbloc_parents(blocs, a, level=3, done=None, blocby_label=None): def getbloc_parents_strict( - blocs, a, level=3, rez=None, done=None, blocby_label=None): + blocs, a, level=3, rez=None, done=None, blocby_label=None): if not blocby_label: blocby_label = {} @@ -1280,4 +1326,3 @@ class disasmEngine(object): dont_dis_nulstart_bloc=self.dont_dis_nulstart_bloc, attrib=self.attrib) return blocs - diff --git a/miasm2/core/utils.py b/miasm2/core/utils.py index 782c21d6..75eb3113 100644 --- a/miasm2/core/utils.py +++ b/miasm2/core/utils.py @@ -1,6 +1,7 @@ import struct import inspect import UserDict +from operator import itemgetter upck8 = lambda x: struct.unpack('B', x)[0] upck16 = lambda x: struct.unpack('H', x)[0] @@ -73,7 +74,8 @@ class BoundedDict(UserDict.DictMixin): self._min_size = min_size if min_size else max_size / 3 self._max_size = max_size self._size = len(self._data) - self._counter = collections.Counter(self._data.keys()) + # Do not use collections.Counter as it is quite slow + self._counter = {k: 1 for k in self._data} self._delete_cb = delete_cb def __setitem__(self, asked_key, value): @@ -83,31 +85,46 @@ class BoundedDict(UserDict.DictMixin): # Bound can only be reached on a new element if (self._size >= self._max_size): - most_commons = [key for key, _ in self._counter.most_common()] + most_common = sorted(self._counter.iteritems(), + key=itemgetter(1), reverse=True) # Handle callback if self._delete_cb is not None: - for key in most_commons[self._min_size - 1:]: + for key, _ in most_common[self._min_size - 1:]: self._delete_cb(key) # Keep only the most @_min_size used self._data = {key:self._data[key] - for key in most_commons[:self._min_size - 1]} + for key, _ in most_common[:self._min_size - 1]} self._size = self._min_size # Reset use's counter - self._counter = collections.Counter(self._data.keys()) + self._counter = {k: 1 for k in self._data} + + # Avoid rechecking in dict: set to 1 here, add 1 otherwise + self._counter[asked_key] = 1 + else: + self._counter[asked_key] += 1 self._data[asked_key] = value - self._counter.update([asked_key]) + + def __contains__(self, key): + # Do not call has_key to avoid adding function call overhead + return key in self._data + + def has_key(self, key): + return key in self._data def keys(self): "Return the list of dict's keys" return self._data.keys() def __getitem__(self, key): - self._counter.update([key]) - return self._data[key] + # Retrieve data first to raise the proper exception on error + data = self._data[key] + # Should never raise, since the key is in self._data + self._counter[key] += 1 + return data def __delitem__(self, key): if self._delete_cb is not None: diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py index af4544a9..b6645d2b 100644 --- a/miasm2/ir/translators/z3_ir.py +++ b/miasm2/ir/translators/z3_ir.py @@ -167,6 +167,10 @@ class TranslatorZ3(Translator): res = res >> arg elif expr.op == "a<<": res = res << arg + elif expr.op == "<<<": + res = z3.RotateLeft(res, arg) + elif expr.op == ">>>": + res = z3.RotateRight(res, arg) elif expr.op == "idiv": res = res / arg elif expr.op == "udiv": diff --git a/miasm2/jitter/jitload.py b/miasm2/jitter/jitload.py index 1c88d0b7..112920a1 100644 --- a/miasm2/jitter/jitload.py +++ b/miasm2/jitter/jitload.py @@ -113,6 +113,9 @@ class CallbackHandler(object): return empty_keys + def has_callbacks(self, name): + return name in self.callbacks + def call_callbacks(self, name, *args): """Call callbacks associated to key 'name' with arguments args. While callbacks return True, continue with next callback. @@ -134,13 +137,17 @@ class CallbackHandlerBitflag(CallbackHandler): "Handle a list of callback with conditions on bitflag" + # Overrides CallbackHandler's implem, but do not serve for optimization + def has_callbacks(self, bitflag): + return any(cb_mask & bitflag != 0 for cb_mask in self.callbacks) + def __call__(self, bitflag, *args): """Call each callbacks associated with bit set in bitflag. While callbacks return True, continue with next callback. Iterator on other results""" res = True - for b in self.callbacks.keys(): + for b in self.callbacks: if b & bitflag != 0: # If the flag matched @@ -301,9 +308,10 @@ class jitter: # Check breakpoints old_pc = self.pc - for res in self.breakpoints_handler(self.pc, self): - if res is not True: - yield res + if self.breakpoints_handler.has_callbacks(self.pc): + for res in self.breakpoints_handler(self.pc, self): + if res is not True: + yield res # If a callback changed pc, re call every callback if old_pc != self.pc: diff --git a/test/test_all.py b/test/test_all.py index 9773022f..54537bdf 100644 --- a/test/test_all.py +++ b/test/test_all.py @@ -402,9 +402,9 @@ if __name__ == "__main__": action="store_true") parser.add_argument("-c", "--coverage", help="Include code coverage", action="store_true") - parser.add_argument("-t", "--ommit-tags", help="Ommit tests based on tags \ + parser.add_argument("-t", "--omit-tags", help="Omit tests based on tags \ (tag1,tag2). Available tags are %s. \ -By default, no tag is ommited." % ", ".join(TAGS.keys()), default="") +By default, no tag is omitted." % ", ".join(TAGS.keys()), default="") parser.add_argument("-n", "--do-not-clean", help="Do not clean tests products", action="store_true") args = parser.parse_args() @@ -414,9 +414,9 @@ By default, no tag is ommited." % ", ".join(TAGS.keys()), default="") if args.mono is True or args.coverage is True: multiproc = False - ## Parse ommit-tags argument + ## Parse omit-tags argument exclude_tags = [] - for tag in args.ommit_tags.split(","): + for tag in args.omit_tags.split(","): if not tag: continue if tag not in TAGS: |