|  |  |  |
|---|---|---|
| author | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2025-10-14 09:09:29 +0000 |
| committer | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2025-10-14 09:09:29 +0000 |
| commit | 579cf1d03fb932083e6317967d1613d5c2587fb6 (patch) | |
| tree | 629f039935382a2a7391bce9253f6c9968159049 | /src/miasm/core |
| parent | 51c15d3ea2e16d4fc5f0f01a3b9befc66b1f982e (diff) | |
| download | focaccia-miasm-ta/nix.tar.gz, focaccia-miasm-ta/nix.zip | |
Convert to src-layout (ref ta/nix)
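For context on the commit subject: a "src-layout" places the importable package under a `src/` directory instead of the repository root, so `src/miasm/core/...` becomes the canonical location of the `miasm.core` package. As a hedged illustration only (this repository's actual build configuration is not part of the diff shown here), a setuptools-based project is typically rewired along these lines:

```python
# Hypothetical setup.py sketch for a src-layout, shown for illustration only;
# the real build configuration of this repository is not included in this diff.
from setuptools import setup, find_packages

setup(
    name="miasm",
    package_dir={"": "src"},              # package code now lives under src/
    packages=find_packages(where="src"),  # picks up miasm, miasm.core, ...
)
```

With this mapping, import paths such as `miasm.core.asmblock` are unchanged; only the on-disk location of the sources moves, which is consistent with the purely additive diffstat below.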
Diffstat (limited to 'src/miasm/core')
| mode | file | lines added |
|---|---|---|
| -rw-r--r-- | src/miasm/core/__init__.py | 1 |
| -rw-r--r-- | src/miasm/core/asm_ast.py | 93 |
| -rw-r--r-- | src/miasm/core/asmblock.py | 1474 |
| -rw-r--r-- | src/miasm/core/bin_stream.py | 319 |
| -rw-r--r-- | src/miasm/core/bin_stream_ida.py | 45 |
| -rw-r--r-- | src/miasm/core/cpu.py | 1715 |
| -rw-r--r-- | src/miasm/core/ctypesmngr.py | 771 |
| -rw-r--r-- | src/miasm/core/graph.py | 1123 |
| -rw-r--r-- | src/miasm/core/interval.py | 284 |
| -rw-r--r-- | src/miasm/core/locationdb.py | 495 |
| -rw-r--r-- | src/miasm/core/modint.py | 270 |
| -rw-r--r-- | src/miasm/core/objc.py | 1763 |
| -rw-r--r-- | src/miasm/core/parse_asm.py | 288 |
| -rw-r--r-- | src/miasm/core/sembuilder.py | 341 |
| -rw-r--r-- | src/miasm/core/types.py | 1693 |
| -rw-r--r-- | src/miasm/core/utils.py | 292 |
16 files changed, 10967 insertions, 0 deletions
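The tree added below contains, among other things, the `disasmEngine`/`AsmCFG` machinery from `asmblock.py` and the `bin_stream` classes from `bin_stream.py`. As a rough usage sketch only (it relies on `miasm.analysis.machine.Machine` and `miasm.core.locationdb.LocationDB`, which live outside the `src/miasm/core` paths shown in this diff), disassembling a small buffer with these APIs usually looks like this:

```python
# Hedged usage sketch of the APIs added in this tree; Machine and LocationDB
# are assumed from the rest of the miasm code base, not from this diff.
from miasm.analysis.machine import Machine
from miasm.core.locationdb import LocationDB
from miasm.core.bin_stream import bin_stream_str

loc_db = LocationDB()
bs = bin_stream_str(b"\x55\x89\xe5\xc3")       # x86: push ebp; mov ebp, esp; ret
mdis = Machine("x86_32").dis_engine(bs, loc_db=loc_db)

asmcfg = mdis.dis_multiblock(0)                # AsmCFG of reachable AsmBlocks
for block in asmcfg.blocks:
    print(block)
```

`dis_multiblock` follows branch targets and `c_next` fall-throughs until the configured watchdogs or `dont_dis` filters stop it, then returns the resulting control-flow graph.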
diff --git a/src/miasm/core/__init__.py b/src/miasm/core/__init__.py new file mode 100644 index 00000000..d154134b --- /dev/null +++ b/src/miasm/core/__init__.py @@ -0,0 +1 @@ +"Core components" diff --git a/src/miasm/core/asm_ast.py b/src/miasm/core/asm_ast.py new file mode 100644 index 00000000..69ff1f9c --- /dev/null +++ b/src/miasm/core/asm_ast.py @@ -0,0 +1,93 @@ +from builtins import int as int_types + +class AstNode(object): + """ + Ast node object + """ + def __neg__(self): + if isinstance(self, AstInt): + value = AstInt(-self.value) + else: + value = AstOp('-', self) + return value + + def __add__(self, other): + return AstOp('+', self, other) + + def __sub__(self, other): + return AstOp('-', self, other) + + def __div__(self, other): + return AstOp('/', self, other) + + def __mod__(self, other): + return AstOp('%', self, other) + + def __mul__(self, other): + return AstOp('*', self, other) + + def __lshift__(self, other): + return AstOp('<<', self, other) + + def __rshift__(self, other): + return AstOp('>>', self, other) + + def __xor__(self, other): + return AstOp('^', self, other) + + def __or__(self, other): + return AstOp('|', self, other) + + def __and__(self, other): + return AstOp('&', self, other) + + +class AstInt(AstNode): + """ + Ast integer + """ + def __init__(self, value): + self.value = value + + def __str__(self): + return "%s" % self.value + + +class AstId(AstNode): + """ + Ast Id + """ + def __init__(self, name): + self.name = name + + def __str__(self): + return "%s" % self.name + + +class AstMem(AstNode): + """ + Ast memory deref + """ + def __init__(self, ptr, size): + assert isinstance(ptr, AstNode) + assert isinstance(size, int_types) + self.ptr = ptr + self.size = size + + def __str__(self): + return "@%d[%s]" % (self.size, self.ptr) + + +class AstOp(AstNode): + """ + Ast operator + """ + def __init__(self, op, *args): + assert all(isinstance(arg, AstNode) for arg in args) + self.op = op + self.args = args + + def __str__(self): + if len(self.args) == 1: + return "(%s %s)" % (self.op, self.args[0]) + return '(' + ("%s" % self.op).join(str(x) for x in self.args) + ')' diff --git a/src/miasm/core/asmblock.py b/src/miasm/core/asmblock.py new file mode 100644 index 00000000..e92034fe --- /dev/null +++ b/src/miasm/core/asmblock.py @@ -0,0 +1,1474 @@ +#-*- coding:utf-8 -*- + +from builtins import map +from builtins import range +import logging +import warnings +from collections import namedtuple +from builtins import int as int_types + +from future.utils import viewitems, viewvalues + +from miasm.expression.expression import ExprId, ExprInt, get_expr_locs +from miasm.expression.expression import LocKey +from miasm.expression.simplifications import expr_simp +from miasm.core.utils import Disasm_Exception, pck +from miasm.core.graph import DiGraph, DiGraphSimplifier, MatchGraphJoker +from miasm.core.interval import interval + + +log_asmblock = logging.getLogger("asmblock") +console_handler = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) +log_asmblock.addHandler(console_handler) +log_asmblock.setLevel(logging.WARNING) + + +class AsmRaw(object): + + def __init__(self, raw=b""): + self.raw = raw + + def __str__(self): + return repr(self.raw) + + def to_string(self, loc_db): + return str(self) + + def to_html(self, loc_db): + return str(self) + + +class AsmConstraint(object): + c_to = "c_to" + c_next = "c_next" + + def __init__(self, loc_key, c_t=c_to): + # Sanity check + assert isinstance(loc_key, LocKey) + 
+ self.loc_key = loc_key + self.c_t = c_t + + def to_string(self, loc_db=None): + if loc_db is None: + return "%s:%s" % (self.c_t, self.loc_key) + else: + return "%s:%s" % ( + self.c_t, + loc_db.pretty_str(self.loc_key) + ) + + def __str__(self): + return self.to_string() + + +class AsmConstraintNext(AsmConstraint): + + def __init__(self, loc_key): + super(AsmConstraintNext, self).__init__( + loc_key, + c_t=AsmConstraint.c_next + ) + + +class AsmConstraintTo(AsmConstraint): + + def __init__(self, loc_key): + super(AsmConstraintTo, self).__init__( + loc_key, + c_t=AsmConstraint.c_to + ) + + +class AsmBlock(object): + + def __init__(self, loc_db, loc_key, alignment=1): + assert isinstance(loc_key, LocKey) + + self.bto = set() + self.lines = [] + self.loc_db = loc_db + self._loc_key = loc_key + self.alignment = alignment + + loc_key = property(lambda self:self._loc_key) + + def to_string(self): + out = [] + out.append(self.loc_db.pretty_str(self.loc_key)) + + for instr in self.lines: + out.append(instr.to_string(self.loc_db)) + if self.bto: + lbls = ["->"] + for dst in self.bto: + if dst is None: + lbls.append("Unknown? ") + else: + lbls.append(dst.to_string(self.loc_db) + " ") + lbls = '\t'.join(sorted(lbls)) + out.append(lbls) + return '\n'.join(out) + + def __str__(self): + return self.to_string() + + def addline(self, l): + self.lines.append(l) + + def addto(self, c): + assert isinstance(self.bto, set) + self.bto.add(c) + + def split(self, offset): + loc_key = self.loc_db.get_or_create_offset_location(offset) + log_asmblock.debug('split at %x', offset) + offsets = [x.offset for x in self.lines] + offset = self.loc_db.get_location_offset(loc_key) + if offset not in offsets: + log_asmblock.warning( + 'cannot split block at %X ' % offset + + 'middle instruction? 
default middle') + offsets.sort() + return None + new_block = AsmBlock(self.loc_db, loc_key) + i = offsets.index(offset) + + self.lines, new_block.lines = self.lines[:i], self.lines[i:] + flow_mod_instr = self.get_flow_instr() + log_asmblock.debug('flow mod %r', flow_mod_instr) + c = AsmConstraint(loc_key, AsmConstraint.c_next) + # move dst if flowgraph modifier was in original block + # (usecase: split delayslot block) + if flow_mod_instr: + for xx in self.bto: + log_asmblock.debug('lbl %s', xx) + c_next = set( + x for x in self.bto if x.c_t == AsmConstraint.c_next + ) + c_to = [x for x in self.bto if x.c_t != AsmConstraint.c_next] + self.bto = set([c] + c_to) + new_block.bto = c_next + else: + new_block.bto = self.bto + self.bto = set([c]) + return new_block + + def get_range(self): + """Returns the offset hull of an AsmBlock""" + if len(self.lines): + return (self.lines[0].offset, + self.lines[-1].offset + self.lines[-1].l) + else: + return 0, 0 + + def get_offsets(self): + return [x.offset for x in self.lines] + + def add_cst(self, loc_key, constraint_type): + """ + Add constraint between current block and block at @loc_key + @loc_key: LocKey instance of constraint target + @constraint_type: AsmConstraint c_to/c_next + """ + assert isinstance(loc_key, LocKey) + c = AsmConstraint(loc_key, constraint_type) + self.bto.add(c) + + def get_flow_instr(self): + if not self.lines: + return None + for i in range(-1, -1 - self.lines[0].delayslot - 1, -1): + if not 0 <= i < len(self.lines): + return None + l = self.lines[i] + if l.splitflow() or l.breakflow(): + raise NotImplementedError('not fully functional') + + def get_subcall_instr(self): + if not self.lines: + return None + delayslot = self.lines[0].delayslot + end_index = len(self.lines) - 1 + ds_max_index = max(end_index - delayslot, 0) + for i in range(end_index, ds_max_index - 1, -1): + l = self.lines[i] + if l.is_subcall(): + return l + return None + + def get_next(self): + for constraint in self.bto: + if constraint.c_t == AsmConstraint.c_next: + return constraint.loc_key + return None + + @staticmethod + def _filter_constraint(constraints): + """Sort and filter @constraints for AsmBlock.bto + @constraints: non-empty set of AsmConstraint instance + + Always the same type -> one of the constraint + c_next and c_to -> c_next + """ + # Only one constraint + if len(constraints) == 1: + return next(iter(constraints)) + + # Constraint type -> set of corresponding constraint + cbytype = {} + for cons in constraints: + cbytype.setdefault(cons.c_t, set()).add(cons) + + # Only one type -> any constraint is OK + if len(cbytype) == 1: + return next(iter(constraints)) + + # At least 2 types -> types = {c_next, c_to} + # c_to is included in c_next + return next(iter(cbytype[AsmConstraint.c_next])) + + def fix_constraints(self): + """Fix next block constraints""" + # destination -> associated constraints + dests = {} + for constraint in self.bto: + dests.setdefault(constraint.loc_key, set()).add(constraint) + + self.bto = set( + self._filter_constraint(constraints) + for constraints in viewvalues(dests) + ) + + +class AsmBlockBad(AsmBlock): + + """Stand for a *bad* ASM block (malformed, unreachable, + not disassembled, ...)""" + + + ERROR_UNKNOWN = -1 + ERROR_CANNOT_DISASM = 0 + ERROR_NULL_STARTING_BLOCK = 1 + ERROR_FORBIDDEN = 2 + ERROR_IO = 3 + + + ERROR_TYPES = { + ERROR_UNKNOWN: "Unknown error", + ERROR_CANNOT_DISASM: "Unable to disassemble", + ERROR_NULL_STARTING_BLOCK: "Null starting block", + ERROR_FORBIDDEN: "Address forbidden by dont_dis", + 
ERROR_IO: "IOError", + } + + def __init__(self, loc_db, loc_key=None, alignment=1, errno=ERROR_UNKNOWN, *args, **kwargs): + """Instantiate an AsmBlock_bad. + @loc_key, @alignment: same as AsmBlock.__init__ + @errno: (optional) specify a error type associated with the block + """ + super(AsmBlockBad, self).__init__(loc_db, loc_key, alignment, *args, **kwargs) + self._errno = errno + + errno = property(lambda self: self._errno) + + def __str__(self): + error_txt = self.ERROR_TYPES.get(self._errno, self._errno) + return "%s\n\tBad block: %s" % ( + self.loc_key, + error_txt + ) + + def addline(self, *args, **kwargs): + raise RuntimeError("An AsmBlockBad cannot have line") + + def addto(self, *args, **kwargs): + raise RuntimeError("An AsmBlockBad cannot have bto") + + def split(self, *args, **kwargs): + raise RuntimeError("An AsmBlockBad cannot be split") + + +class AsmCFG(DiGraph): + + """Directed graph standing for a ASM Control Flow Graph with: + - nodes: AsmBlock + - edges: constraints between blocks, synchronized with AsmBlock's "bto" + + Specialized the .dot export and force the relation between block to be uniq, + and associated with a constraint. + + Offer helpers on AsmCFG management, such as research by loc_key, sanity + checking and mnemonic size guessing. + """ + + # Internal structure for pending management + AsmCFGPending = namedtuple("AsmCFGPending", + ["waiter", "constraint"]) + + def __init__(self, loc_db, *args, **kwargs): + super(AsmCFG, self).__init__(*args, **kwargs) + # Edges -> constraint + self.edges2constraint = {} + # Expected LocKey -> set( (src, dst), constraint ) + self._pendings = {} + # Loc_Key2block built on the fly + self._loc_key_to_block = {} + # loc_db + self.loc_db = loc_db + + + def copy(self): + """Copy the current graph instance""" + graph = self.__class__(self.loc_db) + return graph + self + + def __len__(self): + """Return the number of blocks in AsmCFG""" + return len(self._nodes) + + @property + def blocks(self): + return viewvalues(self._loc_key_to_block) + + # Manage graph with associated constraints + def add_edge(self, src, dst, constraint): + """Add an edge to the graph + @src: LocKey instance, source + @dst: LocKey instance, destination + @constraint: constraint associated to this edge + """ + # Sanity check + assert isinstance(src, LocKey) + assert isinstance(dst, LocKey) + known_cst = self.edges2constraint.get((src, dst), None) + if known_cst is not None: + assert known_cst == constraint + return + + # Add the edge to src.bto if needed + block_src = self.loc_key_to_block(src) + if block_src: + if dst not in [cons.loc_key for cons in block_src.bto]: + block_src.bto.add(AsmConstraint(dst, constraint)) + + # Add edge + self.edges2constraint[(src, dst)] = constraint + super(AsmCFG, self).add_edge(src, dst) + + def add_uniq_edge(self, src, dst, constraint): + """ + Synonym for `add_edge` + """ + self.add_edge(src, dst, constraint) + + def del_edge(self, src, dst): + """Delete the edge @src->@dst and its associated constraint""" + src_blk = self.loc_key_to_block(src) + dst_blk = self.loc_key_to_block(dst) + assert src_blk is not None + assert dst_blk is not None + # Delete from src.bto + to_remove = [cons for cons in src_blk.bto if cons.loc_key == dst] + if to_remove: + assert len(to_remove) == 1 + src_blk.bto.remove(to_remove[0]) + + # Del edge + del self.edges2constraint[(src, dst)] + super(AsmCFG, self).del_edge(src, dst) + + def del_block(self, block): + super(AsmCFG, self).del_node(block.loc_key) + del self._loc_key_to_block[block.loc_key] + + 
+ def add_node(self, node): + assert isinstance(node, LocKey) + return super(AsmCFG, self).add_node(node) + + def add_block(self, block): + """ + Add the block @block to the current instance, if it is not already in + @block: AsmBlock instance + + Edges will be created for @block.bto, if destinations are already in + this instance. If not, they will be resolved when adding these + aforementioned destinations. + `self.pendings` indicates which blocks are not yet resolved. + + """ + status = super(AsmCFG, self).add_node(block.loc_key) + + if not status: + return status + + # Update waiters + if block.loc_key in self._pendings: + for bblpend in self._pendings[block.loc_key]: + self.add_edge(bblpend.waiter.loc_key, block.loc_key, bblpend.constraint) + del self._pendings[block.loc_key] + + # Synchronize edges with block destinations + self._loc_key_to_block[block.loc_key] = block + + for constraint in block.bto: + dst = self._loc_key_to_block.get(constraint.loc_key, + None) + if dst is None: + # Block is yet unknown, add it to pendings + to_add = self.AsmCFGPending(waiter=block, + constraint=constraint.c_t) + self._pendings.setdefault(constraint.loc_key, + set()).add(to_add) + else: + # Block is already in known nodes + self.add_edge(block.loc_key, dst.loc_key, constraint.c_t) + + return status + + def merge(self, graph): + """Merge with @graph, taking in account constraints""" + # Add known blocks + for block in graph.blocks: + self.add_block(block) + # Add nodes not already in it (ie. not linked to a block) + for node in graph.nodes(): + self.add_node(node) + # -> add_edge(x, y, constraint) + for edge in graph._edges: + # May fail if there is an incompatibility in edges constraints + # between the two graphs + self.add_edge(*edge, constraint=graph.edges2constraint[edge]) + + def escape_text(self, text): + return text + + + def node2lines(self, node): + loc_key_name = self.loc_db.pretty_str(node) + yield self.DotCellDescription(text=loc_key_name, + attr={'align': 'center', + 'colspan': 2, + 'bgcolor': 'grey'}) + block = self._loc_key_to_block.get(node, None) + if block is None: + return + if isinstance(block, AsmBlockBad): + yield [ + self.DotCellDescription( + text=block.ERROR_TYPES.get(block._errno, + block._errno + ), + attr={}) + ] + return + for line in block.lines: + if self._dot_offset: + yield [self.DotCellDescription(text="%.8X" % line.offset, + attr={}), + self.DotCellDescription(text=line.to_html(self.loc_db), attr={})] + else: + yield self.DotCellDescription(text=line.to_html(self.loc_db), attr={}) + + def node_attr(self, node): + block = self._loc_key_to_block.get(node, None) + if isinstance(block, AsmBlockBad): + return {'style': 'filled', 'fillcolor': 'red'} + return {} + + def edge_attr(self, src, dst): + cst = self.edges2constraint.get((src, dst), None) + edge_color = "blue" + + if len(self.successors(src)) > 1: + if cst == AsmConstraint.c_next: + edge_color = "red" + else: + edge_color = "limegreen" + + return {"color": edge_color} + + def dot(self, offset=False): + """ + @offset: (optional) if set, add the corresponding offsets in each node + """ + self._dot_offset = offset + return super(AsmCFG, self).dot() + + # Helpers + @property + def pendings(self): + """Dictionary of loc_key -> set(AsmCFGPending instance) indicating + which loc_key are missing in the current instance. 
+ A loc_key is missing if a block which is already in nodes has constraints + with him (thanks to its .bto) and the corresponding block is not yet in + nodes + """ + return self._pendings + + def rebuild_edges(self): + """Consider blocks '.bto' and rebuild edges according to them, ie: + - update constraint type + - add missing edge + - remove no more used edge + + This method should be called if a block's '.bto' in nodes have been + modified without notifying this instance to resynchronize edges. + """ + self._pendings = {} + for block in self.blocks: + edges = [] + # Rebuild edges from bto + for constraint in block.bto: + dst = self._loc_key_to_block.get(constraint.loc_key, + None) + if dst is None: + # Missing destination, add to pendings + self._pendings.setdefault( + constraint.loc_key, + set() + ).add( + self.AsmCFGPending( + block, + constraint.c_t + ) + ) + continue + edge = (block.loc_key, dst.loc_key) + edges.append(edge) + if edge in self._edges: + # Already known edge, constraint may have changed + self.edges2constraint[edge] = constraint.c_t + else: + # An edge is missing + self.add_edge(edge[0], edge[1], constraint.c_t) + + # Remove useless edges + for succ in self.successors(block.loc_key): + edge = (block.loc_key, succ) + if edge not in edges: + self.del_edge(*edge) + + def get_bad_blocks(self): + """Iterator on AsmBlockBad elements""" + # A bad asm block is always a leaf + for loc_key in self.leaves(): + block = self._loc_key_to_block.get(loc_key, None) + if isinstance(block, AsmBlockBad): + yield block + + def get_bad_blocks_predecessors(self, strict=False): + """Iterator on loc_keys with an AsmBlockBad destination + @strict: (optional) if set, return loc_key with only bad + successors + """ + # Avoid returning the same block + done = set() + for badblock in self.get_bad_blocks(): + for predecessor in self.predecessors_iter(badblock.loc_key): + if predecessor not in done: + if (strict and + not all(isinstance(self._loc_key_to_block.get(block, None), AsmBlockBad) + for block in self.successors_iter(predecessor))): + continue + yield predecessor + done.add(predecessor) + + def getby_offset(self, offset): + """Return asmblock containing @offset""" + for block in self.blocks: + if block.lines[0].offset <= offset < \ + (block.lines[-1].offset + block.lines[-1].l): + return block + return None + + def loc_key_to_block(self, loc_key): + """ + Return the asmblock corresponding to loc_key @loc_key, None if unknown + loc_key + @loc_key: LocKey instance + """ + return self._loc_key_to_block.get(loc_key, None) + + def sanity_check(self): + """Do sanity checks on blocks' constraints: + * no pendings + * no multiple next constraint to same block + * no next constraint to self + """ + + if len(self._pendings) != 0: + raise RuntimeError( + "Some blocks are missing: %s" % list( + map( + str, + self._pendings + ) + ) + ) + + next_edges = { + edge: constraint + for edge, constraint in viewitems(self.edges2constraint) + if constraint == AsmConstraint.c_next + } + + for loc_key in self._nodes: + if loc_key not in self._loc_key_to_block: + raise RuntimeError("Not supported yet: every node must have a corresponding AsmBlock") + # No next constraint to self + if (loc_key, loc_key) in next_edges: + raise RuntimeError('Bad constraint: self in next') + + # No multiple next constraint to same block + pred_next = list(ploc_key + for (ploc_key, dloc_key) in next_edges + if dloc_key == loc_key) + + if len(pred_next) > 1: + raise RuntimeError("Too many next constraints for block %r" + "(%s)" % (loc_key, 
+ pred_next)) + + def guess_blocks_size(self, mnemo): + """Asm and compute max block size + Add a 'size' and 'max_size' attribute on each block + @mnemo: metamn instance""" + for block in self.blocks: + size = 0 + for instr in block.lines: + if isinstance(instr, AsmRaw): + # for special AsmRaw, only extract len + if isinstance(instr.raw, list): + data = None + if len(instr.raw) == 0: + l = 0 + else: + l = (instr.raw[0].size // 8) * len(instr.raw) + elif isinstance(instr.raw, str): + data = instr.raw.encode() + l = len(data) + elif isinstance(instr.raw, bytes): + data = instr.raw + l = len(data) + else: + raise NotImplementedError('asm raw') + else: + # Assemble the instruction to retrieve its len. + # If the instruction uses symbol it will fail + # In this case, the max_instruction_len is used + try: + candidates = mnemo.asm(instr) + l = len(candidates[-1]) + except: + l = mnemo.max_instruction_len + data = None + instr.data = data + instr.l = l + size += l + + block.size = size + block.max_size = size + log_asmblock.info("size: %d max: %d", block.size, block.max_size) + + def apply_splitting(self, loc_db, dis_block_callback=None, **kwargs): + warnings.warn('DEPRECATION WARNING: apply_splitting is member of disasm_engine') + raise RuntimeError("Moved api") + + def __str__(self): + out = [] + for block in self.blocks: + out.append(str(block)) + for loc_key_a, loc_key_b in self.edges(): + out.append("%s -> %s" % (loc_key_a, loc_key_b)) + return '\n'.join(out) + + def __repr__(self): + return "<%s %s>" % (self.__class__.__name__, hex(id(self))) + +# Out of _merge_blocks to be computed only once +_acceptable_block = lambda graph, loc_key: (not isinstance(graph.loc_key_to_block(loc_key), AsmBlockBad) and + len(graph.loc_key_to_block(loc_key).lines) > 0) +_parent = MatchGraphJoker(restrict_in=False, filt=_acceptable_block) +_son = MatchGraphJoker(restrict_out=False, filt=_acceptable_block) +_expgraph = _parent >> _son + + +def _merge_blocks(dg, graph): + """Graph simplification merging AsmBlock with one and only one son with this + son if this son has one and only one parent""" + + # Blocks to ignore, because they have been removed from the graph + to_ignore = set() + + for match in _expgraph.match(graph): + + # Get matching blocks + lbl_block, lbl_succ = match[_parent], match[_son] + block = graph.loc_key_to_block(lbl_block) + succ = graph.loc_key_to_block(lbl_succ) + + # Ignore already deleted blocks + if (block in to_ignore or + succ in to_ignore): + continue + + # Remove block last instruction if needed + last_instr = block.lines[-1] + if last_instr.delayslot > 0: + # TODO: delayslot + raise RuntimeError("Not implemented yet") + + if last_instr.is_subcall(): + continue + if last_instr.breakflow() and last_instr.dstflow(): + block.lines.pop() + + # Merge block + block.lines += succ.lines + for nextb in graph.successors_iter(lbl_succ): + graph.add_edge(lbl_block, nextb, graph.edges2constraint[(lbl_succ, nextb)]) + + graph.del_block(succ) + to_ignore.add(lbl_succ) + + +bbl_simplifier = DiGraphSimplifier() +bbl_simplifier.enable_passes([_merge_blocks]) + + +def conservative_asm(mnemo, instr, symbols, conservative): + """ + Asm instruction; + Try to keep original instruction bytes if it exists + """ + candidates = mnemo.asm(instr, symbols) + if not candidates: + raise ValueError('cannot asm:%s' % str(instr)) + if not hasattr(instr, "b"): + return candidates[0], candidates + if instr.b in candidates: + return instr.b, candidates + if conservative: + for c in candidates: + if len(c) == 
len(instr.b): + return c, candidates + return candidates[0], candidates + + +def fix_expr_val(expr, symbols): + """Resolve an expression @expr using @symbols""" + def expr_calc(e): + if isinstance(e, ExprId): + # Example: + # toto: + # .dword label + loc_key = symbols.get_name_location(e.name) + offset = symbols.get_location_offset(loc_key) + e = ExprInt(offset, e.size) + return e + result = expr.visit(expr_calc) + result = expr_simp(result) + if not isinstance(result, ExprInt): + raise RuntimeError('Cannot resolve symbol %s' % expr) + return result + + +def fix_loc_offset(loc_db, loc_key, offset, modified): + """ + Fix the @loc_key offset to @offset. If the @offset has changed, add @loc_key + to @modified + @loc_db: current loc_db + """ + loc_offset = loc_db.get_location_offset(loc_key) + if loc_offset == offset: + return + if loc_offset is not None: + loc_db.unset_location_offset(loc_key) + loc_db.set_location_offset(loc_key, offset) + modified.add(loc_key) + + +class BlockChain(object): + + """Manage blocks linked with an asm_constraint_next""" + + def __init__(self, loc_db, blocks): + self.loc_db = loc_db + self.blocks = blocks + self.place() + + @property + def pinned(self): + """Return True iff at least one block is pinned""" + return self.pinned_block_idx is not None + + def _set_pinned_block_idx(self): + self.pinned_block_idx = None + for i, block in enumerate(self.blocks): + loc_key = block.loc_key + if self.loc_db.get_location_offset(loc_key) is not None: + if self.pinned_block_idx is not None: + raise ValueError("Multiples pinned block detected") + self.pinned_block_idx = i + + def place(self): + """Compute BlockChain min_offset and max_offset using pinned block and + blocks' size + """ + self._set_pinned_block_idx() + self.max_size = 0 + for block in self.blocks: + self.max_size += block.max_size + block.alignment - 1 + + # Check if chain has one block pinned + if not self.pinned: + return + + loc = self.blocks[self.pinned_block_idx].loc_key + offset_base = self.loc_db.get_location_offset(loc) + assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0) + + self.offset_min = offset_base + for block in self.blocks[:self.pinned_block_idx - 1:-1]: + self.offset_min -= block.max_size + \ + (block.alignment - block.max_size) % block.alignment + + self.offset_max = offset_base + for block in self.blocks[self.pinned_block_idx:]: + self.offset_max += block.max_size + \ + (block.alignment - block.max_size) % block.alignment + + def merge(self, chain): + """Best effort merge two block chains + Return the list of resulting blockchains""" + self.blocks += chain.blocks + self.place() + return [self] + + def fix_blocks(self, modified_loc_keys): + """Propagate a pinned to its blocks' neighbour + @modified_loc_keys: store new pinned loc_keys""" + + if not self.pinned: + raise ValueError('Trying to fix unpinned block') + + # Propagate offset to blocks before pinned block + pinned_block = self.blocks[self.pinned_block_idx] + offset = self.loc_db.get_location_offset(pinned_block.loc_key) + if offset % pinned_block.alignment != 0: + raise RuntimeError('Bad alignment') + + for block in self.blocks[:self.pinned_block_idx - 1:-1]: + new_offset = offset - block.size + new_offset = new_offset - new_offset % pinned_block.alignment + fix_loc_offset(self.loc_db, + block.loc_key, + new_offset, + modified_loc_keys) + + # Propagate offset to blocks after pinned block + offset = self.loc_db.get_location_offset(pinned_block.loc_key) + pinned_block.size + + last_block = pinned_block + for block in 
self.blocks[self.pinned_block_idx + 1:]: + offset += (- offset) % last_block.alignment + fix_loc_offset(self.loc_db, + block.loc_key, + offset, + modified_loc_keys) + offset += block.size + last_block = block + return modified_loc_keys + + +class BlockChainWedge(object): + + """Stand for wedges between blocks""" + + def __init__(self, loc_db, offset, size): + self.loc_db = loc_db + self.offset = offset + self.max_size = size + self.offset_min = offset + self.offset_max = offset + size + + def merge(self, chain): + """Best effort merge two block chains + Return the list of resulting blockchains""" + self.loc_db.set_location_offset(chain.blocks[0].loc_key, self.offset_max) + chain.place() + return [self, chain] + + +def group_constrained_blocks(asmcfg): + """ + Return the BlockChains list built from grouped blocks in asmcfg linked by + asm_constraint_next + @asmcfg: an AsmCfg instance + """ + log_asmblock.info('group_constrained_blocks') + + # Group adjacent asmcfg + remaining_blocks = list(asmcfg.blocks) + known_block_chains = {} + + while remaining_blocks: + # Create a new block chain + block_list = [remaining_blocks.pop()] + + # Find sons in remainings blocks linked with a next constraint + while True: + # Get next block + next_loc_key = block_list[-1].get_next() + if next_loc_key is None or asmcfg.loc_key_to_block(next_loc_key) is None: + break + next_block = asmcfg.loc_key_to_block(next_loc_key) + + # Add the block at the end of the current chain + if next_block not in remaining_blocks: + break + block_list.append(next_block) + remaining_blocks.remove(next_block) + + # Check if son is in a known block group + if next_loc_key is not None and next_loc_key in known_block_chains: + block_list += known_block_chains[next_loc_key] + del known_block_chains[next_loc_key] + + known_block_chains[block_list[0].loc_key] = block_list + + out_block_chains = [] + for loc_key in known_block_chains: + chain = BlockChain(asmcfg.loc_db, known_block_chains[loc_key]) + out_block_chains.append(chain) + return out_block_chains + + +def get_blockchains_address_interval(blockChains, dst_interval): + """Compute the interval used by the pinned @blockChains + Check if the placed chains are in the @dst_interval""" + + allocated_interval = interval() + for chain in blockChains: + if not chain.pinned: + continue + chain_interval = interval([(chain.offset_min, chain.offset_max - 1)]) + if chain_interval not in dst_interval: + raise ValueError('Chain placed out of destination interval') + allocated_interval += chain_interval + return allocated_interval + + +def resolve_symbol(blockChains, loc_db, dst_interval=None): + """Place @blockChains in the @dst_interval""" + + log_asmblock.info('resolve_symbol') + if dst_interval is None: + dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)]) + + forbidden_interval = interval( + [(-1, 0xFFFFFFFFFFFFFFFF + 1)]) - dst_interval + allocated_interval = get_blockchains_address_interval(blockChains, + dst_interval) + log_asmblock.debug('allocated interval: %s', allocated_interval) + + pinned_chains = [chain for chain in blockChains if chain.pinned] + + # Add wedge in forbidden intervals + for start, stop in forbidden_interval.intervals: + wedge = BlockChainWedge( + loc_db, offset=start, size=stop + 1 - start) + pinned_chains.append(wedge) + + # Try to place bigger blockChains first + pinned_chains.sort(key=lambda x: x.offset_min) + blockChains.sort(key=lambda x: -x.max_size) + + fixed_chains = list(pinned_chains) + + log_asmblock.debug("place chains") + for chain in blockChains: + if 
chain.pinned: + continue + fixed = False + for i in range(1, len(fixed_chains)): + prev_chain = fixed_chains[i - 1] + next_chain = fixed_chains[i] + + if prev_chain.offset_max + chain.max_size < next_chain.offset_min: + new_chains = prev_chain.merge(chain) + fixed_chains[i - 1:i] = new_chains + fixed = True + break + if not fixed: + raise RuntimeError('Cannot find enough space to place blocks') + + return [chain for chain in fixed_chains if isinstance(chain, BlockChain)] + + +def get_block_loc_keys(block): + """Extract loc_keys used by @block""" + symbols = set() + for instr in block.lines: + if isinstance(instr, AsmRaw): + if isinstance(instr.raw, list): + for expr in instr.raw: + symbols.update(get_expr_locs(expr)) + else: + for arg in instr.args: + symbols.update(get_expr_locs(arg)) + return symbols + + +def assemble_block(mnemo, block, conservative=False): + """Assemble a @block + @conservative: (optional) use original bytes when possible + """ + offset_i = 0 + + for instr in block.lines: + if isinstance(instr, AsmRaw): + if isinstance(instr.raw, list): + # Fix special AsmRaw + data = b"" + for expr in instr.raw: + expr_int = fix_expr_val(expr, block.loc_db) + data += pck[expr_int.size](int(expr_int)) + instr.data = data + + instr.offset = offset_i + offset_i += instr.l + continue + + # Assemble an instruction + saved_args = list(instr.args) + instr.offset = block.loc_db.get_location_offset(block.loc_key) + offset_i + + # Replace instruction's arguments by resolved ones + instr.args = instr.resolve_args_with_symbols(block.loc_db) + + if instr.dstflow(): + instr.fixDstOffset() + + old_l = instr.l + cached_candidate, _ = conservative_asm( + mnemo, instr, block.loc_db, + conservative + ) + if len(cached_candidate) != instr.l: + # The output instruction length is different from the one we guessed + # Retry assembly with updated length + instr.l = len(cached_candidate) + instr.args = saved_args + instr.args = instr.resolve_args_with_symbols(block.loc_db) + if instr.dstflow(): + instr.fixDstOffset() + cached_candidate, _ = conservative_asm( + mnemo, instr, block.loc_db, + conservative + ) + assert len(cached_candidate) == instr.l + + # Restore original arguments + instr.args = saved_args + + # We need to update the block size + block.size = block.size - old_l + len(cached_candidate) + instr.data = cached_candidate + instr.l = len(cached_candidate) + + offset_i += instr.l + + +def asmblock_final(mnemo, asmcfg, blockChains, conservative=False): + """Resolve and assemble @blockChains until fixed point is + reached""" + + log_asmblock.debug("asmbloc_final") + + # Init structures + blocks_using_loc_key = {} + for block in asmcfg.blocks: + exprlocs = get_block_loc_keys(block) + loc_keys = set(expr.loc_key for expr in exprlocs) + for loc_key in loc_keys: + blocks_using_loc_key.setdefault(loc_key, set()).add(block) + + block2chain = {} + for chain in blockChains: + for block in chain.blocks: + block2chain[block] = chain + + # Init worklist + blocks_to_rework = set(asmcfg.blocks) + + # Fix and re-assemble blocks until fixed point is reached + while True: + + # Propagate pinned blocks into chains + modified_loc_keys = set() + for chain in blockChains: + chain.fix_blocks(modified_loc_keys) + + for loc_key in modified_loc_keys: + # Retrieve block with modified reference + mod_block = asmcfg.loc_key_to_block(loc_key) + if mod_block is not None: + blocks_to_rework.add(mod_block) + + # Enqueue blocks referencing a modified loc_key + if loc_key not in blocks_using_loc_key: + continue + for block in 
blocks_using_loc_key[loc_key]: + blocks_to_rework.add(block) + + # No more work + if not blocks_to_rework: + break + + while blocks_to_rework: + block = blocks_to_rework.pop() + assemble_block(mnemo, block, conservative) + + +def asm_resolve_final(mnemo, asmcfg, dst_interval=None): + """Resolve and assemble @asmcfg into interval + @dst_interval""" + + asmcfg.sanity_check() + + asmcfg.guess_blocks_size(mnemo) + blockChains = group_constrained_blocks(asmcfg) + resolved_blockChains = resolve_symbol(blockChains, asmcfg.loc_db, dst_interval) + asmblock_final(mnemo, asmcfg, resolved_blockChains) + patches = {} + output_interval = interval() + + for block in asmcfg.blocks: + offset = asmcfg.loc_db.get_location_offset(block.loc_key) + for instr in block.lines: + if not instr.data: + # Empty line + continue + assert len(instr.data) == instr.l + patches[offset] = instr.data + instruction_interval = interval([(offset, offset + instr.l - 1)]) + if not (instruction_interval & output_interval).empty: + raise RuntimeError("overlapping bytes %X" % int(offset)) + output_interval = output_interval.union(instruction_interval) + instr.offset = offset + offset += instr.l + return patches + + +class disasmEngine(object): + + """Disassembly engine, taking care of disassembler options and mutli-block + strategy. + + Engine options: + + + Object supporting membership test (offset in ..) + - dont_dis: stop the current disassembly branch if reached + - split_dis: force a basic block end if reached, + with a next constraint on its successor + - dont_dis_retcall_funcs: stop disassembly after a call to one + of the given functions + + + On/Off + - follow_call: recursively disassemble CALL destinations + - dontdis_retcall: stop on CALL return addresses + - dont_dis_nulstart_bloc: stop if a block begin with a few \x00 + + + Number + - lines_wd: maximum block's size (in number of instruction) + - blocs_wd: maximum number of distinct disassembled block + + + callback(mdis, cur_block, offsets_to_dis) + - dis_block_callback: callback after each new disassembled block + """ + + def __init__(self, arch, attrib, bin_stream, loc_db, **kwargs): + """Instantiate a new disassembly engine + @arch: targeted architecture + @attrib: architecture attribute + @bin_stream: bytes source + @kwargs: (optional) custom options + """ + self.arch = arch + self.attrib = attrib + self.bin_stream = bin_stream + self.loc_db = loc_db + + # Setup options + self.dont_dis = [] + self.split_dis = [] + self.follow_call = False + self.dontdis_retcall = False + self.lines_wd = None + self.blocs_wd = None + self.dis_block_callback = None + self.dont_dis_nulstart_bloc = False + self.dont_dis_retcall_funcs = set() + + # Override options if needed + self.__dict__.update(kwargs) + + def _dis_block(self, offset, job_done=None): + """Disassemble the block at offset @offset + @job_done: a set of already disassembled addresses + Return the created AsmBlock and future offsets to disassemble + """ + + if job_done is None: + job_done = set() + lines_cpt = 0 + in_delayslot = False + delayslot_count = self.arch.delayslot + offsets_to_dis = set() + add_next_offset = False + loc_key = self.loc_db.get_or_create_offset_location(offset) + cur_block = AsmBlock(self.loc_db, loc_key) + log_asmblock.debug("dis at %X", int(offset)) + while not in_delayslot or delayslot_count > 0: + if in_delayslot: + delayslot_count -= 1 + + if offset in self.dont_dis: + if not cur_block.lines: + job_done.add(offset) + # Block is empty -> bad block + cur_block = AsmBlockBad(self.loc_db, loc_key, 
errno=AsmBlockBad.ERROR_FORBIDDEN) + else: + # Block is not empty, stop the desassembly pass and add a + # constraint to the next block + loc_key_cst = self.loc_db.get_or_create_offset_location(offset) + cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) + break + + if lines_cpt > 0 and offset in self.split_dis: + loc_key_cst = self.loc_db.get_or_create_offset_location(offset) + cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) + offsets_to_dis.add(offset) + break + + lines_cpt += 1 + if self.lines_wd is not None and lines_cpt > self.lines_wd: + log_asmblock.debug("lines watchdog reached at %X", int(offset)) + break + + if offset in job_done: + loc_key_cst = self.loc_db.get_or_create_offset_location(offset) + cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) + break + + off_i = offset + error = None + try: + instr = self.arch.dis(self.bin_stream, self.attrib, offset) + except Disasm_Exception as e: + log_asmblock.warning(e) + instr = None + error = AsmBlockBad.ERROR_CANNOT_DISASM + except IOError as e: + log_asmblock.warning(e) + instr = None + error = AsmBlockBad.ERROR_IO + + + if instr is None: + log_asmblock.warning("cannot disasm at %X", int(off_i)) + if not cur_block.lines: + job_done.add(offset) + # Block is empty -> bad block + cur_block = AsmBlockBad(self.loc_db, loc_key, errno=error) + else: + # Block is not empty, stop the desassembly pass and add a + # constraint to the next block + loc_key_cst = self.loc_db.get_or_create_offset_location(off_i) + cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) + break + + # XXX TODO nul start block option + if (self.dont_dis_nulstart_bloc and + not cur_block.lines and + instr.b.count(b'\x00') == instr.l): + log_asmblock.warning("reach nul instr at %X", int(off_i)) + # Block is empty -> bad block + cur_block = AsmBlockBad(self.loc_db, loc_key, errno=AsmBlockBad.ERROR_NULL_STARTING_BLOCK) + break + + # special case: flow graph modificator in delayslot + if in_delayslot and instr and (instr.splitflow() or instr.breakflow()): + add_next_offset = True + break + + job_done.add(offset) + log_asmblock.debug("dis at %X", int(offset)) + + offset += instr.l + log_asmblock.debug(instr) + log_asmblock.debug(instr.args) + + cur_block.addline(instr) + if not instr.breakflow(): + continue + # test split + if instr.splitflow() and not (instr.is_subcall() and self.dontdis_retcall): + add_next_offset = True + if instr.dstflow(): + instr.dstflow2label(self.loc_db) + destinations = instr.getdstflow(self.loc_db) + known_dsts = [] + for dst in destinations: + if not dst.is_loc(): + continue + loc_key = dst.loc_key + loc_key_offset = self.loc_db.get_location_offset(loc_key) + known_dsts.append(loc_key) + if loc_key_offset in self.dont_dis_retcall_funcs: + add_next_offset = False + if (not instr.is_subcall()) or self.follow_call: + cur_block.bto.update([AsmConstraint(loc_key, AsmConstraint.c_to) for loc_key in known_dsts]) + + # get in delayslot mode + in_delayslot = True + delayslot_count = instr.delayslot + + for c in cur_block.bto: + loc_key_offset = self.loc_db.get_location_offset(c.loc_key) + offsets_to_dis.add(loc_key_offset) + + if add_next_offset: + loc_key_cst = self.loc_db.get_or_create_offset_location(offset) + cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) + offsets_to_dis.add(offset) + + # Fix multiple constraints + cur_block.fix_constraints() + + if self.dis_block_callback is not None: + self.dis_block_callback(self, cur_block, offsets_to_dis) + return cur_block, offsets_to_dis + + def dis_block(self, offset): + """Disassemble the block 
at offset @offset and return the created + AsmBlock + @offset: targeted offset to disassemble + """ + current_block, _ = self._dis_block(offset) + return current_block + + def dis_multiblock(self, offset, blocks=None, job_done=None): + """Disassemble every block reachable from @offset regarding + specific disasmEngine conditions + Return an AsmCFG instance containing disassembled blocks + @offset: starting offset + @blocks: (optional) AsmCFG instance of already disassembled blocks to + merge with + """ + log_asmblock.info("dis block all") + if job_done is None: + job_done = set() + if blocks is None: + blocks = AsmCFG(self.loc_db) + todo = [offset] + + bloc_cpt = 0 + while len(todo): + bloc_cpt += 1 + if self.blocs_wd is not None and bloc_cpt > self.blocs_wd: + log_asmblock.debug("blocks watchdog reached at %X", int(offset)) + break + + target_offset = int(todo.pop(0)) + if (target_offset is None or + target_offset in job_done): + continue + cur_block, nexts = self._dis_block(target_offset, job_done) + todo += nexts + blocks.add_block(cur_block) + + self.apply_splitting(blocks) + return blocks + + def apply_splitting(self, blocks): + """Consider @blocks' bto destinations and split block in @blocks if one + of these destinations jumps in the middle of this block. In order to + work, they must be only one block in @self per loc_key in + + @blocks: Asmcfg + """ + # Get all possible destinations not yet resolved, with a resolved + # offset + block_dst = [] + for loc_key in blocks.pendings: + offset = self.loc_db.get_location_offset(loc_key) + if offset is not None: + block_dst.append(offset) + + todo = set(blocks.blocks) + rebuild_needed = False + + while todo: + # Find a block with a destination inside another one + cur_block = todo.pop() + range_start, range_stop = cur_block.get_range() + + for off in block_dst: + if not (off > range_start and off < range_stop): + continue + + # `cur_block` must be split at offset `off`from miasm.core.locationdb import LocationDB + + new_b = cur_block.split(off) + log_asmblock.debug("Split block %x", off) + if new_b is None: + log_asmblock.error("Cannot split %x!!", off) + continue + + # Remove pending from cur_block + # Links from new_b will be generated in rebuild_edges + for dst in new_b.bto: + if dst.loc_key not in blocks.pendings: + continue + blocks.pendings[dst.loc_key] = set(pending for pending in blocks.pendings[dst.loc_key] + if pending.waiter != cur_block) + + # The new block destinations may need to be disassembled + if self.dis_block_callback: + offsets_to_dis = set( + self.loc_db.get_location_offset(constraint.loc_key) + for constraint in new_b.bto + ) + self.dis_block_callback(self, new_b, offsets_to_dis) + + # Update structure + rebuild_needed = True + blocks.add_block(new_b) + + # The new block must be considered + todo.add(new_b) + range_start, range_stop = cur_block.get_range() + + # Rebuild edges to match new blocks'bto + if rebuild_needed: + blocks.rebuild_edges() + + def dis_instr(self, offset): + """Disassemble one instruction at offset @offset and return the + corresponding instruction instance + @offset: targeted offset to disassemble + """ + old_lineswd = self.lines_wd + self.lines_wd = 1 + try: + block = self.dis_block(offset) + finally: + self.lines_wd = old_lineswd + + instr = block.lines[0] + return instr diff --git a/src/miasm/core/bin_stream.py b/src/miasm/core/bin_stream.py new file mode 100644 index 00000000..46165d49 --- /dev/null +++ b/src/miasm/core/bin_stream.py @@ -0,0 +1,319 @@ +# +# Copyright (C) 2011 EADS France, 
Fabrice Desclaux <fabrice.desclaux@eads.net> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# + +from builtins import str +from future.utils import PY3 + +from miasm.core.utils import BIG_ENDIAN, LITTLE_ENDIAN +from miasm.core.utils import upck8le, upck16le, upck32le, upck64le +from miasm.core.utils import upck8be, upck16be, upck32be, upck64be + + +class bin_stream(object): + + # Cache must be initialized by entering atomic mode + _cache = None + CACHE_SIZE = 10000 + # By default, no atomic mode + _atomic_mode = False + + def __init__(self, *args, **kargs): + self.endianness = LITTLE_ENDIAN + + def __repr__(self): + return "<%s !!>" % self.__class__.__name__ + + def __str__(self): + if PY3: + return repr(self) + return self.__bytes__() + + def hexdump(self, offset, l): + return + + def enter_atomic_mode(self): + """Enter atomic mode. In this mode, read may be cached""" + assert not self._atomic_mode + self._atomic_mode = True + self._cache = {} + + def leave_atomic_mode(self): + """Leave atomic mode""" + assert self._atomic_mode + self._atomic_mode = False + self._cache = None + + def _getbytes(self, start, length): + return self.bin[start:start + length] + + def getbytes(self, start, l=1): + """Return the bytes from the bit stream + @start: starting offset (in byte) + @l: (optional) number of bytes to read + + Wrapper on _getbytes, with atomic mode handling. 
+ """ + if self._atomic_mode: + val = self._cache.get((start,l), None) + if val is None: + val = self._getbytes(start, l) + self._cache[(start,l)] = val + else: + val = self._getbytes(start, l) + return val + + def getbits(self, start, n): + """Return the bits from the bit stream + @start: the offset in bits + @n: number of bits to read + """ + # Trivial case + if n == 0: + return 0 + + # Get initial bytes + if n > self.getlen() * 8: + raise IOError('not enough bits %r %r' % (n, len(self.bin) * 8)) + byte_start = start // 8 + byte_stop = (start + n + 7) // 8 + temp = self.getbytes(byte_start, byte_stop - byte_start) + if not temp: + raise IOError('cannot get bytes') + + # Init + start = start % 8 + out = 0 + while n: + # Get needed bits, working on maximum 8 bits at a time + cur_byte_idx = start // 8 + new_bits = ord(temp[cur_byte_idx:cur_byte_idx + 1]) + to_keep = 8 - start % 8 + new_bits &= (1 << to_keep) - 1 + cur_len = min(to_keep, n) + new_bits >>= (to_keep - cur_len) + + # Update output + out <<= cur_len + out |= new_bits + + # Update counters + n -= cur_len + start += cur_len + return out + + def get_u8(self, addr, endianness=None): + """ + Return u8 from address @addr + endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN + """ + if endianness is None: + endianness = self.endianness + data = self.getbytes(addr, 1) + if endianness == LITTLE_ENDIAN: + return upck8le(data) + else: + return upck8be(data) + + def get_u16(self, addr, endianness=None): + """ + Return u16 from address @addr + endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN + """ + if endianness is None: + endianness = self.endianness + data = self.getbytes(addr, 2) + if endianness == LITTLE_ENDIAN: + return upck16le(data) + else: + return upck16be(data) + + def get_u32(self, addr, endianness=None): + """ + Return u32 from address @addr + endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN + """ + if endianness is None: + endianness = self.endianness + data = self.getbytes(addr, 4) + if endianness == LITTLE_ENDIAN: + return upck32le(data) + else: + return upck32be(data) + + def get_u64(self, addr, endianness=None): + """ + Return u64 from address @addr + endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN + """ + if endianness is None: + endianness = self.endianness + data = self.getbytes(addr, 8) + if endianness == LITTLE_ENDIAN: + return upck64le(data) + else: + return upck64be(data) + + +class bin_stream_str(bin_stream): + + def __init__(self, input_str=b"", offset=0, base_address=0, shift=None): + bin_stream.__init__(self) + if shift is not None: + raise DeprecationWarning("use base_address instead of shift") + self.bin = input_str + self.offset = offset + self.base_address = base_address + self.l = len(input_str) + + def _getbytes(self, start, l=1): + if start + l - self.base_address > self.l: + raise IOError("not enough bytes in str") + if start - self.base_address < 0: + raise IOError("Negative offset") + + return super(bin_stream_str, self)._getbytes(start - self.base_address, l) + + def readbs(self, l=1): + if self.offset + l - self.base_address > self.l: + raise IOError("not enough bytes in str") + if self.offset - self.base_address < 0: + raise IOError("Negative offset") + self.offset += l + return self.bin[self.offset - l - self.base_address:self.offset - self.base_address] + + def __bytes__(self): + return self.bin[self.offset - self.base_address:] + + def setoffset(self, val): + self.offset = val + + def getlen(self): + return self.l - (self.offset - self.base_address) + + +class bin_stream_file(bin_stream): + + def 
__init__(self, binary, offset=0, base_address=0, shift=None): + bin_stream.__init__(self) + if shift is not None: + raise DeprecationWarning("use base_address instead of shift") + self.bin = binary + self.bin.seek(0, 2) + self.base_address = base_address + self.l = self.bin.tell() + self.offset = offset + + def getoffset(self): + return self.bin.tell() + self.base_address + + def setoffset(self, val): + self.bin.seek(val - self.base_address) + offset = property(getoffset, setoffset) + + def readbs(self, l=1): + if self.offset + l - self.base_address > self.l: + raise IOError("not enough bytes in file") + if self.offset - self.base_address < 0: + raise IOError("Negative offset") + return self.bin.read(l) + + def __bytes__(self): + return self.bin.read() + + def getlen(self): + return self.l - (self.offset - self.base_address) + + +class bin_stream_container(bin_stream): + + def __init__(self, binary, offset=0): + bin_stream.__init__(self) + self.bin = binary + self.l = binary.virt.max_addr() + self.offset = offset + + def is_addr_in(self, ad): + return self.bin.virt.is_addr_in(ad) + + def getlen(self): + return self.l + + def readbs(self, l=1): + if self.offset + l > self.l: + raise IOError("not enough bytes") + if self.offset < 0: + raise IOError("Negative offset") + self.offset += l + return self.bin.virt.get(self.offset - l, self.offset) + + def _getbytes(self, start, l=1): + try: + return self.bin.virt.get(start, start + l) + except ValueError: + raise IOError("cannot get bytes") + + def __bytes__(self): + return self.bin.virt.get(self.offset, self.offset + self.l) + + def setoffset(self, val): + self.offset = val + + +class bin_stream_pe(bin_stream_container): + def __init__(self, binary, *args, **kwargs): + super(bin_stream_pe, self).__init__(binary, *args, **kwargs) + self.endianness = binary._sex + + +class bin_stream_elf(bin_stream_container): + def __init__(self, binary, *args, **kwargs): + super(bin_stream_elf, self).__init__(binary, *args, **kwargs) + self.endianness = binary.sex + + +class bin_stream_vm(bin_stream): + + def __init__(self, vm, offset=0, base_offset=0): + self.offset = offset + self.base_offset = base_offset + self.vm = vm + if self.vm.is_little_endian(): + self.endianness = LITTLE_ENDIAN + else: + self.endianness = BIG_ENDIAN + + def getlen(self): + return 0xFFFFFFFFFFFFFFFF + + def _getbytes(self, start, l=1): + try: + s = self.vm.get_mem(start + self.base_offset, l) + except: + raise IOError('cannot get mem ad', hex(start)) + return s + + def readbs(self, l=1): + try: + s = self.vm.get_mem(self.offset + self.base_offset, l) + except: + raise IOError('cannot get mem ad', hex(self.offset)) + self.offset += l + return s + + def setoffset(self, val): + self.offset = val diff --git a/src/miasm/core/bin_stream_ida.py b/src/miasm/core/bin_stream_ida.py new file mode 100644 index 00000000..15bd9d8b --- /dev/null +++ b/src/miasm/core/bin_stream_ida.py @@ -0,0 +1,45 @@ +from builtins import range +from idc import get_wide_byte, get_segm_end +from idautils import Segments +from idaapi import is_mapped + +from miasm.core.utils import int_to_byte +from miasm.core.bin_stream import bin_stream_str + + +class bin_stream_ida(bin_stream_str): + """ + bin_stream implementation for IDA + + Don't generate xrange using address computation: + It can raise error on overflow 7FFFFFFF with 32 bit python + """ + def _getbytes(self, start, l=1): + out = [] + for ad in range(l): + offset = ad + start + self.base_address + if not is_mapped(offset): + raise IOError(f"not enough bytes @ 
offset {offset:x}") + out.append(int_to_byte(get_wide_byte(offset))) + return b''.join(out) + + def readbs(self, l=1): + if self.offset + l > self.l: + raise IOError("not enough bytes") + content = self.getbytes(self.offset) + self.offset += l + return content + + def __str__(self): + raise NotImplementedError('Not fully functional') + + def setoffset(self, val): + self.offset = val + + def getlen(self): + # Lazy version + if hasattr(self, "_getlen"): + return self._getlen + max_addr = get_segm_end(list(Segments())[-1] - (self.offset - self.base_address)) + self._getlen = max_addr + return max_addr diff --git a/src/miasm/core/cpu.py b/src/miasm/core/cpu.py new file mode 100644 index 00000000..7df9f991 --- /dev/null +++ b/src/miasm/core/cpu.py @@ -0,0 +1,1715 @@ +#-*- coding:utf-8 -*- + +from builtins import range +import re +import struct +import logging +from collections import defaultdict + + +from future.utils import viewitems, viewvalues + +import pyparsing + +from miasm.core.utils import decode_hex +import miasm.expression.expression as m2_expr +from miasm.core.bin_stream import bin_stream, bin_stream_str +from miasm.core.utils import Disasm_Exception +from miasm.expression.simplifications import expr_simp + + +from miasm.core.asm_ast import AstNode, AstInt, AstId, AstOp +from miasm.core import utils +from future.utils import with_metaclass + +log = logging.getLogger("cpuhelper") +console_handler = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) +log.addHandler(console_handler) +log.setLevel(logging.WARN) + + +class bitobj(object): + + def __init__(self, s=b""): + if not s: + bits = [] + else: + bits = [int(x) for x in bin(int(encode_hex(s), 16))[2:]] + if len(bits) % 8: + bits = [0 for x in range(8 - (len(bits) % 8))] + bits + self.bits = bits + self.offset = 0 + + def __len__(self): + return len(self.bits) - self.offset + + def getbits(self, n): + if not n: + return 0 + if n > len(self.bits) - self.offset: + raise ValueError('not enough bits %r %r' % (n, len(self.bits))) + b = self.bits[self.offset:self.offset + n] + b = int("".join(str(x) for x in b), 2) + self.offset += n + return b + + def putbits(self, b, n): + if not n: + return + bits = list(bin(b)[2:]) + bits = [int(x) for x in bits] + bits = [0 for x in range(n - len(bits))] + bits + self.bits += bits + + def tostring(self): + if len(self.bits) % 8: + raise ValueError( + 'num bits must be 8 bit aligned: %d' % len(self.bits) + ) + b = int("".join(str(x) for x in self.bits), 2) + b = "%X" % b + b = '0' * (len(self.bits) // 4 - len(b)) + b + b = decode_hex(b.encode()) + return b + + def reset(self): + self.offset = 0 + + def copy_state(self): + b = self.__class__() + b.bits = self.bits + b.offset = self.offset + return b + + +def literal_list(l): + l = l[:] + l.sort() + l = l[::-1] + o = pyparsing.Literal(l[0]) + for x in l[1:]: + o |= pyparsing.Literal(x) + return o + + +class reg_info(object): + + def __init__(self, reg_str, reg_expr): + self.str = reg_str + self.expr = reg_expr + self.parser = literal_list(reg_str).setParseAction(self.cb_parse) + + def cb_parse(self, tokens): + assert len(tokens) == 1 + i = self.str.index(tokens[0]) + reg = self.expr[i] + result = AstId(reg) + return result + + def reg2expr(self, s): + i = self.str.index(s[0]) + return self.expr[i] + + def expr2regi(self, e): + return self.expr.index(e) + + +class reg_info_dct(object): + + def __init__(self, reg_expr): + self.dct_str_inv = dict((v.name, k) for k, v in viewitems(reg_expr)) + 
self.dct_expr = reg_expr + self.dct_expr_inv = dict((v, k) for k, v in viewitems(reg_expr)) + reg_str = [v.name for v in viewvalues(reg_expr)] + self.parser = literal_list(reg_str).setParseAction(self.cb_parse) + + def cb_parse(self, tokens): + assert len(tokens) == 1 + i = self.dct_str_inv[tokens[0]] + reg = self.dct_expr[i] + result = AstId(reg) + return result + + def reg2expr(self, s): + i = self.dct_str_inv[s[0]] + return self.dct_expr[i] + + def expr2regi(self, e): + return self.dct_expr_inv[e] + + +def gen_reg(reg_name, sz=32): + """Gen reg expr and parser""" + reg = m2_expr.ExprId(reg_name, sz) + reginfo = reg_info([reg_name], [reg]) + return reg, reginfo + + +def gen_reg_bs(reg_name, reg_info, base_cls): + """ + Generate: + class bs_reg_name(base_cls): + reg = reg_info + + bs_reg_name = bs(l=0, cls=(bs_reg_name,)) + """ + + bs_name = "bs_%s" % reg_name + cls = type(bs_name, base_cls, {'reg': reg_info}) + + bs_obj = bs(l=0, cls=(cls,)) + + return cls, bs_obj + + +def gen_regs(rnames, env, sz=32): + regs_str = [] + regs_expr = [] + regs_init = [] + for rname in rnames: + r = m2_expr.ExprId(rname, sz) + r_init = m2_expr.ExprId(rname+'_init', sz) + regs_str.append(rname) + regs_expr.append(r) + regs_init.append(r_init) + env[rname] = r + + reginfo = reg_info(regs_str, regs_expr) + return regs_expr, regs_init, reginfo + + +LPARENTHESIS = pyparsing.Literal("(") +RPARENTHESIS = pyparsing.Literal(")") + + +def int2expr(tokens): + v = tokens[0] + return (m2_expr.ExprInt, v) + + +def parse_op(tokens): + v = tokens[0] + return (m2_expr.ExprOp, v) + + +def parse_id(tokens): + v = tokens[0] + return (m2_expr.ExprId, v) + + +def ast_parse_op(tokens): + if len(tokens) == 1: + return tokens[0] + if len(tokens) == 2: + if tokens[0] in ['-', '+', '!']: + return m2_expr.ExprOp(tokens[0], tokens[1]) + if len(tokens) == 3: + if tokens[1] == '-': + # a - b => a + (-b) + tokens[1] = '+' + tokens[2] = - tokens[2] + return m2_expr.ExprOp(tokens[1], tokens[0], tokens[2]) + tokens = tokens[::-1] + while len(tokens) >= 3: + o1, op, o2 = tokens.pop(), tokens.pop(), tokens.pop() + if op == '-': + # a - b => a + (-b) + op = '+' + o2 = - o2 + e = m2_expr.ExprOp(op, o1, o2) + tokens.append(e) + if len(tokens) != 1: + raise NotImplementedError('strange op') + return tokens[0] + + +def ast_id2expr(a): + return m2_expr.ExprId(a, 32) + + +def ast_int2expr(a): + return m2_expr.ExprInt(a, 32) + + +def neg_int(tokens): + x = -tokens[0] + return x + + +integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0])) +hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums) +hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16)) + +# str_int = (Optional('-') + (hex_int | integer)) +str_int_pos = (hex_int | integer) +str_int_neg = (pyparsing.Suppress('-') + \ + (hex_int | integer)).setParseAction(neg_int) + +str_int = str_int_pos | str_int_neg +str_int.setParseAction(int2expr) + +logicop = pyparsing.oneOf('& | ^ >> << <<< >>>') +signop = pyparsing.oneOf('+ -') +multop = pyparsing.oneOf('* / %') +plusop = pyparsing.oneOf('+ -') + + +########################## + +def literal_list(l): + l = l[:] + l.sort() + l = l[::-1] + o = pyparsing.Literal(l[0]) + for x in l[1:]: + o |= pyparsing.Literal(x) + return o + + +def cb_int(tokens): + assert len(tokens) == 1 + integer = AstInt(tokens[0]) + return integer + + +def cb_parse_id(tokens): + assert len(tokens) == 1 + reg = tokens[0] + return AstId(reg) + + +def cb_op_not(tokens): + tokens = tokens[0] + assert 
len(tokens) == 2 + assert tokens[0] == "!" + result = AstOp("!", tokens[1]) + return result + + +def merge_ops(tokens, op): + args = [] + if len(tokens) >= 3: + args = [tokens.pop(0)] + i = 0 + while i < len(tokens): + op_tmp = tokens[i] + arg = tokens[i+1] + i += 2 + if op_tmp != op: + raise ValueError("Bad operator") + args.append(arg) + result = AstOp(op, *args) + return result + + +def cb_op_and(tokens): + result = merge_ops(tokens[0], "&") + return result + + +def cb_op_xor(tokens): + result = merge_ops(tokens[0], "^") + return result + + +def cb_op_sign(tokens): + assert len(tokens) == 1 + op, value = tokens[0] + return -value + + +def cb_op_div(tokens): + tokens = tokens[0] + assert len(tokens) == 3 + assert tokens[1] == "/" + result = AstOp("/", tokens[0], tokens[2]) + return result + + +def cb_op_plusminus(tokens): + tokens = tokens[0] + if len(tokens) == 3: + # binary op + assert isinstance(tokens[0], AstNode) + assert isinstance(tokens[2], AstNode) + op, args = tokens[1], [tokens[0], tokens[2]] + elif len(tokens) > 3: + args = [tokens.pop(0)] + i = 0 + while i < len(tokens): + op = tokens[i] + arg = tokens[i+1] + i += 2 + if op == '-': + arg = -arg + elif op == '+': + pass + else: + raise ValueError("Bad operator") + args.append(arg) + op = '+' + else: + raise ValueError("Parsing error") + assert all(isinstance(arg, AstNode) for arg in args) + result = AstOp(op, *args) + return result + + +def cb_op_mul(tokens): + tokens = tokens[0] + assert len(tokens) == 3 + assert isinstance(tokens[0], AstNode) + assert isinstance(tokens[2], AstNode) + + # binary op + op, args = tokens[1], [tokens[0], tokens[2]] + result = AstOp(op, *args) + return result + + +integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0])) +hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums) +hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16)) + +str_int_pos = (hex_int | integer) + +str_int = str_int_pos +str_int.setParseAction(cb_int) + +notop = pyparsing.oneOf('!') +andop = pyparsing.oneOf('&') +orop = pyparsing.oneOf('|') +xorop = pyparsing.oneOf('^') +shiftop = pyparsing.oneOf('>> <<') +rotop = pyparsing.oneOf('<<< >>>') +signop = pyparsing.oneOf('+ -') +mulop = pyparsing.oneOf('*') +plusop = pyparsing.oneOf('+ -') +divop = pyparsing.oneOf('/') + + +variable = pyparsing.Word(pyparsing.alphas + "_$.", pyparsing.alphanums + "_") +variable.setParseAction(cb_parse_id) +operand = str_int | variable + +base_expr = pyparsing.infixNotation(operand, + [(notop, 1, pyparsing.opAssoc.RIGHT, cb_op_not), + (andop, 2, pyparsing.opAssoc.RIGHT, cb_op_and), + (xorop, 2, pyparsing.opAssoc.RIGHT, cb_op_xor), + (signop, 1, pyparsing.opAssoc.RIGHT, cb_op_sign), + (mulop, 2, pyparsing.opAssoc.RIGHT, cb_op_mul), + (divop, 2, pyparsing.opAssoc.RIGHT, cb_op_div), + (plusop, 2, pyparsing.opAssoc.LEFT, cb_op_plusminus), + ]) + + +default_prio = 0x1337 + + +def isbin(s): + return re.match(r'[0-1]+$', s) + + +def int2bin(i, l): + s = '0' * l + bin(i)[2:] + return s[-l:] + + +def myror32(v, r): + return ((v & 0xFFFFFFFF) >> r) | ((v << (32 - r)) & 0xFFFFFFFF) + + +def myrol32(v, r): + return ((v & 0xFFFFFFFF) >> (32 - r)) | ((v << r) & 0xFFFFFFFF) + + +class bs(object): + all_new_c = {} + prio = default_prio + + def __init__(self, strbits=None, l=None, cls=None, + fname=None, order=0, flen=None, **kargs): + if fname is None: + fname = hex(id(str((strbits, l, cls, fname, order, flen, kargs)))) + if strbits is None: + strbits = "" # "X"*l + elif l is None: + l 
= len(strbits) + if strbits and isbin(strbits): + value = int(strbits, 2) + elif 'default_val' in kargs: + value = int(kargs['default_val'], 2) + else: + value = None + allbits = list(strbits) + allbits.reverse() + fbits = 0 + fmask = 0 + while allbits: + a = allbits.pop() + if a == " ": + continue + fbits <<= 1 + fmask <<= 1 + if a in '01': + a = int(a) + fbits |= a + fmask |= 1 + lmask = (1 << l) - 1 + # gen conditional field + if cls: + for b in cls: + if 'flen' in b.__dict__: + flen = getattr(b, 'flen') + + self.strbits = strbits + self.l = l + self.cls = cls + self.fname = fname + self.order = order + self.fbits = fbits + self.fmask = fmask + self.flen = flen + self.value = value + self.kargs = kargs + + lmask = property(lambda self:(1 << self.l) - 1) + + def __getitem__(self, item): + return getattr(self, item) + + def __repr__(self): + o = self.__class__.__name__ + if self.fname: + o += "_%s" % self.fname + o += "_%(strbits)s" % self + if self.cls: + o += '_' + '_'.join([x.__name__ for x in self.cls]) + return o + + def gen(self, parent): + c_name = 'nbsi' + if self.cls: + c_name += '_' + '_'.join([x.__name__ for x in self.cls]) + bases = list(self.cls) + else: + bases = [] + # bsi added at end of list + # used to use first function of added class + bases += [bsi] + k = c_name, tuple(bases) + if k in self.all_new_c: + new_c = self.all_new_c[k] + else: + new_c = type(c_name, tuple(bases), {}) + self.all_new_c[k] = new_c + c = new_c(parent, + self.strbits, self.l, self.cls, + self.fname, self.order, self.lmask, self.fbits, + self.fmask, self.value, self.flen, **self.kargs) + return c + + def check_fbits(self, v): + return v & self.fmask == self.fbits + + @classmethod + def flen(cls, v): + raise NotImplementedError('not fully functional') + + +class dum_arg(object): + + def __init__(self, e=None): + self.expr = e + + +class bsopt(bs): + + def ispresent(self): + return True + + +class bsi(object): + + def __init__(self, parent, strbits, l, cls, fname, order, + lmask, fbits, fmask, value, flen, **kargs): + self.parent = parent + self.strbits = strbits + self.l = l + self.cls = cls + self.fname = fname + self.order = order + self.fbits = fbits + self.fmask = fmask + self.flen = flen + self.value = value + self.kargs = kargs + self.__dict__.update(self.kargs) + + lmask = property(lambda self:(1 << self.l) - 1) + + def decode(self, v): + self.value = v & self.lmask + return True + + def encode(self): + return True + + def clone(self): + s = self.__class__(self.parent, + self.strbits, self.l, self.cls, + self.fname, self.order, self.lmask, self.fbits, + self.fmask, self.value, self.flen, **self.kargs) + s.__dict__.update(self.kargs) + if hasattr(self, 'expr'): + s.expr = self.expr + return s + + def __hash__(self): + kargs = [] + for k, v in list(viewitems(self.kargs)): + if isinstance(v, list): + v = tuple(v) + kargs.append((k, v)) + l = [self.strbits, self.l, self.cls, + self.fname, self.order, self.lmask, self.fbits, + self.fmask, self.value] # + kargs + + return hash(tuple(l)) + + +class bs_divert(object): + prio = default_prio + + def __init__(self, **kargs): + self.args = kargs + + def __getattr__(self, item): + if item in self.__dict__: + return self.__dict__[item] + elif item in self.args: + return self.args.get(item) + else: + raise AttributeError + + +class bs_name(bs_divert): + prio = 1 + + def divert(self, i, candidates): + out = [] + for cls, _, bases, dct, fields in candidates: + for new_name, value in viewitems(self.args['name']): + nfields = fields[:] + s = int2bin(value, 
self.args['l']) + args = dict(self.args) + args.update({'strbits': s}) + f = bs(**args) + nfields[i] = f + ndct = dict(dct) + ndct['name'] = new_name + out.append((cls, new_name, bases, ndct, nfields)) + return out + + +class bs_mod_name(bs_divert): + prio = 2 + + def divert(self, i, candidates): + out = [] + for cls, _, bases, dct, fields in candidates: + tab = self.args['mn_mod'] + if isinstance(tab, list): + tmp = {} + for j, v in enumerate(tab): + tmp[j] = v + tab = tmp + for value, new_name in viewitems(tab): + nfields = fields[:] + s = int2bin(value, self.args['l']) + args = dict(self.args) + args.update({'strbits': s}) + f = bs(**args) + nfields[i] = f + ndct = dict(dct) + ndct['name'] = self.modname(ndct['name'], value) + out.append((cls, new_name, bases, ndct, nfields)) + return out + + def modname(self, name, i): + return name + self.args['mn_mod'][i] + + +class bs_cond(bsi): + pass + + +class bs_swapargs(bs_divert): + + def divert(self, i, candidates): + out = [] + for cls, name, bases, dct, fields in candidates: + # args not permuted + ndct = dict(dct) + nfields = fields[:] + # gen fix field + f = gen_bsint(0, self.args['l'], self.args) + nfields[i] = f + out.append((cls, name, bases, ndct, nfields)) + + # args permuted + ndct = dict(dct) + nfields = fields[:] + ap = ndct['args_permut'][:] + a = ap.pop(0) + b = ap.pop(0) + ndct['args_permut'] = [b, a] + ap + # gen fix field + f = gen_bsint(1, self.args['l'], self.args) + nfields[i] = f + + out.append((cls, name, bases, ndct, nfields)) + return out + + +class m_arg(object): + + def fromstring(self, text, loc_db, parser_result=None): + if parser_result: + e, start, stop = parser_result[self.parser] + self.expr = e + return start, stop + try: + v, start, stop = next(self.parser.scanString(text)) + except StopIteration: + return None, None + arg = v[0] + expr = self.asm_ast_to_expr(arg, loc_db) + self.expr = expr + return start, stop + + def asm_ast_to_expr(self, arg, loc_db, **kwargs): + raise NotImplementedError("Virtual") + + +class m_reg(m_arg): + prio = default_prio + + @property + def parser(self): + return self.reg.parser + + def decode(self, v): + self.expr = self.reg.expr[0] + return True + + def encode(self): + return self.expr == self.reg.expr[0] + + +class reg_noarg(object): + reg_info = None + parser = None + + def fromstring(self, text, loc_db, parser_result=None): + if parser_result: + e, start, stop = parser_result[self.parser] + self.expr = e + return start, stop + try: + v, start, stop = next(self.parser.scanString(text)) + except StopIteration: + return None, None + arg = v[0] + expr = self.parses_to_expr(arg, loc_db) + self.expr = expr + return start, stop + + def decode(self, v): + v = v & self.lmask + if v >= len(self.reg_info.expr): + return False + self.expr = self.reg_info.expr[v] + return True + + def encode(self): + if not self.expr in self.reg_info.expr: + log.debug("cannot encode reg %r", self.expr) + return False + self.value = self.reg_info.expr.index(self.expr) + return True + + def check_fbits(self, v): + return v & self.fmask == self.fbits + + +class mn_prefix(object): + pass + + +def swap16(v): + return struct.unpack('<H', struct.pack('>H', v))[0] + + +def swap32(v): + return struct.unpack('<I', struct.pack('>I', v))[0] + + +def perm_inv(p): + o = [None for x in range(len(p))] + for i, x in enumerate(p): + o[x] = i + return o + + +def gen_bsint(value, l, args): + s = int2bin(value, l) + args = dict(args) + args.update({'strbits': s}) + f = bs(**args) + return f + +total_scans = 0 + + +def 
branch2nodes(branch, nodes=None): + if nodes is None: + nodes = [] + for k, v in viewitems(branch): + if not isinstance(v, dict): + continue + for k2 in v: + nodes.append((k, k2)) + branch2nodes(v, nodes) + + +def factor_one_bit(tree): + if isinstance(tree, set): + return tree + new_keys = defaultdict(lambda: defaultdict(dict)) + if len(tree) == 1: + return tree + for k, v in viewitems(tree): + if k == "mn": + new_keys[k] = v + continue + l, fmask, fbits, fname, flen = k + if flen is not None or l <= 1: + new_keys[k] = v + continue + cfmask = fmask >> (l - 1) + nfmask = fmask & ((1 << (l - 1)) - 1) + cfbits = fbits >> (l - 1) + nfbits = fbits & ((1 << (l - 1)) - 1) + ck = 1, cfmask, cfbits, None, flen + nk = l - 1, nfmask, nfbits, fname, flen + if nk in new_keys[ck]: + raise NotImplementedError('not fully functional') + new_keys[ck][nk] = v + for k, v in list(viewitems(new_keys)): + new_keys[k] = factor_one_bit(v) + # try factor sons + if len(new_keys) != 1: + return new_keys + subtree = next(iter(viewvalues(new_keys))) + if len(subtree) != 1: + return new_keys + if next(iter(subtree)) == 'mn': + return new_keys + + return new_keys + + +def factor_fields(tree): + if not isinstance(tree, dict): + return tree + if len(tree) != 1: + return tree + # merge + k1, v1 = next(iter(viewitems(tree))) + if k1 == "mn": + return tree + l1, fmask1, fbits1, fname1, flen1 = k1 + if fname1 is not None: + return tree + if flen1 is not None: + return tree + + if not isinstance(v1, dict): + return tree + if len(v1) != 1: + return tree + k2, v2 = next(iter(viewitems(v1))) + if k2 == "mn": + return tree + l2, fmask2, fbits2, fname2, flen2 = k2 + if fname2 is not None: + return tree + if flen2 is not None: + return tree + l = l1 + l2 + fmask = (fmask1 << l2) | fmask2 + fbits = (fbits1 << l2) | fbits2 + fname = fname2 + flen = flen2 + k = l, fmask, fbits, fname, flen + new_keys = {k: v2} + return new_keys + + +def factor_fields_all(tree): + if not isinstance(tree, dict): + return tree + new_keys = {} + for k, v in viewitems(tree): + v = factor_fields(v) + new_keys[k] = factor_fields_all(v) + return new_keys + + +def graph_tree(tree): + nodes = [] + branch2nodes(tree, nodes) + + out = """ + digraph G { + """ + for a, b in nodes: + if b == 'mn': + continue + out += "%s -> %s;\n" % (id(a), id(b)) + out += "}" + open('graph.txt', 'w').write(out) + + +def add_candidate_to_tree(tree, c): + branch = tree + for f in c.fields: + if f.l == 0: + continue + node = f.l, f.fmask, f.fbits, f.fname, f.flen + + if not node in branch: + branch[node] = {} + branch = branch[node] + if not 'mn' in branch: + branch['mn'] = set() + branch['mn'].add(c) + + +def add_candidate(bases, c): + add_candidate_to_tree(bases[0].bintree, c) + + +def getfieldby_name(fields, fname): + f = [x for x in fields if hasattr(x, 'fname') and x.fname == fname] + if len(f) != 1: + raise ValueError('more than one field with name: %s' % fname) + return f[0] + + +def getfieldindexby_name(fields, fname): + for i, f in enumerate(fields): + if hasattr(f, 'fname') and f.fname == fname: + return f, i + return None + + +class metamn(type): + + def __new__(mcs, name, bases, dct): + if name == "cls_mn" or name.startswith('mn_'): + return type.__new__(mcs, name, bases, dct) + alias = dct.get('alias', False) + + fields = bases[0].mod_fields(dct['fields']) + if not 'name' in dct: + dct["name"] = bases[0].getmn(name) + if 'args' in dct: + # special case for permuted arguments + o = [] + p = [] + for i, a in enumerate(dct['args']): + o.append((i, a)) + if a in fields: + 
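# --- Illustrative sketch (editorial addition, not part of this patch) ---
# Expected behaviour of the small helpers defined above, shown as
# doctest-style assertions:
#
#     assert int2bin(5, 8) == '00000101'        # fixed-width binary string
#     assert swap16(0x1234) == 0x3412           # byte-swap a 16 bit value
#     assert perm_inv([2, 0, 1]) == [1, 2, 0]   # inverse of a permutation
#
# perm_inv is what builds 'args_permut' below: it maps the textual
# argument order back onto the encoded field order.
# --- end of sketch ---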
p.append((fields.index(a), a)) + p.sort() + p = [x[1] for x in p] + p = [dct['args'].index(x) for x in p] + dct['args_permut'] = perm_inv(p) + # order fields + f_ordered = [x for x in enumerate(fields)] + f_ordered.sort(key=lambda x: (x[1].prio, x[0])) + candidates = bases[0].gen_modes(mcs, name, bases, dct, fields) + for i, fc in f_ordered: + if isinstance(fc, bs_divert): + candidates = fc.divert(i, candidates) + for cls, name, bases, dct, fields in candidates: + ndct = dict(dct) + fields = [f for f in fields if f] + ndct['fields'] = fields + ndct['mn_len'] = sum([x.l for x in fields]) + c = type.__new__(cls, name, bases, ndct) + c.alias = alias + c.check_mnemo(fields) + c.num = bases[0].num + bases[0].num += 1 + bases[0].all_mn.append(c) + mode = dct['mode'] + bases[0].all_mn_mode[mode].append(c) + bases[0].all_mn_name[c.name].append(c) + i = c() + i.init_class() + bases[0].all_mn_inst[c].append(i) + add_candidate(bases, c) + # gen byte lookup + o = "" + for f in i.fields_order: + if not isinstance(f, bsi): + raise ValueError('f is not bsi') + if f.l == 0: + continue + o += f.strbits + return c + + +class instruction(object): + __slots__ = ["name", "mode", "args", + "l", "b", "offset", "data", + "additional_info", "delayslot"] + + def __init__(self, name, mode, args, additional_info=None): + self.name = name + self.mode = mode + self.args = args + self.additional_info = additional_info + self.offset = None + self.l = None + self.b = None + self.delayslot = 0 + + def gen_args(self, args): + out = ', '.join([str(x) for x in args]) + return out + + def __str__(self): + return self.to_string() + + def to_string(self, loc_db=None): + o = "%-10s " % self.name + args = [] + for i, arg in enumerate(self.args): + if not isinstance(arg, m2_expr.Expr): + raise ValueError('zarb arg type') + x = self.arg2str(arg, i, loc_db) + args.append(x) + o += self.gen_args(args) + return o + + def to_html(self, loc_db=None): + out = "%-10s " % self.name + out = '<font color="%s">%s</font>' % (utils.COLOR_MNEMO, out) + + args = [] + for i, arg in enumerate(self.args): + if not isinstance(arg, m2_expr.Expr): + raise ValueError('zarb arg type') + x = self.arg2html(arg, i, loc_db) + args.append(x) + out += self.gen_args(args) + return out + + def get_asm_offset(self, expr): + return m2_expr.ExprInt(self.offset, expr.size) + + def get_asm_next_offset(self, expr): + return m2_expr.ExprInt(self.offset+self.l, expr.size) + + def resolve_args_with_symbols(self, loc_db): + args_out = [] + for expr in self.args: + # try to resolve symbols using loc_db (0 for default value) + loc_keys = m2_expr.get_expr_locs(expr) + fixed_expr = {} + for exprloc in loc_keys: + loc_key = exprloc.loc_key + names = loc_db.get_location_names(loc_key) + # special symbols + if '$' in names: + fixed_expr[exprloc] = self.get_asm_offset(exprloc) + continue + if '_' in names: + fixed_expr[exprloc] = self.get_asm_next_offset(exprloc) + continue + arg_int = loc_db.get_location_offset(loc_key) + if arg_int is not None: + fixed_expr[exprloc] = m2_expr.ExprInt(arg_int, exprloc.size) + continue + if not names: + raise ValueError('Unresolved symbol: %r' % exprloc) + + offset = loc_db.get_location_offset(loc_key) + if offset is None: + raise ValueError( + 'The offset of loc_key "%s" cannot be determined' % names + ) + else: + # Fix symbol with its offset + size = exprloc.size + if size is None: + default_size = self.get_symbol_size(exprloc, loc_db) + size = default_size + value = m2_expr.ExprInt(offset, size) + fixed_expr[exprloc] = value + + expr = 
expr.replace_expr(fixed_expr) + expr = expr_simp(expr) + args_out.append(expr) + return args_out + + def get_info(self, c): + return + + +class cls_mn(with_metaclass(metamn, object)): + args_symb = [] + instruction = instruction + # Block's offset alignment + alignment = 1 + + @classmethod + def guess_mnemo(cls, bs, attrib, pre_dis_info, offset): + candidates = [] + + candidates = set() + + fname_values = pre_dis_info + todo = [ + (dict(fname_values), branch, offset * 8) + for branch in list(viewitems(cls.bintree)) + ] + for fname_values, branch, offset_b in todo: + (l, fmask, fbits, fname, flen), vals = branch + + if flen is not None: + l = flen(attrib, fname_values) + if l is not None: + try: + v = cls.getbits(bs, attrib, offset_b, l) + except IOError: + # Raised if offset is out of bound + continue + offset_b += l + if v & fmask != fbits: + continue + if fname is not None and not fname in fname_values: + fname_values[fname] = v + for nb, v in viewitems(vals): + if 'mn' in nb: + candidates.update(v) + else: + todo.append((dict(fname_values), (nb, v), offset_b)) + + return [c for c in candidates] + + def reset_class(self): + for f in self.fields_order: + if f.strbits and isbin(f.strbits): + f.value = int(f.strbits, 2) + elif 'default_val' in f.kargs: + f.value = int(f.kargs['default_val'], 2) + else: + f.value = None + if f.fname: + setattr(self, f.fname, f) + + def init_class(self): + args = [] + fields_order = [] + to_decode = [] + off = 0 + for i, fc in enumerate(self.fields): + f = fc.gen(self) + f.offset = off + off += f.l + fields_order.append(f) + to_decode.append((i, f)) + + if isinstance(f, m_arg): + args.append(f) + if f.fname: + setattr(self, f.fname, f) + if hasattr(self, 'args_permut'): + args = [args[self.args_permut[i]] + for i in range(len(self.args_permut))] + to_decode.sort(key=lambda x: (x[1].order, x[0])) + to_decode = [fields_order.index(f[1]) for f in to_decode] + self.args = args + self.fields_order = fields_order + self.to_decode = to_decode + + def add_pre_dis_info(self, prefix=None): + return True + + @classmethod + def getbits(cls, bs, attrib, offset_b, l): + return bs.getbits(offset_b, l) + + @classmethod + def getbytes(cls, bs, offset, l): + return bs.getbytes(offset, l) + + @classmethod + def pre_dis(cls, v_o, attrib, offset): + return {}, v_o, attrib, offset, 0 + + def post_dis(self): + return self + + @classmethod + def check_mnemo(cls, fields): + pass + + @classmethod + def mod_fields(cls, fields): + return fields + + @classmethod + def dis(cls, bs_o, mode_o = None, offset=0): + if not isinstance(bs_o, bin_stream): + bs_o = bin_stream_str(bs_o) + + bs_o.enter_atomic_mode() + + offset_o = offset + try: + pre_dis_info, bs, mode, offset, prefix_len = cls.pre_dis( + bs_o, mode_o, offset) + except: + bs_o.leave_atomic_mode() + raise + candidates = cls.guess_mnemo(bs, mode, pre_dis_info, offset) + if not candidates: + bs_o.leave_atomic_mode() + raise Disasm_Exception('cannot disasm (guess) at %X' % offset) + + out = [] + out_c = [] + if hasattr(bs, 'getlen'): + bs_l = bs.getlen() + else: + bs_l = len(bs) + + alias = False + for c in candidates: + log.debug("*" * 40, mode, c.mode) + log.debug(c.fields) + + c = cls.all_mn_inst[c][0] + + c.reset_class() + c.mode = mode + + if not c.add_pre_dis_info(pre_dis_info): + continue + + todo = {} + getok = True + fname_values = dict(pre_dis_info) + offset_b = offset * 8 + + total_l = 0 + for i, f in enumerate(c.fields_order): + if f.flen is not None: + l = f.flen(mode, fname_values) + else: + l = f.l + if l is not None: + 
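# --- Illustrative sketch (editorial addition, not part of this patch) ---
# What cls_mn.dis()/cls_mn.asm() look like from the outside, assuming an
# architecture backend such as miasm.arch.x86.arch.mn_x86 (a cls_mn
# subclass) is available; the exact candidate encodings may vary:
#
#     from miasm.arch.x86.arch import mn_x86
#
#     instr = mn_x86.dis(b"\x31\xc0", 32)   # raw bytes are wrapped into a
#     assert instr.name == "XOR"            # bin_stream_str automatically
#     candidates = mn_x86.asm(instr)        # list of possible encodings,
#     assert b"\x31\xc0" in candidates      # sorted shortest first
# --- end of sketch ---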
total_l += l + f.l = l + f.is_present = True + log.debug("FIELD %s %s %s %s", f.__class__, f.fname, + offset_b, l) + if bs_l * 8 - offset_b < l: + getok = False + break + try: + bv = cls.getbits(bs, mode, offset_b, l) + except: + bs_o.leave_atomic_mode() + raise + offset_b += l + if not f.fname in fname_values: + fname_values[f.fname] = bv + todo[i] = bv + else: + f.is_present = False + todo[i] = None + + if not getok: + continue + + c.l = prefix_len + total_l // 8 + for i in c.to_decode: + f = c.fields_order[i] + if f.is_present: + ret = f.decode(todo[i]) + if not ret: + log.debug("cannot decode %r", f) + break + + if not ret: + continue + for a in c.args: + a.expr = expr_simp(a.expr) + + c.b = cls.getbytes(bs, offset_o, c.l) + c.offset = offset_o + c = c.post_dis() + if c is None: + continue + c_args = [a.expr for a in c.args] + instr = cls.instruction(c.name, mode, c_args, + additional_info=c.additional_info()) + instr.l = prefix_len + total_l // 8 + instr.b = cls.getbytes(bs, offset_o, instr.l) + instr.offset = offset_o + instr.get_info(c) + if c.alias: + alias = True + out.append(instr) + out_c.append(c) + + bs_o.leave_atomic_mode() + + if not out: + raise Disasm_Exception('cannot disasm at %X' % offset_o) + if len(out) != 1: + if not alias: + log.warning('dis multiple args ret default') + + for i, o in enumerate(out_c): + if o.alias: + return out[i] + raise NotImplementedError( + 'Multiple disas: \n' + + "\n".join(str(x) for x in out) + ) + return out[0] + + @classmethod + def fromstring(cls, text, loc_db, mode = None): + global total_scans + name = re.search(r'(\S+)', text).groups() + if not name: + raise ValueError('cannot find name', text) + name = name[0] + + if not name in cls.all_mn_name: + raise ValueError('unknown name', name) + clist = [x for x in cls.all_mn_name[name]] + out = [] + out_args = [] + parsers = defaultdict(dict) + + for cc in clist: + for c in cls.get_cls_instance(cc, mode): + args_expr = [] + args_str = text[len(name):].strip(' ') + + start = 0 + cannot_parse = False + len_o = len(args_str) + + for i, f in enumerate(c.args): + start_i = len_o - len(args_str) + if type(f.parser) == tuple: + parser = f.parser + else: + parser = (f.parser,) + for p in parser: + if p in parsers[(i, start_i)]: + continue + try: + total_scans += 1 + v, start, stop = next(p.scanString(args_str)) + except StopIteration: + v, start, stop = [None], None, None + if start != 0: + v, start, stop = [None], None, None + if v != [None]: + v = f.asm_ast_to_expr(v[0], loc_db) + if v is None: + v, start, stop = [None], None, None + parsers[(i, start_i)][p] = v, start, stop + start, stop = f.fromstring(args_str, loc_db, parsers[(i, start_i)]) + if start != 0: + log.debug("cannot fromstring %r", args_str) + cannot_parse = True + break + if f.expr is None: + raise NotImplementedError('not fully functional') + f.expr = expr_simp(f.expr) + args_expr.append(f.expr) + args_str = args_str[stop:].strip(' ') + if args_str.startswith(','): + args_str = args_str[1:] + args_str = args_str.strip(' ') + if args_str: + cannot_parse = True + if cannot_parse: + continue + + out.append(c) + out_args.append(args_expr) + break + + if len(out) == 0: + raise ValueError('cannot fromstring %r' % text) + if len(out) != 1: + log.debug('fromstring multiple args ret default') + c = out[0] + c_args = out_args[0] + + instr = cls.instruction(c.name, mode, c_args, + additional_info=c.additional_info()) + return instr + + def dup_info(self, infos): + return + + @classmethod + def get_cls_instance(cls, cc, mode, infos=None): + c 
= cls.all_mn_inst[cc][0] + + c.reset_class() + c.add_pre_dis_info() + c.dup_info(infos) + + c.mode = mode + yield c + + @classmethod + def asm(cls, instr, loc_db=None): + """ + Re asm instruction by searching mnemo using name and args. We then + can modify args and get the hex of a modified instruction + """ + clist = cls.all_mn_name[instr.name] + clist = [x for x in clist] + vals = [] + candidates = [] + args = instr.resolve_args_with_symbols(loc_db) + + for cc in clist: + + for c in cls.get_cls_instance( + cc, instr.mode, instr.additional_info): + + cannot_parse = False + if len(c.args) != len(instr.args): + continue + + # only fix args expr + for i in range(len(c.args)): + c.args[i].expr = args[i] + + v = c.value(instr.mode) + if not v: + log.debug("cannot encode %r", c) + cannot_parse = True + if cannot_parse: + continue + vals += v + candidates.append((c, v)) + if len(vals) == 0: + raise ValueError( + 'cannot asm %r %r' % + (instr.name, [str(x) for x in instr.args]) + ) + if len(vals) != 1: + log.debug('asm multiple args ret default') + + vals = cls.filter_asm_candidates(instr, candidates) + return vals + + @classmethod + def filter_asm_candidates(cls, instr, candidates): + o = [] + for _, v in candidates: + o += v + o.sort(key=len) + return o + + def value(self, mode): + todo = [(0, 0, [(x, self.fields_order[x]) for x in self.to_decode[::-1]])] + + result = [] + done = [] + + while todo: + index, cur_len, to_decode = todo.pop() + # TEST XXX + for _, f in to_decode: + setattr(self, f.fname, f) + if (index, [x[1].value for x in to_decode]) in done: + continue + done.append((index, [x[1].value for x in to_decode])) + + can_encode = True + for i, f in to_decode[index:]: + f.parent.l = cur_len + ret = f.encode() + if not ret: + log.debug('cannot encode %r', f) + can_encode = False + break + + if f.value is not None and f.l: + if f.value > f.lmask: + log.debug('cannot encode %r', f) + can_encode = False + break + cur_len += f.l + index += 1 + if ret is True: + continue + + for _ in ret: + o = [] + if ((index, cur_len, [xx[1].value for xx in to_decode]) in todo or + (index, cur_len, [xx[1].value for xx in to_decode]) in done): + raise NotImplementedError('not fully functional') + + for p, f in to_decode: + fnew = f.clone() + o.append((p, fnew)) + todo.append((index, cur_len, o)) + can_encode = False + + break + if not can_encode: + continue + result.append(to_decode) + + return self.decoded2bytes(result) + + def encodefields(self, decoded): + bits = bitobj() + for _, f in decoded: + setattr(self, f.fname, f) + + if f.value is None: + continue + bits.putbits(f.value, f.l) + + return bits.tostring() + + def decoded2bytes(self, result): + if not result: + return [] + + out = [] + for decoded in result: + decoded.sort() + + o = self.encodefields(decoded) + if o is None: + continue + out.append(o) + out = list(set(out)) + return out + + def gen_args(self, args): + out = ', '.join([str(x) for x in args]) + return out + + def args2str(self): + args = [] + for arg in self.args: + # XXX todo test + if not (isinstance(arg, m2_expr.Expr) or + isinstance(arg.expr, m2_expr.Expr)): + raise ValueError('zarb arg type') + x = str(arg) + args.append(x) + return args + + def __str__(self): + o = "%-10s " % self.name + args = [] + for arg in self.args: + # XXX todo test + if not (isinstance(arg, m2_expr.Expr) or + isinstance(arg.expr, m2_expr.Expr)): + raise ValueError('zarb arg type') + x = str(arg) + args.append(x) + + o += self.gen_args(args) + return o + + def parse_prefix(self, v): + return 0 + + def 
set_dst_symbol(self, loc_db): + dst = self.getdstflow(loc_db) + args = [] + for d in dst: + if isinstance(d, m2_expr.ExprInt): + l = loc_db.get_or_create_offset_location(int(d)) + + a = m2_expr.ExprId(l.name, d.size) + else: + a = d + args.append(a) + self.args_symb = args + + def getdstflow(self, loc_db): + return [self.args[0].expr] + + +class imm_noarg(object): + intsize = 32 + intmask = (1 << intsize) - 1 + + def int2expr(self, v): + if (v & ~self.intmask) != 0: + return None + return m2_expr.ExprInt(v, self.intsize) + + def expr2int(self, e): + if not isinstance(e, m2_expr.ExprInt): + return None + v = int(e) + if v & ~self.intmask != 0: + return None + return v + + def fromstring(self, text, loc_db, parser_result=None): + if parser_result: + e, start, stop = parser_result[self.parser] + else: + try: + e, start, stop = next(self.parser.scanString(text)) + except StopIteration: + return None, None + if e == [None]: + return None, None + + assert(m2_expr.is_expr(e)) + self.expr = e + if self.expr is None: + log.debug('cannot fromstring int %r', text) + return None, None + return start, stop + + def decodeval(self, v): + return v + + def encodeval(self, v): + return v + + def decode(self, v): + v = v & self.lmask + v = self.decodeval(v) + e = self.int2expr(v) + if not e: + return False + self.expr = e + return True + + def encode(self): + v = self.expr2int(self.expr) + if v is None: + return False + v = self.encodeval(v) + if v is False: + return False + self.value = v + return True + + +class imm08_noarg(object): + int2expr = lambda self, x: m2_expr.ExprInt(x, 8) + + +class imm16_noarg(object): + int2expr = lambda self, x: m2_expr.ExprInt(x, 16) + + +class imm32_noarg(object): + int2expr = lambda self, x: m2_expr.ExprInt(x, 32) + + +class imm64_noarg(object): + int2expr = lambda self, x: m2_expr.ExprInt(x, 64) + + +class int32_noarg(imm_noarg): + intsize = 32 + intmask = (1 << intsize) - 1 + + def decode(self, v): + v = sign_ext(v, self.l, self.intsize) + v = self.decodeval(v) + self.expr = self.int2expr(v) + return True + + def encode(self): + if not isinstance(self.expr, m2_expr.ExprInt): + return False + v = int(self.expr) + if sign_ext(v & self.lmask, self.l, self.intsize) != v: + return False + v = self.encodeval(v & self.lmask) + if v is False: + return False + self.value = v & self.lmask + return True + +class bs8(bs): + prio = default_prio + + def __init__(self, v, cls=None, fname=None, **kargs): + super(bs8, self).__init__(int2bin(v, 8), 8, + cls=cls, fname=fname, **kargs) + + + + +def swap_uint(size, i): + if size == 8: + return i & 0xff + elif size == 16: + return struct.unpack('<H', struct.pack('>H', i & 0xffff))[0] + elif size == 32: + return struct.unpack('<I', struct.pack('>I', i & 0xffffffff))[0] + elif size == 64: + return struct.unpack('<Q', struct.pack('>Q', i & 0xffffffffffffffff))[0] + raise ValueError('unknown int len %r' % size) + + +def swap_sint(size, i): + if size == 8: + return i + elif size == 16: + return struct.unpack('<h', struct.pack('>H', i & 0xffff))[0] + elif size == 32: + return struct.unpack('<i', struct.pack('>I', i & 0xffffffff))[0] + elif size == 64: + return struct.unpack('<q', struct.pack('>Q', i & 0xffffffffffffffff))[0] + raise ValueError('unknown int len %r' % size) + + +def sign_ext(v, s_in, s_out): + assert(s_in <= s_out) + v &= (1 << s_in) - 1 + sign_in = v & (1 << (s_in - 1)) + if not sign_in: + return v + m = (1 << (s_out)) - 1 + m ^= (1 << s_in) - 1 + v |= m + return v diff --git a/src/miasm/core/ctypesmngr.py 
b/src/miasm/core/ctypesmngr.py new file mode 100644 index 00000000..94c96f7e --- /dev/null +++ b/src/miasm/core/ctypesmngr.py @@ -0,0 +1,771 @@ +import re + +from pycparser import c_parser, c_ast + +RE_HASH_CMT = re.compile(r'^#\s*\d+.*$', flags=re.MULTILINE) + +# Ref: ISO/IEC 9899:TC2 +# http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1124.pdf + + +def c_to_ast(parser, c_str): + """Transform a @c_str into a C ast + Note: will ignore lines containing code refs ie: + # 23 "miasm.h" + + @parser: pycparser instance + @c_str: c string + """ + + new_str = re.sub(RE_HASH_CMT, "", c_str) + return parser.parse(new_str, filename='<stdin>') + + +class CTypeBase(object): + """Object to represent the 3 forms of C type: + * object types + * function types + * incomplete types + """ + + def __init__(self): + self.__repr = str(self) + self.__hash = hash(self.__repr) + + @property + def _typerepr(self): + return self.__repr + + def __eq__(self, other): + raise NotImplementedError("Abstract method") + + def __ne__(self, other): + return not self.__eq__(other) + + def eq_base(self, other): + """Trivial common equality test""" + return self.__class__ == other.__class__ + + def __hash__(self): + return self.__hash + + def __repr__(self): + return self._typerepr + + +class CTypeId(CTypeBase): + """C type id: + int + unsigned int + """ + + def __init__(self, *names): + # Type specifier order does not matter + # so the canonical form is ordered + self.names = tuple(sorted(names)) + super(CTypeId, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.names)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.names == other.names) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Id:%s>" % ', '.join(self.names) + + +class CTypeArray(CTypeBase): + """C type for array: + typedef int XXX[4]; + """ + + def __init__(self, target, size): + assert isinstance(target, CTypeBase) + self.target = target + self.size = size + super(CTypeArray, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.target, self.size)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.target == other.target and + self.size == other.size) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Array[%s]:%s>" % (self.size, str(self.target)) + + +class CTypePtr(CTypeBase): + """C type for pointer: + typedef int* XXX; + """ + + def __init__(self, target): + assert isinstance(target, CTypeBase) + self.target = target + super(CTypePtr, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.target)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.target == other.target) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Ptr:%s>" % str(self.target) + + +class CTypeStruct(CTypeBase): + """C type for structure""" + + def __init__(self, name, fields=None): + assert name is not None + self.name = name + if fields is None: + fields = () + for field_name, field in fields: + assert field_name is not None + assert isinstance(field, CTypeBase) + self.fields = tuple(fields) + super(CTypeStruct, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.name, self.fields)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.name == other.name and + self.fields == other.fields) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + out = [] + 
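# --- Illustrative sketch (editorial addition, not part of this patch) ---
# The CType* classes defined in this module behave as hashable value
# objects: structurally identical type descriptors compare and hash equal,
# so they can serve as dictionary keys:
#
#     assert CTypeId('unsigned', 'int') == CTypeId('int', 'unsigned')
#     assert CTypePtr(CTypeId('char')) == CTypePtr(CTypeId('char'))
#     assert hash(CTypeArray(CTypeId('int'), 4)) == \
#            hash(CTypeArray(CTypeId('int'), 4))
#     sizes = {CTypeId('int'): 4}            # usable as a dict key
#     assert sizes[CTypeId('int')] == 4
# --- end of sketch ---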
out.append("<Struct:%s>" % self.name) + for name, field in self.fields: + out.append("\t%-10s %s" % (name, field)) + return '\n'.join(out) + + +class CTypeUnion(CTypeBase): + """C type for union""" + + def __init__(self, name, fields=None): + assert name is not None + self.name = name + if fields is None: + fields = [] + for field_name, field in fields: + assert field_name is not None + assert isinstance(field, CTypeBase) + self.fields = tuple(fields) + super(CTypeUnion, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.name, self.fields)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.name == other.name and + self.fields == other.fields) + + def __str__(self): + out = [] + out.append("<Union:%s>" % self.name) + for name, field in self.fields: + out.append("\t%-10s %s" % (name, field)) + return '\n'.join(out) + + +class CTypeEnum(CTypeBase): + """C type for enums""" + + def __init__(self, name): + self.name = name + super(CTypeEnum, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.name)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.name == other.name) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Enum:%s>" % self.name + + +class CTypeFunc(CTypeBase): + """C type for enums""" + + def __init__(self, name, abi=None, type_ret=None, args=None): + if type_ret: + assert isinstance(type_ret, CTypeBase) + if args: + for arg_name, arg in args: + assert isinstance(arg, CTypeBase) + args = tuple(args) + else: + args = tuple() + self.name = name + self.abi = abi + self.type_ret = type_ret + self.args = args + super(CTypeFunc, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.name, self.abi, + self.type_ret, self.args)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.name == other.name and + self.abi == other.abi and + self.type_ret == other.type_ret and + self.args == other.args) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Func:%s (%s) %s(%s)>" % (self.type_ret, + self.abi, + self.name, + ", ".join(["%s %s" % (name, arg) for (name, arg) in self.args])) + + +class CTypeEllipsis(CTypeBase): + """C type for ellipsis argument (...)""" + + def __hash__(self): + return hash((self.__class__)) + + def __eq__(self, other): + return self.eq_base(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Ellipsis>" + + +class CTypeSizeof(CTypeBase): + """C type for sizeof""" + + def __init__(self, target): + self.target = target + super(CTypeSizeof, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.target)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.target == other.target) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "<Sizeof(%s)>" % self.target + + +class CTypeOp(CTypeBase): + """C type for operator (+ * ...)""" + + def __init__(self, operator, *args): + self.operator = operator + self.args = tuple(args) + super(CTypeOp, self).__init__() + + def __hash__(self): + return hash((self.__class__, self.operator, self.args)) + + def __eq__(self, other): + return (self.eq_base(other) and + self.operator == other.operator and + self.args == other.args) + + def __str__(self): + return "<CTypeOp(%s, %s)>" % (self.operator, + ', '.join([str(arg) for arg in self.args])) + + +class FuncNameIdentifier(c_ast.NodeVisitor): + """Visit an 
c_ast to find IdentifierType""" + + def __init__(self): + super(FuncNameIdentifier, self).__init__() + self.node_name = None + + def visit_TypeDecl(self, node): + """Retrieve the name in a function declaration: + Only one IdentifierType is present""" + self.node_name = node + + +class CAstTypes(object): + """Store all defined C types and typedefs""" + INTERNAL_PREFIX = "__GENTYPE__" + ANONYMOUS_PREFIX = "__ANONYMOUS__" + + def __init__(self, knowntypes=None, knowntypedefs=None): + if knowntypes is None: + knowntypes = {} + if knowntypedefs is None: + knowntypedefs = {} + + self._types = dict(knowntypes) + self._typedefs = dict(knowntypedefs) + self.cpt = 0 + self.loc_to_decl_info = {} + self.parser = c_parser.CParser() + self._cpt_decl = 0 + + + self.ast_to_typeid_rules = { + c_ast.Struct: self.ast_to_typeid_struct, + c_ast.Union: self.ast_to_typeid_union, + c_ast.IdentifierType: self.ast_to_typeid_identifiertype, + c_ast.TypeDecl: self.ast_to_typeid_typedecl, + c_ast.Decl: self.ast_to_typeid_decl, + c_ast.Typename: self.ast_to_typeid_typename, + c_ast.FuncDecl: self.ast_to_typeid_funcdecl, + c_ast.Enum: self.ast_to_typeid_enum, + c_ast.PtrDecl: self.ast_to_typeid_ptrdecl, + c_ast.EllipsisParam: self.ast_to_typeid_ellipsisparam, + c_ast.ArrayDecl: self.ast_to_typeid_arraydecl, + } + + self.ast_parse_rules = { + c_ast.Struct: self.ast_parse_struct, + c_ast.Union: self.ast_parse_union, + c_ast.Typedef: self.ast_parse_typedef, + c_ast.TypeDecl: self.ast_parse_typedecl, + c_ast.IdentifierType: self.ast_parse_identifiertype, + c_ast.Decl: self.ast_parse_decl, + c_ast.PtrDecl: self.ast_parse_ptrdecl, + c_ast.Enum: self.ast_parse_enum, + c_ast.ArrayDecl: self.ast_parse_arraydecl, + c_ast.FuncDecl: self.ast_parse_funcdecl, + c_ast.FuncDef: self.ast_parse_funcdef, + c_ast.Pragma: self.ast_parse_pragma, + } + + def gen_uniq_name(self): + """Generate uniq name for unnamed strucs/union""" + cpt = self.cpt + self.cpt += 1 + return self.INTERNAL_PREFIX + "%d" % cpt + + def gen_anon_name(self): + """Generate name for anonymous strucs/union""" + cpt = self.cpt + self.cpt += 1 + return self.ANONYMOUS_PREFIX + "%d" % cpt + + def is_generated_name(self, name): + """Return True if the name is internal""" + return name.startswith(self.INTERNAL_PREFIX) + + def is_anonymous_name(self, name): + """Return True if the name is anonymous""" + return name.startswith(self.ANONYMOUS_PREFIX) + + def add_type(self, type_id, type_obj): + """Add new C type + @type_id: Type descriptor (CTypeBase instance) + @type_obj: Obj* instance""" + assert isinstance(type_id, CTypeBase) + if type_id in self._types: + assert self._types[type_id] == type_obj + else: + self._types[type_id] = type_obj + + def add_typedef(self, type_new, type_src): + """Add new typedef + @type_new: CTypeBase instance of the new type name + @type_src: CTypeBase instance of the target type""" + assert isinstance(type_src, CTypeBase) + self._typedefs[type_new] = type_src + + def get_type(self, type_id): + """Get ObjC corresponding to the @type_id + @type_id: Type descriptor (CTypeBase instance) + """ + assert isinstance(type_id, CTypeBase) + if isinstance(type_id, CTypePtr): + subobj = self.get_type(type_id.target) + return CTypePtr(subobj) + if type_id in self._types: + return self._types[type_id] + elif type_id in self._typedefs: + return self.get_type(self._typedefs[type_id]) + return type_id + + def is_known_type(self, type_id): + """Return true if @type_id is known + @type_id: Type descriptor (CTypeBase instance) + """ + if isinstance(type_id, CTypePtr): + 
return self.is_known_type(type_id.target) + if type_id in self._types: + return True + if type_id in self._typedefs: + return self.is_known_type(self._typedefs[type_id]) + return False + + def add_c_decl_from_ast(self, ast): + """ + Adds types from a C ast + @ast: C ast + """ + self.ast_parse_declarations(ast) + + + def digest_decl(self, c_str): + + char_id = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + + + # Seek deck + index_decl = [] + index = 0 + for decl in ['__cdecl__', '__stdcall__']: + index = 0 + while True: + index = c_str.find(decl, index) + if index == -1: + break + decl_off = index + decl_len = len(decl) + + index = index+len(decl) + while c_str[index] not in char_id: + index += 1 + + id_start = index + + while c_str[index] in char_id: + index += 1 + id_stop = index + + name = c_str[id_start:id_stop] + index_decl.append((decl_off, decl_len, id_start, id_stop, decl, )) + + index_decl.sort() + + # Remove decl + off = 0 + offsets = [] + for decl_off, decl_len, id_start, id_stop, decl in index_decl: + decl_off -= off + c_str = c_str[:decl_off] + c_str[decl_off+decl_len:] + off += decl_len + offsets.append((id_start-off, id_stop-off, decl)) + + index = 0 + lineno = 1 + + # Index to lineno, column + for id_start, id_stop, decl in offsets: + nbr = c_str.count('\n', index, id_start) + lineno += nbr + last_cr = c_str.rfind('\n', 0, id_start) + # column starts at 1 + column = id_start - last_cr + index = id_start + self.loc_to_decl_info[(lineno, column)] = decl + return c_str + + + def add_c_decl(self, c_str): + """ + Adds types from a C string types declaring + Note: will ignore lines containing code refs ie: + '# 23 "miasm.h"' + Returns the C ast + @c_str: C string containing C types declarations + """ + c_str = self.digest_decl(c_str) + + ast = c_to_ast(self.parser, c_str) + self.add_c_decl_from_ast(ast) + + return ast + + def ast_eval_int(self, ast): + """Eval a C ast object integer + + @ast: parsed pycparser.c_ast object + """ + + if isinstance(ast, c_ast.BinaryOp): + left = self.ast_eval_int(ast.left) + right = self.ast_eval_int(ast.right) + is_pure_int = (isinstance(left, int) and + isinstance(right, int)) + + if is_pure_int: + if ast.op == '*': + result = left * right + elif ast.op == '/': + assert left % right == 0 + result = left // right + elif ast.op == '+': + result = left + right + elif ast.op == '-': + result = left - right + elif ast.op == '<<': + result = left << right + elif ast.op == '>>': + result = left >> right + else: + raise NotImplementedError("Not implemented!") + else: + result = CTypeOp(ast.op, left, right) + + elif isinstance(ast, c_ast.UnaryOp): + if ast.op == 'sizeof' and isinstance(ast.expr, c_ast.Typename): + subobj = self.ast_to_typeid(ast.expr) + result = CTypeSizeof(subobj) + else: + raise NotImplementedError("Not implemented!") + + elif isinstance(ast, c_ast.Constant): + result = int(ast.value, 0) + elif isinstance(ast, c_ast.Cast): + # TODO: Can trunc integers? 
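# --- Illustrative sketch (editorial addition, not part of this patch) ---
# Typical use of CAstTypes: feed it C declarations, then query the
# resulting type descriptors (requires pycparser, imported at the top of
# this module):
#
#     types = CAstTypes()
#     types.add_c_decl("struct foo { int a; char b[4]; };")
#     assert types.is_known_type(CTypeStruct("foo"))
#     obj = types.get_type(CTypeStruct("foo"))
#     assert len(obj.fields) == 2            # ('a', int) and ('b', char[4])
#
#     types.add_typedef(CTypeId("u32"), CTypeId("unsigned", "int"))
#     assert types.get_type(CTypeId("u32")) == CTypeId("unsigned", "int")
# --- end of sketch ---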
+ result = self.ast_eval_int(ast.expr) + else: + raise NotImplementedError("Not implemented!") + return result + + def ast_to_typeid_struct(self, ast): + """Return the CTypeBase of an Struct ast""" + name = self.gen_uniq_name() if ast.name is None else ast.name + args = [] + if ast.decls: + for arg in ast.decls: + if arg.name is None: + arg_name = self.gen_anon_name() + else: + arg_name = arg.name + args.append((arg_name, self.ast_to_typeid(arg))) + decl = CTypeStruct(name, args) + return decl + + def ast_to_typeid_union(self, ast): + """Return the CTypeBase of an Union ast""" + name = self.gen_uniq_name() if ast.name is None else ast.name + args = [] + if ast.decls: + for arg in ast.decls: + if arg.name is None: + arg_name = self.gen_anon_name() + else: + arg_name = arg.name + args.append((arg_name, self.ast_to_typeid(arg))) + decl = CTypeUnion(name, args) + return decl + + def ast_to_typeid_identifiertype(self, ast): + """Return the CTypeBase of an IdentifierType ast""" + return CTypeId(*ast.names) + + def ast_to_typeid_typedecl(self, ast): + """Return the CTypeBase of a TypeDecl ast""" + return self.ast_to_typeid(ast.type) + + def ast_to_typeid_decl(self, ast): + """Return the CTypeBase of a Decl ast""" + return self.ast_to_typeid(ast.type) + + def ast_to_typeid_typename(self, ast): + """Return the CTypeBase of a TypeName ast""" + return self.ast_to_typeid(ast.type) + + def get_funcname(self, ast): + """Return the name of a function declaration ast""" + funcnameid = FuncNameIdentifier() + funcnameid.visit(ast) + node_name = funcnameid.node_name + if node_name.coord is not None: + lineno, column = node_name.coord.line, node_name.coord.column + decl_info = self.loc_to_decl_info.get((lineno, column), None) + else: + decl_info = None + return node_name.declname, decl_info + + def ast_to_typeid_funcdecl(self, ast): + """Return the CTypeBase of an FuncDecl ast""" + type_ret = self.ast_to_typeid(ast.type) + name, decl_info = self.get_funcname(ast.type) + if ast.args: + args = [] + for arg in ast.args.params: + typeid = self.ast_to_typeid(arg) + if isinstance(typeid, CTypeEllipsis): + arg_name = None + else: + arg_name = arg.name + args.append((arg_name, typeid)) + else: + args = [] + + obj = CTypeFunc(name, decl_info, type_ret, args) + decl = CTypeFunc(name) + if not self.is_known_type(decl): + self.add_type(decl, obj) + return obj + + def ast_to_typeid_enum(self, ast): + """Return the CTypeBase of an Enum ast""" + name = self.gen_uniq_name() if ast.name is None else ast.name + return CTypeEnum(name) + + def ast_to_typeid_ptrdecl(self, ast): + """Return the CTypeBase of a PtrDecl ast""" + return CTypePtr(self.ast_to_typeid(ast.type)) + + def ast_to_typeid_ellipsisparam(self, _): + """Return the CTypeBase of an EllipsisParam ast""" + return CTypeEllipsis() + + def ast_to_typeid_arraydecl(self, ast): + """Return the CTypeBase of an ArrayDecl ast""" + target = self.ast_to_typeid(ast.type) + if ast.dim is None: + value = None + else: + value = self.ast_eval_int(ast.dim) + return CTypeArray(target, value) + + def ast_to_typeid(self, ast): + """Return the CTypeBase of the @ast + @ast: pycparser.c_ast instance""" + cls = ast.__class__ + if not cls in self.ast_to_typeid_rules: + raise NotImplementedError("Strange type %r" % ast) + return self.ast_to_typeid_rules[cls](ast) + + # Ast parse type declarators + + def ast_parse_decl(self, ast): + """Parse ast Decl""" + return self.ast_parse_declaration(ast.type) + + def ast_parse_typedecl(self, ast): + """Parse ast Typedecl""" + return 
self.ast_parse_declaration(ast.type) + + def ast_parse_struct(self, ast): + """Parse ast Struct""" + obj = self.ast_to_typeid(ast) + if ast.decls and ast.name is not None: + # Add struct to types if named + decl = CTypeStruct(ast.name) + if not self.is_known_type(decl): + self.add_type(decl, obj) + return obj + + def ast_parse_union(self, ast): + """Parse ast Union""" + obj = self.ast_to_typeid(ast) + if ast.decls and ast.name is not None: + # Add union to types if named + decl = CTypeUnion(ast.name) + if not self.is_known_type(decl): + self.add_type(decl, obj) + return obj + + def ast_parse_typedef(self, ast): + """Parse ast TypeDef""" + decl = CTypeId(ast.name) + obj = self.ast_parse_declaration(ast.type) + if (isinstance(obj, (CTypeStruct, CTypeUnion)) and + self.is_generated_name(obj.name)): + # Add typedef name to default name + # for a question of clarity + obj.name += "__%s" % ast.name + self.add_typedef(decl, obj) + # Typedef does not return any object + return None + + def ast_parse_identifiertype(self, ast): + """Parse ast IdentifierType""" + return CTypeId(*ast.names) + + def ast_parse_ptrdecl(self, ast): + """Parse ast PtrDecl""" + return CTypePtr(self.ast_parse_declaration(ast.type)) + + def ast_parse_enum(self, ast): + """Parse ast Enum""" + return self.ast_to_typeid(ast) + + def ast_parse_arraydecl(self, ast): + """Parse ast ArrayDecl""" + return self.ast_to_typeid(ast) + + def ast_parse_funcdecl(self, ast): + """Parse ast FuncDecl""" + return self.ast_to_typeid(ast) + + def ast_parse_funcdef(self, ast): + """Parse ast FuncDef""" + return self.ast_to_typeid(ast.decl) + + def ast_parse_pragma(self, _): + """Prama does not return any object""" + return None + + def ast_parse_declaration(self, ast): + """Add one ast type declaration to the type manager + (packed style in type manager) + + @ast: parsed pycparser.c_ast object + """ + cls = ast.__class__ + if not cls in self.ast_parse_rules: + raise NotImplementedError("Strange declaration %r" % cls) + return self.ast_parse_rules[cls](ast) + + def ast_parse_declarations(self, ast): + """Add ast types declaration to the type manager + (packed style in type manager) + + @ast: parsed pycparser.c_ast object + """ + for ext in ast.ext: + ret = self.ast_parse_declaration(ext) + + def parse_c_type(self, c_str): + """Parse a C string representing a C type and return the associated + Miasm C object. 
+ @c_str: C string of a C type + """ + + new_str = "%s __MIASM_INTERNAL_%s;" % (c_str, self._cpt_decl) + ret = self.parser.cparser.parse(input=new_str, lexer=self.parser.clex) + self._cpt_decl += 1 + return ret diff --git a/src/miasm/core/graph.py b/src/miasm/core/graph.py new file mode 100644 index 00000000..debea38e --- /dev/null +++ b/src/miasm/core/graph.py @@ -0,0 +1,1123 @@ +from collections import defaultdict, namedtuple + +from future.utils import viewitems, viewvalues +import re + + +class DiGraph(object): + + """Implementation of directed graph""" + + # Stand for a cell in a dot node rendering + DotCellDescription = namedtuple("DotCellDescription", + ["text", "attr"]) + + def __init__(self): + self._nodes = set() + self._edges = [] + # N -> Nodes N2 with a edge (N -> N2) + self._nodes_succ = {} + # N -> Nodes N2 with a edge (N2 -> N) + self._nodes_pred = {} + + self.escape_chars = re.compile('[' + re.escape('{}[]') + '&|<>' + ']') + + + def __repr__(self): + out = [] + for node in self._nodes: + out.append(str(node)) + for src, dst in self._edges: + out.append("%s -> %s" % (src, dst)) + return '\n'.join(out) + + def nodes(self): + return self._nodes + + def edges(self): + return self._edges + + def merge(self, graph): + """Merge the current graph with @graph + @graph: DiGraph instance + """ + for node in graph._nodes: + self.add_node(node) + for edge in graph._edges: + self.add_edge(*edge) + + def __add__(self, graph): + """Wrapper on `.merge`""" + self.merge(graph) + return self + + def copy(self): + """Copy the current graph instance""" + graph = self.__class__() + return graph + self + + def __eq__(self, graph): + if not isinstance(graph, self.__class__): + return False + if self._nodes != graph.nodes(): + return False + return sorted(self._edges) == sorted(graph.edges()) + + def __ne__(self, other): + return not self.__eq__(other) + + def add_node(self, node): + """Add the node @node to the graph. + If the node was already present, return False. 
+ Otherwise, return True + """ + if node in self._nodes: + return False + self._nodes.add(node) + self._nodes_succ[node] = [] + self._nodes_pred[node] = [] + return True + + def del_node(self, node): + """Delete the @node of the graph; Also delete every edge to/from this + @node""" + + if node in self._nodes: + self._nodes.remove(node) + for pred in self.predecessors(node): + self.del_edge(pred, node) + for succ in self.successors(node): + self.del_edge(node, succ) + + def add_edge(self, src, dst): + if not src in self._nodes: + self.add_node(src) + if not dst in self._nodes: + self.add_node(dst) + self._edges.append((src, dst)) + self._nodes_succ[src].append(dst) + self._nodes_pred[dst].append(src) + + def add_uniq_edge(self, src, dst): + """Add an edge from @src to @dst if it doesn't already exist""" + if (src not in self._nodes_succ or + dst not in self._nodes_succ[src]): + self.add_edge(src, dst) + + def del_edge(self, src, dst): + self._edges.remove((src, dst)) + self._nodes_succ[src].remove(dst) + self._nodes_pred[dst].remove(src) + + def discard_edge(self, src, dst): + """Remove edge between @src and @dst if it exits""" + if (src, dst) in self._edges: + self.del_edge(src, dst) + + def predecessors_iter(self, node): + if not node in self._nodes_pred: + return + for n_pred in self._nodes_pred[node]: + yield n_pred + + def predecessors(self, node): + return [x for x in self.predecessors_iter(node)] + + def successors_iter(self, node): + if not node in self._nodes_succ: + return + for n_suc in self._nodes_succ[node]: + yield n_suc + + def successors(self, node): + return [x for x in self.successors_iter(node)] + + def leaves_iter(self): + for node in self._nodes: + if not self._nodes_succ[node]: + yield node + + def leaves(self): + return [x for x in self.leaves_iter()] + + def heads_iter(self): + for node in self._nodes: + if not self._nodes_pred[node]: + yield node + + def heads(self): + return [x for x in self.heads_iter()] + + def find_path(self, src, dst, cycles_count=0, done=None): + """ + Searches for paths from @src to @dst + @src: loc_key of basic block from which it should start + @dst: loc_key of basic block where it should stop + @cycles_count: maximum number of times a basic block can be processed + @done: dictionary of already processed loc_keys, it's value is number of times it was processed + @out: list of paths from @src to @dst + """ + if done is None: + done = {} + if dst in done and done[dst] > cycles_count: + return [[]] + if src == dst: + return [[src]] + out = [] + for node in self.predecessors(dst): + done_n = dict(done) + done_n[dst] = done_n.get(dst, 0) + 1 + for path in self.find_path(src, node, cycles_count, done_n): + if path and path[0] == src: + out.append(path + [dst]) + return out + + def find_path_from_src(self, src, dst, cycles_count=0, done=None): + """ + This function does the same as function find_path. + But it searches the paths from src to dst, not vice versa like find_path. + This approach might be more efficient in some cases. 
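+
+        Small worked example (illustrative; assumes a DiGraph called
+        graph holding nodes 1, 2, 3 and the edges 1 -> 2 and 2 -> 3):
+
+        >>> graph.find_path_from_src(1, 3)
+        [[1, 2, 3]]
+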
+ @src: loc_key of basic block from which it should start + @dst: loc_key of basic block where it should stop + @cycles_count: maximum number of times a basic block can be processed + @done: dictionary of already processed loc_keys, it's value is number of times it was processed + @out: list of paths from @src to @dst + """ + + if done is None: + done = {} + if src == dst: + return [[src]] + if src in done and done[src] > cycles_count: + return [[]] + out = [] + for node in self.successors(src): + done_n = dict(done) + done_n[src] = done_n.get(src, 0) + 1 + for path in self.find_path_from_src(node, dst, cycles_count, done_n): + if path and path[len(path)-1] == dst: + out.append([src] + path) + return out + + def nodeid(self, node): + """ + Returns uniq id for a @node + @node: a node of the graph + """ + return hash(node) & 0xFFFFFFFFFFFFFFFF + + def node2lines(self, node): + """ + Returns an iterator on cells of the dot @node. + A DotCellDescription or a list of DotCellDescription are accepted + @node: a node of the graph + """ + yield self.DotCellDescription(text=str(node), attr={}) + + def node_attr(self, node): + """ + Returns a dictionary of the @node's attributes + @node: a node of the graph + """ + return {} + + def edge_attr(self, src, dst): + """ + Return a dictionary of attributes for the edge between @src and @dst + @src: the source node of the edge + @dst: the destination node of the edge + """ + return {} + + @staticmethod + def _fix_chars(token): + return "&#%04d;" % ord(token.group()) + + @staticmethod + def _attr2str(default_attr, attr): + return ' '.join( + '%s="%s"' % (name, value) + for name, value in + viewitems(dict(default_attr, + **attr)) + ) + + def escape_text(self, text): + return self.escape_chars.sub(self._fix_chars, text) + + def dot(self): + """Render dot graph with HTML""" + + td_attr = {'align': 'left'} + nodes_attr = {'shape': 'Mrecord', + 'fontname': 'Courier New'} + + out = ["digraph asm_graph {"] + + # Generate basic nodes + out_nodes = [] + for node in self.nodes(): + node_id = self.nodeid(node) + out_node = '%s [\n' % node_id + out_node += self._attr2str(nodes_attr, self.node_attr(node)) + out_node += 'label =<<table border="0" cellborder="0" cellpadding="3">' + + node_html_lines = [] + + for lineDesc in self.node2lines(node): + out_render = "" + if isinstance(lineDesc, self.DotCellDescription): + lineDesc = [lineDesc] + for col in lineDesc: + out_render += "<td %s>%s</td>" % ( + self._attr2str(td_attr, col.attr), + self.escape_text(str(col.text))) + node_html_lines.append(out_render) + + node_html_lines = ('<tr>' + + ('</tr><tr>').join(node_html_lines) + + '</tr>') + + out_node += node_html_lines + "</table>> ];" + out_nodes.append(out_node) + + out += out_nodes + + # Generate links + for src, dst in self.edges(): + attrs = self.edge_attr(src, dst) + + attrs = ' '.join( + '%s="%s"' % (name, value) + for name, value in viewitems(attrs) + ) + + out.append('%s -> %s' % (self.nodeid(src), self.nodeid(dst)) + + '[' + attrs + '];') + + out.append("}") + return '\n'.join(out) + + + def graphviz(self): + try: + import re + import graphviz + + + self.gv = graphviz.Digraph('html_table') + self._dot_offset = False + td_attr = {'align': 'left'} + nodes_attr = {'shape': 'Mrecord', + 'fontname': 'Courier New'} + + for node in self.nodes(): + elements = [x for x in self.node2lines(node)] + node_id = self.nodeid(node) + out_node = '<<table border="0" cellborder="0" cellpadding="3">' + + node_html_lines = [] + for lineDesc in elements: + out_render = "" + if 
isinstance(lineDesc, self.DotCellDescription): + lineDesc = [lineDesc] + for col in lineDesc: + out_render += "<td %s>%s</td>" % ( + self._attr2str(td_attr, col.attr), + self.escape_text(str(col.text))) + node_html_lines.append(out_render) + + node_html_lines = ('<tr>' + + ('</tr><tr>').join(node_html_lines) + + '</tr>') + + out_node += node_html_lines + "</table>>" + attrs = dict(nodes_attr) + attrs.update(self.node_attr(node)) + self.gv.node( + "%s" % node_id, + out_node, + attrs, + ) + + + for src, dst in self.edges(): + attrs = self.edge_attr(src, dst) + self.gv.edge( + str(self.nodeid(src)), + str(self.nodeid(dst)), + "", + attrs, + ) + + return self.gv + except ImportError: + # Skip as graphviz is not installed + return None + + + @staticmethod + def _reachable_nodes(head, next_cb): + """Generic algorithm to compute all nodes reachable from/to node + @head""" + + todo = set([head]) + reachable = set() + while todo: + node = todo.pop() + if node in reachable: + continue + reachable.add(node) + yield node + for next_node in next_cb(node): + todo.add(next_node) + + def predecessors_stop_node_iter(self, node, head): + if node == head: + return + for next_node in self.predecessors_iter(node): + yield next_node + + def reachable_sons(self, head): + """Compute all nodes reachable from node @head. Each son is an + immediate successor of an arbitrary, already yielded son of @head""" + return self._reachable_nodes(head, self.successors_iter) + + def reachable_parents(self, leaf): + """Compute all parents of node @leaf. Each parent is an immediate + predecessor of an arbitrary, already yielded parent of @leaf""" + return self._reachable_nodes(leaf, self.predecessors_iter) + + def reachable_parents_stop_node(self, leaf, head): + """Compute all parents of node @leaf. Each parent is an immediate + predecessor of an arbitrary, already yielded parent of @leaf. + Do not compute reachables past @head node""" + return self._reachable_nodes( + leaf, + lambda node_cur: self.predecessors_stop_node_iter( + node_cur, head + ) + ) + + + @staticmethod + def _compute_generic_dominators(head, reachable_cb, prev_cb, next_cb): + """Generic algorithm to compute either the dominators or postdominators + of the graph. 
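+
+        The sets are computed as a classic iterative fixpoint: every
+        reachable node starts with the full reachable set, and each pass
+        replaces a node's set by the intersection of its predecessors'
+        sets plus the node itself, until nothing changes.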
+ @head: the head/leaf of the graph + @reachable_cb: sons/parents of the head/leaf + @prev_cb: return predecessors/successors of a node + @next_cb: return successors/predecessors of a node + """ + + nodes = set(reachable_cb(head)) + dominators = {} + for node in nodes: + dominators[node] = set(nodes) + + dominators[head] = set([head]) + todo = set(nodes) + + while todo: + node = todo.pop() + + # Heads state must not be changed + if node == head: + continue + + # Compute intersection of all predecessors'dominators + new_dom = None + for pred in prev_cb(node): + if not pred in nodes: + continue + if new_dom is None: + new_dom = set(dominators[pred]) + new_dom.intersection_update(dominators[pred]) + + # We are not a head to we have at least one dominator + assert(new_dom is not None) + + new_dom.update(set([node])) + + # If intersection has changed, add sons to the todo list + if new_dom == dominators[node]: + continue + + dominators[node] = new_dom + for succ in next_cb(node): + todo.add(succ) + return dominators + + def compute_dominators(self, head): + """Compute the dominators of the graph""" + return self._compute_generic_dominators(head, + self.reachable_sons, + self.predecessors_iter, + self.successors_iter) + + def compute_postdominators(self, leaf): + """Compute the postdominators of the graph""" + return self._compute_generic_dominators(leaf, + self.reachable_parents, + self.successors_iter, + self.predecessors_iter) + + + + + def compute_dominator_tree(self, head): + """ + Computes the dominator tree of a graph + :param head: head of graph + :return: DiGraph + """ + idoms = self.compute_immediate_dominators(head) + dominator_tree = DiGraph() + for node in idoms: + dominator_tree.add_edge(idoms[node], node) + + return dominator_tree + + @staticmethod + def _walk_generic_dominator(node, gen_dominators, succ_cb): + """Generic algorithm to return an iterator of the ordered list of + @node's dominators/post_dominator. + + The function doesn't return the self reference in dominators. + @node: The start node + @gen_dominators: The dictionary containing at least node's + dominators/post_dominators + @succ_cb: return predecessors/successors of a node + + """ + # Init + done = set() + if node not in gen_dominators: + # We are in a branch which doesn't reach head + return + node_gen_dominators = set(gen_dominators[node]) + todo = set([node]) + + # Avoid working on itself + node_gen_dominators.remove(node) + + # For each level + while node_gen_dominators: + new_node = None + + # Worklist pattern + while todo: + node = todo.pop() + if node in done: + continue + if node in node_gen_dominators: + new_node = node + break + + # Avoid loops + done.add(node) + + # Look for the next level + for pred in succ_cb(node): + todo.add(pred) + + # Return the node; it's the next starting point + assert(new_node is not None) + yield new_node + node_gen_dominators.remove(new_node) + todo = set([new_node]) + + def walk_dominators(self, node, dominators): + """Return an iterator of the ordered list of @node's dominators + The function doesn't return the self reference in dominators. + @node: The start node + @dominators: The dictionary containing at least node's dominators + """ + return self._walk_generic_dominator(node, + dominators, + self.predecessors_iter) + + def walk_postdominators(self, node, postdominators): + """Return an iterator of the ordered list of @node's postdominators + The function doesn't return the self reference in postdominators. 
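+
+        This is the mirror image of walk_dominators, following successor
+        edges instead of predecessor edges.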
+ @node: The start node + @postdominators: The dictionary containing at least node's + postdominators + + """ + return self._walk_generic_dominator(node, + postdominators, + self.successors_iter) + + def compute_immediate_dominators(self, head): + """Compute the immediate dominators of the graph""" + dominators = self.compute_dominators(head) + idoms = {} + + for node in dominators: + for predecessor in self.walk_dominators(node, dominators): + if predecessor in dominators[node] and node != predecessor: + idoms[node] = predecessor + break + return idoms + + def compute_immediate_postdominators(self,tail): + """Compute the immediate postdominators of the graph""" + postdominators = self.compute_postdominators(tail) + ipdoms = {} + + for node in postdominators: + for successor in self.walk_postdominators(node, postdominators): + if successor in postdominators[node] and node != successor: + ipdoms[node] = successor + break + return ipdoms + + def compute_dominance_frontier(self, head): + """ + Compute the dominance frontier of the graph + + Source: Cooper, Keith D., Timothy J. Harvey, and Ken Kennedy. + "A simple, fast dominance algorithm." + Software Practice & Experience 4 (2001), p. 9 + """ + idoms = self.compute_immediate_dominators(head) + frontier = {} + + for node in idoms: + if len(self._nodes_pred[node]) >= 2: + for predecessor in self.predecessors_iter(node): + runner = predecessor + if runner not in idoms: + continue + while runner != idoms[node]: + if runner not in frontier: + frontier[runner] = set() + + frontier[runner].add(node) + runner = idoms[runner] + return frontier + + def _walk_generic_first(self, head, flag, succ_cb): + """ + Generic algorithm to compute breadth or depth first search + for a node. + @head: the head of the graph + @flag: denotes if @todo is used as queue or stack + @succ_cb: returns a node's predecessors/successors + :return: next node + """ + todo = [head] + done = set() + + while todo: + node = todo.pop(flag) + if node in done: + continue + done.add(node) + + for succ in succ_cb(node): + todo.append(succ) + + yield node + + def walk_breadth_first_forward(self, head): + """Performs a breadth first search on the graph from @head""" + return self._walk_generic_first(head, 0, self.successors_iter) + + def walk_depth_first_forward(self, head): + """Performs a depth first search on the graph from @head""" + return self._walk_generic_first(head, -1, self.successors_iter) + + def walk_breadth_first_backward(self, head): + """Performs a breadth first search on the reversed graph from @head""" + return self._walk_generic_first(head, 0, self.predecessors_iter) + + def walk_depth_first_backward(self, head): + """Performs a depth first search on the reversed graph from @head""" + return self._walk_generic_first(head, -1, self.predecessors_iter) + + def has_loop(self): + """Return True if the graph contains at least a cycle""" + todo = list(self.nodes()) + # tested nodes + done = set() + # current DFS nodes + current = set() + while todo: + node = todo.pop() + if node in done: + continue + + if node in current: + # DFS branch end + for succ in self.successors_iter(node): + if succ in current: + return True + # A node cannot be in current AND in done + current.remove(node) + done.add(node) + else: + # Launch DFS from node + todo.append(node) + current.add(node) + todo += self.successors(node) + + return False + + def compute_natural_loops(self, head): + """ + Computes all natural loops in the graph. + + Source: Aho, Alfred V., Lam, Monica S., Sethi, R. 
and Jeffrey Ullman. + "Compilers: Principles, Techniques, & Tools, Second Edition" + Pearson/Addison Wesley (2007), Chapter 9.6.6 + :param head: head of the graph + :return: yield a tuple of the form (back edge, loop body) + """ + for a, b in self.compute_back_edges(head): + body = self._compute_natural_loop_body(b, a) + yield ((a, b), body) + + def compute_back_edges(self, head): + """ + Computes all back edges from a node to a + dominator in the graph. + :param head: head of graph + :return: yield a back edge + """ + dominators = self.compute_dominators(head) + + # traverse graph + for node in self.walk_depth_first_forward(head): + for successor in self.successors_iter(node): + # check for a back edge to a dominator + if successor in dominators[node]: + edge = (node, successor) + yield edge + + def _compute_natural_loop_body(self, head, leaf): + """ + Computes the body of a natural loop by a depth-first + search on the reversed control flow graph. + :param head: leaf of the loop + :param leaf: header of the loop + :return: set containing loop body + """ + todo = [leaf] + done = {head} + + while todo: + node = todo.pop() + if node in done: + continue + done.add(node) + + for predecessor in self.predecessors_iter(node): + todo.append(predecessor) + return done + + def compute_strongly_connected_components(self): + """ + Partitions the graph into strongly connected components. + + Iterative implementation of Gabow's path-based SCC algorithm. + Source: Gabow, Harold N. + "Path-based depth-first search for strong and biconnected components." + Information Processing Letters 74.3 (2000), pp. 109--110 + + The iterative implementation is inspired by Mark Dickinson's + code: + http://code.activestate.com/recipes/ + 578507-strongly-connected-components-of-a-directed-graph/ + :return: yield a strongly connected component + """ + stack = [] + boundaries = [] + counter = len(self.nodes()) + + # init index with 0 + index = {v: 0 for v in self.nodes()} + + # state machine for worklist algorithm + VISIT, HANDLE_RECURSION, MERGE = 0, 1, 2 + NodeState = namedtuple('NodeState', ['state', 'node']) + + for node in self.nodes(): + # next node if node was already visited + if index[node]: + continue + + todo = [NodeState(VISIT, node)] + done = set() + + while todo: + current = todo.pop() + + if current.node in done: + continue + + # node is unvisited + if current.state == VISIT: + stack.append(current.node) + index[current.node] = len(stack) + boundaries.append(index[current.node]) + + todo.append(NodeState(MERGE, current.node)) + # follow successors + for successor in self.successors_iter(current.node): + todo.append(NodeState(HANDLE_RECURSION, successor)) + + # iterative handling of recursion algorithm + elif current.state == HANDLE_RECURSION: + # visit unvisited successor + if index[current.node] == 0: + todo.append(NodeState(VISIT, current.node)) + else: + # contract cycle if necessary + while index[current.node] < boundaries[-1]: + boundaries.pop() + + # merge strongly connected component + else: + if index[current.node] == boundaries[-1]: + boundaries.pop() + counter += 1 + scc = set() + + while index[current.node] <= len(stack): + popped = stack.pop() + index[popped] = counter + scc.add(popped) + + done.add(current.node) + + yield scc + + + def compute_weakly_connected_components(self): + """ + Return the weakly connected components + """ + remaining = set(self.nodes()) + components = [] + while remaining: + node = remaining.pop() + todo = set() + todo.add(node) + component = set() + done = set() + 
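+            # Flood-fill from the chosen node, following edges in both
+            # directions, to gather its weakly connected component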
while todo: + node = todo.pop() + if node in done: + continue + done.add(node) + remaining.discard(node) + component.add(node) + todo.update(self.predecessors(node)) + todo.update(self.successors(node)) + components.append(component) + return components + + + + def replace_node(self, node, new_node): + """ + Replace @node by @new_node + """ + + predecessors = self.predecessors(node) + successors = self.successors(node) + self.del_node(node) + for predecessor in predecessors: + if predecessor == node: + predecessor = new_node + self.add_uniq_edge(predecessor, new_node) + for successor in successors: + if successor == node: + successor = new_node + self.add_uniq_edge(new_node, successor) + +class DiGraphSimplifier(object): + + """Wrapper on graph simplification passes. + + Instance handle passes lists. + """ + + def __init__(self): + self.passes = [] + + def enable_passes(self, passes): + """Add @passes to passes to applied + @passes: sequence of function (DiGraphSimplifier, DiGraph) -> None + """ + self.passes += passes + + def apply_simp(self, graph): + """Apply enabled simplifications on graph @graph + @graph: DiGraph instance + """ + while True: + new_graph = graph.copy() + for simp_func in self.passes: + simp_func(self, new_graph) + + if new_graph == graph: + break + graph = new_graph + return new_graph + + def __call__(self, graph): + """Wrapper on 'apply_simp'""" + return self.apply_simp(graph) + + +class MatchGraphJoker(object): + + """MatchGraphJoker are joker nodes of MatchGraph, that is to say nodes which + stand for any node. Restrictions can be added to jokers. + + If j1, j2 and j3 are MatchGraphJoker, one can quickly build a matcher for + the pattern: + | + +----v----+ + | (j1) | + +----+----+ + | + +----v----+ + | (j2) |<---+ + +----+--+-+ | + | +------+ + +----v----+ + | (j3) | + +----+----+ + | + v + Using: + >>> matcher = j1 >> j2 >> j3 + >>> matcher += j2 >> j2 + Or: + >>> matcher = j1 >> j2 >> j2 >> j3 + + """ + + def __init__(self, restrict_in=True, restrict_out=True, filt=None, + name=None): + """Instantiate a MatchGraphJoker, with restrictions + @restrict_in: (optional) if set, the number of predecessors of the + matched node must be the same than the joker node in the + associated MatchGraph + @restrict_out: (optional) counterpart of @restrict_in for successors + @filt: (optional) function(graph, node) -> boolean for filtering + candidate node + @name: (optional) helper for displaying the current joker + """ + if filt is None: + filt = lambda graph, node: True + self.filt = filt + if name is None: + name = str(id(self)) + self._name = name + self.restrict_in = restrict_in + self.restrict_out = restrict_out + + def __rshift__(self, joker): + """Helper for describing a MatchGraph from @joker + J1 >> J2 stands for an edge going to J2 from J1 + @joker: MatchGraphJoker instance + """ + assert isinstance(joker, MatchGraphJoker) + + graph = MatchGraph() + graph.add_node(self) + graph.add_node(joker) + graph.add_edge(self, joker) + + # For future "A >> B" idiom construction + graph._last_node = joker + + return graph + + def __str__(self): + info = [] + if not self.restrict_in: + info.append("In:*") + if not self.restrict_out: + info.append("Out:*") + return "Joker %s %s" % (self._name, + "(%s)" % " ".join(info) if info else "") + + +class MatchGraph(DiGraph): + + """MatchGraph intends to be the counterpart of match_expr, but for DiGraph + + This class provides API to match a given DiGraph pattern, with addidionnal + restrictions. 
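+
+    Hedged usage sketch (graph stands for any DiGraph instance to search
+    in):
+
+    >>> j1, j2 = MatchGraphJoker(name="dad"), MatchGraphJoker(name="son")
+    >>> matcher = j1 >> j2
+    >>> for solution in matcher.match(graph):
+    ...     print(solution[j1], solution[j2])
+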
+ The implemented algorithm is a naive approach. + + The recommended way to instantiate a MatchGraph is the use of + MatchGraphJoker. + """ + + def __init__(self, *args, **kwargs): + super(MatchGraph, self).__init__(*args, **kwargs) + # Construction helper + self._last_node = None + + # Construction helpers + def __rshift__(self, joker): + """Construction helper, adding @joker to the current graph as a son of + _last_node + @joker: MatchGraphJoker instance""" + assert isinstance(joker, MatchGraphJoker) + assert isinstance(self._last_node, MatchGraphJoker) + + self.add_node(joker) + self.add_edge(self._last_node, joker) + self._last_node = joker + return self + + def __add__(self, graph): + """Construction helper, merging @graph with self + @graph: MatchGraph instance + """ + assert isinstance(graph, MatchGraph) + + # Reset helpers flag + self._last_node = None + graph._last_node = None + + # Merge graph into self + for node in graph.nodes(): + self.add_node(node) + for edge in graph.edges(): + self.add_edge(*edge) + + return self + + # Graph matching + def _check_node(self, candidate, expected, graph, partial_sol=None): + """Check if @candidate can stand for @expected in @graph, given @partial_sol + @candidate: @graph's node + @expected: MatchGraphJoker instance + @graph: DiGraph instance + @partial_sol: (optional) dictionary of MatchGraphJoker -> @graph's node + standing for a partial solution + """ + # Avoid having 2 different joker for the same node + if partial_sol and candidate in viewvalues(partial_sol): + return False + + # Check lambda filtering + if not expected.filt(graph, candidate): + return False + + # Check arity + # If filter_in/out, then arity must be the same + # Otherwise, arity of the candidate must be at least equal + if ((expected.restrict_in == True and + len(self.predecessors(expected)) != len(graph.predecessors(candidate))) or + (expected.restrict_in == False and + len(self.predecessors(expected)) > len(graph.predecessors(candidate)))): + return False + if ((expected.restrict_out == True and + len(self.successors(expected)) != len(graph.successors(candidate))) or + (expected.restrict_out == False and + len(self.successors(expected)) > len(graph.successors(candidate)))): + return False + + # Check edges with partial solution if any + if not partial_sol: + return True + for pred in self.predecessors(expected): + if (pred in partial_sol and + partial_sol[pred] not in graph.predecessors(candidate)): + return False + + for succ in self.successors(expected): + if (succ in partial_sol and + partial_sol[succ] not in graph.successors(candidate)): + return False + + # All checks OK + return True + + def _propagate_sol(self, node, partial_sol, graph, todo, propagator): + """ + Try to extend the current @partial_sol by propagating the solution using + @propagator on @node. 
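+        For each pattern neighbour of @node not yet bound, every
+        corresponding neighbour of the already matched real node is
+        tested with _check_node; each compatible binding yields a new
+        candidate partial solution.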
+ New solutions are added to @todo + """ + real_node = partial_sol[node] + for candidate in propagator(self, node): + # Edge already in the partial solution, skip it + if candidate in partial_sol: + continue + + # Check candidate + for candidate_real in propagator(graph, real_node): + if self._check_node(candidate_real, candidate, graph, + partial_sol): + temp_sol = partial_sol.copy() + temp_sol[candidate] = candidate_real + if temp_sol not in todo: + todo.append(temp_sol) + + @staticmethod + def _propagate_successors(graph, node): + """Propagate through @node successors in @graph""" + return graph.successors_iter(node) + + @staticmethod + def _propagate_predecessors(graph, node): + """Propagate through @node predecessors in @graph""" + return graph.predecessors_iter(node) + + def match(self, graph): + """Naive subgraph matching between graph and self. + Iterator on matching solution, as dictionary MatchGraphJoker -> @graph + @graph: DiGraph instance + In order to obtained correct and complete results, @graph must be + connected. + """ + # Partial solution: nodes corrects, edges between these nodes corrects + # A partial solution is a dictionary MatchGraphJoker -> @graph's node + todo = list() # Dictionaries containing partial solution + done = list() # Already computed partial solutions + + # Elect first candidates + to_match = next(iter(self._nodes)) + for node in graph.nodes(): + if self._check_node(node, to_match, graph): + to_add = {to_match: node} + if to_add not in todo: + todo.append(to_add) + + while todo: + # When a partial_sol is computed, if more precise partial solutions + # are found, they will be added to 'todo' + # -> using last entry of todo first performs a "depth first" + # approach on solutions + # -> the algorithm may converge faster to a solution, a desired + # behavior while doing graph simplification (stopping after one + # sol) + partial_sol = todo.pop() + + # Avoid infinite loop and recurrent work + if partial_sol in done: + continue + done.append(partial_sol) + + # If all nodes are matching, this is a potential solution + if len(partial_sol) == len(self._nodes): + yield partial_sol + continue + + # Find node to tests using edges + for node in partial_sol: + self._propagate_sol(node, partial_sol, graph, todo, + MatchGraph._propagate_successors) + self._propagate_sol(node, partial_sol, graph, todo, + MatchGraph._propagate_predecessors) diff --git a/src/miasm/core/interval.py b/src/miasm/core/interval.py new file mode 100644 index 00000000..172197c0 --- /dev/null +++ b/src/miasm/core/interval.py @@ -0,0 +1,284 @@ +from __future__ import print_function + +INT_EQ = 0 # Equivalent +INT_B_IN_A = 1 # B in A +INT_A_IN_B = -1 # A in B +INT_DISJOIN = 2 # Disjoint +INT_JOIN = 3 # Overlap +INT_JOIN_AB = 4 # B starts at the end of A +INT_JOIN_BA = 5 # A starts at the end of B + + +def cmp_interval(inter1, inter2): + """Compare @inter1 and @inter2 and returns the associated INT_* case + @inter1, @inter2: interval instance + """ + if inter1 == inter2: + return INT_EQ + + inter1_start, inter1_stop = inter1 + inter2_start, inter2_stop = inter2 + result = INT_JOIN + if inter1_start <= inter2_start and inter1_stop >= inter2_stop: + result = INT_B_IN_A + if inter2_start <= inter1_start and inter2_stop >= inter1_stop: + result = INT_A_IN_B + if inter1_stop + 1 == inter2_start: + result = INT_JOIN_AB + if inter2_stop + 1 == inter1_start: + result = INT_JOIN_BA + if inter1_start > inter2_stop + 1 or inter2_start > inter1_stop + 1: + result = INT_DISJOIN + return result + + +class 
interval(object): + """Stands for intervals with integer bounds + + Offers common methods to work with interval""" + + def __init__(self, bounds=None): + """Instance an interval object + @bounds: (optional) list of (int, int) and/or interval instance + """ + if bounds is None: + bounds = [] + elif isinstance(bounds, interval): + bounds = bounds.intervals + self.is_cannon = False + self.intervals = bounds + self.cannon() + + def __iter__(self): + """Iterate on intervals""" + for inter in self.intervals: + yield inter + + @staticmethod + def cannon_list(tmp): + """ + Return a cannonizes list of intervals + @tmp: list of (int, int) + """ + tmp = sorted([x for x in tmp if x[0] <= x[1]]) + out = [] + if not tmp: + return out + out.append(tmp.pop()) + while tmp: + x = tmp.pop() + rez = cmp_interval(out[-1], x) + + if rez == INT_EQ: + continue + elif rez == INT_DISJOIN: + out.append(x) + elif rez == INT_B_IN_A: + continue + elif rez in [INT_JOIN, INT_JOIN_AB, INT_JOIN_BA, INT_A_IN_B]: + u, v = x + while out and cmp_interval(out[-1], (u, v)) in [ + INT_JOIN, INT_JOIN_AB, INT_JOIN_BA, INT_A_IN_B]: + u = min(u, out[-1][0]) + v = max(v, out[-1][1]) + out.pop() + out.append((u, v)) + else: + raise ValueError('unknown state', rez) + return out[::-1] + + def cannon(self): + "Apply .cannon_list() on self contained intervals" + if self.is_cannon is True: + return + self.intervals = interval.cannon_list(self.intervals) + self.is_cannon = True + + def __repr__(self): + if self.intervals: + o = " U ".join(["[0x%X 0x%X]" % (x[0], x[1]) + for x in self.intervals]) + else: + o = "[]" + return o + + def __contains__(self, other): + if isinstance(other, interval): + for intervalB in other.intervals: + is_in = False + for intervalA in self.intervals: + if cmp_interval(intervalA, intervalB) in [INT_EQ, INT_B_IN_A]: + is_in = True + break + if not is_in: + return False + return True + else: + for intervalA in self.intervals: + if intervalA[0] <= other <= intervalA[1]: + return True + return False + + def __eq__(self, i): + return self.intervals == i.intervals + + def __ne__(self, other): + return not self.__eq__(other) + + def union(self, other): + """ + Return the union of intervals + @other: interval instance + """ + + if isinstance(other, interval): + other = other.intervals + other = interval(self.intervals + other) + return other + + def difference(self, other): + """ + Return the difference of intervals + @other: interval instance + """ + + to_test = self.intervals[:] + i = -1 + to_del = other.intervals[:] + while i < len(to_test) - 1: + i += 1 + x = to_test[i] + if x[0] > x[1]: + del to_test[i] + i -= 1 + continue + + while to_del and to_del[0][1] < x[0]: + del to_del[0] + + for y in to_del: + if y[0] > x[1]: + break + rez = cmp_interval(x, y) + if rez == INT_DISJOIN: + continue + elif rez == INT_EQ: + del to_test[i] + i -= 1 + break + elif rez == INT_A_IN_B: + del to_test[i] + i -= 1 + break + elif rez == INT_B_IN_A: + del to_test[i] + i1 = (x[0], y[0] - 1) + i2 = (y[1] + 1, x[1]) + to_test[i:i] = [i1, i2] + i -= 1 + break + elif rez in [INT_JOIN_AB, INT_JOIN_BA]: + continue + elif rez == INT_JOIN: + del to_test[i] + if x[0] < y[0]: + to_test[i:i] = [(x[0], y[0] - 1)] + else: + to_test[i:i] = [(y[1] + 1, x[1])] + i -= 1 + break + else: + raise ValueError('unknown state', rez) + return interval(to_test) + + def intersection(self, other): + """ + Return the intersection of intervals + @other: interval instance + """ + + out = [] + for x in self.intervals: + if x[0] > x[1]: + continue + for y in 
other.intervals: + rez = cmp_interval(x, y) + + if rez == INT_DISJOIN: + continue + elif rez == INT_EQ: + out.append(x) + continue + elif rez == INT_A_IN_B: + out.append(x) + continue + elif rez == INT_B_IN_A: + out.append(y) + continue + elif rez == INT_JOIN_AB: + continue + elif rez == INT_JOIN_BA: + continue + elif rez == INT_JOIN: + if x[0] < y[0]: + out.append((y[0], x[1])) + else: + out.append((x[0], y[1])) + continue + else: + raise ValueError('unknown state', rez) + return interval(out) + + + def __add__(self, other): + return self.union(other) + + def __and__(self, other): + return self.intersection(other) + + def __sub__(self, other): + return self.difference(other) + + def hull(self): + "Return the first and the last bounds of intervals" + if not self.intervals: + return None, None + return self.intervals[0][0], self.intervals[-1][1] + + + @property + def empty(self): + """Return True iff the interval is empty""" + return not self.intervals + + def show(self, img_x=1350, img_y=20, dry_run=False): + """ + show image representing the interval + """ + try: + import Image + import ImageDraw + except ImportError: + print('cannot import python PIL imaging') + return + + img = Image.new('RGB', (img_x, img_y), (100, 100, 100)) + draw = ImageDraw.Draw(img) + i_min, i_max = self.hull() + + print(hex(i_min), hex(i_max)) + + addr2x = lambda addr: ((addr - i_min) * img_x) // (i_max - i_min) + for a, b in self.intervals: + draw.rectangle((addr2x(a), 0, addr2x(b), img_y), (200, 0, 0)) + + if dry_run is False: + img.show() + + @property + def length(self): + """ + Return the cumulated length of intervals + """ + # Do not use __len__ because we may return a value > 32 bits + return sum((stop - start + 1) for start, stop in self.intervals) diff --git a/src/miasm/core/locationdb.py b/src/miasm/core/locationdb.py new file mode 100644 index 00000000..b7e16ea2 --- /dev/null +++ b/src/miasm/core/locationdb.py @@ -0,0 +1,495 @@ +import warnings +from builtins import int as int_types + +from functools import reduce +from future.utils import viewitems, viewvalues + +from miasm.core.utils import printable +from miasm.expression.expression import LocKey, ExprLoc + + +class LocationDB(object): + """ + LocationDB is a "database" of information associated to location. + + An entry in a LocationDB is uniquely identified with a LocKey. 
+ Additional information which can be associated with a LocKey are: + - an offset (uniq per LocationDB) + - several names (each are uniqs per LocationDB) + + As a schema: + loc_key 1 <-> 0..1 offset + 1 <-> 0..n name + + >>> loc_db = LocationDB() + # Add a location with no additional information + >>> loc_key1 = loc_db.add_location() + # Add a location with an offset + >>> loc_key2 = loc_db.add_location(offset=0x1234) + # Add a location with several names + >>> loc_key3 = loc_db.add_location(name="first_name") + >>> loc_db.add_location_name(loc_key3, "second_name") + # Associate an offset to an existing location + >>> loc_db.set_location_offset(loc_key3, 0x5678) + # Remove a name from an existing location + >>> loc_db.remove_location_name(loc_key3, "second_name") + + # Get back offset + >>> loc_db.get_location_offset(loc_key1) + None + >>> loc_db.get_location_offset(loc_key2) + 0x1234 + >>> loc_db.get_location_offset("first_name") + 0x5678 + + # Display a location + >>> loc_db.pretty_str(loc_key1) + loc_key_1 + >>> loc_db.pretty_str(loc_key2) + loc_1234 + >>> loc_db.pretty_str(loc_key3) + first_name + """ + + def __init__(self): + # Known LocKeys + self._loc_keys = set() + + # Association tables + self._loc_key_to_offset = {} + self._loc_key_to_names = {} + self._name_to_loc_key = {} + self._offset_to_loc_key = {} + + # Counter for new LocKey generation + self._loc_key_num = 0 + + def get_location_offset(self, loc_key): + """ + Return the offset of @loc_key if any, None otherwise. + @loc_key: LocKey instance + """ + assert isinstance(loc_key, LocKey) + return self._loc_key_to_offset.get(loc_key) + + def get_location_names(self, loc_key): + """ + Return the frozenset of names associated to @loc_key + @loc_key: LocKey instance + """ + assert isinstance(loc_key, LocKey) + return frozenset(self._loc_key_to_names.get(loc_key, set())) + + def get_name_location(self, name): + """ + Return the LocKey of @name if any, None otherwise. + @name: target name + """ + assert isinstance(name, str) + return self._name_to_loc_key.get(name) + + def get_or_create_name_location(self, name): + """ + Return the LocKey of @name if any, create one otherwise. + @name: target name + """ + assert isinstance(name, str) + loc_key = self._name_to_loc_key.get(name) + if loc_key is not None: + return loc_key + return self.add_location(name=name) + + def get_offset_location(self, offset): + """ + Return the LocKey of @offset if any, None otherwise. + @offset: target offset + """ + return self._offset_to_loc_key.get(offset) + + def get_or_create_offset_location(self, offset): + """ + Return the LocKey of @offset if any, create one otherwise. + @offset: target offset + """ + loc_key = self._offset_to_loc_key.get(offset) + if loc_key is not None: + return loc_key + return self.add_location(offset=offset) + + def get_name_offset(self, name): + """ + Return the offset of @name if any, None otherwise. 
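+
+        Illustrative (mirrors the class-level example above):
+
+        >>> loc_db.get_name_offset("first_name")
+        0x5678
+        >>> loc_db.get_name_offset("unknown_name") is None
+        True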
+ @name: target name + """ + assert isinstance(name, str) + loc_key = self.get_name_location(name) + if loc_key is None: + return None + return self.get_location_offset(loc_key) + + def add_location_name(self, loc_key, name): + """Associate a name @name to a given @loc_key + @name: str instance + @loc_key: LocKey instance + """ + assert isinstance(name, str) + assert loc_key in self._loc_keys + already_existing_loc = self._name_to_loc_key.get(name) + if already_existing_loc is not None and already_existing_loc != loc_key: + raise KeyError("%r is already associated to a different loc_key " + "(%r)" % (name, already_existing_loc)) + self._loc_key_to_names.setdefault(loc_key, set()).add(name) + self._name_to_loc_key[name] = loc_key + + def remove_location_name(self, loc_key, name): + """Disassociate a name @name from a given @loc_key + Fail if @name is not already associated to @loc_key + @name: str instance + @loc_key: LocKey instance + """ + assert loc_key in self._loc_keys + assert isinstance(name, str) + already_existing_loc = self._name_to_loc_key.get(name) + if already_existing_loc is None: + raise KeyError("%r is not already associated" % name) + if already_existing_loc != loc_key: + raise KeyError("%r is already associated to a different loc_key " + "(%r)" % (name, already_existing_loc)) + del self._name_to_loc_key[name] + self._loc_key_to_names[loc_key].remove(name) + + def set_location_offset(self, loc_key, offset, force=False): + """Associate the offset @offset to an LocKey @loc_key + + If @force is set, override silently. Otherwise, if an offset is already + associated to @loc_key, an error will be raised + """ + assert loc_key in self._loc_keys + already_existing_loc = self.get_offset_location(offset) + if already_existing_loc is not None and already_existing_loc != loc_key: + raise KeyError("%r is already associated to a different loc_key " + "(%r)" % (offset, already_existing_loc)) + already_existing_off = self._loc_key_to_offset.get(loc_key) + if (already_existing_off is not None and + already_existing_off != offset): + if not force: + raise ValueError( + "%r already has an offset (0x%x). 
Use 'force=True'" + " for silent overriding" % ( + loc_key, already_existing_off + )) + else: + self.unset_location_offset(loc_key) + self._offset_to_loc_key[offset] = loc_key + self._loc_key_to_offset[loc_key] = offset + + def unset_location_offset(self, loc_key): + """Disassociate LocKey @loc_key's offset + + Fail if there is already no offset associate with it + @loc_key: LocKey + """ + assert loc_key in self._loc_keys + already_existing_off = self._loc_key_to_offset.get(loc_key) + if already_existing_off is None: + raise ValueError("%r already has no offset" % (loc_key)) + del self._offset_to_loc_key[already_existing_off] + del self._loc_key_to_offset[loc_key] + + def consistency_check(self): + """Ensure internal structures are consistent with each others""" + assert set(self._loc_key_to_names).issubset(self._loc_keys) + assert set(self._loc_key_to_offset).issubset(self._loc_keys) + assert self._loc_key_to_offset == {v: k for k, v in viewitems(self._offset_to_loc_key)} + assert reduce( + lambda x, y:x.union(y), + viewvalues(self._loc_key_to_names), + set(), + ) == set(self._name_to_loc_key) + for name, loc_key in viewitems(self._name_to_loc_key): + assert name in self._loc_key_to_names[loc_key] + + def find_free_name(self, name): + """ + If @name is not known in DB, return it + Else append an index to it corresponding to the next unknown name + + @name: string + """ + assert isinstance(name, str) + if self.get_name_location(name) is None: + return name + i = 0 + while True: + new_name = "%s_%d" % (name, i) + if self.get_name_location(new_name) is None: + return new_name + i += 1 + + def add_location(self, name=None, offset=None, strict=True): + """Add a new location in the locationDB. Returns the corresponding LocKey. + If @name is set, also associate a name to this new location. + If @offset is set, also associate an offset to this new location. + + Strict mode (set by @strict, default): + If a location with @offset or @name already exists, an error will be + raised. + Otherwise: + If a location with @offset or @name already exists, the corresponding + LocKey may be updated and will be returned. + """ + + # Deprecation handling + if isinstance(name, int_types): + assert offset is None or offset == name + warnings.warn("Deprecated API: use 'add_location(offset=)' instead." 
+ " An additional 'name=' can be provided to also " + "associate a name (there is no more default name)") + offset = name + name = None + + # Argument cleaning + offset_loc_key = None + if offset is not None: + offset = int(offset) + offset_loc_key = self.get_offset_location(offset) + + # Test for collisions + name_loc_key = None + if name is not None: + assert isinstance(name, str) + name_loc_key = self.get_name_location(name) + + if strict: + if name_loc_key is not None: + raise ValueError("An entry for %r already exists (%r), and " + "strict mode is enabled" % ( + name, name_loc_key + )) + if offset_loc_key is not None: + raise ValueError("An entry for 0x%x already exists (%r), and " + "strict mode is enabled" % ( + offset, offset_loc_key + )) + else: + # Non-strict mode + if name_loc_key is not None: + known_offset = self.get_offset_location(name_loc_key) + if known_offset is None: + if offset is not None: + self.set_location_offset(name_loc_key, offset) + elif known_offset != offset: + raise ValueError( + "Location with name '%s' already have an offset: 0x%x " + "(!= 0x%x)" % (name, offset, known_offset) + ) + # Name already known, same offset -> nothing to do + return name_loc_key + + elif offset_loc_key is not None: + if name is not None: + # Check for already known name are checked above + return self.add_location_name(offset_loc_key, name) + # Offset already known, no name specified + return offset_loc_key + + # No collision, this is a brand new location + loc_key = LocKey(self._loc_key_num) + self._loc_key_num += 1 + self._loc_keys.add(loc_key) + + if offset is not None: + assert offset not in self._offset_to_loc_key + self._offset_to_loc_key[offset] = loc_key + self._loc_key_to_offset[loc_key] = offset + + if name is not None: + self._name_to_loc_key[name] = loc_key + self._loc_key_to_names[loc_key] = set([name]) + + return loc_key + + def remove_location(self, loc_key): + """ + Delete the location corresponding to @loc_key + @loc_key: LocKey instance + """ + assert isinstance(loc_key, LocKey) + if loc_key not in self._loc_keys: + raise KeyError("Unknown loc_key %r" % loc_key) + names = self._loc_key_to_names.pop(loc_key, []) + for name in names: + del self._name_to_loc_key[name] + offset = self._loc_key_to_offset.pop(loc_key, None) + self._offset_to_loc_key.pop(offset, None) + self._loc_keys.remove(loc_key) + + def pretty_str(self, loc_key): + """Return a human readable version of @loc_key, according to information + available in this LocationDB instance""" + names = self.get_location_names(loc_key) + new_names = set() + for name in names: + try: + name = name.decode() + except AttributeError: + pass + new_names.add(name) + names = new_names + if names: + return ",".join(names) + offset = self.get_location_offset(loc_key) + if offset is not None: + return "loc_%x" % offset + return str(loc_key) + + @property + def loc_keys(self): + """Return all loc_keys""" + return self._loc_keys + + @property + def names(self): + """Return all known names""" + return list(self._name_to_loc_key) + + @property + def offsets(self): + """Return all known offsets""" + return list(self._offset_to_loc_key) + + def __str__(self): + out = [] + for loc_key in self._loc_keys: + names = self.get_location_names(loc_key) + offset = self.get_location_offset(loc_key) + out.append( + "%s: %s - %s" % ( + loc_key, + "0x%x" % offset if offset is not None else None, + ",".join(printable(name) for name in names) + ) + ) + return "\n".join(out) + + def merge(self, location_db): + """Merge with another LocationDB 
@location_db + + WARNING: old reference to @location_db information (such as LocKeys) + must be retrieved from the updated version of this instance. The + dedicated "get_*" APIs may be used for this task + """ + # A simple merge is not doable here, because LocKey will certainly + # collides + + for foreign_loc_key in location_db.loc_keys: + foreign_names = location_db.get_location_names(foreign_loc_key) + foreign_offset = location_db.get_location_offset(foreign_loc_key) + if foreign_names: + init_name = list(foreign_names)[0] + else: + init_name = None + loc_key = self.add_location(offset=foreign_offset, name=init_name, + strict=False) + cur_names = self.get_location_names(loc_key) + for name in foreign_names: + if name not in cur_names and name != init_name: + self.add_location_name(loc_key, name=name) + + def canonize_to_exprloc(self, expr): + """ + If expr is ExprInt, return ExprLoc with corresponding loc_key + Else, return expr + + @expr: Expr instance + """ + if expr.is_int(): + loc_key = self.get_or_create_offset_location(int(expr)) + ret = ExprLoc(loc_key, expr.size) + return ret + return expr + + # Deprecated APIs + @property + def items(self): + """Return all loc_keys""" + warnings.warn('DEPRECATION WARNING: use "loc_keys" instead of "items"') + return list(self._loc_keys) + + def __getitem__(self, item): + warnings.warn('DEPRECATION WARNING: use "get_name_location" or ' + '"get_offset_location"') + if item in self._name_to_loc_key: + return self._name_to_loc_key[item] + if item in self._offset_to_loc_key: + return self._offset_to_loc_key[item] + raise KeyError('unknown symbol %r' % item) + + def __contains__(self, item): + warnings.warn('DEPRECATION WARNING: use "get_name_location" or ' + '"get_offset_location", or ".offsets" or ".names"') + return item in self._name_to_loc_key or item in self._offset_to_loc_key + + def loc_key_to_name(self, loc_key): + """[DEPRECATED API], see 'get_location_names'""" + warnings.warn("Deprecated API: use 'get_location_names'") + return sorted(self.get_location_names(loc_key))[0] + + def loc_key_to_offset(self, loc_key): + """[DEPRECATED API], see 'get_location_offset'""" + warnings.warn("Deprecated API: use 'get_location_offset'") + return self.get_location_offset(loc_key) + + def remove_loc_key(self, loc_key): + """[DEPRECATED API], see 'remove_location'""" + warnings.warn("Deprecated API: use 'remove_location'") + self.remove_location(loc_key) + + def del_loc_key_offset(self, loc_key): + """[DEPRECATED API], see 'unset_location_offset'""" + warnings.warn("Deprecated API: use 'unset_location_offset'") + self.unset_location_offset(loc_key) + + def getby_offset(self, offset): + """[DEPRECATED API], see 'get_offset_location'""" + warnings.warn("Deprecated API: use 'get_offset_location'") + return self.get_offset_location(offset) + + def getby_name(self, name): + """[DEPRECATED API], see 'get_name_location'""" + warnings.warn("Deprecated API: use 'get_name_location'") + return self.get_name_location(name) + + def getby_offset_create(self, offset): + """[DEPRECATED API], see 'get_or_create_offset_location'""" + warnings.warn("Deprecated API: use 'get_or_create_offset_location'") + return self.get_or_create_offset_location(offset) + + def getby_name_create(self, name): + """[DEPRECATED API], see 'get_or_create_name_location'""" + warnings.warn("Deprecated API: use 'get_or_create_name_location'") + return self.get_or_create_name_location(name) + + def rename_location(self, loc_key, newname): + """[DEPRECATED API], see 'add_name_location' and 
'remove_location_name' + """ + warnings.warn("Deprecated API: use 'add_location_name' and " + "'remove_location_name'") + for name in self.get_location_names(loc_key): + self.remove_location_name(loc_key, name) + self.add_location_name(loc_key, name) + + def set_offset(self, loc_key, offset): + """[DEPRECATED API], see 'set_location_offset'""" + warnings.warn("Deprecated API: use 'set_location_offset'") + self.set_location_offset(loc_key, offset, force=True) + + def gen_loc_key(self): + """[DEPRECATED API], see 'add_location'""" + warnings.warn("Deprecated API: use 'add_location'") + return self.add_location() + + def str_loc_key(self, loc_key): + """[DEPRECATED API], see 'pretty_str'""" + warnings.warn("Deprecated API: use 'pretty_str'") + return self.pretty_str(loc_key) diff --git a/src/miasm/core/modint.py b/src/miasm/core/modint.py new file mode 100644 index 00000000..14b4dc2c --- /dev/null +++ b/src/miasm/core/modint.py @@ -0,0 +1,270 @@ +#-*- coding:utf-8 -*- + +from builtins import range +from functools import total_ordering + +@total_ordering +class moduint(object): + + def __init__(self, arg): + self.arg = int(arg) % self.__class__.limit + assert(self.arg >= 0 and self.arg < self.__class__.limit) + + def __repr__(self): + return self.__class__.__name__ + '(' + hex(self.arg) + ')' + + def __hash__(self): + return hash(self.arg) + + @classmethod + def maxcast(cls, c2): + c2 = c2.__class__ + if cls.size > c2.size: + return cls + else: + return c2 + + def __eq__(self, y): + if isinstance(y, moduint): + return self.arg == y.arg + return self.arg == y + + def __ne__(self, y): + # required Python 2.7.14 + return not self == y + + def __lt__(self, y): + if isinstance(y, moduint): + return self.arg < y.arg + return self.arg < y + + def __add__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg + y.arg) + else: + return self.__class__(self.arg + y) + + def __and__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg & y.arg) + else: + return self.__class__(self.arg & y) + + def __div__(self, y): + # Python: 8 / -7 == -2 (C-like: -1) + # int(float) trick cannot be used, due to information loss + # Examples: + # + # 42 / 10 => 4 + # 42 % 10 => 2 + # + # -42 / 10 => -4 + # -42 % 10 => -2 + # + # 42 / -10 => -4 + # 42 % -10 => 2 + # + # -42 / -10 => 4 + # -42 % -10 => -2 + + den = int(y) + num = int(self) + result_sign = 1 if (den * num) >= 0 else -1 + cls = self.__class__ + if isinstance(y, moduint): + cls = self.maxcast(y) + return (abs(num) // abs(den)) * result_sign + + def __floordiv__(self, y): + return self.__div__(y) + + def __int__(self): + return int(self.arg) + + def __long__(self): + return int(self.arg) + + def __index__(self): + return int(self.arg) + + def __invert__(self): + return self.__class__(~self.arg) + + def __lshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg << y.arg) + else: + return self.__class__(self.arg << y) + + def __mod__(self, y): + # See __div__ for implementation choice + cls = self.__class__ + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg - y * (self // y)) + + def __mul__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg * y.arg) + else: + return self.__class__(self.arg * y) + + def __neg__(self): + return self.__class__(-self.arg) + + def __or__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg | y.arg) + else: + return self.__class__(self.arg | y) + + 
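+    # The __r*__ methods below are the reflected operator variants; they
+    # let a plain Python int appear on the left-hand side,
+    # e.g. 1 + uint32(2) == uint32(3)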
def __radd__(self, y): + return self.__add__(y) + + def __rand__(self, y): + return self.__and__(y) + + def __rdiv__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg // self.arg) + else: + return self.__class__(y // self.arg) + + def __rfloordiv__(self, y): + return self.__rdiv__(y) + + def __rlshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg << self.arg) + else: + return self.__class__(y << self.arg) + + def __rmod__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg % self.arg) + else: + return self.__class__(y % self.arg) + + def __rmul__(self, y): + return self.__mul__(y) + + def __ror__(self, y): + return self.__or__(y) + + def __rrshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg >> self.arg) + else: + return self.__class__(y >> self.arg) + + def __rshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg >> y.arg) + else: + return self.__class__(self.arg >> y) + + def __rsub__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg - self.arg) + else: + return self.__class__(y - self.arg) + + def __rxor__(self, y): + return self.__xor__(y) + + def __sub__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg - y.arg) + else: + return self.__class__(self.arg - y) + + def __xor__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg ^ y.arg) + else: + return self.__class__(self.arg ^ y) + + def __hex__(self): + return hex(self.arg) + + def __abs__(self): + return abs(self.arg) + + def __rpow__(self, v): + return v ** self.arg + + def __pow__(self, v): + return self.__class__(self.arg ** v) + + +class modint(moduint): + + def __init__(self, arg): + if isinstance(arg, moduint): + arg = arg.arg + a = arg % self.__class__.limit + if a >= self.__class__.limit // 2: + a -= self.__class__.limit + self.arg = a + assert( + self.arg >= -self.__class__.limit // 2 and + self.arg < self.__class__.limit + ) + + +def is_modint(a): + return isinstance(a, moduint) + + +mod_size2uint = {} +mod_size2int = {} + +mod_uint2size = {} +mod_int2size = {} + +def define_int(size): + """Build the 'modint' instance corresponding to size @size""" + global mod_size2int, mod_int2size + + name = 'int%d' % size + cls = type(name, (modint,), {"size": size, "limit": 1 << size}) + globals()[name] = cls + mod_size2int[size] = cls + mod_int2size[cls] = size + return cls + +def define_uint(size): + """Build the 'moduint' instance corresponding to size @size""" + global mod_size2uint, mod_uint2size + + name = 'uint%d' % size + cls = type(name, (moduint,), {"size": size, "limit": 1 << size}) + globals()[name] = cls + mod_size2uint[size] = cls + mod_uint2size[cls] = size + return cls + +def define_common_int(): + "Define common int" + common_int = range(1, 257) + + for i in common_int: + define_int(i) + + for i in common_int: + define_uint(i) + +define_common_int() diff --git a/src/miasm/core/objc.py b/src/miasm/core/objc.py new file mode 100644 index 00000000..24ee84ab --- /dev/null +++ b/src/miasm/core/objc.py @@ -0,0 +1,1763 @@ +""" +C helper for Miasm: +* raw C to Miasm expression +* Miasm expression to raw C +* Miasm expression to C type +""" + +from builtins import zip +from builtins import int as int_types + +import warnings +from pycparser import c_parser, c_ast +from functools import total_ordering + +from miasm.core.utils import cmp_elts +from 
miasm.expression.expression_reduce import ExprReducer +from miasm.expression.expression import ExprInt, ExprId, ExprOp, ExprMem +from miasm.arch.x86.arch import is_op_segm + +from miasm.core.ctypesmngr import CTypeUnion, CTypeStruct, CTypeId, CTypePtr,\ + CTypeArray, CTypeOp, CTypeSizeof, CTypeEnum, CTypeFunc, CTypeEllipsis + + +PADDING_TYPE_NAME = "___padding___" + +def missing_definition(objtype): + warnings.warn("Null size type: Missing definition? %r" % objtype) + +""" +Display C type +source: "The C Programming Language - 2nd Edition - Ritchie Kernighan.pdf" +p. 124 +""" + +def objc_to_str(objc, result=None): + if result is None: + result = "" + while True: + if isinstance(objc, ObjCArray): + result += "[%d]" % objc.elems + objc = objc.objtype + elif isinstance(objc, ObjCPtr): + if not result and isinstance(objc.objtype, ObjCFunc): + result = objc.objtype.name + if isinstance(objc.objtype, (ObjCPtr, ObjCDecl, ObjCStruct, ObjCUnion)): + result = "*%s" % result + else: + result = "(*%s)" % result + + objc = objc.objtype + elif isinstance(objc, (ObjCDecl, ObjCStruct, ObjCUnion)): + if result: + result = "%s %s" % (objc, result) + else: + result = str(objc) + break + elif isinstance(objc, ObjCFunc): + args_str = [] + for name, arg in objc.args: + args_str.append(objc_to_str(arg, name)) + args = ", ".join(args_str) + result += "(%s)" % args + objc = objc.type_ret + elif isinstance(objc, ObjCInt): + return "int" + elif isinstance(objc, ObjCEllipsis): + return "..." + else: + raise TypeError("Unknown c type") + return result + + +@total_ordering +class ObjC(object): + """Generic ObjC""" + + def __init__(self, align, size): + self._align = align + self._size = size + + @property + def align(self): + """Alignment (in bytes) of the C object""" + return self._align + + @property + def size(self): + """Size (in bytes) of the C object""" + return self._size + + def cmp_base(self, other): + assert self.__class__ in OBJC_PRIO + assert other.__class__ in OBJC_PRIO + + if OBJC_PRIO[self.__class__] != OBJC_PRIO[other.__class__]: + return cmp_elts( + OBJC_PRIO[self.__class__], + OBJC_PRIO[other.__class__] + ) + if self.align != other.align: + return cmp_elts(self.align, other.align) + return cmp_elts(self.size, other.size) + + def __hash__(self): + return hash((self.__class__, self._align, self._size)) + + def __str__(self): + return objc_to_str(self) + + def __eq__(self, other): + return self.cmp_base(other) == 0 + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def __lt__(self, other): + return self.cmp_base(other) < 0 + + +@total_ordering +class ObjCDecl(ObjC): + """C Declaration identified""" + + def __init__(self, name, align, size): + super(ObjCDecl, self).__init__(align, size) + self._name = name + + name = property(lambda self: self._name) + + def __hash__(self): + return hash((super(ObjCDecl, self).__hash__(), self._name)) + + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.name) + + def __str__(self): + return str(self.name) + + def __eq__(self, other): + ret = self.cmp_base(other) + if ret: + return False + return self.name == other.name + + def __lt__(self, other): + ret = self.cmp_base(other) + if ret: + if ret < 0: + return True + return False + return self.name < other.name + + +class ObjCInt(ObjC): + """C integer""" + + def __init__(self): + super(ObjCInt, self).__init__(None, 0) + + def __str__(self): + return 'int' + + +@total_ordering +class ObjCPtr(ObjC): + """C Pointer""" + + def __init__(self, objtype, void_p_align, 
void_p_size): + """Init ObjCPtr + + @objtype: pointer target ObjC + @void_p_align: pointer alignment (in bytes) + @void_p_size: pointer size (in bytes) + """ + + super(ObjCPtr, self).__init__(void_p_align, void_p_size) + self._lock = False + + self.objtype = objtype + if objtype is None: + self._lock = False + + def get_objtype(self): + assert self._lock is True + return self._objtype + + def set_objtype(self, objtype): + assert self._lock is False + self._lock = True + self._objtype = objtype + + objtype = property(get_objtype, set_objtype) + + def __hash__(self): + # Don't try to hash on an unlocked Ptr (still mutable) + assert self._lock + return hash((super(ObjCPtr, self).__hash__(), hash(self._objtype))) + + def __repr__(self): + return '<%s %r>' % ( + self.__class__.__name__, + self.objtype.__class__ + ) + + def __eq__(self, other): + ret = self.cmp_base(other) + if ret: + return False + return self.objtype == other.objtype + + def __lt__(self, other): + ret = self.cmp_base(other) + if ret: + if ret < 0: + return True + return False + return self.objtype < other.objtype + + +@total_ordering +class ObjCArray(ObjC): + """C array (test[XX])""" + + def __init__(self, objtype, elems): + """Init ObjCArray + + @objtype: pointer target ObjC + @elems: number of elements in the array + """ + + super(ObjCArray, self).__init__(objtype.align, elems * objtype.size) + self._elems = elems + self._objtype = objtype + + objtype = property(lambda self: self._objtype) + elems = property(lambda self: self._elems) + + def __hash__(self): + return hash((super(ObjCArray, self).__hash__(), self._elems, hash(self._objtype))) + + def __repr__(self): + return '<%r[%d]>' % (self.objtype, self.elems) + + def __eq__(self, other): + ret = self.cmp_base(other) + if ret: + return False + if self.objtype != other.objtype: + return False + return self.elems == other.elems + + def __lt__(self, other): + ret = self.cmp_base(other) + if ret > 0: + return False + if self.objtype > other.objtype: + return False + return self.elems < other.elems + +@total_ordering +class ObjCStruct(ObjC): + """C object for structures""" + + def __init__(self, name, align, size, fields): + super(ObjCStruct, self).__init__(align, size) + self._name = name + self._fields = tuple(fields) + + name = property(lambda self: self._name) + fields = property(lambda self: self._fields) + + def __hash__(self): + return hash((super(ObjCStruct, self).__hash__(), self._name)) + + def __repr__(self): + out = [] + out.append("Struct %s: (align: %d)" % (self.name, self.align)) + out.append(" off sz name") + for name, objtype, offset, size in self.fields: + out.append(" 0x%-3x %-3d %-10s %r" % + (offset, size, name, objtype.__class__.__name__)) + return '\n'.join(out) + + def __str__(self): + return 'struct %s' % (self.name) + + def __eq__(self, other): + ret = self.cmp_base(other) + if ret: + return False + return self.name == other.name + + def __lt__(self, other): + ret = self.cmp_base(other) + if ret: + if ret < 0: + return True + return False + return self.name < other.name + + +@total_ordering +class ObjCUnion(ObjC): + """C object for unions""" + + def __init__(self, name, align, size, fields): + super(ObjCUnion, self).__init__(align, size) + self._name = name + self._fields = tuple(fields) + + name = property(lambda self: self._name) + fields = property(lambda self: self._fields) + + def __hash__(self): + return hash((super(ObjCUnion, self).__hash__(), self._name)) + + def __repr__(self): + out = [] + out.append("Union %s: (align: %d)" % (self.name, 
self.align)) + out.append(" off sz name") + for name, objtype, offset, size in self.fields: + out.append(" 0x%-3x %-3d %-10s %r" % + (offset, size, name, objtype)) + return '\n'.join(out) + + def __str__(self): + return 'union %s' % (self.name) + + def __eq__(self, other): + ret = self.cmp_base(other) + if ret: + return False + return self.name == other.name + + def __lt__(self, other): + ret = self.cmp_base(other) + if ret: + if ret < 0: + return True + return False + return self.name < other.name + +class ObjCEllipsis(ObjC): + """C integer""" + + def __init__(self): + super(ObjCEllipsis, self).__init__(None, None) + + align = property(lambda self: self._align) + size = property(lambda self: self._size) + +@total_ordering +class ObjCFunc(ObjC): + """C object for Functions""" + + def __init__(self, name, abi, type_ret, args, void_p_align, void_p_size): + super(ObjCFunc, self).__init__(void_p_align, void_p_size) + self._name = name + self._abi = abi + self._type_ret = type_ret + self._args = tuple(args) + + args = property(lambda self: self._args) + type_ret = property(lambda self: self._type_ret) + abi = property(lambda self: self._abi) + name = property(lambda self: self._name) + + def __hash__(self): + return hash((super(ObjCFunc, self).__hash__(), hash(self._args), self._name)) + + def __repr__(self): + return "<%s %s>" % ( + self.__class__.__name__, + self.name + ) + + def __str__(self): + out = [] + out.append("Function (%s) %s: (align: %d)" % (self.abi, self.name, self.align)) + out.append(" ret: %s" % (str(self.type_ret))) + out.append(" Args:") + for name, arg in self.args: + out.append(" %s %s" % (name, arg)) + return '\n'.join(out) + + def __eq__(self, other): + ret = self.cmp_base(other) + if ret: + return False + return self.name == other.name + + def __lt__(self, other): + ret = self.cmp_base(other) + if ret: + if ret < 0: + return True + return False + return self.name < other.name + +OBJC_PRIO = { + ObjC: 0, + ObjCDecl:1, + ObjCInt:2, + ObjCPtr:3, + ObjCArray:4, + ObjCStruct:5, + ObjCUnion:6, + ObjCEllipsis:7, + ObjCFunc:8, +} + + +def access_simplifier(expr): + """Expression visitor to simplify a C access represented in Miasm + + @expr: Miasm expression representing the C access + + Example: + + IN: (In c: ['*(&((&((*(ptr_Test)).a))[0]))']) + [ExprOp('deref', ExprOp('addr', ExprOp('[]', ExprOp('addr', + ExprOp('field', ExprOp('deref', ExprId('ptr_Test', 64)), + ExprId('a', 64))), ExprInt(0x0, 64))))] + + OUT: (In c: ['(ptr_Test)->a']) + [ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64))] + """ + + if (expr.is_op("addr") and + expr.args[0].is_op("[]") and + expr.args[0].args[1] == ExprInt(0, 64)): + return expr.args[0].args[0] + elif (expr.is_op("[]") and + expr.args[0].is_op("addr") and + expr.args[1] == ExprInt(0, 64)): + return expr.args[0].args[0] + elif (expr.is_op("addr") and + expr.args[0].is_op("deref")): + return expr.args[0].args[0] + elif (expr.is_op("deref") and + expr.args[0].is_op("addr")): + return expr.args[0].args[0] + elif (expr.is_op("field") and + expr.args[0].is_op("deref")): + return ExprOp("->", expr.args[0].args[0], expr.args[1]) + return expr + + +def access_str(expr): + """Return the C string of a C access represented in Miasm + + @expr: Miasm expression representing the C access + + In: + ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64)) + OUT: + '(ptr_Test)->a' + """ + + if isinstance(expr, ExprId): + out = str(expr) + elif isinstance(expr, ExprInt): + out = str(int(expr)) + elif expr.is_op("addr"): + out = "&(%s)" % 
access_str(expr.args[0]) + elif expr.is_op("deref"): + out = "*(%s)" % access_str(expr.args[0]) + elif expr.is_op("field"): + out = "(%s).%s" % (access_str(expr.args[0]), access_str(expr.args[1])) + elif expr.is_op("->"): + out = "(%s)->%s" % (access_str(expr.args[0]), access_str(expr.args[1])) + elif expr.is_op("[]"): + out = "(%s)[%s]" % (access_str(expr.args[0]), access_str(expr.args[1])) + else: + raise RuntimeError("unknown op") + + return out + + +class CGen(object): + """Generic object to represent a C expression""" + + default_size = 64 + + + def __init__(self, ctype): + self._ctype = ctype + + @property + def ctype(self): + """Type (ObjC instance) of the current object""" + return self._ctype + + def __hash__(self): + return hash(self.__class__) + + def __eq__(self, other): + return (self.__class__ == other.__class__ and + self._ctype == other.ctype) + + def __ne__(self, other): + return not self.__eq__(other) + + def to_c(self): + """Generate corresponding C""" + + raise NotImplementedError("Virtual") + + def to_expr(self): + """Generate Miasm expression representing the C access""" + + raise NotImplementedError("Virtual") + + +class CGenInt(CGen): + """Int C object""" + + def __init__(self, integer): + assert isinstance(integer, int_types) + self._integer = integer + super(CGenInt, self).__init__(ObjCInt()) + + @property + def integer(self): + """Value of the object""" + return self._integer + + def __hash__(self): + return hash((super(CGenInt, self).__hash__(), self._integer)) + + def __eq__(self, other): + return (super(CGenInt, self).__eq__(other) and + self._integer == other.integer) + + def __ne__(self, other): + return not self.__eq__(other) + + def to_c(self): + """Generate corresponding C""" + + return "0x%X" % self.integer + + def __repr__(self): + return "<%s %s>" % (self.__class__.__name__, + self.integer) + + def to_expr(self): + """Generate Miasm expression representing the C access""" + + return ExprInt(self.integer, self.default_size) + + +class CGenId(CGen): + """ID of a C object""" + + def __init__(self, ctype, name): + self._name = name + assert isinstance(name, str) + super(CGenId, self).__init__(ctype) + + @property + def name(self): + """Name of the Id""" + return self._name + + def __hash__(self): + return hash((super(CGenId, self).__hash__(), self._name)) + + def __eq__(self, other): + return (super(CGenId, self).__eq__(other) and + self._name == other.name) + + def __repr__(self): + return "<%s %s>" % (self.__class__.__name__, + self.name) + + def to_c(self): + """Generate corresponding C""" + + return "%s" % (self.name) + + def to_expr(self): + """Generate Miasm expression representing the C access""" + + return ExprId(self.name, self.default_size) + + +class CGenField(CGen): + """ + Field of a C struct/union + + IN: + - struct (not ptr struct) + - field name + OUT: + - input type of the field => output type + - X[] => X[] + - X => X* + """ + + def __init__(self, struct, field, fieldtype, void_p_align, void_p_size): + self._struct = struct + self._field = field + assert isinstance(field, str) + if isinstance(fieldtype, ObjCArray): + ctype = fieldtype + else: + ctype = ObjCPtr(fieldtype, void_p_align, void_p_size) + super(CGenField, self).__init__(ctype) + + @property + def struct(self): + """Structure containing the field""" + return self._struct + + @property + def field(self): + """Field name""" + return self._field + + def __hash__(self): + return hash((super(CGenField, self).__hash__(), self._struct, self._field)) + + def __eq__(self, other): + 
return (super(CGenField, self).__eq__(other) and + self._struct == other.struct and + self._field == other.field) + + def to_c(self): + """Generate corresponding C""" + + if isinstance(self.ctype, ObjCArray): + return "(%s).%s" % (self.struct.to_c(), self.field) + elif isinstance(self.ctype, ObjCPtr): + return "&((%s).%s)" % (self.struct.to_c(), self.field) + else: + raise RuntimeError("Strange case") + + def __repr__(self): + return "<%s %s %s>" % (self.__class__.__name__, + self.struct, + self.field) + + def to_expr(self): + """Generate Miasm expression representing the C access""" + + if isinstance(self.ctype, ObjCArray): + return ExprOp("field", + self.struct.to_expr(), + ExprId(self.field, self.default_size)) + elif isinstance(self.ctype, ObjCPtr): + return ExprOp("addr", + ExprOp("field", + self.struct.to_expr(), + ExprId(self.field, self.default_size))) + else: + raise RuntimeError("Strange case") + + +class CGenArray(CGen): + """ + C Array + + This object does *not* deref the source, it only do object casting. + + IN: + - obj + OUT: + - X* => X* + - ..[][] => ..[] + - X[] => X* + """ + + def __init__(self, base, elems, void_p_align, void_p_size): + ctype = base.ctype + if isinstance(ctype, ObjCPtr): + pass + elif isinstance(ctype, ObjCArray) and isinstance(ctype.objtype, ObjCArray): + ctype = ctype.objtype + elif isinstance(ctype, ObjCArray): + ctype = ObjCPtr(ctype.objtype, void_p_align, void_p_size) + else: + raise TypeError("Strange case") + self._base = base + self._elems = elems + super(CGenArray, self).__init__(ctype) + + @property + def base(self): + """Base object supporting the array""" + return self._base + + @property + def elems(self): + """Number of elements in the array""" + return self._elems + + def __hash__(self): + return hash((super(CGenArray, self).__hash__(), self._base, self._elems)) + + def __eq__(self, other): + return (super(CGenField, self).__eq__(other) and + self._base == other.base and + self._elems == other.elems) + + def __repr__(self): + return "<%s %s>" % (self.__class__.__name__, + self.base) + + def to_c(self): + """Generate corresponding C""" + + if isinstance(self.ctype, ObjCPtr): + out_str = "&((%s)[%d])" % (self.base.to_c(), self.elems) + elif isinstance(self.ctype, ObjCArray): + out_str = "(%s)[%d]" % (self.base.to_c(), self.elems) + else: + raise RuntimeError("Strange case") + return out_str + + def to_expr(self): + """Generate Miasm expression representing the C access""" + + if isinstance(self.ctype, ObjCPtr): + return ExprOp("addr", + ExprOp("[]", + self.base.to_expr(), + ExprInt(self.elems, self.default_size))) + elif isinstance(self.ctype, ObjCArray): + return ExprOp("[]", + self.base.to_expr(), + ExprInt(self.elems, self.default_size)) + else: + raise RuntimeError("Strange case") + + +class CGenDeref(CGen): + """ + C dereference + + IN: + - ptr + OUT: + - X* => X + """ + + def __init__(self, ptr): + assert isinstance(ptr.ctype, ObjCPtr) + self._ptr = ptr + super(CGenDeref, self).__init__(ptr.ctype.objtype) + + @property + def ptr(self): + """Pointer object""" + return self._ptr + + def __hash__(self): + return hash((super(CGenDeref, self).__hash__(), self._ptr)) + + def __eq__(self, other): + return (super(CGenField, self).__eq__(other) and + self._ptr == other.ptr) + + def __repr__(self): + return "<%s %s>" % (self.__class__.__name__, + self.ptr) + + def to_c(self): + """Generate corresponding C""" + + if not isinstance(self.ptr.ctype, ObjCPtr): + raise RuntimeError() + return "*(%s)" % (self.ptr.to_c()) + + def to_expr(self): + 
"""Generate Miasm expression representing the C access""" + + if not isinstance(self.ptr.ctype, ObjCPtr): + raise RuntimeError() + return ExprOp("deref", self.ptr.to_expr()) + + +def ast_get_c_access_expr(ast, expr_types, lvl=0): + """Transform C ast object into a C Miasm expression + + @ast: parsed pycparser.c_ast object + @expr_types: a dictionary linking ID names to their types + @lvl: actual recursion level + + Example: + + IN: + StructRef: -> + ID: ptr_Test + ID: a + + OUT: + ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64)) + """ + + if isinstance(ast, c_ast.Constant): + obj = ExprInt(int(ast.value), 64) + elif isinstance(ast, c_ast.StructRef): + name, field = ast.name, ast.field.name + name = ast_get_c_access_expr(name, expr_types) + if ast.type == "->": + s_name = name + s_field = ExprId(field, 64) + obj = ExprOp('->', s_name, s_field) + elif ast.type == ".": + s_name = name + s_field = ExprId(field, 64) + obj = ExprOp("field", s_name, s_field) + else: + raise RuntimeError("Unknown struct access") + elif isinstance(ast, c_ast.UnaryOp) and ast.op == "&": + tmp = ast_get_c_access_expr(ast.expr, expr_types, lvl + 1) + obj = ExprOp("addr", tmp) + elif isinstance(ast, c_ast.ArrayRef): + tmp = ast_get_c_access_expr(ast.name, expr_types, lvl + 1) + index = ast_get_c_access_expr(ast.subscript, expr_types, lvl + 1) + obj = ExprOp("[]", tmp, index) + elif isinstance(ast, c_ast.ID): + assert ast.name in expr_types + obj = ExprId(ast.name, 64) + elif isinstance(ast, c_ast.UnaryOp) and ast.op == "*": + tmp = ast_get_c_access_expr(ast.expr, expr_types, lvl + 1) + obj = ExprOp("deref", tmp) + else: + raise NotImplementedError("Unknown type") + return obj + + +def parse_access(c_access): + """Parse C access + + @c_access: C access string + """ + + main = ''' + int main() { + %s; + } + ''' % c_access + + parser = c_parser.CParser() + node = parser.parse(main, filename='<stdin>') + access = node.ext[-1].body.block_items[0] + return access + + +class ExprToAccessC(ExprReducer): + """ + Generate the C access object(s) for a given native Miasm expression + Example: + IN: + @32[ptr_Test] + OUT: + [<CGenDeref <CGenArray <CGenField <CGenDeref <CGenId ptr_Test>> a>>>] + + An expression may be represented by multiple accessor (due to unions). + """ + + def __init__(self, expr_types, types_mngr, enforce_strict_access=True): + """Init GenCAccess + + @expr_types: a dictionary linking ID names to their types + @types_mngr: types manager + @enforce_strict_access: If false, generate access even on expression + pointing to a middle of an object. If true, raise exception if such a + pointer is encountered + """ + + self.expr_types = expr_types + self.types_mngr = types_mngr + self.enforce_strict_access = enforce_strict_access + + def updt_expr_types(self, expr_types): + """Update expr_types + @expr_types: Dictionary associating name to type + """ + + self.expr_types = expr_types + + def cgen_access(self, cgenobj, base_type, offset, deref, lvl=0): + """Return the access(es) which lead to the element at @offset of an + object of type @base_type + + In case of no @deref, stops recursion as soon as we reached the base of + an object. 
+ In other cases, we need to go down to the final dereferenced object + + @cgenobj: current object access + @base_type: type of main object + @offset: offset (in bytes) of the target sub object + @deref: get type for a pointer or a deref + @lvl: actual recursion level + + + IN: + - base_type: struct Toto{ + int a + int b + } + - base_name: var + - 4 + OUT: + - CGenField(var, b) + + + + IN: + - base_type: int a + - 0 + OUT: + - CGenAddr(a) + + IN: + - base_type: X = int* a + - 0 + OUT: + - CGenAddr(X) + + IN: + - X = int* a + - 8 + OUT: + - ASSERT + + + IN: + - struct toto{ + int a + int b[10] + } + - 8 + OUT: + - CGenArray(CGenField(toto, b), 1) + """ + if base_type.size == 0: + missing_definition(base_type) + return set() + + + void_type = self.types_mngr.void_ptr + if isinstance(base_type, ObjCStruct): + if not 0 <= offset < base_type.size: + return set() + + if offset == 0 and not deref: + # In this case, return the struct* + return set([cgenobj]) + + for fieldname, subtype, field_offset, size in base_type.fields: + if not field_offset <= offset < field_offset + size: + continue + fieldptr = CGenField(CGenDeref(cgenobj), fieldname, subtype, + void_type.align, void_type.size) + new_type = self.cgen_access(fieldptr, subtype, + offset - field_offset, + deref, lvl + 1) + break + else: + return set() + elif isinstance(base_type, ObjCArray): + if base_type.objtype.size == 0: + missing_definition(base_type.objtype) + return set() + element_num = offset // (base_type.objtype.size) + field_offset = offset % base_type.objtype.size + if element_num >= base_type.elems: + return set() + if offset == 0 and not deref: + # In this case, return the array + return set([cgenobj]) + + curobj = CGenArray(cgenobj, element_num, + void_type.align, + void_type.size) + if field_offset == 0: + # We point to the start of the sub object, + # return it directly + return set([curobj]) + new_type = self.cgen_access(curobj, base_type.objtype, + field_offset, deref, lvl + 1) + + elif isinstance(base_type, ObjCDecl): + if self.enforce_strict_access and offset % base_type.size != 0: + return set() + elem_num = offset // base_type.size + + nobj = CGenArray(cgenobj, elem_num, + void_type.align, void_type.size) + new_type = set([nobj]) + + elif isinstance(base_type, ObjCUnion): + if offset == 0 and not deref: + # In this case, return the struct* + return set([cgenobj]) + + out = set() + for fieldname, objtype, field_offset, size in base_type.fields: + if not field_offset <= offset < field_offset + size: + continue + field = CGenField(CGenDeref(cgenobj), fieldname, objtype, + void_type.align, void_type.size) + out.update(self.cgen_access(field, objtype, + offset - field_offset, + deref, lvl + 1)) + new_type = out + + elif isinstance(base_type, ObjCPtr): + elem_num = offset // base_type.size + if self.enforce_strict_access and offset % base_type.size != 0: + return set() + nobj = CGenArray(cgenobj, elem_num, + void_type.align, void_type.size) + new_type = set([nobj]) + + else: + raise NotImplementedError("deref type %r" % base_type) + return new_type + + def reduce_known_expr(self, node, ctxt, **kwargs): + """Generate access for known expr""" + if node.expr in ctxt: + objcs = ctxt[node.expr] + return set(CGenId(objc, str(node.expr)) for objc in objcs) + return None + + def reduce_int(self, node, **kwargs): + """Generate access for ExprInt""" + + if not isinstance(node.expr, ExprInt): + return None + return set([CGenInt(int(node.expr))]) + + def get_solo_type(self, node): + """Return the type of the @node if it has only one 
possible type, + different from not None. In other cases, return None. + """ + if node.info is None or len(node.info) != 1: + return None + return type(list(node.info)[0].ctype) + + def reduce_op(self, node, lvl=0, **kwargs): + """Generate access for ExprOp""" + if not (node.expr.is_op("+") or is_op_segm(node.expr)) \ + or len(node.args) != 2: + return None + type_arg1 = self.get_solo_type(node.args[1]) + if type_arg1 != ObjCInt: + return None + arg0, arg1 = node.args + if arg0.info is None: + return None + void_type = self.types_mngr.void_ptr + out = set() + if not arg1.expr.is_int(): + return None + ptr_offset = int(arg1.expr) + for info in arg0.info: + if isinstance(info.ctype, ObjCArray): + field_type = info.ctype + elif isinstance(info.ctype, ObjCPtr): + field_type = info.ctype.objtype + else: + continue + target_type = info.ctype.objtype + + # Array-like: int* ptr; ptr[1] = X + out.update(self.cgen_access(info, field_type, ptr_offset, False, lvl)) + return out + + def reduce_mem(self, node, lvl=0, **kwargs): + """Generate access for ExprMem: + * @NN[ptr<elem>] -> elem (type) + * @64[ptr<ptr<elem>>] -> ptr<elem> + * @32[ptr<struct>] -> struct.00 + """ + + if not isinstance(node.expr, ExprMem): + return None + if node.ptr.info is None: + return None + assert isinstance(node.ptr.info, set) + void_type = self.types_mngr.void_ptr + found = set() + for subcgenobj in node.ptr.info: + if isinstance(subcgenobj.ctype, ObjCArray): + nobj = CGenArray(subcgenobj, 0, + void_type.align, + void_type.size) + target = nobj.ctype.objtype + for finalcgenobj in self.cgen_access(nobj, target, 0, True, lvl): + assert isinstance(finalcgenobj.ctype, ObjCPtr) + if self.enforce_strict_access and finalcgenobj.ctype.objtype.size != node.expr.size // 8: + continue + found.add(CGenDeref(finalcgenobj)) + + elif isinstance(subcgenobj.ctype, ObjCPtr): + target = subcgenobj.ctype.objtype + # target : type(elem) + if isinstance(target, (ObjCStruct, ObjCUnion)): + for finalcgenobj in self.cgen_access(subcgenobj, target, 0, True, lvl): + target = finalcgenobj.ctype.objtype + if self.enforce_strict_access and target.size != node.expr.size // 8: + continue + found.add(CGenDeref(finalcgenobj)) + elif isinstance(target, ObjCArray): + if self.enforce_strict_access and subcgenobj.ctype.size != node.expr.size // 8: + continue + found.update(self.cgen_access(CGenDeref(subcgenobj), target, + 0, False, lvl)) + else: + if self.enforce_strict_access and target.size != node.expr.size // 8: + continue + found.add(CGenDeref(subcgenobj)) + if not found: + return None + return found + + reduction_rules = [reduce_known_expr, + reduce_int, + reduce_op, + reduce_mem, + ] + + def get_accesses(self, expr, expr_context=None): + """Generate C access(es) for the native Miasm expression @expr + @expr: native Miasm expression + @expr_context: a dictionary linking known expressions to their + types. An expression is linked to a tuple of types. 
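        Example (an illustrative sketch; `access_gen` stands for an assumed
        ExprToAccessC instance whose expr_types maps ExprId('ptr_Test', 64)
        to a set containing a struct pointer type):

            expr = ExprMem(ExprId('ptr_Test', 64), 32)
            for access in access_gen.get_accesses(expr):
                # access.to_c() gives the raw C access string,
                # access.ctype its ObjC type
                print(access.to_c(), access.ctype)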
+ """ + if expr_context is None: + expr_context = self.expr_types + ret = self.reduce(expr, ctxt=expr_context) + if ret.info is None: + return set() + return ret.info + + +class ExprCToExpr(ExprReducer): + """Translate a Miasm expression (representing a C access) into a native + Miasm expression and its C type: + + Example: + + IN: ((ptr_struct -> f_mini) field x) + OUT: @32[ptr_struct + 0x80], int + + + Tricky cases: + Struct S0 { + int x; + int y[0x10]; + } + + Struct S1 { + int a; + S0 toto; + } + + S1* ptr; + + Case 1: + ptr->toto => ptr + 0x4 + &(ptr->toto) => ptr + 0x4 + + Case 2: + (ptr->toto).x => @32[ptr + 0x4] + &((ptr->toto).x) => ptr + 0x4 + + Case 3: + (ptr->toto).y => ptr + 0x8 + &((ptr->toto).y) => ptr + 0x8 + + Case 4: + (ptr->toto).y[1] => @32[ptr + 0x8 + 0x4] + &((ptr->toto).y[1]) => ptr + 0x8 + 0x4 + + """ + + def __init__(self, expr_types, types_mngr): + """Init ExprCAccess + + @expr_types: a dictionary linking ID names to their types + @types_mngr: types manager + """ + + self.expr_types = expr_types + self.types_mngr = types_mngr + + def updt_expr_types(self, expr_types): + """Update expr_types + @expr_types: Dictionary associating name to type + """ + + self.expr_types = expr_types + + CST = "CST" + + def reduce_known_expr(self, node, ctxt, **kwargs): + """Reduce known expressions""" + if str(node.expr) in ctxt: + objc = ctxt[str(node.expr)] + out = (node.expr, objc) + elif node.expr.is_id(): + out = (node.expr, None) + else: + out = None + return out + + def reduce_int(self, node, **kwargs): + """Reduce ExprInt""" + + if not isinstance(node.expr, ExprInt): + return None + return self.CST + + def reduce_op_memberof(self, node, **kwargs): + """Reduce -> operator""" + + if not node.expr.is_op('->'): + return None + assert len(node.args) == 2 + out = [] + assert isinstance(node.args[1].expr, ExprId) + field = node.args[1].expr.name + src, src_type = node.args[0].info + if src_type is None: + return None + assert isinstance(src_type, (ObjCPtr, ObjCArray)) + struct_dst = src_type.objtype + assert isinstance(struct_dst, ObjCStruct) + + found = False + for name, objtype, offset, _ in struct_dst.fields: + if name != field: + continue + expr = src + ExprInt(offset, src.size) + if isinstance(objtype, (ObjCArray, ObjCStruct, ObjCUnion)): + pass + else: + expr = ExprMem(expr, objtype.size * 8) + assert not found + found = True + out = (expr, objtype) + assert found + return out + + def reduce_op_field(self, node, **kwargs): + """Reduce field operator (Struct or Union)""" + + if not node.expr.is_op('field'): + return None + assert len(node.args) == 2 + out = [] + assert isinstance(node.args[1].expr, ExprId) + field = node.args[1].expr.name + src, src_type = node.args[0].info + struct_dst = src_type + + if isinstance(struct_dst, ObjCStruct): + found = False + for name, objtype, offset, _ in struct_dst.fields: + if name != field: + continue + expr = src + ExprInt(offset, src.size) + if isinstance(objtype, ObjCArray): + # Case 4 + pass + elif isinstance(objtype, (ObjCStruct, ObjCUnion)): + # Case 1 + pass + else: + # Case 2 + expr = ExprMem(expr, objtype.size * 8) + assert not found + found = True + out = (expr, objtype) + elif isinstance(struct_dst, ObjCUnion): + found = False + for name, objtype, offset, _ in struct_dst.fields: + if name != field: + continue + expr = src + ExprInt(offset, src.size) + if isinstance(objtype, ObjCArray): + # Case 4 + pass + elif isinstance(objtype, (ObjCStruct, ObjCUnion)): + # Case 1 + pass + else: + # Case 2 + expr = ExprMem(expr, objtype.size * 8) 
+ assert not found + found = True + out = (expr, objtype) + else: + raise NotImplementedError("unknown ObjC") + assert found + return out + + def reduce_op_array(self, node, **kwargs): + """Reduce array operator""" + + if not node.expr.is_op('[]'): + return None + assert len(node.args) == 2 + out = [] + assert isinstance(node.args[1].expr, ExprInt) + cst = node.args[1].expr + src, src_type = node.args[0].info + objtype = src_type.objtype + expr = src + cst * ExprInt(objtype.size, cst.size) + if isinstance(src_type, ObjCPtr): + if isinstance(objtype, ObjCArray): + final = objtype.objtype + expr = src + cst * ExprInt(final.size, cst.size) + objtype = final + expr = ExprMem(expr, final.size * 8) + found = True + else: + expr = ExprMem(expr, objtype.size * 8) + found = True + elif isinstance(src_type, ObjCArray): + if isinstance(objtype, ObjCArray): + final = objtype + found = True + elif isinstance(objtype, ObjCStruct): + found = True + else: + expr = ExprMem(expr, objtype.size * 8) + found = True + else: + raise NotImplementedError("Unknown access" % node.expr) + assert found + out = (expr, objtype) + return out + + def reduce_op_addr(self, node, **kwargs): + """Reduce addr operator""" + + if not node.expr.is_op('addr'): + return None + assert len(node.args) == 1 + out = [] + src, src_type = node.args[0].info + + void_type = self.types_mngr.void_ptr + + if isinstance(src_type, ObjCArray): + out = (src.arg, ObjCPtr(src_type.objtype, + void_type.align, void_type.size)) + elif isinstance(src, ExprMem): + out = (src.ptr, ObjCPtr(src_type, + void_type.align, void_type.size)) + elif isinstance(src_type, ObjCStruct): + out = (src, ObjCPtr(src_type, + void_type.align, void_type.size)) + elif isinstance(src_type, ObjCUnion): + out = (src, ObjCPtr(src_type, + void_type.align, void_type.size)) + else: + raise NotImplementedError("unk type") + return out + + def reduce_op_deref(self, node, **kwargs): + """Reduce deref operator""" + + if not node.expr.is_op('deref'): + return None + out = [] + src, src_type = node.args[0].info + assert isinstance(src_type, (ObjCPtr, ObjCArray)) + void_type = self.types_mngr.void_ptr + if isinstance(src_type, ObjCPtr): + if isinstance(src_type.objtype, ObjCArray): + size = void_type.size*8 + else: + size = src_type.objtype.size * 8 + out = (ExprMem(src, size), (src_type.objtype)) + else: + size = src_type.objtype.size * 8 + out = (ExprMem(src, size), (src_type.objtype)) + return out + + reduction_rules = [reduce_known_expr, + reduce_int, + reduce_op_memberof, + reduce_op_field, + reduce_op_array, + reduce_op_addr, + reduce_op_deref, + ] + + def get_expr(self, expr, c_context): + """Translate a Miasm expression @expr (representing a C access) into a + tuple composed of a native Miasm expression and its C type. + @expr: Miasm expression (representing a C access) + @c_context: a dictionary linking known tokens (strings) to their + types. A token is linked to only one type. 
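        Example (an illustrative sketch; `conv` and `objc_ptr_test` are
        assumed to be an ExprCToExpr instance and the ObjCPtr type of
        'ptr_Test' respectively):

            access = ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64))
            expr, objc = conv.get_expr(access, {'ptr_Test': objc_ptr_test})
            # expr is the native access (e.g. @32[ptr_Test + 0x0]),
            # objc the C type of the accessed field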
+ """ + ret = self.reduce(expr, ctxt=c_context) + if ret.info is None: + return (None, None) + return ret.info + + +class CTypesManager(object): + """Represent a C object, without any layout information""" + + def __init__(self, types_ast, leaf_types): + self.types_ast = types_ast + self.leaf_types = leaf_types + + @property + def void_ptr(self): + """Retrieve a void* objc""" + return self.leaf_types.types.get(CTypePtr(CTypeId('void'))) + + @property + def padding(self): + """Retrieve a padding ctype""" + return CTypeId(PADDING_TYPE_NAME) + + def _get_objc(self, type_id, resolved=None, to_fix=None, lvl=0): + if resolved is None: + resolved = {} + if to_fix is None: + to_fix = [] + if type_id in resolved: + return resolved[type_id] + type_id = self.types_ast.get_type(type_id) + fixed = True + if isinstance(type_id, CTypeId): + out = self.leaf_types.types.get(type_id, None) + assert out is not None + elif isinstance(type_id, CTypeUnion): + args = [] + align_max, size_max = 0, 0 + for name, field in type_id.fields: + objc = self._get_objc(field, resolved, to_fix, lvl + 1) + resolved[field] = objc + align_max = max(align_max, objc.align) + size_max = max(size_max, objc.size) + args.append((name, objc, 0, objc.size)) + + align, size = self.union_compute_align_size(align_max, size_max) + out = ObjCUnion(type_id.name, align, size, args) + + elif isinstance(type_id, CTypeStruct): + align_max, size_max = 0, 0 + + args = [] + offset, align_max = 0, 1 + pad_index = 0 + for name, field in type_id.fields: + objc = self._get_objc(field, resolved, to_fix, lvl + 1) + resolved[field] = objc + align_max = max(align_max, objc.align) + new_offset = self.struct_compute_field_offset(objc, offset) + if new_offset - offset: + pad_name = "__PAD__%d__" % pad_index + pad_index += 1 + size = new_offset - offset + pad_objc = self._get_objc(CTypeArray(self.padding, size), resolved, to_fix, lvl + 1) + args.append((pad_name, pad_objc, offset, pad_objc.size)) + offset = new_offset + args.append((name, objc, offset, objc.size)) + offset += objc.size + + align, size = self.struct_compute_align_size(align_max, offset) + out = ObjCStruct(type_id.name, align, size, args) + + elif isinstance(type_id, CTypePtr): + target = type_id.target + out = ObjCPtr(None, self.void_ptr.align, self.void_ptr.size) + fixed = False + + elif isinstance(type_id, CTypeArray): + target = type_id.target + objc = self._get_objc(target, resolved, to_fix, lvl + 1) + resolved[target] = objc + if type_id.size is None: + # case: toto[] + # return ObjCPtr + out = ObjCPtr(objc, self.void_ptr.align, self.void_ptr.size) + else: + size = self.size_to_int(type_id.size) + if size is None: + raise RuntimeError('Enable to compute objc size') + else: + out = ObjCArray(objc, size) + assert out.size is not None and out.align is not None + elif isinstance(type_id, CTypeEnum): + # Enum are integer + return self.leaf_types.types.get(CTypeId('int')) + elif isinstance(type_id, CTypeFunc): + type_ret = self._get_objc( + type_id.type_ret, resolved, to_fix, lvl + 1) + resolved[type_id.type_ret] = type_ret + args = [] + for name, arg in type_id.args: + objc = self._get_objc(arg, resolved, to_fix, lvl + 1) + resolved[arg] = objc + args.append((name, objc)) + out = ObjCFunc(type_id.name, type_id.abi, type_ret, args, + self.void_ptr.align, self.void_ptr.size) + elif isinstance(type_id, CTypeEllipsis): + out = ObjCEllipsis() + else: + raise TypeError("Unknown type %r" % type_id.__class__) + if not isinstance(out, ObjCEllipsis): + assert out.align is not None and out.size is not 
None + + if fixed: + resolved[type_id] = out + else: + to_fix.append((type_id, out)) + return out + + def get_objc(self, type_id): + """Get the ObjC corresponding to the CType @type_id + @type_id: CTypeBase instance""" + resolved = {} + to_fix = [] + out = self._get_objc(type_id, resolved, to_fix) + # Fix sub objects + while to_fix: + type_id, objc_to_fix = to_fix.pop() + objc = self._get_objc(type_id.target, resolved, to_fix) + objc_to_fix.objtype = objc + self.check_objc(out) + return out + + def check_objc(self, objc, done=None): + """Ensure each sub ObjC is resolved + @objc: ObjC instance""" + if done is None: + done = set() + if objc in done: + return True + done.add(objc) + if isinstance(objc, (ObjCDecl, ObjCInt, ObjCEllipsis)): + return True + elif isinstance(objc, (ObjCPtr, ObjCArray)): + assert self.check_objc(objc.objtype, done) + return True + elif isinstance(objc, (ObjCStruct, ObjCUnion)): + for _, field, _, _ in objc.fields: + assert self.check_objc(field, done) + return True + elif isinstance(objc, ObjCFunc): + assert self.check_objc(objc.type_ret, done) + for name, arg in objc.args: + assert self.check_objc(arg, done) + return True + else: + assert False + + def size_to_int(self, size): + """Resolve an array size + @size: CTypeOp or integer""" + if isinstance(size, CTypeOp): + assert len(size.args) == 2 + arg0, arg1 = [self.size_to_int(arg) for arg in size.args] + if size.operator == "+": + return arg0 + arg1 + elif size.operator == "-": + return arg0 - arg1 + elif size.operator == "*": + return arg0 * arg1 + elif size.operator == "/": + return arg0 // arg1 + elif size.operator == "<<": + return arg0 << arg1 + elif size.operator == ">>": + return arg0 >> arg1 + else: + raise ValueError("Unknown operator %s" % size.operator) + elif isinstance(size, int_types): + return size + elif isinstance(size, CTypeSizeof): + obj = self._get_objc(size.target) + return obj.size + else: + raise TypeError("Unknown size type") + + def struct_compute_field_offset(self, obj, offset): + """Compute the offset of the field @obj in the current structure""" + raise NotImplementedError("Abstract method") + + def struct_compute_align_size(self, align_max, size): + """Compute the alignment and size of the current structure""" + raise NotImplementedError("Abstract method") + + def union_compute_align_size(self, align_max, size): + """Compute the alignment and size of the current union""" + raise NotImplementedError("Abstract method") + + +class CTypesManagerNotPacked(CTypesManager): + """Store defined C types (not packed)""" + + def struct_compute_field_offset(self, obj, offset): + """Compute the offset of the field @obj in the current structure + (not packed)""" + + if obj.align > 1: + offset = (offset + obj.align - 1) & ~(obj.align - 1) + return offset + + def struct_compute_align_size(self, align_max, size): + """Compute the alignment and size of the current structure + (not packed)""" + if align_max > 1: + size = (size + align_max - 1) & ~(align_max - 1) + return align_max, size + + def union_compute_align_size(self, align_max, size): + """Compute the alignment and size of the current union + (not packed)""" + return align_max, size + + +class CTypesManagerPacked(CTypesManager): + """Store defined C types (packed form)""" + + def struct_compute_field_offset(self, _, offset): + """Compute the offset of the field @obj in the current structure + (packed form)""" + return offset + + def struct_compute_align_size(self, _, size): + """Compute the alignment and size of the current structure + (packed 
form)""" + return 1, size + + def union_compute_align_size(self, align_max, size): + """Compute the alignment and size of the current union + (packed form)""" + return 1, size + + +class CHandler(object): + """ + C manipulator for Miasm + Miasm expr <-> C + """ + + exprCToExpr_cls = ExprCToExpr + exprToAccessC_cls = ExprToAccessC + + def __init__(self, types_mngr, expr_types=None, + C_types=None, + simplify_c=access_simplifier, + enforce_strict_access=True): + self.exprc2expr = self.exprCToExpr_cls(expr_types, types_mngr) + self.access_c_gen = self.exprToAccessC_cls(expr_types, + types_mngr, + enforce_strict_access) + self.types_mngr = types_mngr + self.simplify_c = simplify_c + if expr_types is None: + expr_types = {} + self.expr_types = expr_types + if C_types is None: + C_types = {} + self.C_types = C_types + + def updt_expr_types(self, expr_types): + """Update expr_types + @expr_types: Dictionary associating name to type + """ + + self.expr_types = expr_types + self.exprc2expr.updt_expr_types(expr_types) + self.access_c_gen.updt_expr_types(expr_types) + + def expr_to_c_access(self, expr, expr_context=None): + """Generate the C access object(s) for a given native Miasm expression. + @expr: Miasm expression + @expr_context: a dictionary linking known expressions to a set of types + """ + + if expr_context is None: + expr_context = self.expr_types + return self.access_c_gen.get_accesses(expr, expr_context) + + + def expr_to_c_and_types(self, expr, expr_context=None): + """Generate the C access string and corresponding type for a given + native Miasm expression. + @expr_context: a dictionary linking known expressions to a set of types + """ + + accesses = set() + for access in self.expr_to_c_access(expr, expr_context): + c_str = access_str(access.to_expr().visit(self.simplify_c)) + accesses.add((c_str, access.ctype)) + return accesses + + def expr_to_c(self, expr, expr_context=None): + """Convert a Miasm @expr into it's C equivalent string + @expr_context: a dictionary linking known expressions to a set of types + """ + + return set(access[0] + for access in self.expr_to_c_and_types(expr, expr_context)) + + def expr_to_types(self, expr, expr_context=None): + """Get the possible types of the Miasm @expr + @expr_context: a dictionary linking known expressions to a set of types + """ + + return set(access.ctype + for access in self.expr_to_c_access(expr, expr_context)) + + def c_to_expr_and_type(self, c_str, c_context=None): + """Convert a C string expression to a Miasm expression and it's + corresponding c type + @c_str: C string + @c_context: (optional) dictionary linking known tokens (strings) to its + type. + """ + + ast = parse_access(c_str) + if c_context is None: + c_context = self.C_types + access_c = ast_get_c_access_expr(ast, c_context) + return self.exprc2expr.get_expr(access_c, c_context) + + def c_to_expr(self, c_str, c_context=None): + """Convert a C string expression to a Miasm expression + @c_str: C string + @c_context: (optional) dictionary linking known tokens (strings) to its + type. + """ + + if c_context is None: + c_context = self.C_types + expr, _ = self.c_to_expr_and_type(c_str, c_context) + return expr + + def c_to_type(self, c_str, c_context=None): + """Get the type of a C string expression + @expr: Miasm expression + @c_context: (optional) dictionary linking known tokens (strings) to its + type. 
+ """ + + if c_context is None: + c_context = self.C_types + _, ctype = self.c_to_expr_and_type(c_str, c_context) + return ctype + + +class CLeafTypes(object): + """Define C types sizes/alignment for a given architecture""" + pass diff --git a/src/miasm/core/parse_asm.py b/src/miasm/core/parse_asm.py new file mode 100644 index 00000000..79ef416d --- /dev/null +++ b/src/miasm/core/parse_asm.py @@ -0,0 +1,288 @@ +#-*- coding:utf-8 -*- +import re +import codecs +from builtins import range + +from miasm.core.utils import force_str +from miasm.expression.expression import ExprId, ExprInt, ExprOp, LocKey +import miasm.core.asmblock as asmblock +from miasm.core.cpu import instruction, base_expr +from miasm.core.asm_ast import AstInt, AstId, AstOp + +declarator = {'byte': 8, + 'word': 16, + 'dword': 32, + 'qword': 64, + 'long': 32, + } + +size2pck = {8: 'B', + 16: 'H', + 32: 'I', + 64: 'Q', + } + +EMPTY_RE = re.compile(r'\s*$') +COMMENT_RE = re.compile(r'\s*;\S*') +LOCAL_LABEL_RE = re.compile(r'\s*(\.L\S+)\s*:') +DIRECTIVE_START_RE = re.compile(r'\s*\.') +DIRECTIVE_RE = re.compile(r'\s*\.(\S+)') +LABEL_RE = re.compile(r'\s*(\S+)\s*:') +FORGET_LABEL_RE = re.compile(r'\s*\.LF[BE]\d\s*:') + + +class Directive(object): + + """Stand for Directive""" + + pass + +class DirectiveAlign(Directive): + + """Stand for alignment representation""" + + def __init__(self, alignment=1): + self.alignment = alignment + + def __str__(self): + return "Alignment %s" % self.alignment + + +class DirectiveSplit(Directive): + + """Stand for alignment representation""" + + pass + + +class DirectiveDontSplit(Directive): + + """Stand for alignment representation""" + + pass + + +STATE_NO_BLOC = 0 +STATE_IN_BLOC = 1 + + +def asm_ast_to_expr_with_size(arg, loc_db, size): + if isinstance(arg, AstId): + return ExprId(force_str(arg.name), size) + if isinstance(arg, AstOp): + args = [asm_ast_to_expr_with_size(tmp, loc_db, size) for tmp in arg.args] + return ExprOp(arg.op, *args) + if isinstance(arg, AstInt): + return ExprInt(arg.value, size) + return None + +def parse_txt(mnemo, attrib, txt, loc_db): + """Parse an assembly listing. 
Returns an AsmCfg instance + + @mnemo: architecture used + @attrib: architecture attribute + @txt: assembly listing + @loc_db: the LocationDB instance used to handle labels of the listing + + """ + + C_NEXT = asmblock.AsmConstraint.c_next + C_TO = asmblock.AsmConstraint.c_to + + lines = [] + # parse each line + for line in txt.split('\n'): + # empty + if EMPTY_RE.match(line): + continue + # comment + if COMMENT_RE.match(line): + continue + # labels to forget + if FORGET_LABEL_RE.match(line): + continue + # label beginning with .L + match_re = LABEL_RE.match(line) + if match_re: + label_name = match_re.group(1) + label = loc_db.get_or_create_name_location(label_name) + lines.append(label) + continue + # directive + if DIRECTIVE_START_RE.match(line): + match_re = DIRECTIVE_RE.match(line) + directive = match_re.group(1) + if directive in ['text', 'data', 'bss']: + continue + if directive in ['string', 'ascii']: + # XXX HACK + line = line.replace(r'\n', '\n').replace(r'\r', '\r') + raw = line[line.find(r'"') + 1:line.rfind(r'"')] + raw = codecs.escape_decode(raw)[0] + if directive == 'string': + raw += b"\x00" + lines.append(asmblock.AsmRaw(raw)) + continue + if directive == 'ustring': + # XXX HACK + line = line.replace(r'\n', '\n').replace(r'\r', '\r') + raw = line[line.find(r'"') + 1:line.rfind(r'"')] + "\x00" + raw = codecs.escape_decode(raw)[0] + out = b'' + for i in range(len(raw)): + out += raw[i:i+1] + b'\x00' + lines.append(asmblock.AsmRaw(out)) + continue + if directive in declarator: + data_raw = line[match_re.end():].split(' ', 1)[1] + data_raw = data_raw.split(',') + size = declarator[directive] + expr_list = [] + + # parser + + for element in data_raw: + element = element.strip() + element_parsed = base_expr.parseString(element)[0] + element_expr = asm_ast_to_expr_with_size(element_parsed, loc_db, size) + expr_list.append(element_expr) + + raw_data = asmblock.AsmRaw(expr_list) + raw_data.element_size = size + lines.append(raw_data) + continue + if directive == 'comm': + # TODO + continue + if directive == 'split': # custom command + lines.append(DirectiveSplit()) + continue + if directive == 'dontsplit': # custom command + lines.append(DirectiveDontSplit()) + continue + if directive == "align": + align_value = int(line[match_re.end():], 0) + lines.append(DirectiveAlign(align_value)) + continue + if directive in ['file', 'intel_syntax', 'globl', 'local', + 'type', 'size', 'align', 'ident', 'section']: + continue + if directive[0:4] == 'cfi_': + continue + + raise ValueError("unknown directive %s" % directive) + + # label + match_re = LABEL_RE.match(line) + if match_re: + label_name = match_re.group(1) + label = loc_db.get_or_create_name_location(label_name) + lines.append(label) + continue + + # code + if ';' in line: + line = line[:line.find(';')] + line = line.strip(' ').strip('\t') + instr = mnemo.fromstring(line, loc_db, attrib) + lines.append(instr) + + asmblock.log_asmblock.info("___pre asm oki___") + # make asmcfg + + cur_block = None + state = STATE_NO_BLOC + i = 0 + asmcfg = asmblock.AsmCFG(loc_db) + block_to_nlink = None + delayslot = 0 + while i < len(lines): + if delayslot: + delayslot -= 1 + if delayslot == 0: + state = STATE_NO_BLOC + line = lines[i] + # no current block + if state == STATE_NO_BLOC: + if isinstance(line, DirectiveDontSplit): + block_to_nlink = cur_block + i += 1 + continue + elif isinstance(line, DirectiveSplit): + block_to_nlink = None + i += 1 + continue + elif not isinstance(line, LocKey): + # First line must be a label. 
If it's not the case, generate + # it. + loc = loc_db.add_location() + cur_block = asmblock.AsmBlock(loc_db, loc, alignment=mnemo.alignment) + else: + cur_block = asmblock.AsmBlock(loc_db, line, alignment=mnemo.alignment) + i += 1 + # Generate the current block + asmcfg.add_block(cur_block) + state = STATE_IN_BLOC + if block_to_nlink: + block_to_nlink.addto( + asmblock.AsmConstraint( + cur_block.loc_key, + C_NEXT + ) + ) + block_to_nlink = None + continue + + # in block + elif state == STATE_IN_BLOC: + if isinstance(line, DirectiveSplit): + state = STATE_NO_BLOC + block_to_nlink = None + elif isinstance(line, DirectiveDontSplit): + state = STATE_NO_BLOC + block_to_nlink = cur_block + elif isinstance(line, DirectiveAlign): + cur_block.alignment = line.alignment + elif isinstance(line, asmblock.AsmRaw): + cur_block.addline(line) + block_to_nlink = cur_block + elif isinstance(line, LocKey): + if block_to_nlink: + cur_block.addto( + asmblock.AsmConstraint(line, C_NEXT) + ) + block_to_nlink = None + state = STATE_NO_BLOC + continue + # instruction + elif isinstance(line, instruction): + cur_block.addline(line) + block_to_nlink = cur_block + if not line.breakflow(): + i += 1 + continue + if delayslot: + raise RuntimeError("Cannot have breakflow in delayslot") + if line.dstflow(): + for dst in line.getdstflow(loc_db): + if not isinstance(dst, ExprId): + continue + if dst in mnemo.regs.all_regs_ids: + continue + cur_block.addto(asmblock.AsmConstraint(dst.name, C_TO)) + + if not line.splitflow(): + block_to_nlink = None + + delayslot = line.delayslot + 1 + else: + raise RuntimeError("unknown class %s" % line.__class__) + i += 1 + + for block in asmcfg.blocks: + # Fix multiple constraints + block.fix_constraints() + + # Log block + asmblock.log_asmblock.info(block) + return asmcfg diff --git a/src/miasm/core/sembuilder.py b/src/miasm/core/sembuilder.py new file mode 100644 index 00000000..9843ee6a --- /dev/null +++ b/src/miasm/core/sembuilder.py @@ -0,0 +1,341 @@ +"Helper to quickly build instruction's semantic side effects" + +import inspect +import ast +import re + +from future.utils import PY3 + +import miasm.expression.expression as m2_expr +from miasm.ir.ir import IRBlock, AssignBlock + + +class MiasmTransformer(ast.NodeTransformer): + """AST visitor translating DSL to Miasm expression + + memX[Y] -> ExprMem(Y, X) + iX(Y) -> ExprIntX(Y) + X if Y else Z -> ExprCond(Y, X, Z) + 'X'(Y) -> ExprOp('X', Y) + ('X' % Y)(Z) -> ExprOp('X' % Y, Z) + {a, b} -> ExprCompose(((a, 0, a.size), (b, a.size, a.size + b.size))) + """ + + # Parsers + parse_integer = re.compile(r"^i([0-9]+)$") + parse_mem = re.compile(r"^mem([0-9]+)$") + + # Visitors + def visit_Call(self, node): + """iX(Y) -> ExprIntX(Y), + 'X'(Y) -> ExprOp('X', Y), ('X' % Y)(Z) -> ExprOp('X' % Y, Z)""" + + # Recursive visit + node = self.generic_visit(node) + if isinstance(node.func, ast.Name): + # iX(Y) -> ExprInt(Y, X) + fc_name = node.func.id + + # Match the function name + new_name = fc_name + integer = self.parse_integer.search(fc_name) + + # Do replacement + if integer is not None: + size = int(integer.groups()[0]) + new_name = "ExprInt" + # Replace in the node + node.func.id = new_name + node.args.append(ast.Num(n=size)) + + elif (isinstance(node.func, ast.Str) or + (isinstance(node.func, ast.BinOp) and + isinstance(node.func.op, ast.Mod) and + isinstance(node.func.left, ast.Str))): + # 'op'(args...) -> ExprOp('op', args...) + # ('op' % (fmt))(args...) -> ExprOp('op' % (fmt), args...) 
+ op_name = node.func + + # Do replacement + node.func = ast.Name(id="ExprOp", ctx=ast.Load()) + node.args[0:0] = [op_name] + + return node + + def visit_IfExp(self, node): + """X if Y else Z -> ExprCond(Y, X, Z)""" + # Recursive visit + node = self.generic_visit(node) + + # Build the new ExprCond + call = ast.Call(func=ast.Name(id='ExprCond', ctx=ast.Load()), + args=[self.visit(node.test), + self.visit(node.body), + self.visit(node.orelse)], + keywords=[], starargs=None, kwargs=None) + return call + + def visit_Set(self, node): + "{a, b} -> ExprCompose(a, b)" + if len(node.elts) == 0: + return node + + # Recursive visit + node = self.generic_visit(node) + + return ast.Call(func=ast.Name(id='ExprCompose', + ctx=ast.Load()), + args=node.elts, + keywords=[], + starargs=None, + kwargs=None) + +if PY3: + def get_arg_name(name): + return name.arg + def gen_arg(name, ctx): + return ast.arg(arg=name, ctx=ctx) +else: + def get_arg_name(name): + return name.id + def gen_arg(name, ctx): + return ast.Name(id=name, ctx=ctx) + + +class SemBuilder(object): + """Helper for building instruction's semantic side effects method + + This class provides a decorator @parse to use on them. + The context in which the function will be parsed must be supplied on + instantiation + """ + + def __init__(self, ctx): + """Create a SemBuilder + @ctx: context dictionary used during parsing + """ + # Init + self.transformer = MiasmTransformer() + self._ctx = dict(m2_expr.__dict__) + self._ctx["IRBlock"] = IRBlock + self._ctx["AssignBlock"] = AssignBlock + self._functions = {} + + # Update context + self._ctx.update(ctx) + + @property + def functions(self): + """Return a dictionary name -> func of parsed functions""" + return self._functions.copy() + + @staticmethod + def _create_labels(loc_else=False): + """Return the AST standing for label creations + @loc_else (optional): if set, create a label 'loc_else'""" + loc_end = "loc_end = ir.get_next_loc_key(instr)" + loc_end_expr = "loc_end_expr = ExprLoc(loc_end, ir.IRDst.size)" + out = ast.parse(loc_end).body + out += ast.parse(loc_end_expr).body + loc_if = "loc_if = ir.loc_db.add_location()" + loc_if_expr = "loc_if_expr = ExprLoc(loc_if, ir.IRDst.size)" + out += ast.parse(loc_if).body + out += ast.parse(loc_if_expr).body + if loc_else: + loc_else = "loc_else = ir.loc_db.add_location()" + loc_else_expr = "loc_else_expr = ExprLoc(loc_else, ir.IRDst.size)" + out += ast.parse(loc_else).body + out += ast.parse(loc_else_expr).body + return out + + def _parse_body(self, body, argument_names): + """Recursive function transforming a @body to a block expression + Return: + - AST to append to body (real python statements) + - a list of blocks, ie list of affblock, ie list of ExprAssign (AST)""" + + # Init + ## Real instructions + real_body = [] + ## Final blocks + blocks = [[[]]] + + for statement in body: + + if isinstance(statement, ast.Assign): + src = self.transformer.visit(statement.value) + dst = self.transformer.visit(statement.targets[0]) + + if (isinstance(dst, ast.Name) and + dst.id not in argument_names and + dst.id not in self._ctx and + dst.id not in self._local_ctx): + + # Real variable declaration + statement.value = src + real_body.append(statement) + self._local_ctx[dst.id] = src + continue + + dst.ctx = ast.Load() + + res = ast.Call(func=ast.Name(id='ExprAssign', + ctx=ast.Load()), + args=[dst, src], + keywords=[], + starargs=None, + kwargs=None) + + blocks[-1][-1].append(res) + + elif (isinstance(statement, ast.Expr) and + isinstance(statement.value, ast.Str)): + # 
String (docstring, comment, ...) -> keep it + real_body.append(statement) + + elif isinstance(statement, ast.If): + # Create jumps : ir.IRDst = loc_if if cond else loc_end + # if .. else .. are also handled + cond = statement.test + real_body += self._create_labels(loc_else=True) + + loc_end = ast.Name(id='loc_end_expr', ctx=ast.Load()) + loc_if = ast.Name(id='loc_if_expr', ctx=ast.Load()) + loc_else = ast.Name(id='loc_else_expr', ctx=ast.Load()) \ + if statement.orelse else loc_end + dst = ast.Call(func=ast.Name(id='ExprCond', + ctx=ast.Load()), + args=[cond, + loc_if, + loc_else], + keywords=[], + starargs=None, + kwargs=None) + + if (isinstance(cond, ast.UnaryOp) and + isinstance(cond.op, ast.Not)): + ## if not cond -> switch exprCond + dst.args[1:] = dst.args[1:][::-1] + dst.args[0] = cond.operand + + IRDst = ast.Attribute(value=ast.Name(id='ir', + ctx=ast.Load()), + attr='IRDst', ctx=ast.Load()) + loc_db = ast.Attribute(value=ast.Name(id='ir', + ctx=ast.Load()), + attr='loc_db', ctx=ast.Load()) + blocks[-1][-1].append(ast.Call(func=ast.Name(id='ExprAssign', + ctx=ast.Load()), + args=[IRDst, dst], + keywords=[], + starargs=None, + kwargs=None)) + + # Create the new blocks + elements = [(statement.body, 'loc_if')] + if statement.orelse: + elements.append((statement.orelse, 'loc_else')) + for content, loc_name in elements: + sub_blocks, sub_body = self._parse_body(content, + argument_names) + if len(sub_blocks) > 1: + raise RuntimeError("Imbricated if unimplemented") + + ## Close the last block + jmp_end = ast.Call(func=ast.Name(id='ExprAssign', + ctx=ast.Load()), + args=[IRDst, loc_end], + keywords=[], + starargs=None, + kwargs=None) + sub_blocks[-1][-1].append(jmp_end) + + + instr = ast.Name(id='instr', ctx=ast.Load()) + effects = ast.List(elts=sub_blocks[-1][-1], + ctx=ast.Load()) + assignblk = ast.Call(func=ast.Name(id='AssignBlock', + ctx=ast.Load()), + args=[effects, instr], + keywords=[], + starargs=None, + kwargs=None) + + + ## Replace the block with a call to 'IRBlock' + loc_if_name = ast.Name(id=loc_name, ctx=ast.Load()) + + assignblks = ast.List(elts=[assignblk], + ctx=ast.Load()) + + sub_blocks[-1] = ast.Call(func=ast.Name(id='IRBlock', + ctx=ast.Load()), + args=[ + loc_db, + loc_if_name, + assignblks + ], + keywords=[], + starargs=None, + kwargs=None) + blocks += sub_blocks + real_body += sub_body + + # Prepare a new block for following statement + blocks.append([[]]) + + else: + # TODO: real var, +=, /=, -=, <<=, >>=, if/else, ... 
+ raise RuntimeError("Unimplemented %s" % statement) + + return blocks, real_body + + def parse(self, func): + """Function decorator, returning a correct method from a pseudo-Python + one""" + + # Get the function AST + parsed = ast.parse(inspect.getsource(func)) + fc_ast = parsed.body[0] + argument_names = [get_arg_name(name) for name in fc_ast.args.args] + + # Init local cache + self._local_ctx = {} + + # Translate (blocks[0][0] is the current instr) + blocks, body = self._parse_body(fc_ast.body, argument_names) + + # Build the new function + fc_ast.args.args[0:0] = [ + gen_arg('ir', ast.Param()), + gen_arg('instr', ast.Param()) + ] + cur_instr = blocks[0][0] + if len(blocks[-1][0]) == 0: + ## Last block can be empty + blocks.pop() + other_blocks = blocks[1:] + body.append(ast.Return(value=ast.Tuple(elts=[ast.List(elts=cur_instr, + ctx=ast.Load()), + ast.List(elts=other_blocks, + ctx=ast.Load())], + ctx=ast.Load()))) + + ret = ast.parse('') + ret.body = [ast.FunctionDef(name=fc_ast.name, + args=fc_ast.args, + body=body, + decorator_list=[])] + + # To display the generated function, use codegen.to_source + # codegen: https://github.com/andreif/codegen + + # Compile according to the context + fixed = ast.fix_missing_locations(ret) + codeobj = compile(fixed, '<string>', 'exec') + ctx = self._ctx.copy() + eval(codeobj, ctx) + + # Get the function back + self._functions[fc_ast.name] = ctx[fc_ast.name] + return ctx[fc_ast.name] diff --git a/src/miasm/core/types.py b/src/miasm/core/types.py new file mode 100644 index 00000000..4f99627d --- /dev/null +++ b/src/miasm/core/types.py @@ -0,0 +1,1693 @@ +"""This module provides classes to manipulate pure C types as well as their +representation in memory. A typical usecase is to use this module to +easily manipylate structures backed by a VmMngr object (a miasm sandbox virtual +memory): + + class ListNode(MemStruct): + fields = [ + ("next", Ptr("<I", Self())), + ("data", Ptr("<I", Void())), + ] + + class LinkedList(MemStruct): + fields = [ + ("head", Ptr("<I", ListNode)), + ("tail", Ptr("<I", ListNode)), + ("size", Num("<I")), + ] + + link = LinkedList(vm, addr1) + link.memset() + node = ListNode(vm, addr2) + node.memset() + link.head = node.get_addr() + link.tail = node.get_addr() + link.size += 1 + assert link.head.deref == node + data = Num("<I").lval(vm, addr3) + data.val = 5 + node.data = data.get_addr() + # see examples/jitter/types.py for more info + + +It provides two families of classes, Type-s (Num, Ptr, Str...) and their +associated MemType-s. A Type subclass instance represents a fully defined C +type. A MemType subclass instance represents a C LValue (or variable): it is +a type attached to the memory. Available types are: + + - Num: for number (float or int) handling + - Ptr: a pointer to another Type + - Struct: equivalent to a C struct definition + - Union: similar to union in C, list of Types at the same offset in a + structure; the union has the size of the biggest Type (~ Struct with all + the fields at offset 0) + - Array: an array of items of the same type; can have a fixed size or + not (e.g. char[3] vs char* used as an array in C) + - BitField: similar to C bitfields, a list of + [(<field_name>, <number_of_bits>),]; creates fields that correspond to + certain bits of the field; analogous to a Union of Bits (see Bits below) + - Str: a character string, with an encoding; not directly mapped to a C + type, it is a higher level notion provided for ease of use + - Void: analogous to C void, can be a placeholder in void*-style cases. 
+ - Self: special marker to reference a Struct inside itself (FIXME: to + remove?) + +And some less common types: + + - Bits: mask only some bits of a Num + - RawStruct: abstraction over a simple struct pack/unpack (no mapping to a + standard C type) + +For each type, the `.lval` property returns a MemType subclass that +allows to access the field in memory. + + +The easiest way to use the API to declare and manipulate new structures is to +subclass MemStruct and define a list of (<field_name>, <field_definition>): + + class MyStruct(MemStruct): + fields = [ + # Scalar field: just struct.pack field with one value + ("num", Num("I")), + ("flags", Num("B")), + # Ptr fields contain two fields: "val", for the numerical value, + # and "deref" to get the pointed object + ("other", Ptr("I", OtherStruct)), + # Ptr to a variable length String + ("s", Ptr("I", Str())), + ("i", Ptr("I", Num("I"))), + ] + +And access the fields: + + mstruct = MyStruct(jitter.vm, addr) + mstruct.num = 3 + assert mstruct.num == 3 + mstruct.other.val = addr2 + # Also works: + mstruct.other = addr2 + mstruct.other.deref = OtherStruct(jitter.vm, addr) + +MemUnion and MemBitField can also be subclassed, the `fields` field being +in the format expected by, respectively, Union and BitField. + +The `addr` argument can be omitted if an allocator is set, in which case the +structure will be automatically allocated in memory: + + my_heap = miasm.os_dep.common.heap() + # the allocator is a func(VmMngr) -> integer_address + set_allocator(my_heap) + +Note that some structures (e.g. MemStr or MemArray) do not have a static +size and cannot be allocated automatically. +""" + +from builtins import range, zip +from builtins import int as int_types +import itertools +import logging +import struct +from future.utils import PY3 +from future.utils import viewitems, with_metaclass + +log = logging.getLogger(__name__) +console_handler = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) +log.addHandler(console_handler) +log.setLevel(logging.WARN) + +# Cache for dynamically generated MemTypes +DYN_MEM_STRUCT_CACHE = {} + +def set_allocator(alloc_func): + """Shorthand to set the default allocator of MemType. See + MemType.set_allocator doc for more information. + """ + MemType.set_allocator(alloc_func) + + +# Helpers + +def to_type(obj): + """If possible, return the Type associated with @obj, otherwise raises + a ValueError. + + Works with a Type instance (returns obj) or a MemType subclass or instance + (returns obj.get_type()). + """ + # obj is a python type + if isinstance(obj, type): + if issubclass(obj, MemType): + if obj.get_type() is None: + raise ValueError("%r has no static type; use a subclasses " + "with a non null _type or use a " + "Type instance" % obj) + return obj.get_type() + # obj is not not a type + else: + if isinstance(obj, Type): + return obj + elif isinstance(obj, MemType): + return obj.get_type() + raise ValueError("%r is not a Type or a MemType" % obj) + +def indent(s, size=4): + """Indent a string with @size spaces""" + return ' '*size + ('\n' + ' '*size).join(s.split('\n')) + + +# String generic getter/setter/len-er +# TODO: make miasm.os_dep.common and jitter ones use these ones + +def get_str(vm, addr, enc, max_char=None, end=u'\x00'): + """Get a @end (by default '\\x00') terminated @enc encoded string from a + VmMngr. 
+ + For example: + - get_str(vm, addr, "ascii") will read "foo\\x00" in memory and + return u"foo" + - get_str(vm, addr, "utf-16le") will read "f\\x00o\\x00o\\x00\\x00\\x00" + in memory and return u"foo" as well. + + Setting @max_char=<n> and @end='' allows to read non null terminated strings + from memory. + + @vm: VmMngr instance + @addr: the address at which to read the string + @enc: the encoding of the string to read. + @max_char: max number of bytes to get in memory + @end: the unencoded ending sequence of the string, by default "\\x00". + Unencoded here means that the actual ending sequence that this function + will look for is end.encode(enc), not directly @end. + """ + s = [] + end_char= end.encode(enc) + step = len(end_char) + i = 0 + while max_char is None or i < max_char: + c = vm.get_mem(addr + i, step) + if c == end_char: + break + s.append(c) + i += step + return b''.join(s).decode(enc) + +def raw_str(s, enc, end=u'\x00'): + """Returns a string representing @s as an @end (by default \\x00) + terminated @enc encoded string. + + @s: the unicode str to serialize + @enc: the encoding to apply to @s and @end before serialization. + @end: the ending string/character to append to the string _before encoding_ + and serialization (by default '\\x00') + """ + return (s + end).encode(enc) + +def set_str(vm, addr, s, enc, end=u'\x00'): + """Encode a string to an @end (by default \\x00) terminated @enc encoded + string and set it in a VmMngr memory. + + @vm: VmMngr instance + @addr: start address to serialize the string to + @s: the unicode str to serialize + @enc: the encoding to apply to @s and @end before serialization. + @end: the ending string/character to append to the string _before encoding_ + and serialization (by default '\\x00') + """ + s = raw_str(s, enc, end=end) + vm.set_mem(addr, s) + +def raw_len(py_unic_str, enc, end=u'\x00'): + """Returns the length in bytes of @py_unic_str in memory (once @end has been + added and the full str has been encoded). It returns exactly the room + necessary to call set_str with similar arguments. + + @py_unic_str: the unicode str to work with + @enc: the encoding to encode @py_unic_str to + @end: the ending string/character to append to the string _before encoding_ + (by default \\x00) + """ + return len(raw_str(py_unic_str, enc)) + +def enc_triplet(enc, max_char=None, end=u'\x00'): + """Returns a triplet of functions (get_str_enc, set_str_enc, raw_len_enc) + for a given encoding (as needed by Str to add an encoding). The prototypes + are: + + - get_str_end: same as get_str without the @enc argument + - set_str_end: same as set_str without the @enc argument + - raw_len_enc: same as raw_len without the @enc argument + """ + return ( + lambda vm, addr, max_char=max_char, end=end: \ + get_str(vm, addr, enc, max_char=max_char, end=end), + lambda vm, addr, s, end=end: set_str(vm, addr, s, enc, end=end), + lambda s, end=end: raw_len(s, enc, end=end), + ) + + +# Type classes + +class Type(object): + """Base class to provide methods to describe a type, as well as how to set + and get fields from virtual mem. + + Each Type subclass is linked to a MemType subclass (e.g. Struct to + MemStruct, Ptr to MemPtr, etc.). + + When nothing is specified, MemValue is used to access the type in memory. + MemValue instances have one `.val` field, setting and getting it call + the set and get of the Type. + + Subclasses can either override _pack and _unpack, or get and set if data + serialization requires more work (see Struct implementation for an example). 
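As an illustration of that last point, here is a minimal sketch (not part of miasm) of a Type subclass built only from _pack/_unpack: a 24-bit big-endian integer. FlatVM is a hypothetical stand-in for the jitter's VmMngr that implements just the get_mem/set_mem interface these classes rely on; with a real sandbox you would pass jitter.vm instead.

    import struct
    from miasm.core.types import Type

    class UInt24BE(Type):
        """Unsigned 24-bit big-endian integer, defined via _pack/_unpack only."""

        def _pack(self, val):
            # Keep the low three bytes of a 4-byte big-endian pack
            return struct.pack(">I", val & 0xFFFFFF)[1:]

        def _unpack(self, raw_str):
            return struct.unpack(">I", b"\x00" + raw_str)[0]

        @property
        def size(self):
            return 3

        def __repr__(self):
            return "UInt24BE"

        # __eq__/__hash__ are needed so .lval can cache the generated class
        def __eq__(self, other):
            return self.__class__ == other.__class__

        def __hash__(self):
            return hash(self.__class__)

    class FlatVM(object):
        """Hypothetical VmMngr stand-in: get_mem/set_mem over a bytearray."""
        def __init__(self, size=0x100):
            self.mem = bytearray(size)
        def get_mem(self, addr, size):
            return bytes(self.mem[addr:addr + size])
        def set_mem(self, addr, raw):
            self.mem[addr:addr + len(raw)] = raw

    vm = FlatVM()
    field = UInt24BE().lval(vm, 0x10)   # MemValue subclass built on the fly
    field.val = 0x123456
    assert vm.get_mem(0x10, 3) == b"\x12\x34\x56"
    assert field.val == 0x123456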
+ + TODO: move any trace of vm and addr out of these classes? + """ + + _self_type = None + + def _pack(self, val): + """Serializes the python value @val to a raw str""" + raise NotImplementedError() + + def _unpack(self, raw_str): + """Deserializes a raw str to an object representing the python value + of this field. + """ + raise NotImplementedError() + + def set(self, vm, addr, val): + """Set a VmMngr memory from a value. + + @vm: VmMngr instance + @addr: the start address in memory to set + @val: the python value to serialize in @vm at @addr + """ + raw = self._pack(val) + vm.set_mem(addr, raw) + + def get(self, vm, addr): + """Get the python value of a field from a VmMngr memory at @addr.""" + raw = vm.get_mem(addr, self.size) + return self._unpack(raw) + + @property + def lval(self): + """Returns a class with a (vm, addr) constructor that allows to + interact with this type in memory. + + In compilation terms, it returns a class allowing to instantiate an + lvalue of this type. + + @return: a MemType subclass. + """ + if self in DYN_MEM_STRUCT_CACHE: + return DYN_MEM_STRUCT_CACHE[self] + pinned_type = self._build_pinned_type() + DYN_MEM_STRUCT_CACHE[self] = pinned_type + return pinned_type + + def _build_pinned_type(self): + """Builds the MemType subclass allowing to interact with this type. + + Called by self.lval when it is not in cache. + """ + pinned_base_class = self._get_pinned_base_class() + pinned_type = type( + "Mem%r" % self, + (pinned_base_class,), + {'_type': self} + ) + return pinned_type + + def _get_pinned_base_class(self): + """Return the MemType subclass that maps this type in memory""" + return MemValue + + def _get_self_type(self): + """Used for the Self trick.""" + return self._self_type + + def _set_self_type(self, self_type): + """If this field refers to MemSelf/Self, replace it with @self_type + (a Type instance) when using it. Generally not used outside this + module. + """ + self._self_type = self_type + + @property + def size(self): + """Return the size in bytes of the serialized version of this field""" + raise NotImplementedError() + + def __len__(self): + return self.size + + def __neq__(self, other): + return not self == other + + def __eq__(self, other): + raise NotImplementedError("Abstract method") + + def __ne__(self, other): + return not self == other + + +class RawStruct(Type): + """Dumb struct.pack/unpack field. Mainly used to factorize code. + + Value is a tuple corresponding to the struct @fmt passed to the constructor. + """ + + def __init__(self, fmt): + self._fmt = fmt + + def _pack(self, fields): + return struct.pack(self._fmt, *fields) + + def _unpack(self, raw_str): + return struct.unpack(self._fmt, raw_str) + + @property + def size(self): + return struct.calcsize(self._fmt) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self._fmt) + + def __eq__(self, other): + return self.__class__ == other.__class__ and self._fmt == other._fmt + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.__class__, self._fmt)) + + +class Num(RawStruct): + """Represents a number (integer or float). The number is encoded with + a struct-style format which must represent only one value. + + TODO: use u32, i16, etc. for format. 
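A quick sketch of Num with a few struct-style formats: the size comes straight from struct.calcsize, and the endianness marker only changes the serialized form. The private _pack helper is used below purely to show the raw bytes; normal code goes through .lval and memory.

    from miasm.core.types import Num

    assert Num("B").size == 1      # unsigned byte
    assert Num("<H").size == 2     # little-endian u16
    assert Num(">I").size == 4     # big-endian u32
    assert Num("<d").size == 8     # little-endian double (floats work too)

    # Same value, different byte order
    assert Num("<I")._pack(0x11223344) == b"\x44\x33\x22\x11"
    assert Num(">I")._pack(0x11223344) == b"\x11\x22\x33\x44"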
+ """ + + def _pack(self, number): + return super(Num, self)._pack([number]) + + def _unpack(self, raw_str): + upck = super(Num, self)._unpack(raw_str) + if len(upck) != 1: + raise ValueError("Num format string unpacks to multiple values, " + "should be 1") + return upck[0] + + +class Ptr(Num): + """Special case of number of which value indicates the address of a + MemType. + + Mapped to MemPtr (see its doc for more info): + + assert isinstance(mystruct.ptr, MemPtr) + mystruct.ptr = 0x4000 # Assign the Ptr numeric value + mystruct.ptr.val = 0x4000 # Also assigns the Ptr numeric value + assert isinstance(mystruct.ptr.val, int) # Get the Ptr numeric value + mystruct.ptr.deref # Get the pointed MemType + mystruct.ptr.deref = other # Set the pointed MemType + """ + + def __init__(self, fmt, dst_type, *type_args, **type_kwargs): + """ + @fmt: (str) Num compatible format that will be the Ptr representation + in memory + @dst_type: (MemType or Type) the Type this Ptr points to. + If a Type is given, it is transformed into a MemType with + TheType.lval. + *type_args, **type_kwargs: arguments to pass to the the pointed + MemType when instantiating it (e.g. for MemStr encoding or + MemArray field_type). + """ + if (not isinstance(dst_type, Type) and + not (isinstance(dst_type, type) and + issubclass(dst_type, MemType)) and + not dst_type == MemSelf): + raise ValueError("dst_type of Ptr must be a MemType type, a " + "Type instance, the MemSelf marker or a class " + "name.") + super(Ptr, self).__init__(fmt) + if isinstance(dst_type, Type): + # Patch the field to propagate the MemSelf replacement + dst_type._get_self_type = lambda: self._get_self_type() + # dst_type cannot be patched here, since _get_self_type of the outer + # class has not yet been set. Patching dst_type involves calling + # dst_type.lval, which will only return a type that does not point + # on MemSelf but on the right class only when _get_self_type of the + # outer class has been replaced by _MetaMemStruct. + # In short, dst_type = dst_type.lval is not valid here, it is done + # lazily in _fix_dst_type + self._dst_type = dst_type + self._type_args = type_args + self._type_kwargs = type_kwargs + + def _fix_dst_type(self): + if self._dst_type in [MemSelf, SELF_TYPE_INSTANCE]: + if self._get_self_type() is not None: + self._dst_type = self._get_self_type() + else: + raise ValueError("Unsupported usecase for (Mem)Self, sorry") + self._dst_type = to_type(self._dst_type) + + @property + def dst_type(self): + """Return the type (MemType subtype) this Ptr points to.""" + self._fix_dst_type() + return self._dst_type + + def set(self, vm, addr, val): + """A Ptr field can be set with a MemPtr or an int""" + if isinstance(val, MemType) and isinstance(val.get_type(), Ptr): + self.set_val(vm, addr, val.val) + else: + super(Ptr, self).set(vm, addr, val) + + def get(self, vm, addr): + return self.lval(vm, addr) + + def get_val(self, vm, addr): + """Get the numeric value of a Ptr""" + return super(Ptr, self).get(vm, addr) + + def set_val(self, vm, addr, val): + """Set the numeric value of a Ptr""" + return super(Ptr, self).set(vm, addr, val) + + def deref_get(self, vm, addr): + """Deserializes the data in @vm (VmMngr) at @addr to self.dst_type. + Equivalent to a pointer dereference rvalue in C. + """ + dst_addr = self.get_val(vm, addr) + return self.dst_type.lval(vm, dst_addr, + *self._type_args, **self._type_kwargs) + + def deref_set(self, vm, addr, val): + """Serializes the @val MemType subclass instance in @vm (VmMngr) at + @addr. 
Equivalent to a pointer dereference assignment in C. + """ + # Sanity check + if self.dst_type != val.get_type(): + log.warning("Original type was %s, overridden by value of type %s", + self._dst_type.__name__, val.__class__.__name__) + + # Actual job + dst_addr = self.get_val(vm, addr) + vm.set_mem(dst_addr, bytes(val)) + + def _get_pinned_base_class(self): + return MemPtr + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.dst_type) + + def __eq__(self, other): + return super(Ptr, self).__eq__(other) and \ + self.dst_type == other.dst_type and \ + self._type_args == other._type_args and \ + self._type_kwargs == other._type_kwargs + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((super(Ptr, self).__hash__(), self.dst_type, + self._type_args)) + + +class Struct(Type): + """Equivalent to a C struct type. Composed of a name, and a + (<field_name (str)>, <Type_subclass_instance>) list describing the fields + of the struct. + + Mapped to MemStruct. + + NOTE: The `.lval` property of Struct creates classes on the fly. If an + equivalent structure is created by subclassing MemStruct, an exception + is raised to prevent creating multiple classes designating the same type. + + Example: + s = Struct("Toto", [("f1", Num("I")), ("f2", Num("I"))]) + + Toto1 = s.lval + + # This raises an exception, because it describes the same structure as + # Toto1 + class Toto(MemStruct): + fields = [("f1", Num("I")), ("f2", Num("I"))] + + Anonymous Struct, Union or BitField can be used if their field name + evaluates to False ("" or None). Such anonymous Struct field will generate + fields to the parent Struct, e.g.: + bla = Struct("Bla", [ + ("a", Num("B")), + ("", Union([("b1", Num("B")), ("b2", Num("H"))])), + ("", Struct("", [("c1", Num("B")), ("c2", Num("B"))])), + ] + Will have a b1, b2 and c1, c2 field directly accessible. The anonymous + fields are renamed to "__anon_<num>", with <num> an incremented number. + + In such case, bla.fields will not contain b1, b2, c1 and c2 (only the 3 + actual fields, with the anonymous ones renamed), but bla.all_fields will + return the 3 fields + b1, b2, c1 and c2 (and an information telling if it + has been generated from an anonymous Struct/Union). + + bla.get_field(vm, addr, "b1") will work. 
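A minimal sketch of the anonymous-field behaviour just described, using only offset and size computations (no memory access needed):

    from miasm.core.types import Struct, Union, Num

    bla = Struct("Bla", [
        ("a", Num("B")),
        ("", Union([("b1", Num("B")), ("b2", Num("<H"))])),
    ])

    assert bla.size == 3                  # 1 byte + the 2-byte union
    assert bla.get_offset("a") == 0
    assert bla.get_offset("b1") == 1      # lifted from the anonymous union
    assert bla.get_offset("b2") == 1      # same offset: union semantics
    assert [name for name, _ in bla.fields] == ["a", "__anon_0"]
    # all_fields also lists the lifted sub-fields, flagged as anonymous
    assert [n for n, _, anon in bla.all_fields if anon] == ["b1", "b2"]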
+ """ + + def __init__(self, name, fields): + self.name = name + # generates self._fields and self._fields_desc + self._gen_fields(fields) + + def _gen_fields(self, fields): + """Precompute useful metadata on self.fields.""" + self._fields_desc = {} + offset = 0 + + # Build a proper (name, Field()) list, handling cases where the user + # supplies a MemType subclass instead of a Type instance + real_fields = [] + uniq_count = 0 + for fname, field in fields: + field = to_type(field) + + # For reflexion + field._set_self_type(self) + + # Anonymous Struct/Union + if not fname and isinstance(field, Struct): + # Generate field information + updated_fields = { + name: { + # Same field type than the anon field subfield + 'field': fd['field'], + # But the current offset is added + 'offset': fd['offset'] + offset, + } + for name, fd in viewitems(field._fields_desc) + } + + # Add the newly generated fields from the anon field + self._fields_desc.update(updated_fields) + real_fields += [(name, fld, True) + for name, fld in field.fields] + + # Rename the anonymous field + fname = '__anon_%x' % uniq_count + uniq_count += 1 + + self._fields_desc[fname] = {"field": field, "offset": offset} + real_fields.append((fname, field, False)) + offset = self._next_offset(field, offset) + + # fields is immutable + self._fields = tuple(real_fields) + + def _next_offset(self, field, orig_offset): + return orig_offset + field.size + + @property + def fields(self): + """Returns a sequence of (name, field) describing the fields of this + Struct, in order of offset. + + Fields generated from anonymous Unions or Structs are excluded from + this sequence. + """ + return tuple((name, field) for name, field, anon in self._fields + if not anon) + + @property + def all_fields(self): + """Returns a sequence of (<name>, <field (Type instance)>, <is_anon>), + where is_anon is True when a field is generated from an anonymous + Struct or Union, and False for the fields that have been provided as is. + """ + return self._fields + + def set(self, vm, addr, val): + raw = bytes(val) + vm.set_mem(addr, raw) + + def get(self, vm, addr): + return self.lval(vm, addr) + + def get_field(self, vm, addr, name): + """Get a field value by @name and base structure @addr in @vm VmMngr.""" + if name not in self._fields_desc: + raise ValueError("'%s' type has no field '%s'" % (self, name)) + field = self.get_field_type(name) + offset = self.get_offset(name) + return field.get(vm, addr + offset) + + def set_field(self, vm, addr, name, val): + """Set a field value by @name and base structure @addr in @vm VmMngr. + @val is the python value corresponding to this field type. 
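To make get_field/set_field concrete, a short sketch follows, with a hypothetical FlatVM standing in for the jitter's VmMngr (only get_mem/set_mem are needed). The class generated by .lval exposes the same fields as attributes.

    from miasm.core.types import Struct, Num

    class FlatVM(object):
        """Hypothetical VmMngr stand-in: get_mem/set_mem over a bytearray."""
        def __init__(self, size=0x100):
            self.mem = bytearray(size)
        def get_mem(self, addr, size):
            return bytes(self.mem[addr:addr + size])
        def set_mem(self, addr, raw):
            self.mem[addr:addr + len(raw)] = raw

    point = Struct("Point", [("x", Num("<I")), ("y", Num("<I"))])
    vm = FlatVM()

    point.set_field(vm, 0x10, "y", 0x1337)   # writes at addr + offsetof(y)
    assert point.get_field(vm, 0x10, "y") == 0x1337
    assert vm.get_mem(0x14, 4) == b"\x37\x13\x00\x00"

    # The same structure through its generated MemStruct class
    p = point.lval(vm, 0x10)
    assert p.y == 0x1337
    p.x = 1
    assert point.get_field(vm, 0x10, "x") == 1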
+ """ + if name not in self._fields_desc: + raise AttributeError("'%s' object has no attribute '%s'" + % (self.__class__.__name__, name)) + field = self.get_field_type(name) + offset = self.get_offset(name) + field.set(vm, addr + offset, val) + + @property + def size(self): + return sum(field.size for _, field in self.fields) + + def get_offset(self, field_name): + """ + @field_name: (str, optional) the name of the field to get the + offset of + """ + if field_name not in self._fields_desc: + raise ValueError("This structure has no %s field" % field_name) + return self._fields_desc[field_name]['offset'] + + def get_field_type(self, name): + """Return the Type subclass instance describing field @name.""" + return self._fields_desc[name]['field'] + + def _get_pinned_base_class(self): + return MemStruct + + def __repr__(self): + return "struct %s" % self.name + + def __eq__(self, other): + return self.__class__ == other.__class__ and \ + self.fields == other.fields and \ + self.name == other.name + + def __ne__(self, other): + return not self == other + + def __hash__(self): + # Only hash name, not fields, because if a field is a Ptr to this + # Struct type, an infinite loop occurs + return hash((self.__class__, self.name)) + + +class Union(Struct): + """Represents a C union. + + Allows to put multiple fields at the same offset in a MemStruct, + similar to unions in C. The Union will have the size of the largest of its + fields. + + Mapped to MemUnion. + + Example: + + class Example(MemStruct): + fields = [("uni", Union([ + ("f1", Num("<B")), + ("f2", Num("<H")) + ]) + )] + + ex = Example(vm, addr) + ex.uni.f2 = 0x1234 + assert ex.uni.f1 == 0x34 + """ + + def __init__(self, field_list): + """@field_list: a [(name, field)] list, see the class doc""" + super(Union, self).__init__("union", field_list) + + @property + def size(self): + return max(field.size for _, field in self.fields) + + def _next_offset(self, field, orig_offset): + return orig_offset + + def _get_pinned_base_class(self): + return MemUnion + + def __repr__(self): + fields_repr = ', '.join("%s: %r" % (name, field) + for name, field in self.fields) + return "%s(%s)" % (self.__class__.__name__, fields_repr) + + +class Array(Type): + """An array (contiguous sequence) of a Type subclass elements. + + Can be sized (similar to something like the char[10] type in C) or unsized + if no @array_len is given to the constructor (similar to char* used as an + array). + + Mapped to MemArray or MemSizedArray, depending on if the Array is + sized or not. + + Getting an array field actually returns a MemSizedArray. Setting it is + possible with either a list or a MemSizedArray instance. 
Examples of + syntax: + + class Example(MemStruct): + fields = [("array", Array(Num("B"), 4))] + + mystruct = Example(vm, addr) + mystruct.array[3] = 27 + mystruct.array = [1, 4, 8, 9] + mystruct.array = MemSizedArray(vm, addr2, Num("B"), 4) + """ + + def __init__(self, field_type, array_len=None): + # Handle both Type instance and MemType subclasses + self.field_type = to_type(field_type) + self.array_len = array_len + + def _set_self_type(self, self_type): + super(Array, self)._set_self_type(self_type) + self.field_type._set_self_type(self_type) + + def set(self, vm, addr, val): + # MemSizedArray assignment + if isinstance(val, MemSizedArray): + if val.array_len != self.array_len or len(val) != self.size: + raise ValueError("Size mismatch in MemSizedArray assignment") + raw = bytes(val) + vm.set_mem(addr, raw) + + # list assignment + elif isinstance(val, list): + if len(val) != self.array_len: + raise ValueError("Size mismatch in MemSizedArray assignment ") + offset = 0 + for elt in val: + self.field_type.set(vm, addr + offset, elt) + offset += self.field_type.size + + else: + raise RuntimeError( + "Assignment only implemented for list and MemSizedArray") + + def get(self, vm, addr): + return self.lval(vm, addr) + + @property + def size(self): + if self.is_sized(): + return self.get_offset(self.array_len) + else: + raise ValueError("%s is unsized, use an array with a fixed " + "array_len instead." % self) + + def get_offset(self, idx): + """Returns the offset of the item at index @idx.""" + return self.field_type.size * idx + + def get_item(self, vm, addr, idx): + """Get the item(s) at index @idx. + + @idx: int, long or slice + """ + if isinstance(idx, slice): + res = [] + idx = self._normalize_slice(idx) + for i in range(idx.start, idx.stop, idx.step): + res.append(self.field_type.get(vm, addr + self.get_offset(i))) + return res + else: + idx = self._normalize_idx(idx) + return self.field_type.get(vm, addr + self.get_offset(idx)) + + def set_item(self, vm, addr, idx, item): + """Sets one or multiple items in this array (@idx can be an int, long + or slice). + """ + if isinstance(idx, slice): + idx = self._normalize_slice(idx) + if len(item) != len(range(idx.start, idx.stop, idx.step)): + raise ValueError("Mismatched lengths in slice assignment") + for i, val in zip(range(idx.start, idx.stop, idx.step), + item): + self.field_type.set(vm, addr + self.get_offset(i), val) + else: + idx = self._normalize_idx(idx) + self.field_type.set(vm, addr + self.get_offset(idx), item) + + def is_sized(self): + """True if this is a sized array (non None self.array_len), False + otherwise. 
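A short sketch of a sized Array, again with a hypothetical FlatVM in place of the real VmMngr: offsets and the total size derive from the element type, and indexing and slicing go through get_item/set_item.

    from miasm.core.types import Array, Num

    class FlatVM(object):
        """Hypothetical VmMngr stand-in: get_mem/set_mem over a bytearray."""
        def __init__(self, size=0x100):
            self.mem = bytearray(size)
        def get_mem(self, addr, size):
            return bytes(self.mem[addr:addr + size])
        def set_mem(self, addr, raw):
            self.mem[addr:addr + len(raw)] = raw

    arr_t = Array(Num("<H"), 4)        # u16[4]
    assert arr_t.size == 8             # 4 * sizeof(u16)
    assert arr_t.get_offset(3) == 6

    vm = FlatVM()
    arr = arr_t.lval(vm, 0x20)         # MemSizedArray subclass
    arr[0] = 0x1122
    arr[1:3] = [3, 4]
    assert list(arr) == [0x1122, 3, 4, 0]
    assert vm.get_mem(0x20, 2) == b"\x22\x11"

An unsized Array(Num("<H")) maps to MemArray instead: it can still be indexed, but it has no size and cannot be dumped as a whole.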
+ """ + return self.array_len is not None + + def _normalize_idx(self, idx): + # Noop for this type + if self.is_sized(): + if idx < 0: + idx = self.array_len + idx + self._check_bounds(idx) + return idx + + def _normalize_slice(self, slice_): + start = slice_.start if slice_.start is not None else 0 + stop = slice_.stop if slice_.stop is not None else self.get_size() + step = slice_.step if slice_.step is not None else 1 + start = self._normalize_idx(start) + stop = self._normalize_idx(stop) + return slice(start, stop, step) + + def _check_bounds(self, idx): + if not isinstance(idx, int_types): + raise ValueError("index must be an int or a long") + if idx < 0 or (self.is_sized() and idx >= self.size): + raise IndexError("Index %s out of bounds" % idx) + + def _get_pinned_base_class(self): + if self.is_sized(): + return MemSizedArray + else: + return MemArray + + def __repr__(self): + return "[%r; %s]" % (self.field_type, self.array_len or "unsized") + + def __eq__(self, other): + return self.__class__ == other.__class__ and \ + self.field_type == other.field_type and \ + self.array_len == other.array_len + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.__class__, self.field_type, self.array_len)) + + +class Bits(Type): + """Helper class for BitField, not very useful on its own. Represents some + bits of a Num. + + The @backing_num is used to know how to serialize/deserialize data in vm, + but getting/setting this fields only assign bits from @bit_offset to + @bit_offset + @bits. Masking and shifting is handled by the class, the aim + is to provide a transparent way to set and get some bits of a num. + """ + + def __init__(self, backing_num, bits, bit_offset): + if not isinstance(backing_num, Num): + raise ValueError("backing_num should be a Num instance") + self._num = backing_num + self._bits = bits + self._bit_offset = bit_offset + + def set(self, vm, addr, val): + val_mask = (1 << self._bits) - 1 + val_shifted = (val & val_mask) << self._bit_offset + num_size = self._num.size * 8 + + full_num_mask = (1 << num_size) - 1 + num_mask = (~(val_mask << self._bit_offset)) & full_num_mask + + num_val = self._num.get(vm, addr) + res_val = (num_val & num_mask) | val_shifted + self._num.set(vm, addr, res_val) + + def get(self, vm, addr): + val_mask = (1 << self._bits) - 1 + num_val = self._num.get(vm, addr) + res_val = (num_val >> self._bit_offset) & val_mask + return res_val + + @property + def size(self): + return self._num.size + + @property + def bit_size(self): + """Number of bits read/written by this class""" + return self._bits + + @property + def bit_offset(self): + """Offset in bits (beginning at 0, the LSB) from which to read/write + bits. + """ + return self._bit_offset + + def __repr__(self): + return "%s%r(%d:%d)" % (self.__class__.__name__, self._num, + self._bit_offset, self._bit_offset + self._bits) + + def __eq__(self, other): + return self.__class__ == other.__class__ and \ + self._num == other._num and self._bits == other._bits and \ + self._bit_offset == other._bit_offset + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.__class__, self._num, self._bits, self._bit_offset)) + + +class BitField(Union): + """A C-like bitfield. + + Constructed with a list [(<field_name>, <number_of_bits>)] and a + @backing_num. 
The @backing_num is a Num instance that determines the total + size of the bitfield and the way the bits are serialized/deserialized (big + endian int, little endian short...). Can be seen (and implemented) as a + Union of Bits fields. + + Mapped to MemBitField. + + Creates fields that allow to access the bitfield fields easily. Example: + + class Example(MemStruct): + fields = [("bf", BitField(Num("B"), [ + ("f1", 2), + ("f2", 4), + ("f3", 1) + ]) + )] + + ex = Example(vm, addr) + ex.memset() + ex.f2 = 2 + ex.f1 = 5 # 5 does not fit on two bits, it will be binarily truncated + assert ex.f1 == 3 + assert ex.f2 == 2 + assert ex.f3 == 0 # previously memset() + assert ex.bf == 3 + 2 << 2 + """ + + def __init__(self, backing_num, bit_list): + """@backing num: Num instance, @bit_list: [(name, n_bits)]""" + self._num = backing_num + fields = [] + offset = 0 + for name, bits in bit_list: + fields.append((name, Bits(self._num, bits, offset))) + offset += bits + if offset > self._num.size * 8: + raise ValueError("sum of bit lengths is > to the backing num size") + super(BitField, self).__init__(fields) + + def set(self, vm, addr, val): + self._num.set(vm, addr, val) + + def _get_pinned_base_class(self): + return MemBitField + + def __eq__(self, other): + return self.__class__ == other.__class__ and \ + self._num == other._num and super(BitField, self).__eq__(other) + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((super(BitField, self).__hash__(), self._num)) + + def __repr__(self): + fields_repr = ', '.join("%s: %r" % (name, field.bit_size) + for name, field in self.fields) + return "%s(%s)" % (self.__class__.__name__, fields_repr) + + +class Str(Type): + """A string type that handles encoding. This type is unsized (no static + size). + + The @encoding is passed to the constructor, and is one of the keys of + Str.encodings, currently: + - ascii + - latin1 + - ansi (= latin1) + - utf8 (= utf-8le) + - utf16 (= utf-16le, Windows UCS-2 compatible) + New encodings can be added with Str.add_encoding. + If an unknown encoding is passed to the constructor, Str will try to add it + to the available ones with Str.add_encoding. + + Mapped to MemStr. + """ + + # Dict of {name: (getter, setter, raw_len)} + # Where: + # - getter(vm, addr) -> unicode + # - setter(vm, addr, unicode) + # - raw_len(unicode_str) -> int (length of the str value one encoded in + # memory) + # See enc_triplet() + # + # NOTE: this appears like it could be implemented only with + # (getter, raw_str), but this would cause trouble for length-prefixed str + # encoding (Pascal-style strings). + encodings = { + "ascii": enc_triplet("ascii"), + "latin1": enc_triplet("latin1"), + "ansi": enc_triplet("latin1"), + "utf8": enc_triplet("utf8"), + "utf16": enc_triplet("utf-16le"), + } + + def __init__(self, encoding="ansi"): + if encoding not in self.encodings: + self.add_encoding(encoding) + self._enc = encoding + + @classmethod + def add_encoding(cls, enc_name, str_enc=None, getter=None, setter=None, + raw_len=None): + """Add an available Str encoding. 
+ + @enc_name: the name that will be used to designate this encoding in the + Str constructor + @str_end: (optional) the actual str encoding name if it differs from + @enc_name + @getter: (optional) func(vm, addr) -> unicode, to force usage of this + function to retrieve the str from memory + @setter: (optional) func(vm, addr, unicode), to force usage of this + function to set the str in memory + @raw_len: (optional) func(unicode_str) -> int (length of the str value + one encoded in memory), to force usage of this function to compute + the length of this string once in memory + """ + default = enc_triplet(str_enc or enc_name) + actual = ( + getter or default[0], + setter or default[1], + raw_len or default[2], + ) + cls.encodings[enc_name] = actual + + def get(self, vm, addr): + """Set the string value in memory""" + get_str = self.encodings[self.enc][0] + return get_str(vm, addr) + + def set(self, vm, addr, s): + """Get the string value from memory""" + set_str = self.encodings[self.enc][1] + set_str(vm, addr, s) + + @property + def size(self): + """This type is unsized.""" + raise ValueError("Str is unsized") + + def value_size(self, py_str): + """Returns the in-memory size of a @py_str for this Str type (handles + encoding, i.e. will not return the same size for "utf16" and "ansi"). + """ + raw_len = self.encodings[self.enc][2] + return raw_len(py_str) + + @property + def enc(self): + """This Str's encoding name (as a str).""" + return self._enc + + def _get_pinned_base_class(self): + return MemStr + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.enc) + + def __eq__(self, other): + return self.__class__ == other.__class__ and self._enc == other._enc + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.__class__, self._enc)) + + +class Void(Type): + """Represents the C void type. + + Mapped to MemVoid. + """ + + def _build_pinned_type(self): + return MemVoid + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(self.__class__) + + def __repr__(self): + return self.__class__.__name__ + + +class Self(Void): + """Special marker to reference a type inside itself. + + Mapped to MemSelf. + + Example: + class ListNode(MemStruct): + fields = [ + ("next", Ptr("<I", Self())), + ("data", Ptr("<I", Void())), + ] + """ + + def _build_pinned_type(self): + return MemSelf + +# To avoid reinstantiation when testing equality +SELF_TYPE_INSTANCE = Self() +VOID_TYPE_INSTANCE = Void() + + +# MemType classes + +class _MetaMemType(type): + def __repr__(cls): + return cls.__name__ + + +class _MetaMemStruct(_MetaMemType): + """MemStruct metaclass. Triggers the magic that generates the class + fields from the cls.fields list. + + Just calls MemStruct.gen_fields() if the fields class attribute has been + set, the actual implementation can seen be there. + """ + + def __init__(cls, name, bases, dct): + super(_MetaMemStruct, cls).__init__(name, bases, dct) + if cls.fields is not None: + cls.fields = tuple(cls.fields) + # Am I able to generate fields? (if not, let the user do it manually + # later) + if cls.get_type() is not None or cls.fields is not None: + cls.gen_fields() + + +class MemType(with_metaclass(_MetaMemType, object)): + """Base class for classes that allow to map python objects to C types in + virtual memory. Represents an lvalue of a given type. 
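Before moving on to the MemType classes, a quick sketch of the Str handling defined above: the in-memory size depends on the encoding and includes the terminator. FlatVM is, as before, a hypothetical VmMngr stand-in.

    from miasm.core.types import Str

    class FlatVM(object):
        """Hypothetical VmMngr stand-in: get_mem/set_mem over a bytearray."""
        def __init__(self, size=0x100):
            self.mem = bytearray(size)
        def get_mem(self, addr, size):
            return bytes(self.mem[addr:addr + size])
        def set_mem(self, addr, raw):
            self.mem[addr:addr + len(raw)] = raw

    ansi, utf16 = Str(), Str("utf16")
    assert ansi.value_size(u"abc") == 4    # "abc\x00"
    assert utf16.value_size(u"abc") == 8   # UTF-16LE, 2 bytes per char

    vm = FlatVM()
    s = utf16.lval(vm, 0x40)               # MemStr subclass
    s.val = u"hi"
    assert vm.get_mem(0x40, 6) == b"h\x00i\x00\x00\x00"
    assert s.val == u"hi"
    assert s.get_size() == 6               # strlen-like: re-reads memory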
+ + Globally, MemTypes are not meant to be used directly: specialized + subclasses are generated by Type(...).lval and should be used instead. + The main exception is MemStruct, which you may want to subclass yourself + for syntactic ease. + """ + + # allocator is a function(vm, size) -> allocated_address + allocator = None + + _type = None + + def __init__(self, vm, addr=None, type_=None): + self._vm = vm + if addr is None: + self._addr = self.alloc(vm, self.get_size()) + else: + self._addr = addr + if type_ is not None: + self._type = type_ + if self._type is None: + raise ValueError("Subclass MemType and define cls._type or pass " + "a type to the constructor") + + @classmethod + def alloc(cls, vm, size): + """Returns an allocated page of size @size if cls.allocator is set. + Raises ValueError otherwise. + """ + if cls.allocator is None: + raise ValueError("Cannot provide None address to MemType() if" + "%s.set_allocator has not been called." + % __name__) + return cls.allocator(vm, size) + + @classmethod + def set_allocator(cls, alloc_func): + """Set an allocator for this class; allows to instantiate statically + sized MemTypes (i.e. sizeof() is implemented) without specifying the + address (the object is allocated by @alloc_func in the vm). + + You may call set_allocator on specific MemType classes if you want + to use a different allocator. + + @alloc_func: func(VmMngr) -> integer_address + """ + cls.allocator = alloc_func + + def get_addr(self, field=None): + """Return the address of this MemType or one of its fields. + + @field: (str, optional) used by subclasses to specify the name or index + of the field to get the address of + """ + if field is not None: + raise NotImplementedError("Getting a field's address is not " + "implemented for this class.") + return self._addr + + @classmethod + def get_type(cls): + """Returns the Type subclass instance representing the C type of this + MemType. + """ + return cls._type + + @classmethod + def sizeof(cls): + """Return the static size of this type. By default, it is the size + of the underlying Type. + """ + return cls._type.size + + def get_size(self): + """Return the dynamic size of this structure (e.g. the size of an + instance). Defaults to sizeof for this base class. + + For example, MemStr defines get_size but not sizeof, as an instance + has a fixed size (at least its value has), but all the instance do not + have the same size. + """ + return self.sizeof() + + def memset(self, byte=b'\x00'): + """Fill the memory space of this MemType with @byte ('\x00' by + default). The size is retrieved with self.get_size() (dynamic size). + """ + # TODO: multibyte patterns + if not isinstance(byte, bytes) or len(byte) != 1: + raise ValueError("byte must be a 1-lengthed str") + self._vm.set_mem(self.get_addr(), byte * self.get_size()) + + def cast(self, other_type): + """Cast this MemType to another MemType (same address, same vm, + but different type). Return the casted MemType. + + @other_type: either a Type instance (other_type.lval is used) or a + MemType subclass + """ + if isinstance(other_type, Type): + other_type = other_type.lval + return other_type(self._vm, self.get_addr()) + + def cast_field(self, field, other_type, *type_args, **type_kwargs): + """ABSTRACT: Same as cast, but the address of the returned MemType + is the address at which @field is in the current MemType. + + @field: field specification, for example its name for a struct, or an + index in an array. See the subclass doc. 
+ @other_type: either a Type instance (other_type.lval is used) or a + MemType subclass + """ + raise NotImplementedError("Abstract") + + def raw(self): + """Raw binary (str) representation of the MemType as it is in + memory. + """ + return self._vm.get_mem(self.get_addr(), self.get_size()) + + def __len__(self): + return self.get_size() + + def __str__(self): + if PY3: + return repr(self) + return self.__bytes__() + + def __bytes__(self): + return self.raw() + + def __repr__(self): + return "Mem%r" % self._type + + def __eq__(self, other): + return self.__class__ == other.__class__ and \ + self.get_type() == other.get_type() and \ + bytes(self) == bytes(other) + + def __ne__(self, other): + return not self == other + + +class MemValue(MemType): + """Simple MemType that gets and sets the Type through the `.val` + attribute. + """ + + @property + def val(self): + return self._type.get(self._vm, self._addr) + + @val.setter + def val(self, value): + self._type.set(self._vm, self._addr, value) + + def __repr__(self): + return "%r: %r" % (self.__class__, self.val) + + +class MemStruct(with_metaclass(_MetaMemStruct, MemType)): + """Base class to easily implement VmMngr backed C-like structures in miasm. + Represents a structure in virtual memory. + + The mechanism is the following: + - set a "fields" class field to be a list of + (<field_name (str)>, <Type_subclass_instance>) + - instances of this class will have properties to interact with these + fields. + + Example: + class MyStruct(MemStruct): + fields = [ + # Scalar field: just struct.pack field with one value + ("num", Num("I")), + ("flags", Num("B")), + # Ptr fields contain two fields: "val", for the numerical value, + # and "deref" to get the pointed object + ("other", Ptr("I", OtherStruct)), + # Ptr to a variable length String + ("s", Ptr("I", Str())), + ("i", Ptr("I", Num("I"))), + ] + + mstruct = MyStruct(vm, addr) + + # Field assignment modifies virtual memory + mstruct.num = 3 + assert mstruct.num == 3 + memval = struct.unpack("I", vm.get_mem(mstruct.get_addr(), + 4))[0] + assert memval == mstruct.num + + # Memset sets the whole structure + mstruct.memset() + assert mstruct.num == 0 + mstruct.memset('\x11') + assert mstruct.num == 0x11111111 + + other = OtherStruct(vm, addr2) + mstruct.other = other.get_addr() + assert mstruct.other.val == other.get_addr() + assert mstruct.other.deref == other + assert mstruct.other.deref.foo == 0x1234 + + Note that: + MyStruct = Struct("MyStruct", <same fields>).lval + is equivalent to the previous MyStruct declaration. + + See the various Type-s doc for more information. See MemStruct.gen_fields + doc for more information on how to handle recursive types and cyclic + dependencies. + """ + fields = None + + def get_addr(self, field_name=None): + """ + @field_name: (str, optional) the name of the field to get the + address of + """ + if field_name is not None: + offset = self._type.get_offset(field_name) + else: + offset = 0 + return self._addr + offset + + @classmethod + def get_offset(cls, field_name): + """Shorthand for self.get_type().get_offset(field_name).""" + return cls.get_type().get_offset(field_name) + + def get_field(self, name): + """Get a field value by name. + + useless most of the time since fields are accessible via self.<name>. + """ + return self._type.get_field(self._vm, self.get_addr(), name) + + def set_field(self, name, val): + """Set a field value by name. @val is the python value corresponding to + this field type. 
+ + useless most of the time since fields are accessible via self.<name>. + """ + return self._type.set_field(self._vm, self.get_addr(), name, val) + + def cast_field(self, field, other_type): + """In this implementation, @field is a field name""" + if isinstance(other_type, Type): + other_type = other_type.lval + return other_type(self._vm, self.get_addr(field)) + + # Field generation method, voluntarily public to be able to gen fields + # after class definition + @classmethod + def gen_fields(cls, fields=None): + """Generate the fields of this class (so that they can be accessed with + self.<field_name>) from a @fields list, as described in the class doc. + + Useful in case of a type cyclic dependency. For example, the following + is not possible in python: + + class A(MemStruct): + fields = [("b", Ptr("I", B))] + + class B(MemStruct): + fields = [("a", Ptr("I", A))] + + With gen_fields, the following is the legal equivalent: + + class A(MemStruct): + pass + + class B(MemStruct): + fields = [("a", Ptr("I", A))] + + A.gen_fields([("b", Ptr("I", B))]) + """ + if fields is not None: + if cls.fields is not None: + raise ValueError("Cannot regen fields of a class. Setting " + "cls.fields at class definition and calling " + "gen_fields are mutually exclusive.") + cls.fields = fields + + if cls._type is None: + if cls.fields is None: + raise ValueError("Cannot create a MemStruct subclass without" + " a cls._type or a cls.fields") + cls._type = cls._gen_type(cls.fields) + + if cls._type in DYN_MEM_STRUCT_CACHE: + # FIXME: Maybe a warning would be better? + raise RuntimeError("Another MemType has the same type as this " + "one. Use it instead.") + + # Register this class so that another one will not be created when + # calling cls._type.lval + DYN_MEM_STRUCT_CACHE[cls._type] = cls + + cls._gen_attributes() + + @classmethod + def _gen_attributes(cls): + # Generate self.<name> getter and setters + for name, _, _ in cls._type.all_fields: + setattr(cls, name, property( + lambda self, name=name: self.get_field(name), + lambda self, val, name=name: self.set_field(name, val) + )) + + @classmethod + def _gen_type(cls, fields): + return Struct(cls.__name__, fields) + + def __repr__(self): + out = [] + for name, field in self._type.fields: + val_repr = repr(self.get_field(name)) + if '\n' in val_repr: + val_repr = '\n' + indent(val_repr, 4) + out.append("%s: %r = %s" % (name, field, val_repr)) + return '%r:\n' % self.__class__ + indent('\n'.join(out), 2) + + +class MemUnion(MemStruct): + """Same as MemStruct but all fields have a 0 offset in the struct.""" + @classmethod + def _gen_type(cls, fields): + return Union(fields) + + +class MemBitField(MemUnion): + """MemUnion of Bits(...) fields.""" + @classmethod + def _gen_type(cls, fields): + return BitField(fields) + + +class MemSelf(MemStruct): + """Special Marker class for reference to current class in a Ptr or Array + (mostly Array of Ptr). See Self doc. + """ + def __repr__(self): + return self.__class__.__name__ + + +class MemVoid(MemType): + """Placeholder for e.g. Ptr to an undetermined type. Useful mostly when + casted to another type. Allows to implement C's "void*" pattern. 
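A minimal sketch of the void* pattern: a Ptr to Void dereferences to a MemVoid, which is then reinterpreted with cast. FlatVM is a hypothetical VmMngr stand-in, not a miasm class.

    from miasm.core.types import MemStruct, Num, Ptr, Void

    class FlatVM(object):
        """Hypothetical VmMngr stand-in: get_mem/set_mem over a bytearray."""
        def __init__(self, size=0x100):
            self.mem = bytearray(size)
        def get_mem(self, addr, size):
            return bytes(self.mem[addr:addr + size])
        def set_mem(self, addr, raw):
            self.mem[addr:addr + len(raw)] = raw

    class Wrapper(MemStruct):
        fields = [("data", Ptr("<I", Void()))]

    vm = FlatVM()
    vm.set_mem(0x50, b"\x2a\x00\x00\x00")

    w = Wrapper(vm, 0x10)
    w.data = 0x50                 # only the pointer value is typed here
    blob = w.data.deref           # a MemVoid: no size, no fields
    value = blob.cast(Num("<I"))  # C-style *(uint32_t *)blob
    assert value.val == 42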
+ """ + _type = Void() + + def __repr__(self): + return self.__class__.__name__ + + +class MemPtr(MemValue): + """Mem version of a Ptr, provides two properties: + - val, to set and get the numeric value of the Ptr + - deref, to set and get the pointed type + """ + @property + def val(self): + return self._type.get_val(self._vm, self._addr) + + @val.setter + def val(self, value): + return self._type.set_val(self._vm, self._addr, value) + + @property + def deref(self): + return self._type.deref_get(self._vm, self._addr) + + @deref.setter + def deref(self, val): + return self._type.deref_set(self._vm, self._addr, val) + + def __repr__(self): + return "*%s" % hex(self.val) + + +class MemStr(MemValue): + """Implements a string representation in memory. + + The string value can be got or set (with python str/unicode) through the + self.val attribute. String encoding/decoding is handled by the class, + + This type is dynamically sized only (get_size is implemented, not sizeof). + """ + + def get_size(self): + """This get_size implementation is quite unsafe: it reads the string + underneath to determine the size, it may therefore read a lot of memory + and provoke mem faults (analogous to strlen). + """ + val = self.val + return self.get_type().value_size(val) + + @classmethod + def from_str(cls, vm, py_str): + """Allocates a MemStr with the global allocator with value py_str. + Raises a ValueError if allocator is not set. + """ + size = cls._type.value_size(py_str) + addr = cls.alloc(vm, size) + memstr = cls(vm, addr) + memstr.val = py_str + return memstr + + def raw(self): + raw = self._vm.get_mem(self.get_addr(), self.get_size()) + return raw + + def __repr__(self): + return "%r: %r" % (self.__class__, self.val) + + +class MemArray(MemType): + """An unsized array of type @field_type (a Type subclass instance). + This class has no static or dynamic size. + + It can be indexed for setting and getting elements, example: + + array = Array(Num("I")).lval(vm, addr)) + array[2] = 5 + array[4:8] = [0, 1, 2, 3] + print array[20] + """ + + @property + def field_type(self): + """Return the Type subclass instance that represents the type of + this MemArray items. + """ + return self.get_type().field_type + + def get_addr(self, idx=0): + return self._addr + self.get_type().get_offset(idx) + + @classmethod + def get_offset(cls, idx): + """Shorthand for self.get_type().get_offset(idx).""" + return cls.get_type().get_offset(idx) + + def __getitem__(self, idx): + return self.get_type().get_item(self._vm, self._addr, idx) + + def __setitem__(self, idx, item): + self.get_type().set_item(self._vm, self._addr, idx, item) + + def raw(self): + raise ValueError("%s is unsized, which prevents from getting its full " + "raw representation. Use MemSizedArray instead." % + self.__class__) + + def __repr__(self): + return "[%r, ...] [%r]" % (self[0], self.field_type) + + +class MemSizedArray(MemArray): + """A fixed size MemArray. + + This type is dynamically sized. Generate a fixed @field_type and @array_len + array which has a static size by using Array(type, size).lval. 
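MemStr.from_str above relies on the global allocator; here is a minimal sketch using a hypothetical bump allocator and FlatVM stand-in. The callable registered with set_allocator is invoked as allocator(vm, size); a real sandbox would normally use a heap-backed callable instead of this toy one.

    from miasm.core.types import Str, set_allocator

    class FlatVM(object):
        """Hypothetical VmMngr stand-in: get_mem/set_mem over a bytearray."""
        def __init__(self, size=0x1000):
            self.mem = bytearray(size)
            self.next_free = 0x100
        def get_mem(self, addr, size):
            return bytes(self.mem[addr:addr + size])
        def set_mem(self, addr, raw):
            self.mem[addr:addr + len(raw)] = raw

    def bump_alloc(vm, size):
        """Toy allocator matching the allocator(vm, size) -> address call."""
        addr = vm.next_free
        vm.next_free += size
        return addr

    set_allocator(bump_alloc)     # registered on MemType, i.e. globally

    vm = FlatVM()
    MemAnsiStr = Str("ansi").lval
    s = MemAnsiStr.from_str(vm, u"hello")   # allocated through bump_alloc
    assert s.get_addr() == 0x100
    assert s.val == u"hello"
    assert vm.get_mem(0x100, 6) == b"hello\x00"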
+ """ + + @property + def array_len(self): + """The length, in number of elements, of this array.""" + return self.get_type().array_len + + def get_size(self): + return self.get_type().size + + def __iter__(self): + for i in range(self.get_type().array_len): + yield self[i] + + def raw(self): + return self._vm.get_mem(self.get_addr(), self.get_size()) + + def __repr__(self): + item_reprs = [repr(item) for item in self] + if self.array_len > 0 and '\n' in item_reprs[0]: + items = '\n' + indent(',\n'.join(item_reprs), 2) + '\n' + else: + items = ', '.join(item_reprs) + return "[%s] [%r; %s]" % (items, self.field_type, self.array_len) + diff --git a/src/miasm/core/utils.py b/src/miasm/core/utils.py new file mode 100644 index 00000000..291c5f4d --- /dev/null +++ b/src/miasm/core/utils.py @@ -0,0 +1,292 @@ +from __future__ import print_function +import re +import sys +from builtins import range +import struct +import inspect + +try: + from collections.abc import MutableMapping as DictMixin +except ImportError: + from collections import MutableMapping as DictMixin + +from operator import itemgetter +import codecs + +from future.utils import viewitems + +import collections + +COLOR_INT = "azure4" +COLOR_ID = "forestgreen"#"chartreuse3" +COLOR_MEM = "deeppink4" +COLOR_OP_FUNC = "blue1" +COLOR_LOC = "darkslateblue" +COLOR_OP = "black" + +COLOR_MNEMO = "blue1" + +ESCAPE_CHARS = re.compile('[' + re.escape('{}[]') + '&|<>' + ']') + + + +def set_html_text_color(text, color): + return '<font color="%s">%s</font>' % (color, text) + + +def _fix_chars(token): + return "&#%04d;" % ord(token.group()) + + +def fix_html_chars(text): + return ESCAPE_CHARS.sub(_fix_chars, str(text)) + +BRACKET_O = fix_html_chars('[') +BRACKET_C = fix_html_chars(']') + +upck8 = lambda x: struct.unpack('B', x)[0] +upck16 = lambda x: struct.unpack('H', x)[0] +upck32 = lambda x: struct.unpack('I', x)[0] +upck64 = lambda x: struct.unpack('Q', x)[0] +pck8 = lambda x: struct.pack('B', x) +pck16 = lambda x: struct.pack('H', x) +pck32 = lambda x: struct.pack('I', x) +pck64 = lambda x: struct.pack('Q', x) + +# Little endian +upck8le = lambda x: struct.unpack('<B', x)[0] +upck16le = lambda x: struct.unpack('<H', x)[0] +upck32le = lambda x: struct.unpack('<I', x)[0] +upck64le = lambda x: struct.unpack('<Q', x)[0] +pck8le = lambda x: struct.pack('<B', x) +pck16le = lambda x: struct.pack('<H', x) +pck32le = lambda x: struct.pack('<I', x) +pck64le = lambda x: struct.pack('<Q', x) + +# Big endian +upck8be = lambda x: struct.unpack('>B', x)[0] +upck16be = lambda x: struct.unpack('>H', x)[0] +upck32be = lambda x: struct.unpack('>I', x)[0] +upck64be = lambda x: struct.unpack('>Q', x)[0] +pck8be = lambda x: struct.pack('>B', x) +pck16be = lambda x: struct.pack('>H', x) +pck32be = lambda x: struct.pack('>I', x) +pck64be = lambda x: struct.pack('>Q', x) + + +LITTLE_ENDIAN = 1 +BIG_ENDIAN = 2 + + +pck = {8: pck8, + 16: pck16, + 32: pck32, + 64: pck64} + + +def get_caller_name(caller_num=0): + """Get the nth caller's name + @caller_num: 0 = the caller of get_caller_name, 1 = next parent, ...""" + pystk = inspect.stack() + if len(pystk) > 1 + caller_num: + return pystk[1 + caller_num][3] + else: + return "Bad caller num" + + +def whoami(): + """Returns the caller's name""" + return get_caller_name(1) + + +class Disasm_Exception(Exception): + pass + + +def printable(string): + if isinstance(string, bytes): + return "".join( + c.decode() if b" " <= c < b"~" else "." 
+ for c in (string[i:i+1] for i in range(len(string))) + ) + return string + + +def force_bytes(value): + if isinstance(value, bytes): + return value + if not isinstance(value, str): + return value + out = [] + for c in value: + c = ord(c) + assert c < 0x100 + out.append(c) + return bytes(out) + + +def force_str(value): + if isinstance(value, str): + return value + elif isinstance(value, bytes): + out = "" + for i in range(len(value)): + # For Python2/Python3 compatibility + c = ord(value[i:i+1]) + out += chr(c) + value = out + else: + raise ValueError("Unsupported type") + return value + + +def iterbytes(string): + for i in range(len(string)): + yield string[i:i+1] + + +def int_to_byte(value): + return struct.pack('B', value) + +def cmp_elts(elt1, elt2): + return (elt1 > elt2) - (elt1 < elt2) + + +_DECODE_HEX = codecs.getdecoder("hex_codec") +_ENCODE_HEX = codecs.getencoder("hex_codec") + +def decode_hex(value): + return _DECODE_HEX(value)[0] + +def encode_hex(value): + return _ENCODE_HEX(value)[0] + +def size2mask(size): + """Return the bit mask of size @size""" + return (1 << size) - 1 + +def hexdump(src, length=16): + lines = [] + for c in range(0, len(src), length): + chars = src[c:c + length] + hexa = ' '.join("%02x" % ord(x) for x in iterbytes(chars)) + printable = ''.join( + x.decode() if 32 <= ord(x) <= 126 else '.' for x in iterbytes(chars) + ) + lines.append("%04x %-*s %s\n" % (c, length * 3, hexa, printable)) + print(''.join(lines)) + + +# stackoverflow.com/questions/2912231 +class keydefaultdict(collections.defaultdict): + + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + value = self[key] = self.default_factory(key) + return value + + +class BoundedDict(DictMixin): + + """Limited in size dictionary. + + To reduce combinatory cost, once an upper limit @max_size is reached, + @max_size - @min_size elements are suppressed. + The targeted elements are the less accessed. 
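Two of the small helpers defined above in action (a minimal sketch): size2mask builds a bit mask from a width, and keydefaultdict passes the missing key to its factory, unlike collections.defaultdict.

    from miasm.core.utils import size2mask, keydefaultdict

    assert size2mask(8) == 0xff
    assert size2mask(32) == 0xffffffff

    squares = keydefaultdict(lambda key: key * key)
    assert squares[4] == 16      # computed from the key and cached
    assert 4 in squares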
+ + One can define a callback called when an element is removed + """ + + def __init__(self, max_size, min_size=None, initialdata=None, + delete_cb=None): + """Create a BoundedDict + @max_size: maximum size of the dictionary + @min_size: (optional) number of most used element to keep when resizing + @initialdata: (optional) dict instance with initial data + @delete_cb: (optional) callback called when an element is removed + """ + self._data = initialdata.copy() if initialdata else {} + self._min_size = min_size if min_size else max_size // 3 + self._max_size = max_size + self._size = len(self._data) + # Do not use collections.Counter as it is quite slow + self._counter = {k: 1 for k in self._data} + self._delete_cb = delete_cb + + def __setitem__(self, asked_key, value): + if asked_key not in self._data: + # Update internal size and use's counter + self._size += 1 + + # Bound can only be reached on a new element + if (self._size >= self._max_size): + most_common = sorted( + viewitems(self._counter), + key=itemgetter(1), + reverse=True + ) + + # Handle callback + if self._delete_cb is not None: + for key, _ in most_common[self._min_size - 1:]: + self._delete_cb(key) + + # Keep only the most @_min_size used + self._data = {key: self._data[key] + for key, _ in most_common[:self._min_size - 1]} + self._size = self._min_size + + # Reset use's counter + self._counter = {k: 1 for k in self._data} + + # Avoid rechecking in dict: set to 1 here, add 1 otherwise + self._counter[asked_key] = 1 + else: + self._counter[asked_key] += 1 + + self._data[asked_key] = value + + def __contains__(self, key): + # Do not call has_key to avoid adding function call overhead + return key in self._data + + def has_key(self, key): + return key in self._data + + def keys(self): + "Return the list of dict's keys" + return list(self._data) + + @property + def data(self): + "Return the current instance as a dictionary" + return self._data + + def __getitem__(self, key): + # Retrieve data first to raise the proper exception on error + data = self._data[key] + # Should never raise, since the key is in self._data + self._counter[key] += 1 + return data + + def __delitem__(self, key): + if self._delete_cb is not None: + self._delete_cb(key) + del self._data[key] + self._size -= 1 + del self._counter[key] + + def __del__(self): + """Ensure the callback is called when last reference is lost""" + if self._delete_cb: + for key in self._data: + self._delete_cb(key) + + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._data) + |
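To close, a short sketch of BoundedDict's eviction behaviour as the code above stands: once max_size distinct keys have been inserted, the least used entries are dropped (min_size - 1 previous keys survive, plus the key that triggered the shrink), and delete_cb is invoked for each evicted key.

    from miasm.core.utils import BoundedDict

    evicted = []
    cache = BoundedDict(5, min_size=2, delete_cb=evicted.append)
    for i in range(5):
        cache[i] = i * 10        # the 5th distinct key triggers the shrink

    assert 4 in cache            # the triggering key is always kept
    assert len(cache) == 2       # one survivor + the new key
    assert len(evicted) == 3     # delete_cb ran for each dropped key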