diff options
| -rw-r--r-- | example/jitter/unpack_upx.py | 3 | ||||
| -rw-r--r-- | miasm2/analysis/debugging.py | 23 | ||||
| -rw-r--r-- | miasm2/analysis/depgraph.py | 87 | ||||
| -rw-r--r-- | miasm2/analysis/gdbserver.py | 9 | ||||
| -rw-r--r-- | miasm2/analysis/sandbox.py | 10 | ||||
| -rw-r--r-- | miasm2/arch/aarch64/arch.py | 7 | ||||
| -rw-r--r-- | miasm2/arch/arm/arch.py | 1 | ||||
| -rw-r--r-- | miasm2/arch/arm/sem.py | 10 | ||||
| -rw-r--r-- | miasm2/core/asmbloc.py | 171 | ||||
| -rw-r--r-- | miasm2/core/utils.py | 33 | ||||
| -rw-r--r-- | miasm2/expression/expression.py | 6 | ||||
| -rw-r--r-- | miasm2/ir/analysis.py | 2 | ||||
| -rw-r--r-- | miasm2/ir/ir.py | 32 | ||||
| -rw-r--r-- | miasm2/ir/translators/z3_ir.py | 4 | ||||
| -rw-r--r-- | miasm2/jitter/jitcore_tcc.py | 8 | ||||
| -rw-r--r-- | miasm2/jitter/jitload.py | 16 | ||||
| -rw-r--r-- | miasm2/jitter/loader/pe.py | 21 | ||||
| -rw-r--r-- | test/arch/aarch64/arch.py | 5 | ||||
| -rw-r--r-- | test/arch/arm/arch.py | 3 | ||||
| -rw-r--r-- | test/test_all.py | 32 |
20 files changed, 321 insertions, 162 deletions
diff --git a/example/jitter/unpack_upx.py b/example/jitter/unpack_upx.py index 08b733a4..2d0a02ea 100644 --- a/example/jitter/unpack_upx.py +++ b/example/jitter/unpack_upx.py @@ -81,6 +81,9 @@ def update_binary(jitter): sdata = sb.jitter.vm.get_mem(sb.pe.rva2virt(s.addr), s.rawsize) sb.pe.virt[sb.pe.rva2virt(s.addr)] = sdata + # Stop execution + jitter.run = False + return False # Set callbacks sb.jitter.add_breakpoint(end_label, update_binary) diff --git a/miasm2/analysis/debugging.py b/miasm2/analysis/debugging.py index 4e6982b3..3fffbf66 100644 --- a/miasm2/analysis/debugging.py +++ b/miasm2/analysis/debugging.py @@ -22,6 +22,16 @@ class DebugBreakpointSoft(DebugBreakpoint): return "Soft BP @0x%08x" % self.addr +class DebugBreakpointTerminate(DebugBreakpoint): + "Stand for an execution termination" + + def __init__(self, status): + self.status = status + + def __str__(self): + return "Terminate with %s" % self.status + + class DebugBreakpointMemory(DebugBreakpoint): "Stand for memory breakpoint" @@ -131,8 +141,9 @@ class Debugguer(object): self.myjit.jit.log_newbloc = newbloc def handle_exception(self, res): - if res is None: - return + if not res: + # A breakpoint has stopped the execution + return DebugBreakpointTerminate(res) if isinstance(res, DebugBreakpointSoft): print "Breakpoint reached @0x%08x" % res.addr @@ -149,6 +160,9 @@ class Debugguer(object): else: raise NotImplementedError("type res") + # Repropagate res + return res + def step(self): "Step in jit" @@ -165,9 +179,8 @@ class Debugguer(object): return res def run(self): - res = self.myjit.continue_run() - self.handle_exception(res) - return res + status = self.myjit.continue_run() + return self.handle_exception(status) def get_mem(self, addr, size=0xF): "hexdump @addr, size" diff --git a/miasm2/analysis/depgraph.py b/miasm2/analysis/depgraph.py index 8d6da3b2..838183bf 100644 --- a/miasm2/analysis/depgraph.py +++ b/miasm2/analysis/depgraph.py @@ -16,6 +16,7 @@ from miasm2.ir.symbexec import symbexec from miasm2.ir.ir import irbloc from miasm2.ir.translators import Translator + class DependencyNode(object): """Node elements of a DependencyGraph @@ -27,6 +28,7 @@ class DependencyNode(object): __slots__ = ["_label", "_element", "_line_nb", "_modifier", "_step", "_nostep_repr", "_hash"] + def __init__(self, label, element, line_nb, step, modifier=False): """Create a dependency node with: @label: asm_label instance @@ -299,7 +301,6 @@ class DependencyDict(object): self._cache = CacheWrapper(self._get_modifiers_in_cache(node_heads)) - def _build_depgraph(self, depnode): """Recursively build the final list of DiGraph, and clean up unmodifier nodes @@ -495,7 +496,7 @@ class DependencyResult(object): def unresolved(self): """Set of nodes whose dependencies weren't found""" return set(node.nostep_repr for node in self._depdict.pending - if node.element != self._ira.IRDst) + if node.element != self._ira.IRDst) @property def relevant_nodes(self): @@ -708,13 +709,10 @@ class DependencyGraph(object): self._cb_follow = [] if apply_simp: self._cb_follow.append(self._follow_simp_expr) - if follow_mem: - self._cb_follow.append(self._follow_mem) - else: - self._cb_follow.append(self._follow_nomem) - if not follow_call: - self._cb_follow.append(self._follow_nocall) - self._cb_follow.append(self._follow_label) + self._cb_follow.append(lambda exprs: self._follow_exprs(exprs, + follow_mem, + follow_call)) + self._cb_follow.append(self._follow_nolabel) @property def step_counter(self): @@ -742,44 +740,59 @@ class DependencyGraph(object): return follow, set() @staticmethod - def _follow_label(exprs): - """Do not follow labels""" - follow = set() - for expr in exprs: - if not expr_is_label(expr): - follow.add(expr) + def get_expr(expr, follow, nofollow): + """Update @follow/@nofollow according to insteresting nodes + Returns same expression (non modifier visitor). - return follow, set() + @expr: expression to handle + @follow: set of nodes to follow + @nofollow: set of nodes not to follow + """ + if isinstance(expr, m2_expr.ExprId): + follow.add(expr) + elif isinstance(expr, m2_expr.ExprInt): + nofollow.add(expr) + return expr @staticmethod - def _follow_mem_wrapper(exprs, mem_read): - """Wrapper to follow or not expression from memory pointer""" - follow = set() - for expr in exprs: - follow.update(expr.get_r(mem_read=mem_read, cst_read=True)) - return follow, set() + def follow_expr(expr, follow, nofollow, follow_mem=False, follow_call=False): + """Returns True if we must visit sub expressions. + @expr: expression to browse + @follow: set of nodes to follow + @nofollow: set of nodes not to follow + @follow_mem: force the visit of memory sub expressions + @follow_call: force the visit of call sub expressions + """ + if not follow_mem and isinstance(expr, m2_expr.ExprMem): + nofollow.add(expr) + return False + if not follow_call and expr.is_function_call(): + nofollow.add(expr) + return False + return True - @staticmethod - def _follow_mem(exprs): - """Follow expression from memory pointer""" - return DependencyGraph._follow_mem_wrapper(exprs, True) + @classmethod + def _follow_exprs(cls, exprs, follow_mem=False, follow_call=False): + """Extracts subnodes from exprs and returns followed/non followed + expressions according to @follow_mem/@follow_call - @staticmethod - def _follow_nomem(exprs): - """Don't follow expression from memory pointer""" - return DependencyGraph._follow_mem_wrapper(exprs, False) + """ + follow, nofollow = set(), set() + for expr in exprs: + expr.visit(lambda x: cls.get_expr(x, follow, nofollow), + lambda x: cls.follow_expr(x, follow, nofollow, + follow_mem, follow_call)) + return follow, nofollow @staticmethod - def _follow_nocall(exprs): - """Don't follow expression from sub_call""" + def _follow_nolabel(exprs): + """Do not follow labels""" follow = set() - nofollow = set() for expr in exprs: - if expr.is_function_call(): - nofollow.add(expr) - else: + if not expr_is_label(expr): follow.add(expr) - return follow, nofollow + + return follow, set() def _follow_apply_cb(self, expr): """Apply callback functions to @expr diff --git a/miasm2/analysis/gdbserver.py b/miasm2/analysis/gdbserver.py index a930cc88..cbc8fe8d 100644 --- a/miasm2/analysis/gdbserver.py +++ b/miasm2/analysis/gdbserver.py @@ -134,7 +134,8 @@ class GdbServer(object): elif msg_type == "k": # Kill self.sock.close() - exit(1) + self.send_queue = [] + self.sock = None elif msg_type == "!": # Extending debugging will be used @@ -245,6 +246,12 @@ class GdbServer(object): self.send_queue.append("S05") else: raise NotImplementedError("Unknown Except") + elif isinstance(ret, debugging.DebugBreakpointTerminate): + # Connexion should close, but keep it running as a TRAP + # The connexion will be close on instance destruction + print ret + self.status = "S05" + self.send_queue.append("S05") else: raise NotImplementedError() diff --git a/miasm2/analysis/sandbox.py b/miasm2/analysis/sandbox.py index 7dc5d76e..688c7592 100644 --- a/miasm2/analysis/sandbox.py +++ b/miasm2/analysis/sandbox.py @@ -266,6 +266,7 @@ class OS_Linux_str(OS): class Arch_x86(Arch): _ARCH_ = None # Arch name STACK_SIZE = 0x100000 + STACK_BASE = 0x123000 def __init__(self): super(Arch_x86, self).__init__() @@ -278,6 +279,7 @@ class Arch_x86(Arch): # Init stack self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE self.jitter.init_stack() @@ -298,46 +300,54 @@ class Arch_x86_64(Arch_x86): class Arch_arml(Arch): _ARCH_ = "arml" STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 def __init__(self): super(Arch_arml, self).__init__() # Init stack self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE self.jitter.init_stack() class Arch_armb(Arch): _ARCH_ = "armb" STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 def __init__(self): super(Arch_armb, self).__init__() # Init stack self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE self.jitter.init_stack() class Arch_aarch64l(Arch): _ARCH_ = "aarch64l" STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 def __init__(self): super(Arch_aarch64l, self).__init__() # Init stack self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE self.jitter.init_stack() class Arch_aarch64b(Arch): _ARCH_ = "aarch64b" STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 def __init__(self): super(Arch_aarch64b, self).__init__() # Init stack self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE self.jitter.init_stack() diff --git a/miasm2/arch/aarch64/arch.py b/miasm2/arch/aarch64/arch.py index 8c439dcc..816d67f4 100644 --- a/miasm2/arch/aarch64/arch.py +++ b/miasm2/arch/aarch64/arch.py @@ -1451,6 +1451,7 @@ rn0 = bs(l=5, cls=(aarch64_gpreg0,), fname="rn") rmz = bs(l=5, cls=(aarch64_gpregz,), fname="rm") rnz = bs(l=5, cls=(aarch64_gpregz,), fname="rn") +rdz = bs(l=5, cls=(aarch64_gpregz,), fname="rd") rn_n1 = bs(l=5, cls=(aarch64_gpreg_n1,), fname="rn") @@ -1602,10 +1603,14 @@ aarch64op("addsub", [sf, bs_adsu_name, modf, bs('01011'), bs('00'), bs('1'), rm_ aarch64op("neg", [sf, bs('1'), modf, bs('01011'), shift, bs('0'), rm_sft, imm6, bs('11111'), rd], [rd, rm_sft], alias=True) -logic_name = {'AND': 0, 'ORR': 1, 'EOR': 2, 'ANDS': 3} +logic_name = {'AND': 0, 'ORR': 1, 'EOR': 2} bs_logic_name = bs_name(l=2, name=logic_name) # logical (imm) aarch64op("logic", [sf, bs_logic_name, bs('100100'), immn, immr, imms, rn0, rd], [rd, rn0, imms]) +# ANDS +aarch64op("ands", [sf, bs('11'), bs('100100'), immn, immr, imms, rn0, rdz], [rdz, rn0, imms]) +aarch64op("tst", [sf, bs('11'), bs('100100'), immn, immr, imms, rn0, bs('11111')], [rn0, imms], alias=True) + # bitfield move p.149 logicbf_name = {'SBFM': 0b00, 'BFM': 0b01, 'UBFM': 0b10} diff --git a/miasm2/arch/arm/arch.py b/miasm2/arch/arm/arch.py index a70718d9..41c99d4d 100644 --- a/miasm2/arch/arm/arch.py +++ b/miasm2/arch/arm/arch.py @@ -1520,6 +1520,7 @@ armop("uxth", [bs('01101111'), bs('1111'), rd, rot_rm, bs('00'), bs('0111'), rm_ armop("sxtb", [bs('01101010'), bs('1111'), rd, rot_rm, bs('00'), bs('0111'), rm_noarg]) armop("sxth", [bs('01101011'), bs('1111'), rd, rot_rm, bs('00'), bs('0111'), rm_noarg]) +armop("rev", [bs('01101011'), bs('1111'), rd, bs('1111'), bs('0011'), rm]) class arm_widthm1(arm_imm, m_arg): def decode(self, v): diff --git a/miasm2/arch/arm/sem.py b/miasm2/arch/arm/sem.py index 85ea8c50..6838ef66 100644 --- a/miasm2/arch/arm/sem.py +++ b/miasm2/arch/arm/sem.py @@ -924,6 +924,15 @@ def bfc(ir, instr, a, b, c): e.append(ExprAff(ir.IRDst, r)) return e +def rev(ir, instr, a, b): + e = [] + c = ExprCompose([(b[:8], 24, 32), + (b[8:16], 16, 24), + (b[16:24], 8, 16), + (b[24:32], 0, 8)]) + e.append(ExprAff(a, c)) + return e + COND_EQ = 0 @@ -1067,6 +1076,7 @@ mnemo_condm0 = {'add': add, 'sxth': sxth, 'ubfx': ubfx, 'bfc': bfc, + 'rev': rev, } mnemo_condm1 = {'adds': add, diff --git a/miasm2/core/asmbloc.py b/miasm2/core/asmbloc.py index 1a2d7a91..31e4bdd7 100644 --- a/miasm2/core/asmbloc.py +++ b/miasm2/core/asmbloc.py @@ -3,6 +3,7 @@ import logging import inspect +import re import miasm2.expression.expression as m2_expr @@ -18,6 +19,7 @@ console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) log_asmbloc.addHandler(console_handler) log_asmbloc.setLevel(logging.WARNING) + def is_int(a): return isinstance(a, int) or isinstance(a, long) or \ isinstance(a, moduint) or isinstance(a, modint) @@ -33,6 +35,7 @@ def expr_is_int_or_label(e): class asm_label: + "Stand for an assembly label" def __init__(self, name="", offset=None): @@ -59,13 +62,16 @@ class asm_label: rep += '>' return rep + class asm_raw: + def __init__(self, raw=""): self.raw = raw def __str__(self): return repr(self.raw) + class asm_constraint(object): c_to = "c_to" c_next = "c_next" @@ -194,7 +200,6 @@ class asm_bloc(object): if l.splitflow() or l.breakflow(): raise NotImplementedError('not fully functional') - def get_subcall_instr(self): if not self.lines: return None @@ -228,11 +233,11 @@ class asm_symbol_pool: # Test for collisions if (label.offset in self._offset2label and - label != self._offset2label[label.offset]): + label != self._offset2label[label.offset]): raise ValueError('symbol %s has same offset as %s' % (label, self._offset2label[label.offset])) if (label.name in self._name2label and - label != self._name2label[label.name]): + label != self._name2label[label.name]): raise ValueError('symbol %s has same name as %s' % (label, self._name2label[label.name])) @@ -438,7 +443,7 @@ def dis_bloc(mnemo, pool_bin, cur_bloc, offset, job_done, symbol_pool, def split_bloc(mnemo, attrib, pool_bin, blocs, - symbol_pool, more_ref=None, dis_bloc_callback=None): + symbol_pool, more_ref=None, dis_bloc_callback=None): if not more_ref: more_ref = [] @@ -472,7 +477,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs, if dis_bloc_callback: offsets_to_dis = set( [x.label.offset for x in new_b.bto - if isinstance(x.label, asm_label)]) + if isinstance(x.label, asm_label)]) dis_bloc_callback( mnemo, attrib, pool_bin, new_b, offsets_to_dis, symbol_pool) @@ -481,6 +486,7 @@ def split_bloc(mnemo, attrib, pool_bin, blocs, return blocs + def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[], split_dis=[], follow_call=False, dontdis_retcall=False, blocs_wd=None, lines_wd=None, blocs=None, @@ -527,48 +533,75 @@ def dis_bloc_all(mnemo, pool_bin, offset, job_done, symbol_pool, dont_dis=[], blocs.append(cur_bloc) return split_bloc(mnemo, attrib, pool_bin, blocs, - symbol_pool, dis_bloc_callback=dis_bloc_callback) - - -def bloc2graph(blocs, label=False, lines=True): - # rankdir=LR; - out = """ -digraph asm_graph { -size="80,50"; -node [ -fontsize = "16", -shape = "box" -]; -""" - for b in blocs: - out += '%s [\n' % b.label.name - out += 'label = "' + symbol_pool, dis_bloc_callback=dis_bloc_callback) + - out += b.label.name + "\\l\\\n" +def bloc2graph(blocks, label=False, lines=True): + """Render dot graph of @blocks""" + + escape_chars = re.compile('[' + re.escape('{}') + ']') + label_attr = 'colspan="2" align="center" bgcolor="grey"' + edge_attr = 'label = "%s" color="%s" style="bold"' + td_attr = 'align="left"' + block_attr = 'shape="Mrecord" fontname="Courier New"' + + out = ["digraph asm_graph {"] + fix_chars = lambda x: '\\' + x.group() + + # Generate basic blocks + out_blocks = [] + for block in blocks: + out_block = '%s [\n' % block.label.name + out_block += "%s " % block_attr + out_block += 'label =<<table border="0" cellborder="0" cellpadding="3">' + + block_label = '<tr><td %s>%s</td></tr>' % ( + label_attr, block.label.name) + block_html_lines = [] if lines: - for l in b.lines: + for line in block.lines: if label: - out += "%.8X " % l.offset - out += ("%s\\l\\\n" % l).replace('"', '\\"') - out += '"\n];\n' - - for b in blocs: - for n in b.bto: - # print 'xxxx', n.label, n.label.__class__ - # if isinstance(n.label, ExprId): - # print n.label.name, n.label.name.__class__ - if isinstance(n.label, m2_expr.ExprId): - dst, name, cst = b.label.name, n.label.name, n.c_t - # out+='%s -> %s [ label = "%s" ];\n'%(b.label.name, - # n.label.name, n.c_t) - elif isinstance(n.label, asm_label): - dst, name, cst = b.label.name, n.label.name, n.c_t + out_render = "%.8X</td><td %s> " % (line.offset, td_attr) + else: + out_render = "" + out_render += escape_chars.sub(fix_chars, str(line)) + block_html_lines.append(out_render) + block_html_lines = ('<tr><td %s>' % td_attr + + ('</td></tr><tr><td %s>' % td_attr).join(block_html_lines) + + '</td></tr>') + out_block += "%s " % block_label + out_block += block_html_lines + "</table>> ];" + out_blocks.append(out_block) + + out += out_blocks + + # Generate links + for block in blocks: + for next_b in block.bto: + if (isinstance(next_b.label, m2_expr.ExprId) or + isinstance(next_b.label, asm_label)): + src, dst, cst = block.label.name, next_b.label.name, next_b.c_t else: continue - out += '%s -> %s [ label = "%s" ];\n' % (dst, name, cst) + if isinstance(src, asm_label): + src = src.name + if isinstance(dst, asm_label): + dst = dst.name + + edge_color = "black" + if next_b.c_t == asm_constraint.c_next: + edge_color = "red" + elif next_b.c_t == asm_constraint.c_to: + edge_color = "limegreen" + # special case + if len(block.bto) == 1: + edge_color = "blue" - out += "}" - return out + out.append('%s -> %s' % (src, dst) + + '[' + edge_attr % (cst, edge_color) + '];') + + out.append("}") + return '\n'.join(out) def conservative_asm(mnemo, instr, symbols, conservative): @@ -589,6 +622,7 @@ def conservative_asm(mnemo, instr, symbols, conservative): return c, candidates return candidates[0], candidates + def fix_expr_val(expr, symbols): """Resolve an expression @expr using @symbols""" def expr_calc(e): @@ -616,7 +650,7 @@ def guess_blocks_size(mnemo, blocks): if len(instr.raw) == 0: l = 0 else: - l = instr.raw[0].size/8 * len(instr.raw) + l = instr.raw[0].size / 8 * len(instr.raw) elif isinstance(instr.raw, str): data = instr.raw l = len(data) @@ -640,6 +674,7 @@ def guess_blocks_size(mnemo, blocks): block.max_size = size log_asmbloc.info("size: %d max: %d", block.size, block.max_size) + def fix_label_offset(symbol_pool, label, offset, modified): """Fix the @label offset to @offset. If the @offset has changed, add @label to @modified @@ -652,6 +687,7 @@ def fix_label_offset(symbol_pool, label, offset, modified): class BlockChain(object): + """Manage blocks linked with an asm_constraint_next""" def __init__(self, symbol_pool, blocks): @@ -672,7 +708,6 @@ class BlockChain(object): raise ValueError("Multiples pinned block detected") self.pinned_block_idx = i - def place(self): """Compute BlockChain min_offset and max_offset using pinned block and blocks' size @@ -686,17 +721,18 @@ class BlockChain(object): if not self.pinned: return - offset_base = self.blocks[self.pinned_block_idx].label.offset assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0) self.offset_min = offset_base - for block in self.blocks[:self.pinned_block_idx-1:-1]: - self.offset_min -= block.max_size + (block.alignment - block.max_size) % block.alignment + for block in self.blocks[:self.pinned_block_idx - 1:-1]: + self.offset_min -= block.max_size + \ + (block.alignment - block.max_size) % block.alignment self.offset_max = offset_base for block in self.blocks[self.pinned_block_idx:]: - self.offset_max += block.max_size + (block.alignment - block.max_size) % block.alignment + self.offset_max += block.max_size + \ + (block.alignment - block.max_size) % block.alignment def merge(self, chain): """Best effort merge two block chains @@ -718,7 +754,7 @@ class BlockChain(object): if offset % pinned_block.alignment != 0: raise RuntimeError('Bad alignment') - for block in self.blocks[:self.pinned_block_idx-1:-1]: + for block in self.blocks[:self.pinned_block_idx - 1:-1]: new_offset = offset - block.size new_offset = new_offset - new_offset % pinned_block.alignment fix_label_offset(self.symbol_pool, @@ -730,7 +766,7 @@ class BlockChain(object): offset = pinned_block.label.offset + pinned_block.size last_block = pinned_block - for block in self.blocks[self.pinned_block_idx+1:]: + for block in self.blocks[self.pinned_block_idx + 1:]: offset += (- offset) % last_block.alignment fix_label_offset(self.symbol_pool, block.label, @@ -742,6 +778,7 @@ class BlockChain(object): class BlockChainWedge(object): + """Stand for wedges between blocks""" def __init__(self, symbol_pool, offset, size): @@ -770,7 +807,7 @@ def group_constrained_blocks(symbol_pool, blocks): # Group adjacent blocks remaining_blocks = list(blocks) known_block_chains = {} - lbl2block = {block.label:block for block in blocks} + lbl2block = {block.label: block for block in blocks} while remaining_blocks: # Create a new block chain @@ -812,12 +849,13 @@ def get_blockchains_address_interval(blockChains, dst_interval): for chain in blockChains: if not chain.pinned: continue - chain_interval = interval([(chain.offset_min, chain.offset_max-1)]) + chain_interval = interval([(chain.offset_min, chain.offset_max - 1)]) if chain_interval not in dst_interval: raise ValueError('Chain placed out of destination interval') allocated_interval += chain_interval return allocated_interval + def resolve_symbol(blockChains, symbol_pool, dst_interval=None): """Place @blockChains in the @dst_interval""" @@ -825,7 +863,8 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): if dst_interval is None: dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)]) - forbidden_interval = interval([(-1, 0xFFFFFFFFFFFFFFFF+1)]) - dst_interval + forbidden_interval = interval( + [(-1, 0xFFFFFFFFFFFFFFFF + 1)]) - dst_interval allocated_interval = get_blockchains_address_interval(blockChains, dst_interval) log_asmbloc.debug('allocated interval: %s', allocated_interval) @@ -834,12 +873,13 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): # Add wedge in forbidden intervals for start, stop in forbidden_interval.intervals: - wedge = BlockChainWedge(symbol_pool, offset=start, size=stop+1-start) + wedge = BlockChainWedge( + symbol_pool, offset=start, size=stop + 1 - start) pinned_chains.append(wedge) # Try to place bigger blockChains first - pinned_chains.sort(key=lambda x:x.offset_min) - blockChains.sort(key=lambda x:-x.max_size) + pinned_chains.sort(key=lambda x: x.offset_min) + blockChains.sort(key=lambda x: -x.max_size) fixed_chains = list(pinned_chains) @@ -849,12 +889,12 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): continue fixed = False for i in xrange(1, len(fixed_chains)): - prev_chain = fixed_chains[i-1] + prev_chain = fixed_chains[i - 1] next_chain = fixed_chains[i] if prev_chain.offset_max + chain.max_size < next_chain.offset_min: new_chains = prev_chain.merge(chain) - fixed_chains[i-1:i] = new_chains + fixed_chains[i - 1:i] = new_chains fixed = True break if not fixed: @@ -862,10 +902,12 @@ def resolve_symbol(blockChains, symbol_pool, dst_interval=None): return [chain for chain in fixed_chains if isinstance(chain, BlockChain)] + def filter_exprid_label(exprs): """Extract labels from list of ExprId @exprs""" return set(expr.name for expr in exprs if isinstance(expr.name, asm_label)) + def get_block_labels(block): """Extract labels used by @block""" symbols = set() @@ -880,6 +922,7 @@ def get_block_labels(block): labels = filter_exprid_label(symbols) return labels + def assemble_block(mnemo, block, symbol_pool, conservative=False): """Assemble a @block using @symbol_pool @conservative: (optional) use original bytes when possible @@ -932,7 +975,7 @@ def asmbloc_final(mnemo, blocks, blockChains, symbol_pool, conservative=False): log_asmbloc.debug("asmbloc_final") # Init structures - lbl2block = {block.label:block for block in blocks} + lbl2block = {block.label: block for block in blocks} blocks_using_label = {} for block in blocks: labels = get_block_labels(block) @@ -992,7 +1035,8 @@ def sanity_check_blocks(blocks): if blocks_graph.blocs[pred].get_next() == label: pred_next.add(pred) if len(pred_next) > 1: - raise RuntimeError("Too many next constraints for bloc %r"%label) + raise RuntimeError("Too many next constraints for bloc %r" % label) + def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): """Resolve and assemble @blocks using @symbol_pool into interval @@ -1002,7 +1046,8 @@ def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): guess_blocks_size(mnemo, blocks) blockChains = group_constrained_blocks(symbol_pool, blocks) - resolved_blockChains = resolve_symbol(blockChains, symbol_pool, dst_interval) + resolved_blockChains = resolve_symbol( + blockChains, symbol_pool, dst_interval) asmbloc_final(mnemo, blocks, resolved_blockChains, symbol_pool) patches = {} @@ -1016,13 +1061,14 @@ def asm_resolve_final(mnemo, blocks, symbol_pool, dst_interval=None): continue assert len(instr.data) == instr.l patches[offset] = instr.data - instruction_interval = interval([(offset, offset + instr.l-1)]) + instruction_interval = interval([(offset, offset + instr.l - 1)]) if not (instruction_interval & output_interval).empty: raise RuntimeError("overlapping bytes %X" % int(offset)) instr.offset = offset offset += instr.l return patches + def blist2graph(ab): """ ab: list of asmbloc @@ -1126,7 +1172,7 @@ def getbloc_parents(blocs, a, level=3, done=None, blocby_label=None): def getbloc_parents_strict( - blocs, a, level=3, rez=None, done=None, blocby_label=None): + blocs, a, level=3, rez=None, done=None, blocby_label=None): if not blocby_label: blocby_label = {} @@ -1280,4 +1326,3 @@ class disasmEngine(object): dont_dis_nulstart_bloc=self.dont_dis_nulstart_bloc, attrib=self.attrib) return blocs - diff --git a/miasm2/core/utils.py b/miasm2/core/utils.py index 782c21d6..75eb3113 100644 --- a/miasm2/core/utils.py +++ b/miasm2/core/utils.py @@ -1,6 +1,7 @@ import struct import inspect import UserDict +from operator import itemgetter upck8 = lambda x: struct.unpack('B', x)[0] upck16 = lambda x: struct.unpack('H', x)[0] @@ -73,7 +74,8 @@ class BoundedDict(UserDict.DictMixin): self._min_size = min_size if min_size else max_size / 3 self._max_size = max_size self._size = len(self._data) - self._counter = collections.Counter(self._data.keys()) + # Do not use collections.Counter as it is quite slow + self._counter = {k: 1 for k in self._data} self._delete_cb = delete_cb def __setitem__(self, asked_key, value): @@ -83,31 +85,46 @@ class BoundedDict(UserDict.DictMixin): # Bound can only be reached on a new element if (self._size >= self._max_size): - most_commons = [key for key, _ in self._counter.most_common()] + most_common = sorted(self._counter.iteritems(), + key=itemgetter(1), reverse=True) # Handle callback if self._delete_cb is not None: - for key in most_commons[self._min_size - 1:]: + for key, _ in most_common[self._min_size - 1:]: self._delete_cb(key) # Keep only the most @_min_size used self._data = {key:self._data[key] - for key in most_commons[:self._min_size - 1]} + for key, _ in most_common[:self._min_size - 1]} self._size = self._min_size # Reset use's counter - self._counter = collections.Counter(self._data.keys()) + self._counter = {k: 1 for k in self._data} + + # Avoid rechecking in dict: set to 1 here, add 1 otherwise + self._counter[asked_key] = 1 + else: + self._counter[asked_key] += 1 self._data[asked_key] = value - self._counter.update([asked_key]) + + def __contains__(self, key): + # Do not call has_key to avoid adding function call overhead + return key in self._data + + def has_key(self, key): + return key in self._data def keys(self): "Return the list of dict's keys" return self._data.keys() def __getitem__(self, key): - self._counter.update([key]) - return self._data[key] + # Retrieve data first to raise the proper exception on error + data = self._data[key] + # Should never raise, since the key is in self._data + self._counter[key] += 1 + return data def __delitem__(self, key): if self._delete_cb is not None: diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py index 2b51ef61..c82aec2b 100644 --- a/miasm2/expression/expression.py +++ b/miasm2/expression/expression.py @@ -517,9 +517,6 @@ class ExprAff(Expr): else: return self._dst.get_w() - def is_function_call(self): - return isinstance(self.src, ExprOp) and self.src.op.startswith('call') - def _exprhash(self): return hash((EXPRAFF, hash(self._dst), hash(self._src))) @@ -821,6 +818,9 @@ class ExprOp(Expr): return True return False + def is_function_call(self): + return self._op.startswith('call') + def is_associative(self): "Return True iff current operation is associative" return (self._op in ['+', '*', '^', '&', '|']) diff --git a/miasm2/ir/analysis.py b/miasm2/ir/analysis.py index 51d2b2b7..31f6294c 100644 --- a/miasm2/ir/analysis.py +++ b/miasm2/ir/analysis.py @@ -167,7 +167,7 @@ class ira: # Function call, memory write or IRDst affectation for k, ir in enumerate(block.irs): for i_cur in ir: - if i_cur.is_function_call(): + if i_cur.src.is_function_call(): # /!\ never remove ir calls useful.add((block.label, k, i_cur)) if isinstance(i_cur.dst, ExprMem): diff --git a/miasm2/ir/ir.py b/miasm2/ir/ir.py index 32c97661..e051dc8c 100644 --- a/miasm2/ir/ir.py +++ b/miasm2/ir/ir.py @@ -135,17 +135,27 @@ class ir(object): ir_bloc_cur, ir_blocs_extra = self.get_ir(l) return ir_bloc_cur, ir_blocs_extra - def get_bloc(self, ad): - if isinstance(ad, m2_expr.ExprId) and isinstance(ad.name, - asmbloc.asm_label): + def get_label(self, ad): + """Transforms an ExprId/ExprInt/label/int into a label + @ad: an ExprId/ExprInt/label/int""" + + if (isinstance(ad, m2_expr.ExprId) and + isinstance(ad.name, asmbloc.asm_label)): ad = ad.name if isinstance(ad, m2_expr.ExprInt): ad = int(ad.arg) if type(ad) in [int, long]: - ad = self.symbol_pool.getby_offset(ad) + ad = self.symbol_pool.getby_offset_create(ad) elif isinstance(ad, asmbloc.asm_label): - ad = self.symbol_pool.getby_name(ad.name) - return self.blocs.get(ad, None) + ad = self.symbol_pool.getby_name_create(ad.name) + return ad + + def get_bloc(self, ad): + """Returns the irbloc associated to an ExprId/ExprInt/label/int + @ad: an ExprId/ExprInt/label/int""" + + label = self.get_label(ad) + return self.blocs.get(label, None) def add_instr(self, l, ad=0, gen_pc_updt = False): b = asmbloc.asm_bloc(l) @@ -227,7 +237,7 @@ class ir(object): ir_blocs_all = [] for l in bloc.lines: if c is None: - label = self.get_label(l) + label = self.get_instr_label(l) c = irbloc(label, [], []) ir_blocs_all.append(c) ir_bloc_cur, ir_blocs_extra = self.instr2ir(l) @@ -290,9 +300,11 @@ class ir(object): self.blocs[irb.label] = irb - def get_label(self, instr): - l = self.symbol_pool.getby_offset_create(instr.offset) - return l + def get_instr_label(self, instr): + """Returns the label associated to an instruction + @instr: current instruction""" + + return self.symbol_pool.getby_offset_create(instr.offset) def gen_label(self): # TODO: fix hardcoded offset diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py index af4544a9..b6645d2b 100644 --- a/miasm2/ir/translators/z3_ir.py +++ b/miasm2/ir/translators/z3_ir.py @@ -167,6 +167,10 @@ class TranslatorZ3(Translator): res = res >> arg elif expr.op == "a<<": res = res << arg + elif expr.op == "<<<": + res = z3.RotateLeft(res, arg) + elif expr.op == ">>>": + res = z3.RotateRight(res, arg) elif expr.op == "idiv": res = res / arg elif expr.op == "udiv": diff --git a/miasm2/jitter/jitcore_tcc.py b/miasm2/jitter/jitcore_tcc.py index 573572d8..20f10339 100644 --- a/miasm2/jitter/jitcore_tcc.py +++ b/miasm2/jitter/jitcore_tcc.py @@ -2,11 +2,11 @@ #-*- coding:utf-8 -*- import os -from miasm2.ir.ir2C import irblocs2C -from subprocess import Popen, PIPE -import miasm2.jitter.jitcore as jitcore from distutils.sysconfig import get_python_inc -import Jittcc +from subprocess import Popen, PIPE + +from miasm2.ir.ir2C import irblocs2C +from miasm2.jitter import jitcore, Jittcc def jit_tcc_compil(func_name, func_code): diff --git a/miasm2/jitter/jitload.py b/miasm2/jitter/jitload.py index 1c88d0b7..112920a1 100644 --- a/miasm2/jitter/jitload.py +++ b/miasm2/jitter/jitload.py @@ -113,6 +113,9 @@ class CallbackHandler(object): return empty_keys + def has_callbacks(self, name): + return name in self.callbacks + def call_callbacks(self, name, *args): """Call callbacks associated to key 'name' with arguments args. While callbacks return True, continue with next callback. @@ -134,13 +137,17 @@ class CallbackHandlerBitflag(CallbackHandler): "Handle a list of callback with conditions on bitflag" + # Overrides CallbackHandler's implem, but do not serve for optimization + def has_callbacks(self, bitflag): + return any(cb_mask & bitflag != 0 for cb_mask in self.callbacks) + def __call__(self, bitflag, *args): """Call each callbacks associated with bit set in bitflag. While callbacks return True, continue with next callback. Iterator on other results""" res = True - for b in self.callbacks.keys(): + for b in self.callbacks: if b & bitflag != 0: # If the flag matched @@ -301,9 +308,10 @@ class jitter: # Check breakpoints old_pc = self.pc - for res in self.breakpoints_handler(self.pc, self): - if res is not True: - yield res + if self.breakpoints_handler.has_callbacks(self.pc): + for res in self.breakpoints_handler(self.pc, self): + if res is not True: + yield res # If a callback changed pc, re call every callback if old_pc != self.pc: diff --git a/miasm2/jitter/loader/pe.py b/miasm2/jitter/loader/pe.py index 3233cd4b..aaa7a469 100644 --- a/miasm2/jitter/loader/pe.py +++ b/miasm2/jitter/loader/pe.py @@ -17,6 +17,7 @@ hnd.setFormatter(logging.Formatter("[%(levelname)s]: %(message)s")) log.addHandler(hnd) log.setLevel(logging.CRITICAL) + def get_import_address_pe(e): import2addr = defaultdict(set) if e.DirImport.impdesc is None: @@ -53,7 +54,6 @@ def preload_pe(vm, e, runtime_lib, patch_vm_imp=True): return dyn_funcs - def is_redirected_export(e, ad): # test is ad points to code or dll name out = '' @@ -89,7 +89,6 @@ def get_export_name_addr_list(e): return out - def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, **kargs): """Load a PE in memory (@vm) from a data buffer @fdata @vm: VmMngr instance @@ -121,7 +120,8 @@ def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, **kargs): min_len = min(pe.SHList[0].addr, 0x1000) # Get and pad the pe_hdr - pe_hdr = pe.content[:hdr_len] + max(0, (min_len - hdr_len)) * "\x00" + pe_hdr = pe.content[:hdr_len] + max( + 0, (min_len - hdr_len)) * "\x00" vm.add_memory_page(pe.NThdr.ImageBase, PAGE_READ | PAGE_WRITE, pe_hdr) @@ -132,7 +132,8 @@ def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, **kargs): new_size = pe.SHList[i + 1].addr - section.addr section.size = new_size section.rawsize = new_size - section.data = strpatchwork.StrPatchwork(section.data[:new_size]) + section.data = strpatchwork.StrPatchwork( + section.data[:new_size]) section.offset = section.addr # Last section alignement @@ -235,8 +236,8 @@ def vm2pe(myjit, fname, libs=None, e_orig=None, if min_addr is None and e_orig is not None: min_addr = min([e_orig.rva2virt(s.addr) for s in e_orig.SHList]) if max_addr is None and e_orig is not None: - max_addr = max([e_orig.rva2virt(s.addr + s.size) for s in e_orig.SHList]) - + max_addr = max([e_orig.rva2virt(s.addr + s.size) + for s in e_orig.SHList]) if img_base is None: img_base = e_orig.NThdr.ImageBase @@ -370,9 +371,9 @@ class libimp_pe(libimp): # Build an IMAGE_IMPORT_DESCRIPTOR # Get fixed addresses - out_ads = dict() # addr -> func_name + out_ads = dict() # addr -> func_name for func_name, dst_addresses in self.lib_imp2dstad[ad].items(): - out_ads.update({addr:func_name for addr in dst_addresses}) + out_ads.update({addr: func_name for addr in dst_addresses}) # Filter available addresses according to @flt all_ads = [addr for addr in out_ads.keys() if flt(addr)] @@ -391,7 +392,8 @@ class libimp_pe(libimp): # Find libname's Import Address Table othunk = all_ads[0] i = 0 - while i + 1 < len(all_ads) and all_ads[i] + 4 == all_ads[i + 1]: + while (i + 1 < len(all_ads) and + all_ads[i] + target_pe._wsize / 8 == all_ads[i + 1]): i += 1 # 'i + 1' is IAT's length @@ -417,6 +419,7 @@ PE_machine = {0x14c: "x86_32", 0x8664: "x86_64", } + def guess_arch(pe): """Return the architecture specified by the PE container @pe. If unknown, return None""" diff --git a/test/arch/aarch64/arch.py b/test/arch/aarch64/arch.py index aa3ab4dd..cca9184a 100644 --- a/test/arch/aarch64/arch.py +++ b/test/arch/aarch64/arch.py @@ -118,6 +118,11 @@ reg_tests_aarch64 = [ "000B1F72"), ("00079A80 ANDS X20, X2, 0xFF", "541C40F2"), + ("XXXXXXXX TST W14, 0x1", + "DF010072"), + ("XXXXXXXX ANDS W12, W13, 0x1", + "AC010072"), + ("0005BD5C AND W0, W0, W24", "0000180A"), diff --git a/test/arch/arm/arch.py b/test/arch/arm/arch.py index 701c45af..2ffbd3b1 100644 --- a/test/arch/arm/arch.py +++ b/test/arch/arm/arch.py @@ -242,6 +242,9 @@ reg_tests_arm = [ ("XXXXXXXX BFC R0, 0x0, 0xD", "1f00cce7"), + ("XXXXXXXX REV R0, R2", + "320FBFE6"), + ] ts = time.time() diff --git a/test/test_all.py b/test/test_all.py index 2f4efc88..54537bdf 100644 --- a/test/test_all.py +++ b/test/test_all.py @@ -139,19 +139,19 @@ for script in ["win_api_x86_32.py", testset += RegressionTest(["depgraph.py"], base_dir="analysis", products=[fname for fnames in ( ["graph_test_%02d_00.dot" % test_nb, + "exp_graph_test_%02d_00.dot" % test_nb, "graph_%02d.dot" % test_nb] - for test_nb in xrange(1, 17)) + for test_nb in xrange(1, 18)) for fname in fnames] + - ["graph_test_03_01.dot", - "graph_test_05_01.dot", - "graph_test_08_01.dot", - "graph_test_09_01.dot", - "graph_test_10_01.dot", - "graph_test_12_01.dot", - "graph_test_13_01.dot", - "graph_test_14_01.dot", - "graph_test_15_01.dot" - ]) + [fname for fnames in ( + ["graph_test_%02d_%02d.dot" % (test_nb, res_nb), + "exp_graph_test_%02d_%02d.dot" % (test_nb, + res_nb)] + for (test_nb, res_nb) in ((3, 1), (5, 1), (8, 1), + (9, 1), (10, 1), + (12, 1), (13, 1), + (14, 1), (15, 1))) + for fname in fnames]) # Examples class Example(Test): @@ -343,7 +343,7 @@ class ExampleSymbolExec(Example): testset += ExampleSymbolExec(["single_instr.py"]) for options, nb_sol, tag in [([], 8, []), - (["-i", "--rename-args"], 12, [TAGS["z3"]])]: + (["-i", "--rename-args"], 10, [TAGS["z3"]])]: testset += ExampleSymbolExec(["depgraph.py", Example.get_sample("simple_test.bin"), "-m", "x86_32", "0x0", "0x8b", @@ -402,9 +402,9 @@ if __name__ == "__main__": action="store_true") parser.add_argument("-c", "--coverage", help="Include code coverage", action="store_true") - parser.add_argument("-t", "--ommit-tags", help="Ommit tests based on tags \ + parser.add_argument("-t", "--omit-tags", help="Omit tests based on tags \ (tag1,tag2). Available tags are %s. \ -By default, no tag is ommited." % ", ".join(TAGS.keys()), default="") +By default, no tag is omitted." % ", ".join(TAGS.keys()), default="") parser.add_argument("-n", "--do-not-clean", help="Do not clean tests products", action="store_true") args = parser.parse_args() @@ -414,9 +414,9 @@ By default, no tag is ommited." % ", ".join(TAGS.keys()), default="") if args.mono is True or args.coverage is True: multiproc = False - ## Parse ommit-tags argument + ## Parse omit-tags argument exclude_tags = [] - for tag in args.ommit_tags.split(","): + for tag in args.omit_tags.split(","): if not tag: continue if tag not in TAGS: |