diff options
Diffstat (limited to 'miasm/core')
| -rw-r--r-- | miasm/core/__init__.py | 1 | ||||
| -rw-r--r-- | miasm/core/asm_ast.py | 93 | ||||
| -rw-r--r-- | miasm/core/asmblock.py | 1474 | ||||
| -rw-r--r-- | miasm/core/bin_stream.py | 319 | ||||
| -rw-r--r-- | miasm/core/bin_stream_ida.py | 45 | ||||
| -rw-r--r-- | miasm/core/cpu.py | 1715 | ||||
| -rw-r--r-- | miasm/core/ctypesmngr.py | 771 | ||||
| -rw-r--r-- | miasm/core/graph.py | 1123 | ||||
| -rw-r--r-- | miasm/core/interval.py | 284 | ||||
| -rw-r--r-- | miasm/core/locationdb.py | 495 | ||||
| -rw-r--r-- | miasm/core/modint.py | 270 | ||||
| -rw-r--r-- | miasm/core/objc.py | 1763 | ||||
| -rw-r--r-- | miasm/core/parse_asm.py | 288 | ||||
| -rw-r--r-- | miasm/core/sembuilder.py | 341 | ||||
| -rw-r--r-- | miasm/core/types.py | 1693 | ||||
| -rw-r--r-- | miasm/core/utils.py | 292 |
16 files changed, 0 insertions, 10967 deletions
diff --git a/miasm/core/__init__.py b/miasm/core/__init__.py deleted file mode 100644 index d154134b..00000000 --- a/miasm/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"Core components" diff --git a/miasm/core/asm_ast.py b/miasm/core/asm_ast.py deleted file mode 100644 index 69ff1f9c..00000000 --- a/miasm/core/asm_ast.py +++ /dev/null @@ -1,93 +0,0 @@ -from builtins import int as int_types - -class AstNode(object): - """ - Ast node object - """ - def __neg__(self): - if isinstance(self, AstInt): - value = AstInt(-self.value) - else: - value = AstOp('-', self) - return value - - def __add__(self, other): - return AstOp('+', self, other) - - def __sub__(self, other): - return AstOp('-', self, other) - - def __div__(self, other): - return AstOp('/', self, other) - - def __mod__(self, other): - return AstOp('%', self, other) - - def __mul__(self, other): - return AstOp('*', self, other) - - def __lshift__(self, other): - return AstOp('<<', self, other) - - def __rshift__(self, other): - return AstOp('>>', self, other) - - def __xor__(self, other): - return AstOp('^', self, other) - - def __or__(self, other): - return AstOp('|', self, other) - - def __and__(self, other): - return AstOp('&', self, other) - - -class AstInt(AstNode): - """ - Ast integer - """ - def __init__(self, value): - self.value = value - - def __str__(self): - return "%s" % self.value - - -class AstId(AstNode): - """ - Ast Id - """ - def __init__(self, name): - self.name = name - - def __str__(self): - return "%s" % self.name - - -class AstMem(AstNode): - """ - Ast memory deref - """ - def __init__(self, ptr, size): - assert isinstance(ptr, AstNode) - assert isinstance(size, int_types) - self.ptr = ptr - self.size = size - - def __str__(self): - return "@%d[%s]" % (self.size, self.ptr) - - -class AstOp(AstNode): - """ - Ast operator - """ - def __init__(self, op, *args): - assert all(isinstance(arg, AstNode) for arg in args) - self.op = op - self.args = args - - def __str__(self): - if len(self.args) == 1: - return "(%s %s)" % (self.op, self.args[0]) - return '(' + ("%s" % self.op).join(str(x) for x in self.args) + ')' diff --git a/miasm/core/asmblock.py b/miasm/core/asmblock.py deleted file mode 100644 index e92034fe..00000000 --- a/miasm/core/asmblock.py +++ /dev/null @@ -1,1474 +0,0 @@ -#-*- coding:utf-8 -*- - -from builtins import map -from builtins import range -import logging -import warnings -from collections import namedtuple -from builtins import int as int_types - -from future.utils import viewitems, viewvalues - -from miasm.expression.expression import ExprId, ExprInt, get_expr_locs -from miasm.expression.expression import LocKey -from miasm.expression.simplifications import expr_simp -from miasm.core.utils import Disasm_Exception, pck -from miasm.core.graph import DiGraph, DiGraphSimplifier, MatchGraphJoker -from miasm.core.interval import interval - - -log_asmblock = logging.getLogger("asmblock") -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) -log_asmblock.addHandler(console_handler) -log_asmblock.setLevel(logging.WARNING) - - -class AsmRaw(object): - - def __init__(self, raw=b""): - self.raw = raw - - def __str__(self): - return repr(self.raw) - - def to_string(self, loc_db): - return str(self) - - def to_html(self, loc_db): - return str(self) - - -class AsmConstraint(object): - c_to = "c_to" - c_next = "c_next" - - def __init__(self, loc_key, c_t=c_to): - # Sanity check - assert isinstance(loc_key, LocKey) - - self.loc_key = loc_key - self.c_t = c_t - - def to_string(self, loc_db=None): - if loc_db is None: - return "%s:%s" % (self.c_t, self.loc_key) - else: - return "%s:%s" % ( - self.c_t, - loc_db.pretty_str(self.loc_key) - ) - - def __str__(self): - return self.to_string() - - -class AsmConstraintNext(AsmConstraint): - - def __init__(self, loc_key): - super(AsmConstraintNext, self).__init__( - loc_key, - c_t=AsmConstraint.c_next - ) - - -class AsmConstraintTo(AsmConstraint): - - def __init__(self, loc_key): - super(AsmConstraintTo, self).__init__( - loc_key, - c_t=AsmConstraint.c_to - ) - - -class AsmBlock(object): - - def __init__(self, loc_db, loc_key, alignment=1): - assert isinstance(loc_key, LocKey) - - self.bto = set() - self.lines = [] - self.loc_db = loc_db - self._loc_key = loc_key - self.alignment = alignment - - loc_key = property(lambda self:self._loc_key) - - def to_string(self): - out = [] - out.append(self.loc_db.pretty_str(self.loc_key)) - - for instr in self.lines: - out.append(instr.to_string(self.loc_db)) - if self.bto: - lbls = ["->"] - for dst in self.bto: - if dst is None: - lbls.append("Unknown? ") - else: - lbls.append(dst.to_string(self.loc_db) + " ") - lbls = '\t'.join(sorted(lbls)) - out.append(lbls) - return '\n'.join(out) - - def __str__(self): - return self.to_string() - - def addline(self, l): - self.lines.append(l) - - def addto(self, c): - assert isinstance(self.bto, set) - self.bto.add(c) - - def split(self, offset): - loc_key = self.loc_db.get_or_create_offset_location(offset) - log_asmblock.debug('split at %x', offset) - offsets = [x.offset for x in self.lines] - offset = self.loc_db.get_location_offset(loc_key) - if offset not in offsets: - log_asmblock.warning( - 'cannot split block at %X ' % offset + - 'middle instruction? default middle') - offsets.sort() - return None - new_block = AsmBlock(self.loc_db, loc_key) - i = offsets.index(offset) - - self.lines, new_block.lines = self.lines[:i], self.lines[i:] - flow_mod_instr = self.get_flow_instr() - log_asmblock.debug('flow mod %r', flow_mod_instr) - c = AsmConstraint(loc_key, AsmConstraint.c_next) - # move dst if flowgraph modifier was in original block - # (usecase: split delayslot block) - if flow_mod_instr: - for xx in self.bto: - log_asmblock.debug('lbl %s', xx) - c_next = set( - x for x in self.bto if x.c_t == AsmConstraint.c_next - ) - c_to = [x for x in self.bto if x.c_t != AsmConstraint.c_next] - self.bto = set([c] + c_to) - new_block.bto = c_next - else: - new_block.bto = self.bto - self.bto = set([c]) - return new_block - - def get_range(self): - """Returns the offset hull of an AsmBlock""" - if len(self.lines): - return (self.lines[0].offset, - self.lines[-1].offset + self.lines[-1].l) - else: - return 0, 0 - - def get_offsets(self): - return [x.offset for x in self.lines] - - def add_cst(self, loc_key, constraint_type): - """ - Add constraint between current block and block at @loc_key - @loc_key: LocKey instance of constraint target - @constraint_type: AsmConstraint c_to/c_next - """ - assert isinstance(loc_key, LocKey) - c = AsmConstraint(loc_key, constraint_type) - self.bto.add(c) - - def get_flow_instr(self): - if not self.lines: - return None - for i in range(-1, -1 - self.lines[0].delayslot - 1, -1): - if not 0 <= i < len(self.lines): - return None - l = self.lines[i] - if l.splitflow() or l.breakflow(): - raise NotImplementedError('not fully functional') - - def get_subcall_instr(self): - if not self.lines: - return None - delayslot = self.lines[0].delayslot - end_index = len(self.lines) - 1 - ds_max_index = max(end_index - delayslot, 0) - for i in range(end_index, ds_max_index - 1, -1): - l = self.lines[i] - if l.is_subcall(): - return l - return None - - def get_next(self): - for constraint in self.bto: - if constraint.c_t == AsmConstraint.c_next: - return constraint.loc_key - return None - - @staticmethod - def _filter_constraint(constraints): - """Sort and filter @constraints for AsmBlock.bto - @constraints: non-empty set of AsmConstraint instance - - Always the same type -> one of the constraint - c_next and c_to -> c_next - """ - # Only one constraint - if len(constraints) == 1: - return next(iter(constraints)) - - # Constraint type -> set of corresponding constraint - cbytype = {} - for cons in constraints: - cbytype.setdefault(cons.c_t, set()).add(cons) - - # Only one type -> any constraint is OK - if len(cbytype) == 1: - return next(iter(constraints)) - - # At least 2 types -> types = {c_next, c_to} - # c_to is included in c_next - return next(iter(cbytype[AsmConstraint.c_next])) - - def fix_constraints(self): - """Fix next block constraints""" - # destination -> associated constraints - dests = {} - for constraint in self.bto: - dests.setdefault(constraint.loc_key, set()).add(constraint) - - self.bto = set( - self._filter_constraint(constraints) - for constraints in viewvalues(dests) - ) - - -class AsmBlockBad(AsmBlock): - - """Stand for a *bad* ASM block (malformed, unreachable, - not disassembled, ...)""" - - - ERROR_UNKNOWN = -1 - ERROR_CANNOT_DISASM = 0 - ERROR_NULL_STARTING_BLOCK = 1 - ERROR_FORBIDDEN = 2 - ERROR_IO = 3 - - - ERROR_TYPES = { - ERROR_UNKNOWN: "Unknown error", - ERROR_CANNOT_DISASM: "Unable to disassemble", - ERROR_NULL_STARTING_BLOCK: "Null starting block", - ERROR_FORBIDDEN: "Address forbidden by dont_dis", - ERROR_IO: "IOError", - } - - def __init__(self, loc_db, loc_key=None, alignment=1, errno=ERROR_UNKNOWN, *args, **kwargs): - """Instantiate an AsmBlock_bad. - @loc_key, @alignment: same as AsmBlock.__init__ - @errno: (optional) specify a error type associated with the block - """ - super(AsmBlockBad, self).__init__(loc_db, loc_key, alignment, *args, **kwargs) - self._errno = errno - - errno = property(lambda self: self._errno) - - def __str__(self): - error_txt = self.ERROR_TYPES.get(self._errno, self._errno) - return "%s\n\tBad block: %s" % ( - self.loc_key, - error_txt - ) - - def addline(self, *args, **kwargs): - raise RuntimeError("An AsmBlockBad cannot have line") - - def addto(self, *args, **kwargs): - raise RuntimeError("An AsmBlockBad cannot have bto") - - def split(self, *args, **kwargs): - raise RuntimeError("An AsmBlockBad cannot be split") - - -class AsmCFG(DiGraph): - - """Directed graph standing for a ASM Control Flow Graph with: - - nodes: AsmBlock - - edges: constraints between blocks, synchronized with AsmBlock's "bto" - - Specialized the .dot export and force the relation between block to be uniq, - and associated with a constraint. - - Offer helpers on AsmCFG management, such as research by loc_key, sanity - checking and mnemonic size guessing. - """ - - # Internal structure for pending management - AsmCFGPending = namedtuple("AsmCFGPending", - ["waiter", "constraint"]) - - def __init__(self, loc_db, *args, **kwargs): - super(AsmCFG, self).__init__(*args, **kwargs) - # Edges -> constraint - self.edges2constraint = {} - # Expected LocKey -> set( (src, dst), constraint ) - self._pendings = {} - # Loc_Key2block built on the fly - self._loc_key_to_block = {} - # loc_db - self.loc_db = loc_db - - - def copy(self): - """Copy the current graph instance""" - graph = self.__class__(self.loc_db) - return graph + self - - def __len__(self): - """Return the number of blocks in AsmCFG""" - return len(self._nodes) - - @property - def blocks(self): - return viewvalues(self._loc_key_to_block) - - # Manage graph with associated constraints - def add_edge(self, src, dst, constraint): - """Add an edge to the graph - @src: LocKey instance, source - @dst: LocKey instance, destination - @constraint: constraint associated to this edge - """ - # Sanity check - assert isinstance(src, LocKey) - assert isinstance(dst, LocKey) - known_cst = self.edges2constraint.get((src, dst), None) - if known_cst is not None: - assert known_cst == constraint - return - - # Add the edge to src.bto if needed - block_src = self.loc_key_to_block(src) - if block_src: - if dst not in [cons.loc_key for cons in block_src.bto]: - block_src.bto.add(AsmConstraint(dst, constraint)) - - # Add edge - self.edges2constraint[(src, dst)] = constraint - super(AsmCFG, self).add_edge(src, dst) - - def add_uniq_edge(self, src, dst, constraint): - """ - Synonym for `add_edge` - """ - self.add_edge(src, dst, constraint) - - def del_edge(self, src, dst): - """Delete the edge @src->@dst and its associated constraint""" - src_blk = self.loc_key_to_block(src) - dst_blk = self.loc_key_to_block(dst) - assert src_blk is not None - assert dst_blk is not None - # Delete from src.bto - to_remove = [cons for cons in src_blk.bto if cons.loc_key == dst] - if to_remove: - assert len(to_remove) == 1 - src_blk.bto.remove(to_remove[0]) - - # Del edge - del self.edges2constraint[(src, dst)] - super(AsmCFG, self).del_edge(src, dst) - - def del_block(self, block): - super(AsmCFG, self).del_node(block.loc_key) - del self._loc_key_to_block[block.loc_key] - - - def add_node(self, node): - assert isinstance(node, LocKey) - return super(AsmCFG, self).add_node(node) - - def add_block(self, block): - """ - Add the block @block to the current instance, if it is not already in - @block: AsmBlock instance - - Edges will be created for @block.bto, if destinations are already in - this instance. If not, they will be resolved when adding these - aforementioned destinations. - `self.pendings` indicates which blocks are not yet resolved. - - """ - status = super(AsmCFG, self).add_node(block.loc_key) - - if not status: - return status - - # Update waiters - if block.loc_key in self._pendings: - for bblpend in self._pendings[block.loc_key]: - self.add_edge(bblpend.waiter.loc_key, block.loc_key, bblpend.constraint) - del self._pendings[block.loc_key] - - # Synchronize edges with block destinations - self._loc_key_to_block[block.loc_key] = block - - for constraint in block.bto: - dst = self._loc_key_to_block.get(constraint.loc_key, - None) - if dst is None: - # Block is yet unknown, add it to pendings - to_add = self.AsmCFGPending(waiter=block, - constraint=constraint.c_t) - self._pendings.setdefault(constraint.loc_key, - set()).add(to_add) - else: - # Block is already in known nodes - self.add_edge(block.loc_key, dst.loc_key, constraint.c_t) - - return status - - def merge(self, graph): - """Merge with @graph, taking in account constraints""" - # Add known blocks - for block in graph.blocks: - self.add_block(block) - # Add nodes not already in it (ie. not linked to a block) - for node in graph.nodes(): - self.add_node(node) - # -> add_edge(x, y, constraint) - for edge in graph._edges: - # May fail if there is an incompatibility in edges constraints - # between the two graphs - self.add_edge(*edge, constraint=graph.edges2constraint[edge]) - - def escape_text(self, text): - return text - - - def node2lines(self, node): - loc_key_name = self.loc_db.pretty_str(node) - yield self.DotCellDescription(text=loc_key_name, - attr={'align': 'center', - 'colspan': 2, - 'bgcolor': 'grey'}) - block = self._loc_key_to_block.get(node, None) - if block is None: - return - if isinstance(block, AsmBlockBad): - yield [ - self.DotCellDescription( - text=block.ERROR_TYPES.get(block._errno, - block._errno - ), - attr={}) - ] - return - for line in block.lines: - if self._dot_offset: - yield [self.DotCellDescription(text="%.8X" % line.offset, - attr={}), - self.DotCellDescription(text=line.to_html(self.loc_db), attr={})] - else: - yield self.DotCellDescription(text=line.to_html(self.loc_db), attr={}) - - def node_attr(self, node): - block = self._loc_key_to_block.get(node, None) - if isinstance(block, AsmBlockBad): - return {'style': 'filled', 'fillcolor': 'red'} - return {} - - def edge_attr(self, src, dst): - cst = self.edges2constraint.get((src, dst), None) - edge_color = "blue" - - if len(self.successors(src)) > 1: - if cst == AsmConstraint.c_next: - edge_color = "red" - else: - edge_color = "limegreen" - - return {"color": edge_color} - - def dot(self, offset=False): - """ - @offset: (optional) if set, add the corresponding offsets in each node - """ - self._dot_offset = offset - return super(AsmCFG, self).dot() - - # Helpers - @property - def pendings(self): - """Dictionary of loc_key -> set(AsmCFGPending instance) indicating - which loc_key are missing in the current instance. - A loc_key is missing if a block which is already in nodes has constraints - with him (thanks to its .bto) and the corresponding block is not yet in - nodes - """ - return self._pendings - - def rebuild_edges(self): - """Consider blocks '.bto' and rebuild edges according to them, ie: - - update constraint type - - add missing edge - - remove no more used edge - - This method should be called if a block's '.bto' in nodes have been - modified without notifying this instance to resynchronize edges. - """ - self._pendings = {} - for block in self.blocks: - edges = [] - # Rebuild edges from bto - for constraint in block.bto: - dst = self._loc_key_to_block.get(constraint.loc_key, - None) - if dst is None: - # Missing destination, add to pendings - self._pendings.setdefault( - constraint.loc_key, - set() - ).add( - self.AsmCFGPending( - block, - constraint.c_t - ) - ) - continue - edge = (block.loc_key, dst.loc_key) - edges.append(edge) - if edge in self._edges: - # Already known edge, constraint may have changed - self.edges2constraint[edge] = constraint.c_t - else: - # An edge is missing - self.add_edge(edge[0], edge[1], constraint.c_t) - - # Remove useless edges - for succ in self.successors(block.loc_key): - edge = (block.loc_key, succ) - if edge not in edges: - self.del_edge(*edge) - - def get_bad_blocks(self): - """Iterator on AsmBlockBad elements""" - # A bad asm block is always a leaf - for loc_key in self.leaves(): - block = self._loc_key_to_block.get(loc_key, None) - if isinstance(block, AsmBlockBad): - yield block - - def get_bad_blocks_predecessors(self, strict=False): - """Iterator on loc_keys with an AsmBlockBad destination - @strict: (optional) if set, return loc_key with only bad - successors - """ - # Avoid returning the same block - done = set() - for badblock in self.get_bad_blocks(): - for predecessor in self.predecessors_iter(badblock.loc_key): - if predecessor not in done: - if (strict and - not all(isinstance(self._loc_key_to_block.get(block, None), AsmBlockBad) - for block in self.successors_iter(predecessor))): - continue - yield predecessor - done.add(predecessor) - - def getby_offset(self, offset): - """Return asmblock containing @offset""" - for block in self.blocks: - if block.lines[0].offset <= offset < \ - (block.lines[-1].offset + block.lines[-1].l): - return block - return None - - def loc_key_to_block(self, loc_key): - """ - Return the asmblock corresponding to loc_key @loc_key, None if unknown - loc_key - @loc_key: LocKey instance - """ - return self._loc_key_to_block.get(loc_key, None) - - def sanity_check(self): - """Do sanity checks on blocks' constraints: - * no pendings - * no multiple next constraint to same block - * no next constraint to self - """ - - if len(self._pendings) != 0: - raise RuntimeError( - "Some blocks are missing: %s" % list( - map( - str, - self._pendings - ) - ) - ) - - next_edges = { - edge: constraint - for edge, constraint in viewitems(self.edges2constraint) - if constraint == AsmConstraint.c_next - } - - for loc_key in self._nodes: - if loc_key not in self._loc_key_to_block: - raise RuntimeError("Not supported yet: every node must have a corresponding AsmBlock") - # No next constraint to self - if (loc_key, loc_key) in next_edges: - raise RuntimeError('Bad constraint: self in next') - - # No multiple next constraint to same block - pred_next = list(ploc_key - for (ploc_key, dloc_key) in next_edges - if dloc_key == loc_key) - - if len(pred_next) > 1: - raise RuntimeError("Too many next constraints for block %r" - "(%s)" % (loc_key, - pred_next)) - - def guess_blocks_size(self, mnemo): - """Asm and compute max block size - Add a 'size' and 'max_size' attribute on each block - @mnemo: metamn instance""" - for block in self.blocks: - size = 0 - for instr in block.lines: - if isinstance(instr, AsmRaw): - # for special AsmRaw, only extract len - if isinstance(instr.raw, list): - data = None - if len(instr.raw) == 0: - l = 0 - else: - l = (instr.raw[0].size // 8) * len(instr.raw) - elif isinstance(instr.raw, str): - data = instr.raw.encode() - l = len(data) - elif isinstance(instr.raw, bytes): - data = instr.raw - l = len(data) - else: - raise NotImplementedError('asm raw') - else: - # Assemble the instruction to retrieve its len. - # If the instruction uses symbol it will fail - # In this case, the max_instruction_len is used - try: - candidates = mnemo.asm(instr) - l = len(candidates[-1]) - except: - l = mnemo.max_instruction_len - data = None - instr.data = data - instr.l = l - size += l - - block.size = size - block.max_size = size - log_asmblock.info("size: %d max: %d", block.size, block.max_size) - - def apply_splitting(self, loc_db, dis_block_callback=None, **kwargs): - warnings.warn('DEPRECATION WARNING: apply_splitting is member of disasm_engine') - raise RuntimeError("Moved api") - - def __str__(self): - out = [] - for block in self.blocks: - out.append(str(block)) - for loc_key_a, loc_key_b in self.edges(): - out.append("%s -> %s" % (loc_key_a, loc_key_b)) - return '\n'.join(out) - - def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, hex(id(self))) - -# Out of _merge_blocks to be computed only once -_acceptable_block = lambda graph, loc_key: (not isinstance(graph.loc_key_to_block(loc_key), AsmBlockBad) and - len(graph.loc_key_to_block(loc_key).lines) > 0) -_parent = MatchGraphJoker(restrict_in=False, filt=_acceptable_block) -_son = MatchGraphJoker(restrict_out=False, filt=_acceptable_block) -_expgraph = _parent >> _son - - -def _merge_blocks(dg, graph): - """Graph simplification merging AsmBlock with one and only one son with this - son if this son has one and only one parent""" - - # Blocks to ignore, because they have been removed from the graph - to_ignore = set() - - for match in _expgraph.match(graph): - - # Get matching blocks - lbl_block, lbl_succ = match[_parent], match[_son] - block = graph.loc_key_to_block(lbl_block) - succ = graph.loc_key_to_block(lbl_succ) - - # Ignore already deleted blocks - if (block in to_ignore or - succ in to_ignore): - continue - - # Remove block last instruction if needed - last_instr = block.lines[-1] - if last_instr.delayslot > 0: - # TODO: delayslot - raise RuntimeError("Not implemented yet") - - if last_instr.is_subcall(): - continue - if last_instr.breakflow() and last_instr.dstflow(): - block.lines.pop() - - # Merge block - block.lines += succ.lines - for nextb in graph.successors_iter(lbl_succ): - graph.add_edge(lbl_block, nextb, graph.edges2constraint[(lbl_succ, nextb)]) - - graph.del_block(succ) - to_ignore.add(lbl_succ) - - -bbl_simplifier = DiGraphSimplifier() -bbl_simplifier.enable_passes([_merge_blocks]) - - -def conservative_asm(mnemo, instr, symbols, conservative): - """ - Asm instruction; - Try to keep original instruction bytes if it exists - """ - candidates = mnemo.asm(instr, symbols) - if not candidates: - raise ValueError('cannot asm:%s' % str(instr)) - if not hasattr(instr, "b"): - return candidates[0], candidates - if instr.b in candidates: - return instr.b, candidates - if conservative: - for c in candidates: - if len(c) == len(instr.b): - return c, candidates - return candidates[0], candidates - - -def fix_expr_val(expr, symbols): - """Resolve an expression @expr using @symbols""" - def expr_calc(e): - if isinstance(e, ExprId): - # Example: - # toto: - # .dword label - loc_key = symbols.get_name_location(e.name) - offset = symbols.get_location_offset(loc_key) - e = ExprInt(offset, e.size) - return e - result = expr.visit(expr_calc) - result = expr_simp(result) - if not isinstance(result, ExprInt): - raise RuntimeError('Cannot resolve symbol %s' % expr) - return result - - -def fix_loc_offset(loc_db, loc_key, offset, modified): - """ - Fix the @loc_key offset to @offset. If the @offset has changed, add @loc_key - to @modified - @loc_db: current loc_db - """ - loc_offset = loc_db.get_location_offset(loc_key) - if loc_offset == offset: - return - if loc_offset is not None: - loc_db.unset_location_offset(loc_key) - loc_db.set_location_offset(loc_key, offset) - modified.add(loc_key) - - -class BlockChain(object): - - """Manage blocks linked with an asm_constraint_next""" - - def __init__(self, loc_db, blocks): - self.loc_db = loc_db - self.blocks = blocks - self.place() - - @property - def pinned(self): - """Return True iff at least one block is pinned""" - return self.pinned_block_idx is not None - - def _set_pinned_block_idx(self): - self.pinned_block_idx = None - for i, block in enumerate(self.blocks): - loc_key = block.loc_key - if self.loc_db.get_location_offset(loc_key) is not None: - if self.pinned_block_idx is not None: - raise ValueError("Multiples pinned block detected") - self.pinned_block_idx = i - - def place(self): - """Compute BlockChain min_offset and max_offset using pinned block and - blocks' size - """ - self._set_pinned_block_idx() - self.max_size = 0 - for block in self.blocks: - self.max_size += block.max_size + block.alignment - 1 - - # Check if chain has one block pinned - if not self.pinned: - return - - loc = self.blocks[self.pinned_block_idx].loc_key - offset_base = self.loc_db.get_location_offset(loc) - assert(offset_base % self.blocks[self.pinned_block_idx].alignment == 0) - - self.offset_min = offset_base - for block in self.blocks[:self.pinned_block_idx - 1:-1]: - self.offset_min -= block.max_size + \ - (block.alignment - block.max_size) % block.alignment - - self.offset_max = offset_base - for block in self.blocks[self.pinned_block_idx:]: - self.offset_max += block.max_size + \ - (block.alignment - block.max_size) % block.alignment - - def merge(self, chain): - """Best effort merge two block chains - Return the list of resulting blockchains""" - self.blocks += chain.blocks - self.place() - return [self] - - def fix_blocks(self, modified_loc_keys): - """Propagate a pinned to its blocks' neighbour - @modified_loc_keys: store new pinned loc_keys""" - - if not self.pinned: - raise ValueError('Trying to fix unpinned block') - - # Propagate offset to blocks before pinned block - pinned_block = self.blocks[self.pinned_block_idx] - offset = self.loc_db.get_location_offset(pinned_block.loc_key) - if offset % pinned_block.alignment != 0: - raise RuntimeError('Bad alignment') - - for block in self.blocks[:self.pinned_block_idx - 1:-1]: - new_offset = offset - block.size - new_offset = new_offset - new_offset % pinned_block.alignment - fix_loc_offset(self.loc_db, - block.loc_key, - new_offset, - modified_loc_keys) - - # Propagate offset to blocks after pinned block - offset = self.loc_db.get_location_offset(pinned_block.loc_key) + pinned_block.size - - last_block = pinned_block - for block in self.blocks[self.pinned_block_idx + 1:]: - offset += (- offset) % last_block.alignment - fix_loc_offset(self.loc_db, - block.loc_key, - offset, - modified_loc_keys) - offset += block.size - last_block = block - return modified_loc_keys - - -class BlockChainWedge(object): - - """Stand for wedges between blocks""" - - def __init__(self, loc_db, offset, size): - self.loc_db = loc_db - self.offset = offset - self.max_size = size - self.offset_min = offset - self.offset_max = offset + size - - def merge(self, chain): - """Best effort merge two block chains - Return the list of resulting blockchains""" - self.loc_db.set_location_offset(chain.blocks[0].loc_key, self.offset_max) - chain.place() - return [self, chain] - - -def group_constrained_blocks(asmcfg): - """ - Return the BlockChains list built from grouped blocks in asmcfg linked by - asm_constraint_next - @asmcfg: an AsmCfg instance - """ - log_asmblock.info('group_constrained_blocks') - - # Group adjacent asmcfg - remaining_blocks = list(asmcfg.blocks) - known_block_chains = {} - - while remaining_blocks: - # Create a new block chain - block_list = [remaining_blocks.pop()] - - # Find sons in remainings blocks linked with a next constraint - while True: - # Get next block - next_loc_key = block_list[-1].get_next() - if next_loc_key is None or asmcfg.loc_key_to_block(next_loc_key) is None: - break - next_block = asmcfg.loc_key_to_block(next_loc_key) - - # Add the block at the end of the current chain - if next_block not in remaining_blocks: - break - block_list.append(next_block) - remaining_blocks.remove(next_block) - - # Check if son is in a known block group - if next_loc_key is not None and next_loc_key in known_block_chains: - block_list += known_block_chains[next_loc_key] - del known_block_chains[next_loc_key] - - known_block_chains[block_list[0].loc_key] = block_list - - out_block_chains = [] - for loc_key in known_block_chains: - chain = BlockChain(asmcfg.loc_db, known_block_chains[loc_key]) - out_block_chains.append(chain) - return out_block_chains - - -def get_blockchains_address_interval(blockChains, dst_interval): - """Compute the interval used by the pinned @blockChains - Check if the placed chains are in the @dst_interval""" - - allocated_interval = interval() - for chain in blockChains: - if not chain.pinned: - continue - chain_interval = interval([(chain.offset_min, chain.offset_max - 1)]) - if chain_interval not in dst_interval: - raise ValueError('Chain placed out of destination interval') - allocated_interval += chain_interval - return allocated_interval - - -def resolve_symbol(blockChains, loc_db, dst_interval=None): - """Place @blockChains in the @dst_interval""" - - log_asmblock.info('resolve_symbol') - if dst_interval is None: - dst_interval = interval([(0, 0xFFFFFFFFFFFFFFFF)]) - - forbidden_interval = interval( - [(-1, 0xFFFFFFFFFFFFFFFF + 1)]) - dst_interval - allocated_interval = get_blockchains_address_interval(blockChains, - dst_interval) - log_asmblock.debug('allocated interval: %s', allocated_interval) - - pinned_chains = [chain for chain in blockChains if chain.pinned] - - # Add wedge in forbidden intervals - for start, stop in forbidden_interval.intervals: - wedge = BlockChainWedge( - loc_db, offset=start, size=stop + 1 - start) - pinned_chains.append(wedge) - - # Try to place bigger blockChains first - pinned_chains.sort(key=lambda x: x.offset_min) - blockChains.sort(key=lambda x: -x.max_size) - - fixed_chains = list(pinned_chains) - - log_asmblock.debug("place chains") - for chain in blockChains: - if chain.pinned: - continue - fixed = False - for i in range(1, len(fixed_chains)): - prev_chain = fixed_chains[i - 1] - next_chain = fixed_chains[i] - - if prev_chain.offset_max + chain.max_size < next_chain.offset_min: - new_chains = prev_chain.merge(chain) - fixed_chains[i - 1:i] = new_chains - fixed = True - break - if not fixed: - raise RuntimeError('Cannot find enough space to place blocks') - - return [chain for chain in fixed_chains if isinstance(chain, BlockChain)] - - -def get_block_loc_keys(block): - """Extract loc_keys used by @block""" - symbols = set() - for instr in block.lines: - if isinstance(instr, AsmRaw): - if isinstance(instr.raw, list): - for expr in instr.raw: - symbols.update(get_expr_locs(expr)) - else: - for arg in instr.args: - symbols.update(get_expr_locs(arg)) - return symbols - - -def assemble_block(mnemo, block, conservative=False): - """Assemble a @block - @conservative: (optional) use original bytes when possible - """ - offset_i = 0 - - for instr in block.lines: - if isinstance(instr, AsmRaw): - if isinstance(instr.raw, list): - # Fix special AsmRaw - data = b"" - for expr in instr.raw: - expr_int = fix_expr_val(expr, block.loc_db) - data += pck[expr_int.size](int(expr_int)) - instr.data = data - - instr.offset = offset_i - offset_i += instr.l - continue - - # Assemble an instruction - saved_args = list(instr.args) - instr.offset = block.loc_db.get_location_offset(block.loc_key) + offset_i - - # Replace instruction's arguments by resolved ones - instr.args = instr.resolve_args_with_symbols(block.loc_db) - - if instr.dstflow(): - instr.fixDstOffset() - - old_l = instr.l - cached_candidate, _ = conservative_asm( - mnemo, instr, block.loc_db, - conservative - ) - if len(cached_candidate) != instr.l: - # The output instruction length is different from the one we guessed - # Retry assembly with updated length - instr.l = len(cached_candidate) - instr.args = saved_args - instr.args = instr.resolve_args_with_symbols(block.loc_db) - if instr.dstflow(): - instr.fixDstOffset() - cached_candidate, _ = conservative_asm( - mnemo, instr, block.loc_db, - conservative - ) - assert len(cached_candidate) == instr.l - - # Restore original arguments - instr.args = saved_args - - # We need to update the block size - block.size = block.size - old_l + len(cached_candidate) - instr.data = cached_candidate - instr.l = len(cached_candidate) - - offset_i += instr.l - - -def asmblock_final(mnemo, asmcfg, blockChains, conservative=False): - """Resolve and assemble @blockChains until fixed point is - reached""" - - log_asmblock.debug("asmbloc_final") - - # Init structures - blocks_using_loc_key = {} - for block in asmcfg.blocks: - exprlocs = get_block_loc_keys(block) - loc_keys = set(expr.loc_key for expr in exprlocs) - for loc_key in loc_keys: - blocks_using_loc_key.setdefault(loc_key, set()).add(block) - - block2chain = {} - for chain in blockChains: - for block in chain.blocks: - block2chain[block] = chain - - # Init worklist - blocks_to_rework = set(asmcfg.blocks) - - # Fix and re-assemble blocks until fixed point is reached - while True: - - # Propagate pinned blocks into chains - modified_loc_keys = set() - for chain in blockChains: - chain.fix_blocks(modified_loc_keys) - - for loc_key in modified_loc_keys: - # Retrieve block with modified reference - mod_block = asmcfg.loc_key_to_block(loc_key) - if mod_block is not None: - blocks_to_rework.add(mod_block) - - # Enqueue blocks referencing a modified loc_key - if loc_key not in blocks_using_loc_key: - continue - for block in blocks_using_loc_key[loc_key]: - blocks_to_rework.add(block) - - # No more work - if not blocks_to_rework: - break - - while blocks_to_rework: - block = blocks_to_rework.pop() - assemble_block(mnemo, block, conservative) - - -def asm_resolve_final(mnemo, asmcfg, dst_interval=None): - """Resolve and assemble @asmcfg into interval - @dst_interval""" - - asmcfg.sanity_check() - - asmcfg.guess_blocks_size(mnemo) - blockChains = group_constrained_blocks(asmcfg) - resolved_blockChains = resolve_symbol(blockChains, asmcfg.loc_db, dst_interval) - asmblock_final(mnemo, asmcfg, resolved_blockChains) - patches = {} - output_interval = interval() - - for block in asmcfg.blocks: - offset = asmcfg.loc_db.get_location_offset(block.loc_key) - for instr in block.lines: - if not instr.data: - # Empty line - continue - assert len(instr.data) == instr.l - patches[offset] = instr.data - instruction_interval = interval([(offset, offset + instr.l - 1)]) - if not (instruction_interval & output_interval).empty: - raise RuntimeError("overlapping bytes %X" % int(offset)) - output_interval = output_interval.union(instruction_interval) - instr.offset = offset - offset += instr.l - return patches - - -class disasmEngine(object): - - """Disassembly engine, taking care of disassembler options and mutli-block - strategy. - - Engine options: - - + Object supporting membership test (offset in ..) - - dont_dis: stop the current disassembly branch if reached - - split_dis: force a basic block end if reached, - with a next constraint on its successor - - dont_dis_retcall_funcs: stop disassembly after a call to one - of the given functions - - + On/Off - - follow_call: recursively disassemble CALL destinations - - dontdis_retcall: stop on CALL return addresses - - dont_dis_nulstart_bloc: stop if a block begin with a few \x00 - - + Number - - lines_wd: maximum block's size (in number of instruction) - - blocs_wd: maximum number of distinct disassembled block - - + callback(mdis, cur_block, offsets_to_dis) - - dis_block_callback: callback after each new disassembled block - """ - - def __init__(self, arch, attrib, bin_stream, loc_db, **kwargs): - """Instantiate a new disassembly engine - @arch: targeted architecture - @attrib: architecture attribute - @bin_stream: bytes source - @kwargs: (optional) custom options - """ - self.arch = arch - self.attrib = attrib - self.bin_stream = bin_stream - self.loc_db = loc_db - - # Setup options - self.dont_dis = [] - self.split_dis = [] - self.follow_call = False - self.dontdis_retcall = False - self.lines_wd = None - self.blocs_wd = None - self.dis_block_callback = None - self.dont_dis_nulstart_bloc = False - self.dont_dis_retcall_funcs = set() - - # Override options if needed - self.__dict__.update(kwargs) - - def _dis_block(self, offset, job_done=None): - """Disassemble the block at offset @offset - @job_done: a set of already disassembled addresses - Return the created AsmBlock and future offsets to disassemble - """ - - if job_done is None: - job_done = set() - lines_cpt = 0 - in_delayslot = False - delayslot_count = self.arch.delayslot - offsets_to_dis = set() - add_next_offset = False - loc_key = self.loc_db.get_or_create_offset_location(offset) - cur_block = AsmBlock(self.loc_db, loc_key) - log_asmblock.debug("dis at %X", int(offset)) - while not in_delayslot or delayslot_count > 0: - if in_delayslot: - delayslot_count -= 1 - - if offset in self.dont_dis: - if not cur_block.lines: - job_done.add(offset) - # Block is empty -> bad block - cur_block = AsmBlockBad(self.loc_db, loc_key, errno=AsmBlockBad.ERROR_FORBIDDEN) - else: - # Block is not empty, stop the desassembly pass and add a - # constraint to the next block - loc_key_cst = self.loc_db.get_or_create_offset_location(offset) - cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) - break - - if lines_cpt > 0 and offset in self.split_dis: - loc_key_cst = self.loc_db.get_or_create_offset_location(offset) - cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) - offsets_to_dis.add(offset) - break - - lines_cpt += 1 - if self.lines_wd is not None and lines_cpt > self.lines_wd: - log_asmblock.debug("lines watchdog reached at %X", int(offset)) - break - - if offset in job_done: - loc_key_cst = self.loc_db.get_or_create_offset_location(offset) - cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) - break - - off_i = offset - error = None - try: - instr = self.arch.dis(self.bin_stream, self.attrib, offset) - except Disasm_Exception as e: - log_asmblock.warning(e) - instr = None - error = AsmBlockBad.ERROR_CANNOT_DISASM - except IOError as e: - log_asmblock.warning(e) - instr = None - error = AsmBlockBad.ERROR_IO - - - if instr is None: - log_asmblock.warning("cannot disasm at %X", int(off_i)) - if not cur_block.lines: - job_done.add(offset) - # Block is empty -> bad block - cur_block = AsmBlockBad(self.loc_db, loc_key, errno=error) - else: - # Block is not empty, stop the desassembly pass and add a - # constraint to the next block - loc_key_cst = self.loc_db.get_or_create_offset_location(off_i) - cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) - break - - # XXX TODO nul start block option - if (self.dont_dis_nulstart_bloc and - not cur_block.lines and - instr.b.count(b'\x00') == instr.l): - log_asmblock.warning("reach nul instr at %X", int(off_i)) - # Block is empty -> bad block - cur_block = AsmBlockBad(self.loc_db, loc_key, errno=AsmBlockBad.ERROR_NULL_STARTING_BLOCK) - break - - # special case: flow graph modificator in delayslot - if in_delayslot and instr and (instr.splitflow() or instr.breakflow()): - add_next_offset = True - break - - job_done.add(offset) - log_asmblock.debug("dis at %X", int(offset)) - - offset += instr.l - log_asmblock.debug(instr) - log_asmblock.debug(instr.args) - - cur_block.addline(instr) - if not instr.breakflow(): - continue - # test split - if instr.splitflow() and not (instr.is_subcall() and self.dontdis_retcall): - add_next_offset = True - if instr.dstflow(): - instr.dstflow2label(self.loc_db) - destinations = instr.getdstflow(self.loc_db) - known_dsts = [] - for dst in destinations: - if not dst.is_loc(): - continue - loc_key = dst.loc_key - loc_key_offset = self.loc_db.get_location_offset(loc_key) - known_dsts.append(loc_key) - if loc_key_offset in self.dont_dis_retcall_funcs: - add_next_offset = False - if (not instr.is_subcall()) or self.follow_call: - cur_block.bto.update([AsmConstraint(loc_key, AsmConstraint.c_to) for loc_key in known_dsts]) - - # get in delayslot mode - in_delayslot = True - delayslot_count = instr.delayslot - - for c in cur_block.bto: - loc_key_offset = self.loc_db.get_location_offset(c.loc_key) - offsets_to_dis.add(loc_key_offset) - - if add_next_offset: - loc_key_cst = self.loc_db.get_or_create_offset_location(offset) - cur_block.add_cst(loc_key_cst, AsmConstraint.c_next) - offsets_to_dis.add(offset) - - # Fix multiple constraints - cur_block.fix_constraints() - - if self.dis_block_callback is not None: - self.dis_block_callback(self, cur_block, offsets_to_dis) - return cur_block, offsets_to_dis - - def dis_block(self, offset): - """Disassemble the block at offset @offset and return the created - AsmBlock - @offset: targeted offset to disassemble - """ - current_block, _ = self._dis_block(offset) - return current_block - - def dis_multiblock(self, offset, blocks=None, job_done=None): - """Disassemble every block reachable from @offset regarding - specific disasmEngine conditions - Return an AsmCFG instance containing disassembled blocks - @offset: starting offset - @blocks: (optional) AsmCFG instance of already disassembled blocks to - merge with - """ - log_asmblock.info("dis block all") - if job_done is None: - job_done = set() - if blocks is None: - blocks = AsmCFG(self.loc_db) - todo = [offset] - - bloc_cpt = 0 - while len(todo): - bloc_cpt += 1 - if self.blocs_wd is not None and bloc_cpt > self.blocs_wd: - log_asmblock.debug("blocks watchdog reached at %X", int(offset)) - break - - target_offset = int(todo.pop(0)) - if (target_offset is None or - target_offset in job_done): - continue - cur_block, nexts = self._dis_block(target_offset, job_done) - todo += nexts - blocks.add_block(cur_block) - - self.apply_splitting(blocks) - return blocks - - def apply_splitting(self, blocks): - """Consider @blocks' bto destinations and split block in @blocks if one - of these destinations jumps in the middle of this block. In order to - work, they must be only one block in @self per loc_key in - - @blocks: Asmcfg - """ - # Get all possible destinations not yet resolved, with a resolved - # offset - block_dst = [] - for loc_key in blocks.pendings: - offset = self.loc_db.get_location_offset(loc_key) - if offset is not None: - block_dst.append(offset) - - todo = set(blocks.blocks) - rebuild_needed = False - - while todo: - # Find a block with a destination inside another one - cur_block = todo.pop() - range_start, range_stop = cur_block.get_range() - - for off in block_dst: - if not (off > range_start and off < range_stop): - continue - - # `cur_block` must be split at offset `off`from miasm.core.locationdb import LocationDB - - new_b = cur_block.split(off) - log_asmblock.debug("Split block %x", off) - if new_b is None: - log_asmblock.error("Cannot split %x!!", off) - continue - - # Remove pending from cur_block - # Links from new_b will be generated in rebuild_edges - for dst in new_b.bto: - if dst.loc_key not in blocks.pendings: - continue - blocks.pendings[dst.loc_key] = set(pending for pending in blocks.pendings[dst.loc_key] - if pending.waiter != cur_block) - - # The new block destinations may need to be disassembled - if self.dis_block_callback: - offsets_to_dis = set( - self.loc_db.get_location_offset(constraint.loc_key) - for constraint in new_b.bto - ) - self.dis_block_callback(self, new_b, offsets_to_dis) - - # Update structure - rebuild_needed = True - blocks.add_block(new_b) - - # The new block must be considered - todo.add(new_b) - range_start, range_stop = cur_block.get_range() - - # Rebuild edges to match new blocks'bto - if rebuild_needed: - blocks.rebuild_edges() - - def dis_instr(self, offset): - """Disassemble one instruction at offset @offset and return the - corresponding instruction instance - @offset: targeted offset to disassemble - """ - old_lineswd = self.lines_wd - self.lines_wd = 1 - try: - block = self.dis_block(offset) - finally: - self.lines_wd = old_lineswd - - instr = block.lines[0] - return instr diff --git a/miasm/core/bin_stream.py b/miasm/core/bin_stream.py deleted file mode 100644 index 46165d49..00000000 --- a/miasm/core/bin_stream.py +++ /dev/null @@ -1,319 +0,0 @@ -# -# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation; either version 2 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -# - -from builtins import str -from future.utils import PY3 - -from miasm.core.utils import BIG_ENDIAN, LITTLE_ENDIAN -from miasm.core.utils import upck8le, upck16le, upck32le, upck64le -from miasm.core.utils import upck8be, upck16be, upck32be, upck64be - - -class bin_stream(object): - - # Cache must be initialized by entering atomic mode - _cache = None - CACHE_SIZE = 10000 - # By default, no atomic mode - _atomic_mode = False - - def __init__(self, *args, **kargs): - self.endianness = LITTLE_ENDIAN - - def __repr__(self): - return "<%s !!>" % self.__class__.__name__ - - def __str__(self): - if PY3: - return repr(self) - return self.__bytes__() - - def hexdump(self, offset, l): - return - - def enter_atomic_mode(self): - """Enter atomic mode. In this mode, read may be cached""" - assert not self._atomic_mode - self._atomic_mode = True - self._cache = {} - - def leave_atomic_mode(self): - """Leave atomic mode""" - assert self._atomic_mode - self._atomic_mode = False - self._cache = None - - def _getbytes(self, start, length): - return self.bin[start:start + length] - - def getbytes(self, start, l=1): - """Return the bytes from the bit stream - @start: starting offset (in byte) - @l: (optional) number of bytes to read - - Wrapper on _getbytes, with atomic mode handling. - """ - if self._atomic_mode: - val = self._cache.get((start,l), None) - if val is None: - val = self._getbytes(start, l) - self._cache[(start,l)] = val - else: - val = self._getbytes(start, l) - return val - - def getbits(self, start, n): - """Return the bits from the bit stream - @start: the offset in bits - @n: number of bits to read - """ - # Trivial case - if n == 0: - return 0 - - # Get initial bytes - if n > self.getlen() * 8: - raise IOError('not enough bits %r %r' % (n, len(self.bin) * 8)) - byte_start = start // 8 - byte_stop = (start + n + 7) // 8 - temp = self.getbytes(byte_start, byte_stop - byte_start) - if not temp: - raise IOError('cannot get bytes') - - # Init - start = start % 8 - out = 0 - while n: - # Get needed bits, working on maximum 8 bits at a time - cur_byte_idx = start // 8 - new_bits = ord(temp[cur_byte_idx:cur_byte_idx + 1]) - to_keep = 8 - start % 8 - new_bits &= (1 << to_keep) - 1 - cur_len = min(to_keep, n) - new_bits >>= (to_keep - cur_len) - - # Update output - out <<= cur_len - out |= new_bits - - # Update counters - n -= cur_len - start += cur_len - return out - - def get_u8(self, addr, endianness=None): - """ - Return u8 from address @addr - endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN - """ - if endianness is None: - endianness = self.endianness - data = self.getbytes(addr, 1) - if endianness == LITTLE_ENDIAN: - return upck8le(data) - else: - return upck8be(data) - - def get_u16(self, addr, endianness=None): - """ - Return u16 from address @addr - endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN - """ - if endianness is None: - endianness = self.endianness - data = self.getbytes(addr, 2) - if endianness == LITTLE_ENDIAN: - return upck16le(data) - else: - return upck16be(data) - - def get_u32(self, addr, endianness=None): - """ - Return u32 from address @addr - endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN - """ - if endianness is None: - endianness = self.endianness - data = self.getbytes(addr, 4) - if endianness == LITTLE_ENDIAN: - return upck32le(data) - else: - return upck32be(data) - - def get_u64(self, addr, endianness=None): - """ - Return u64 from address @addr - endianness: Optional: LITTLE_ENDIAN/BIG_ENDIAN - """ - if endianness is None: - endianness = self.endianness - data = self.getbytes(addr, 8) - if endianness == LITTLE_ENDIAN: - return upck64le(data) - else: - return upck64be(data) - - -class bin_stream_str(bin_stream): - - def __init__(self, input_str=b"", offset=0, base_address=0, shift=None): - bin_stream.__init__(self) - if shift is not None: - raise DeprecationWarning("use base_address instead of shift") - self.bin = input_str - self.offset = offset - self.base_address = base_address - self.l = len(input_str) - - def _getbytes(self, start, l=1): - if start + l - self.base_address > self.l: - raise IOError("not enough bytes in str") - if start - self.base_address < 0: - raise IOError("Negative offset") - - return super(bin_stream_str, self)._getbytes(start - self.base_address, l) - - def readbs(self, l=1): - if self.offset + l - self.base_address > self.l: - raise IOError("not enough bytes in str") - if self.offset - self.base_address < 0: - raise IOError("Negative offset") - self.offset += l - return self.bin[self.offset - l - self.base_address:self.offset - self.base_address] - - def __bytes__(self): - return self.bin[self.offset - self.base_address:] - - def setoffset(self, val): - self.offset = val - - def getlen(self): - return self.l - (self.offset - self.base_address) - - -class bin_stream_file(bin_stream): - - def __init__(self, binary, offset=0, base_address=0, shift=None): - bin_stream.__init__(self) - if shift is not None: - raise DeprecationWarning("use base_address instead of shift") - self.bin = binary - self.bin.seek(0, 2) - self.base_address = base_address - self.l = self.bin.tell() - self.offset = offset - - def getoffset(self): - return self.bin.tell() + self.base_address - - def setoffset(self, val): - self.bin.seek(val - self.base_address) - offset = property(getoffset, setoffset) - - def readbs(self, l=1): - if self.offset + l - self.base_address > self.l: - raise IOError("not enough bytes in file") - if self.offset - self.base_address < 0: - raise IOError("Negative offset") - return self.bin.read(l) - - def __bytes__(self): - return self.bin.read() - - def getlen(self): - return self.l - (self.offset - self.base_address) - - -class bin_stream_container(bin_stream): - - def __init__(self, binary, offset=0): - bin_stream.__init__(self) - self.bin = binary - self.l = binary.virt.max_addr() - self.offset = offset - - def is_addr_in(self, ad): - return self.bin.virt.is_addr_in(ad) - - def getlen(self): - return self.l - - def readbs(self, l=1): - if self.offset + l > self.l: - raise IOError("not enough bytes") - if self.offset < 0: - raise IOError("Negative offset") - self.offset += l - return self.bin.virt.get(self.offset - l, self.offset) - - def _getbytes(self, start, l=1): - try: - return self.bin.virt.get(start, start + l) - except ValueError: - raise IOError("cannot get bytes") - - def __bytes__(self): - return self.bin.virt.get(self.offset, self.offset + self.l) - - def setoffset(self, val): - self.offset = val - - -class bin_stream_pe(bin_stream_container): - def __init__(self, binary, *args, **kwargs): - super(bin_stream_pe, self).__init__(binary, *args, **kwargs) - self.endianness = binary._sex - - -class bin_stream_elf(bin_stream_container): - def __init__(self, binary, *args, **kwargs): - super(bin_stream_elf, self).__init__(binary, *args, **kwargs) - self.endianness = binary.sex - - -class bin_stream_vm(bin_stream): - - def __init__(self, vm, offset=0, base_offset=0): - self.offset = offset - self.base_offset = base_offset - self.vm = vm - if self.vm.is_little_endian(): - self.endianness = LITTLE_ENDIAN - else: - self.endianness = BIG_ENDIAN - - def getlen(self): - return 0xFFFFFFFFFFFFFFFF - - def _getbytes(self, start, l=1): - try: - s = self.vm.get_mem(start + self.base_offset, l) - except: - raise IOError('cannot get mem ad', hex(start)) - return s - - def readbs(self, l=1): - try: - s = self.vm.get_mem(self.offset + self.base_offset, l) - except: - raise IOError('cannot get mem ad', hex(self.offset)) - self.offset += l - return s - - def setoffset(self, val): - self.offset = val diff --git a/miasm/core/bin_stream_ida.py b/miasm/core/bin_stream_ida.py deleted file mode 100644 index 15bd9d8b..00000000 --- a/miasm/core/bin_stream_ida.py +++ /dev/null @@ -1,45 +0,0 @@ -from builtins import range -from idc import get_wide_byte, get_segm_end -from idautils import Segments -from idaapi import is_mapped - -from miasm.core.utils import int_to_byte -from miasm.core.bin_stream import bin_stream_str - - -class bin_stream_ida(bin_stream_str): - """ - bin_stream implementation for IDA - - Don't generate xrange using address computation: - It can raise error on overflow 7FFFFFFF with 32 bit python - """ - def _getbytes(self, start, l=1): - out = [] - for ad in range(l): - offset = ad + start + self.base_address - if not is_mapped(offset): - raise IOError(f"not enough bytes @ offset {offset:x}") - out.append(int_to_byte(get_wide_byte(offset))) - return b''.join(out) - - def readbs(self, l=1): - if self.offset + l > self.l: - raise IOError("not enough bytes") - content = self.getbytes(self.offset) - self.offset += l - return content - - def __str__(self): - raise NotImplementedError('Not fully functional') - - def setoffset(self, val): - self.offset = val - - def getlen(self): - # Lazy version - if hasattr(self, "_getlen"): - return self._getlen - max_addr = get_segm_end(list(Segments())[-1] - (self.offset - self.base_address)) - self._getlen = max_addr - return max_addr diff --git a/miasm/core/cpu.py b/miasm/core/cpu.py deleted file mode 100644 index 7df9f991..00000000 --- a/miasm/core/cpu.py +++ /dev/null @@ -1,1715 +0,0 @@ -#-*- coding:utf-8 -*- - -from builtins import range -import re -import struct -import logging -from collections import defaultdict - - -from future.utils import viewitems, viewvalues - -import pyparsing - -from miasm.core.utils import decode_hex -import miasm.expression.expression as m2_expr -from miasm.core.bin_stream import bin_stream, bin_stream_str -from miasm.core.utils import Disasm_Exception -from miasm.expression.simplifications import expr_simp - - -from miasm.core.asm_ast import AstNode, AstInt, AstId, AstOp -from miasm.core import utils -from future.utils import with_metaclass - -log = logging.getLogger("cpuhelper") -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) -log.addHandler(console_handler) -log.setLevel(logging.WARN) - - -class bitobj(object): - - def __init__(self, s=b""): - if not s: - bits = [] - else: - bits = [int(x) for x in bin(int(encode_hex(s), 16))[2:]] - if len(bits) % 8: - bits = [0 for x in range(8 - (len(bits) % 8))] + bits - self.bits = bits - self.offset = 0 - - def __len__(self): - return len(self.bits) - self.offset - - def getbits(self, n): - if not n: - return 0 - if n > len(self.bits) - self.offset: - raise ValueError('not enough bits %r %r' % (n, len(self.bits))) - b = self.bits[self.offset:self.offset + n] - b = int("".join(str(x) for x in b), 2) - self.offset += n - return b - - def putbits(self, b, n): - if not n: - return - bits = list(bin(b)[2:]) - bits = [int(x) for x in bits] - bits = [0 for x in range(n - len(bits))] + bits - self.bits += bits - - def tostring(self): - if len(self.bits) % 8: - raise ValueError( - 'num bits must be 8 bit aligned: %d' % len(self.bits) - ) - b = int("".join(str(x) for x in self.bits), 2) - b = "%X" % b - b = '0' * (len(self.bits) // 4 - len(b)) + b - b = decode_hex(b.encode()) - return b - - def reset(self): - self.offset = 0 - - def copy_state(self): - b = self.__class__() - b.bits = self.bits - b.offset = self.offset - return b - - -def literal_list(l): - l = l[:] - l.sort() - l = l[::-1] - o = pyparsing.Literal(l[0]) - for x in l[1:]: - o |= pyparsing.Literal(x) - return o - - -class reg_info(object): - - def __init__(self, reg_str, reg_expr): - self.str = reg_str - self.expr = reg_expr - self.parser = literal_list(reg_str).setParseAction(self.cb_parse) - - def cb_parse(self, tokens): - assert len(tokens) == 1 - i = self.str.index(tokens[0]) - reg = self.expr[i] - result = AstId(reg) - return result - - def reg2expr(self, s): - i = self.str.index(s[0]) - return self.expr[i] - - def expr2regi(self, e): - return self.expr.index(e) - - -class reg_info_dct(object): - - def __init__(self, reg_expr): - self.dct_str_inv = dict((v.name, k) for k, v in viewitems(reg_expr)) - self.dct_expr = reg_expr - self.dct_expr_inv = dict((v, k) for k, v in viewitems(reg_expr)) - reg_str = [v.name for v in viewvalues(reg_expr)] - self.parser = literal_list(reg_str).setParseAction(self.cb_parse) - - def cb_parse(self, tokens): - assert len(tokens) == 1 - i = self.dct_str_inv[tokens[0]] - reg = self.dct_expr[i] - result = AstId(reg) - return result - - def reg2expr(self, s): - i = self.dct_str_inv[s[0]] - return self.dct_expr[i] - - def expr2regi(self, e): - return self.dct_expr_inv[e] - - -def gen_reg(reg_name, sz=32): - """Gen reg expr and parser""" - reg = m2_expr.ExprId(reg_name, sz) - reginfo = reg_info([reg_name], [reg]) - return reg, reginfo - - -def gen_reg_bs(reg_name, reg_info, base_cls): - """ - Generate: - class bs_reg_name(base_cls): - reg = reg_info - - bs_reg_name = bs(l=0, cls=(bs_reg_name,)) - """ - - bs_name = "bs_%s" % reg_name - cls = type(bs_name, base_cls, {'reg': reg_info}) - - bs_obj = bs(l=0, cls=(cls,)) - - return cls, bs_obj - - -def gen_regs(rnames, env, sz=32): - regs_str = [] - regs_expr = [] - regs_init = [] - for rname in rnames: - r = m2_expr.ExprId(rname, sz) - r_init = m2_expr.ExprId(rname+'_init', sz) - regs_str.append(rname) - regs_expr.append(r) - regs_init.append(r_init) - env[rname] = r - - reginfo = reg_info(regs_str, regs_expr) - return regs_expr, regs_init, reginfo - - -LPARENTHESIS = pyparsing.Literal("(") -RPARENTHESIS = pyparsing.Literal(")") - - -def int2expr(tokens): - v = tokens[0] - return (m2_expr.ExprInt, v) - - -def parse_op(tokens): - v = tokens[0] - return (m2_expr.ExprOp, v) - - -def parse_id(tokens): - v = tokens[0] - return (m2_expr.ExprId, v) - - -def ast_parse_op(tokens): - if len(tokens) == 1: - return tokens[0] - if len(tokens) == 2: - if tokens[0] in ['-', '+', '!']: - return m2_expr.ExprOp(tokens[0], tokens[1]) - if len(tokens) == 3: - if tokens[1] == '-': - # a - b => a + (-b) - tokens[1] = '+' - tokens[2] = - tokens[2] - return m2_expr.ExprOp(tokens[1], tokens[0], tokens[2]) - tokens = tokens[::-1] - while len(tokens) >= 3: - o1, op, o2 = tokens.pop(), tokens.pop(), tokens.pop() - if op == '-': - # a - b => a + (-b) - op = '+' - o2 = - o2 - e = m2_expr.ExprOp(op, o1, o2) - tokens.append(e) - if len(tokens) != 1: - raise NotImplementedError('strange op') - return tokens[0] - - -def ast_id2expr(a): - return m2_expr.ExprId(a, 32) - - -def ast_int2expr(a): - return m2_expr.ExprInt(a, 32) - - -def neg_int(tokens): - x = -tokens[0] - return x - - -integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0])) -hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums) -hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16)) - -# str_int = (Optional('-') + (hex_int | integer)) -str_int_pos = (hex_int | integer) -str_int_neg = (pyparsing.Suppress('-') + \ - (hex_int | integer)).setParseAction(neg_int) - -str_int = str_int_pos | str_int_neg -str_int.setParseAction(int2expr) - -logicop = pyparsing.oneOf('& | ^ >> << <<< >>>') -signop = pyparsing.oneOf('+ -') -multop = pyparsing.oneOf('* / %') -plusop = pyparsing.oneOf('+ -') - - -########################## - -def literal_list(l): - l = l[:] - l.sort() - l = l[::-1] - o = pyparsing.Literal(l[0]) - for x in l[1:]: - o |= pyparsing.Literal(x) - return o - - -def cb_int(tokens): - assert len(tokens) == 1 - integer = AstInt(tokens[0]) - return integer - - -def cb_parse_id(tokens): - assert len(tokens) == 1 - reg = tokens[0] - return AstId(reg) - - -def cb_op_not(tokens): - tokens = tokens[0] - assert len(tokens) == 2 - assert tokens[0] == "!" - result = AstOp("!", tokens[1]) - return result - - -def merge_ops(tokens, op): - args = [] - if len(tokens) >= 3: - args = [tokens.pop(0)] - i = 0 - while i < len(tokens): - op_tmp = tokens[i] - arg = tokens[i+1] - i += 2 - if op_tmp != op: - raise ValueError("Bad operator") - args.append(arg) - result = AstOp(op, *args) - return result - - -def cb_op_and(tokens): - result = merge_ops(tokens[0], "&") - return result - - -def cb_op_xor(tokens): - result = merge_ops(tokens[0], "^") - return result - - -def cb_op_sign(tokens): - assert len(tokens) == 1 - op, value = tokens[0] - return -value - - -def cb_op_div(tokens): - tokens = tokens[0] - assert len(tokens) == 3 - assert tokens[1] == "/" - result = AstOp("/", tokens[0], tokens[2]) - return result - - -def cb_op_plusminus(tokens): - tokens = tokens[0] - if len(tokens) == 3: - # binary op - assert isinstance(tokens[0], AstNode) - assert isinstance(tokens[2], AstNode) - op, args = tokens[1], [tokens[0], tokens[2]] - elif len(tokens) > 3: - args = [tokens.pop(0)] - i = 0 - while i < len(tokens): - op = tokens[i] - arg = tokens[i+1] - i += 2 - if op == '-': - arg = -arg - elif op == '+': - pass - else: - raise ValueError("Bad operator") - args.append(arg) - op = '+' - else: - raise ValueError("Parsing error") - assert all(isinstance(arg, AstNode) for arg in args) - result = AstOp(op, *args) - return result - - -def cb_op_mul(tokens): - tokens = tokens[0] - assert len(tokens) == 3 - assert isinstance(tokens[0], AstNode) - assert isinstance(tokens[2], AstNode) - - # binary op - op, args = tokens[1], [tokens[0], tokens[2]] - result = AstOp(op, *args) - return result - - -integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0])) -hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums) -hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16)) - -str_int_pos = (hex_int | integer) - -str_int = str_int_pos -str_int.setParseAction(cb_int) - -notop = pyparsing.oneOf('!') -andop = pyparsing.oneOf('&') -orop = pyparsing.oneOf('|') -xorop = pyparsing.oneOf('^') -shiftop = pyparsing.oneOf('>> <<') -rotop = pyparsing.oneOf('<<< >>>') -signop = pyparsing.oneOf('+ -') -mulop = pyparsing.oneOf('*') -plusop = pyparsing.oneOf('+ -') -divop = pyparsing.oneOf('/') - - -variable = pyparsing.Word(pyparsing.alphas + "_$.", pyparsing.alphanums + "_") -variable.setParseAction(cb_parse_id) -operand = str_int | variable - -base_expr = pyparsing.infixNotation(operand, - [(notop, 1, pyparsing.opAssoc.RIGHT, cb_op_not), - (andop, 2, pyparsing.opAssoc.RIGHT, cb_op_and), - (xorop, 2, pyparsing.opAssoc.RIGHT, cb_op_xor), - (signop, 1, pyparsing.opAssoc.RIGHT, cb_op_sign), - (mulop, 2, pyparsing.opAssoc.RIGHT, cb_op_mul), - (divop, 2, pyparsing.opAssoc.RIGHT, cb_op_div), - (plusop, 2, pyparsing.opAssoc.LEFT, cb_op_plusminus), - ]) - - -default_prio = 0x1337 - - -def isbin(s): - return re.match(r'[0-1]+$', s) - - -def int2bin(i, l): - s = '0' * l + bin(i)[2:] - return s[-l:] - - -def myror32(v, r): - return ((v & 0xFFFFFFFF) >> r) | ((v << (32 - r)) & 0xFFFFFFFF) - - -def myrol32(v, r): - return ((v & 0xFFFFFFFF) >> (32 - r)) | ((v << r) & 0xFFFFFFFF) - - -class bs(object): - all_new_c = {} - prio = default_prio - - def __init__(self, strbits=None, l=None, cls=None, - fname=None, order=0, flen=None, **kargs): - if fname is None: - fname = hex(id(str((strbits, l, cls, fname, order, flen, kargs)))) - if strbits is None: - strbits = "" # "X"*l - elif l is None: - l = len(strbits) - if strbits and isbin(strbits): - value = int(strbits, 2) - elif 'default_val' in kargs: - value = int(kargs['default_val'], 2) - else: - value = None - allbits = list(strbits) - allbits.reverse() - fbits = 0 - fmask = 0 - while allbits: - a = allbits.pop() - if a == " ": - continue - fbits <<= 1 - fmask <<= 1 - if a in '01': - a = int(a) - fbits |= a - fmask |= 1 - lmask = (1 << l) - 1 - # gen conditional field - if cls: - for b in cls: - if 'flen' in b.__dict__: - flen = getattr(b, 'flen') - - self.strbits = strbits - self.l = l - self.cls = cls - self.fname = fname - self.order = order - self.fbits = fbits - self.fmask = fmask - self.flen = flen - self.value = value - self.kargs = kargs - - lmask = property(lambda self:(1 << self.l) - 1) - - def __getitem__(self, item): - return getattr(self, item) - - def __repr__(self): - o = self.__class__.__name__ - if self.fname: - o += "_%s" % self.fname - o += "_%(strbits)s" % self - if self.cls: - o += '_' + '_'.join([x.__name__ for x in self.cls]) - return o - - def gen(self, parent): - c_name = 'nbsi' - if self.cls: - c_name += '_' + '_'.join([x.__name__ for x in self.cls]) - bases = list(self.cls) - else: - bases = [] - # bsi added at end of list - # used to use first function of added class - bases += [bsi] - k = c_name, tuple(bases) - if k in self.all_new_c: - new_c = self.all_new_c[k] - else: - new_c = type(c_name, tuple(bases), {}) - self.all_new_c[k] = new_c - c = new_c(parent, - self.strbits, self.l, self.cls, - self.fname, self.order, self.lmask, self.fbits, - self.fmask, self.value, self.flen, **self.kargs) - return c - - def check_fbits(self, v): - return v & self.fmask == self.fbits - - @classmethod - def flen(cls, v): - raise NotImplementedError('not fully functional') - - -class dum_arg(object): - - def __init__(self, e=None): - self.expr = e - - -class bsopt(bs): - - def ispresent(self): - return True - - -class bsi(object): - - def __init__(self, parent, strbits, l, cls, fname, order, - lmask, fbits, fmask, value, flen, **kargs): - self.parent = parent - self.strbits = strbits - self.l = l - self.cls = cls - self.fname = fname - self.order = order - self.fbits = fbits - self.fmask = fmask - self.flen = flen - self.value = value - self.kargs = kargs - self.__dict__.update(self.kargs) - - lmask = property(lambda self:(1 << self.l) - 1) - - def decode(self, v): - self.value = v & self.lmask - return True - - def encode(self): - return True - - def clone(self): - s = self.__class__(self.parent, - self.strbits, self.l, self.cls, - self.fname, self.order, self.lmask, self.fbits, - self.fmask, self.value, self.flen, **self.kargs) - s.__dict__.update(self.kargs) - if hasattr(self, 'expr'): - s.expr = self.expr - return s - - def __hash__(self): - kargs = [] - for k, v in list(viewitems(self.kargs)): - if isinstance(v, list): - v = tuple(v) - kargs.append((k, v)) - l = [self.strbits, self.l, self.cls, - self.fname, self.order, self.lmask, self.fbits, - self.fmask, self.value] # + kargs - - return hash(tuple(l)) - - -class bs_divert(object): - prio = default_prio - - def __init__(self, **kargs): - self.args = kargs - - def __getattr__(self, item): - if item in self.__dict__: - return self.__dict__[item] - elif item in self.args: - return self.args.get(item) - else: - raise AttributeError - - -class bs_name(bs_divert): - prio = 1 - - def divert(self, i, candidates): - out = [] - for cls, _, bases, dct, fields in candidates: - for new_name, value in viewitems(self.args['name']): - nfields = fields[:] - s = int2bin(value, self.args['l']) - args = dict(self.args) - args.update({'strbits': s}) - f = bs(**args) - nfields[i] = f - ndct = dict(dct) - ndct['name'] = new_name - out.append((cls, new_name, bases, ndct, nfields)) - return out - - -class bs_mod_name(bs_divert): - prio = 2 - - def divert(self, i, candidates): - out = [] - for cls, _, bases, dct, fields in candidates: - tab = self.args['mn_mod'] - if isinstance(tab, list): - tmp = {} - for j, v in enumerate(tab): - tmp[j] = v - tab = tmp - for value, new_name in viewitems(tab): - nfields = fields[:] - s = int2bin(value, self.args['l']) - args = dict(self.args) - args.update({'strbits': s}) - f = bs(**args) - nfields[i] = f - ndct = dict(dct) - ndct['name'] = self.modname(ndct['name'], value) - out.append((cls, new_name, bases, ndct, nfields)) - return out - - def modname(self, name, i): - return name + self.args['mn_mod'][i] - - -class bs_cond(bsi): - pass - - -class bs_swapargs(bs_divert): - - def divert(self, i, candidates): - out = [] - for cls, name, bases, dct, fields in candidates: - # args not permuted - ndct = dict(dct) - nfields = fields[:] - # gen fix field - f = gen_bsint(0, self.args['l'], self.args) - nfields[i] = f - out.append((cls, name, bases, ndct, nfields)) - - # args permuted - ndct = dict(dct) - nfields = fields[:] - ap = ndct['args_permut'][:] - a = ap.pop(0) - b = ap.pop(0) - ndct['args_permut'] = [b, a] + ap - # gen fix field - f = gen_bsint(1, self.args['l'], self.args) - nfields[i] = f - - out.append((cls, name, bases, ndct, nfields)) - return out - - -class m_arg(object): - - def fromstring(self, text, loc_db, parser_result=None): - if parser_result: - e, start, stop = parser_result[self.parser] - self.expr = e - return start, stop - try: - v, start, stop = next(self.parser.scanString(text)) - except StopIteration: - return None, None - arg = v[0] - expr = self.asm_ast_to_expr(arg, loc_db) - self.expr = expr - return start, stop - - def asm_ast_to_expr(self, arg, loc_db, **kwargs): - raise NotImplementedError("Virtual") - - -class m_reg(m_arg): - prio = default_prio - - @property - def parser(self): - return self.reg.parser - - def decode(self, v): - self.expr = self.reg.expr[0] - return True - - def encode(self): - return self.expr == self.reg.expr[0] - - -class reg_noarg(object): - reg_info = None - parser = None - - def fromstring(self, text, loc_db, parser_result=None): - if parser_result: - e, start, stop = parser_result[self.parser] - self.expr = e - return start, stop - try: - v, start, stop = next(self.parser.scanString(text)) - except StopIteration: - return None, None - arg = v[0] - expr = self.parses_to_expr(arg, loc_db) - self.expr = expr - return start, stop - - def decode(self, v): - v = v & self.lmask - if v >= len(self.reg_info.expr): - return False - self.expr = self.reg_info.expr[v] - return True - - def encode(self): - if not self.expr in self.reg_info.expr: - log.debug("cannot encode reg %r", self.expr) - return False - self.value = self.reg_info.expr.index(self.expr) - return True - - def check_fbits(self, v): - return v & self.fmask == self.fbits - - -class mn_prefix(object): - pass - - -def swap16(v): - return struct.unpack('<H', struct.pack('>H', v))[0] - - -def swap32(v): - return struct.unpack('<I', struct.pack('>I', v))[0] - - -def perm_inv(p): - o = [None for x in range(len(p))] - for i, x in enumerate(p): - o[x] = i - return o - - -def gen_bsint(value, l, args): - s = int2bin(value, l) - args = dict(args) - args.update({'strbits': s}) - f = bs(**args) - return f - -total_scans = 0 - - -def branch2nodes(branch, nodes=None): - if nodes is None: - nodes = [] - for k, v in viewitems(branch): - if not isinstance(v, dict): - continue - for k2 in v: - nodes.append((k, k2)) - branch2nodes(v, nodes) - - -def factor_one_bit(tree): - if isinstance(tree, set): - return tree - new_keys = defaultdict(lambda: defaultdict(dict)) - if len(tree) == 1: - return tree - for k, v in viewitems(tree): - if k == "mn": - new_keys[k] = v - continue - l, fmask, fbits, fname, flen = k - if flen is not None or l <= 1: - new_keys[k] = v - continue - cfmask = fmask >> (l - 1) - nfmask = fmask & ((1 << (l - 1)) - 1) - cfbits = fbits >> (l - 1) - nfbits = fbits & ((1 << (l - 1)) - 1) - ck = 1, cfmask, cfbits, None, flen - nk = l - 1, nfmask, nfbits, fname, flen - if nk in new_keys[ck]: - raise NotImplementedError('not fully functional') - new_keys[ck][nk] = v - for k, v in list(viewitems(new_keys)): - new_keys[k] = factor_one_bit(v) - # try factor sons - if len(new_keys) != 1: - return new_keys - subtree = next(iter(viewvalues(new_keys))) - if len(subtree) != 1: - return new_keys - if next(iter(subtree)) == 'mn': - return new_keys - - return new_keys - - -def factor_fields(tree): - if not isinstance(tree, dict): - return tree - if len(tree) != 1: - return tree - # merge - k1, v1 = next(iter(viewitems(tree))) - if k1 == "mn": - return tree - l1, fmask1, fbits1, fname1, flen1 = k1 - if fname1 is not None: - return tree - if flen1 is not None: - return tree - - if not isinstance(v1, dict): - return tree - if len(v1) != 1: - return tree - k2, v2 = next(iter(viewitems(v1))) - if k2 == "mn": - return tree - l2, fmask2, fbits2, fname2, flen2 = k2 - if fname2 is not None: - return tree - if flen2 is not None: - return tree - l = l1 + l2 - fmask = (fmask1 << l2) | fmask2 - fbits = (fbits1 << l2) | fbits2 - fname = fname2 - flen = flen2 - k = l, fmask, fbits, fname, flen - new_keys = {k: v2} - return new_keys - - -def factor_fields_all(tree): - if not isinstance(tree, dict): - return tree - new_keys = {} - for k, v in viewitems(tree): - v = factor_fields(v) - new_keys[k] = factor_fields_all(v) - return new_keys - - -def graph_tree(tree): - nodes = [] - branch2nodes(tree, nodes) - - out = """ - digraph G { - """ - for a, b in nodes: - if b == 'mn': - continue - out += "%s -> %s;\n" % (id(a), id(b)) - out += "}" - open('graph.txt', 'w').write(out) - - -def add_candidate_to_tree(tree, c): - branch = tree - for f in c.fields: - if f.l == 0: - continue - node = f.l, f.fmask, f.fbits, f.fname, f.flen - - if not node in branch: - branch[node] = {} - branch = branch[node] - if not 'mn' in branch: - branch['mn'] = set() - branch['mn'].add(c) - - -def add_candidate(bases, c): - add_candidate_to_tree(bases[0].bintree, c) - - -def getfieldby_name(fields, fname): - f = [x for x in fields if hasattr(x, 'fname') and x.fname == fname] - if len(f) != 1: - raise ValueError('more than one field with name: %s' % fname) - return f[0] - - -def getfieldindexby_name(fields, fname): - for i, f in enumerate(fields): - if hasattr(f, 'fname') and f.fname == fname: - return f, i - return None - - -class metamn(type): - - def __new__(mcs, name, bases, dct): - if name == "cls_mn" or name.startswith('mn_'): - return type.__new__(mcs, name, bases, dct) - alias = dct.get('alias', False) - - fields = bases[0].mod_fields(dct['fields']) - if not 'name' in dct: - dct["name"] = bases[0].getmn(name) - if 'args' in dct: - # special case for permuted arguments - o = [] - p = [] - for i, a in enumerate(dct['args']): - o.append((i, a)) - if a in fields: - p.append((fields.index(a), a)) - p.sort() - p = [x[1] for x in p] - p = [dct['args'].index(x) for x in p] - dct['args_permut'] = perm_inv(p) - # order fields - f_ordered = [x for x in enumerate(fields)] - f_ordered.sort(key=lambda x: (x[1].prio, x[0])) - candidates = bases[0].gen_modes(mcs, name, bases, dct, fields) - for i, fc in f_ordered: - if isinstance(fc, bs_divert): - candidates = fc.divert(i, candidates) - for cls, name, bases, dct, fields in candidates: - ndct = dict(dct) - fields = [f for f in fields if f] - ndct['fields'] = fields - ndct['mn_len'] = sum([x.l for x in fields]) - c = type.__new__(cls, name, bases, ndct) - c.alias = alias - c.check_mnemo(fields) - c.num = bases[0].num - bases[0].num += 1 - bases[0].all_mn.append(c) - mode = dct['mode'] - bases[0].all_mn_mode[mode].append(c) - bases[0].all_mn_name[c.name].append(c) - i = c() - i.init_class() - bases[0].all_mn_inst[c].append(i) - add_candidate(bases, c) - # gen byte lookup - o = "" - for f in i.fields_order: - if not isinstance(f, bsi): - raise ValueError('f is not bsi') - if f.l == 0: - continue - o += f.strbits - return c - - -class instruction(object): - __slots__ = ["name", "mode", "args", - "l", "b", "offset", "data", - "additional_info", "delayslot"] - - def __init__(self, name, mode, args, additional_info=None): - self.name = name - self.mode = mode - self.args = args - self.additional_info = additional_info - self.offset = None - self.l = None - self.b = None - self.delayslot = 0 - - def gen_args(self, args): - out = ', '.join([str(x) for x in args]) - return out - - def __str__(self): - return self.to_string() - - def to_string(self, loc_db=None): - o = "%-10s " % self.name - args = [] - for i, arg in enumerate(self.args): - if not isinstance(arg, m2_expr.Expr): - raise ValueError('zarb arg type') - x = self.arg2str(arg, i, loc_db) - args.append(x) - o += self.gen_args(args) - return o - - def to_html(self, loc_db=None): - out = "%-10s " % self.name - out = '<font color="%s">%s</font>' % (utils.COLOR_MNEMO, out) - - args = [] - for i, arg in enumerate(self.args): - if not isinstance(arg, m2_expr.Expr): - raise ValueError('zarb arg type') - x = self.arg2html(arg, i, loc_db) - args.append(x) - out += self.gen_args(args) - return out - - def get_asm_offset(self, expr): - return m2_expr.ExprInt(self.offset, expr.size) - - def get_asm_next_offset(self, expr): - return m2_expr.ExprInt(self.offset+self.l, expr.size) - - def resolve_args_with_symbols(self, loc_db): - args_out = [] - for expr in self.args: - # try to resolve symbols using loc_db (0 for default value) - loc_keys = m2_expr.get_expr_locs(expr) - fixed_expr = {} - for exprloc in loc_keys: - loc_key = exprloc.loc_key - names = loc_db.get_location_names(loc_key) - # special symbols - if '$' in names: - fixed_expr[exprloc] = self.get_asm_offset(exprloc) - continue - if '_' in names: - fixed_expr[exprloc] = self.get_asm_next_offset(exprloc) - continue - arg_int = loc_db.get_location_offset(loc_key) - if arg_int is not None: - fixed_expr[exprloc] = m2_expr.ExprInt(arg_int, exprloc.size) - continue - if not names: - raise ValueError('Unresolved symbol: %r' % exprloc) - - offset = loc_db.get_location_offset(loc_key) - if offset is None: - raise ValueError( - 'The offset of loc_key "%s" cannot be determined' % names - ) - else: - # Fix symbol with its offset - size = exprloc.size - if size is None: - default_size = self.get_symbol_size(exprloc, loc_db) - size = default_size - value = m2_expr.ExprInt(offset, size) - fixed_expr[exprloc] = value - - expr = expr.replace_expr(fixed_expr) - expr = expr_simp(expr) - args_out.append(expr) - return args_out - - def get_info(self, c): - return - - -class cls_mn(with_metaclass(metamn, object)): - args_symb = [] - instruction = instruction - # Block's offset alignment - alignment = 1 - - @classmethod - def guess_mnemo(cls, bs, attrib, pre_dis_info, offset): - candidates = [] - - candidates = set() - - fname_values = pre_dis_info - todo = [ - (dict(fname_values), branch, offset * 8) - for branch in list(viewitems(cls.bintree)) - ] - for fname_values, branch, offset_b in todo: - (l, fmask, fbits, fname, flen), vals = branch - - if flen is not None: - l = flen(attrib, fname_values) - if l is not None: - try: - v = cls.getbits(bs, attrib, offset_b, l) - except IOError: - # Raised if offset is out of bound - continue - offset_b += l - if v & fmask != fbits: - continue - if fname is not None and not fname in fname_values: - fname_values[fname] = v - for nb, v in viewitems(vals): - if 'mn' in nb: - candidates.update(v) - else: - todo.append((dict(fname_values), (nb, v), offset_b)) - - return [c for c in candidates] - - def reset_class(self): - for f in self.fields_order: - if f.strbits and isbin(f.strbits): - f.value = int(f.strbits, 2) - elif 'default_val' in f.kargs: - f.value = int(f.kargs['default_val'], 2) - else: - f.value = None - if f.fname: - setattr(self, f.fname, f) - - def init_class(self): - args = [] - fields_order = [] - to_decode = [] - off = 0 - for i, fc in enumerate(self.fields): - f = fc.gen(self) - f.offset = off - off += f.l - fields_order.append(f) - to_decode.append((i, f)) - - if isinstance(f, m_arg): - args.append(f) - if f.fname: - setattr(self, f.fname, f) - if hasattr(self, 'args_permut'): - args = [args[self.args_permut[i]] - for i in range(len(self.args_permut))] - to_decode.sort(key=lambda x: (x[1].order, x[0])) - to_decode = [fields_order.index(f[1]) for f in to_decode] - self.args = args - self.fields_order = fields_order - self.to_decode = to_decode - - def add_pre_dis_info(self, prefix=None): - return True - - @classmethod - def getbits(cls, bs, attrib, offset_b, l): - return bs.getbits(offset_b, l) - - @classmethod - def getbytes(cls, bs, offset, l): - return bs.getbytes(offset, l) - - @classmethod - def pre_dis(cls, v_o, attrib, offset): - return {}, v_o, attrib, offset, 0 - - def post_dis(self): - return self - - @classmethod - def check_mnemo(cls, fields): - pass - - @classmethod - def mod_fields(cls, fields): - return fields - - @classmethod - def dis(cls, bs_o, mode_o = None, offset=0): - if not isinstance(bs_o, bin_stream): - bs_o = bin_stream_str(bs_o) - - bs_o.enter_atomic_mode() - - offset_o = offset - try: - pre_dis_info, bs, mode, offset, prefix_len = cls.pre_dis( - bs_o, mode_o, offset) - except: - bs_o.leave_atomic_mode() - raise - candidates = cls.guess_mnemo(bs, mode, pre_dis_info, offset) - if not candidates: - bs_o.leave_atomic_mode() - raise Disasm_Exception('cannot disasm (guess) at %X' % offset) - - out = [] - out_c = [] - if hasattr(bs, 'getlen'): - bs_l = bs.getlen() - else: - bs_l = len(bs) - - alias = False - for c in candidates: - log.debug("*" * 40, mode, c.mode) - log.debug(c.fields) - - c = cls.all_mn_inst[c][0] - - c.reset_class() - c.mode = mode - - if not c.add_pre_dis_info(pre_dis_info): - continue - - todo = {} - getok = True - fname_values = dict(pre_dis_info) - offset_b = offset * 8 - - total_l = 0 - for i, f in enumerate(c.fields_order): - if f.flen is not None: - l = f.flen(mode, fname_values) - else: - l = f.l - if l is not None: - total_l += l - f.l = l - f.is_present = True - log.debug("FIELD %s %s %s %s", f.__class__, f.fname, - offset_b, l) - if bs_l * 8 - offset_b < l: - getok = False - break - try: - bv = cls.getbits(bs, mode, offset_b, l) - except: - bs_o.leave_atomic_mode() - raise - offset_b += l - if not f.fname in fname_values: - fname_values[f.fname] = bv - todo[i] = bv - else: - f.is_present = False - todo[i] = None - - if not getok: - continue - - c.l = prefix_len + total_l // 8 - for i in c.to_decode: - f = c.fields_order[i] - if f.is_present: - ret = f.decode(todo[i]) - if not ret: - log.debug("cannot decode %r", f) - break - - if not ret: - continue - for a in c.args: - a.expr = expr_simp(a.expr) - - c.b = cls.getbytes(bs, offset_o, c.l) - c.offset = offset_o - c = c.post_dis() - if c is None: - continue - c_args = [a.expr for a in c.args] - instr = cls.instruction(c.name, mode, c_args, - additional_info=c.additional_info()) - instr.l = prefix_len + total_l // 8 - instr.b = cls.getbytes(bs, offset_o, instr.l) - instr.offset = offset_o - instr.get_info(c) - if c.alias: - alias = True - out.append(instr) - out_c.append(c) - - bs_o.leave_atomic_mode() - - if not out: - raise Disasm_Exception('cannot disasm at %X' % offset_o) - if len(out) != 1: - if not alias: - log.warning('dis multiple args ret default') - - for i, o in enumerate(out_c): - if o.alias: - return out[i] - raise NotImplementedError( - 'Multiple disas: \n' + - "\n".join(str(x) for x in out) - ) - return out[0] - - @classmethod - def fromstring(cls, text, loc_db, mode = None): - global total_scans - name = re.search(r'(\S+)', text).groups() - if not name: - raise ValueError('cannot find name', text) - name = name[0] - - if not name in cls.all_mn_name: - raise ValueError('unknown name', name) - clist = [x for x in cls.all_mn_name[name]] - out = [] - out_args = [] - parsers = defaultdict(dict) - - for cc in clist: - for c in cls.get_cls_instance(cc, mode): - args_expr = [] - args_str = text[len(name):].strip(' ') - - start = 0 - cannot_parse = False - len_o = len(args_str) - - for i, f in enumerate(c.args): - start_i = len_o - len(args_str) - if type(f.parser) == tuple: - parser = f.parser - else: - parser = (f.parser,) - for p in parser: - if p in parsers[(i, start_i)]: - continue - try: - total_scans += 1 - v, start, stop = next(p.scanString(args_str)) - except StopIteration: - v, start, stop = [None], None, None - if start != 0: - v, start, stop = [None], None, None - if v != [None]: - v = f.asm_ast_to_expr(v[0], loc_db) - if v is None: - v, start, stop = [None], None, None - parsers[(i, start_i)][p] = v, start, stop - start, stop = f.fromstring(args_str, loc_db, parsers[(i, start_i)]) - if start != 0: - log.debug("cannot fromstring %r", args_str) - cannot_parse = True - break - if f.expr is None: - raise NotImplementedError('not fully functional') - f.expr = expr_simp(f.expr) - args_expr.append(f.expr) - args_str = args_str[stop:].strip(' ') - if args_str.startswith(','): - args_str = args_str[1:] - args_str = args_str.strip(' ') - if args_str: - cannot_parse = True - if cannot_parse: - continue - - out.append(c) - out_args.append(args_expr) - break - - if len(out) == 0: - raise ValueError('cannot fromstring %r' % text) - if len(out) != 1: - log.debug('fromstring multiple args ret default') - c = out[0] - c_args = out_args[0] - - instr = cls.instruction(c.name, mode, c_args, - additional_info=c.additional_info()) - return instr - - def dup_info(self, infos): - return - - @classmethod - def get_cls_instance(cls, cc, mode, infos=None): - c = cls.all_mn_inst[cc][0] - - c.reset_class() - c.add_pre_dis_info() - c.dup_info(infos) - - c.mode = mode - yield c - - @classmethod - def asm(cls, instr, loc_db=None): - """ - Re asm instruction by searching mnemo using name and args. We then - can modify args and get the hex of a modified instruction - """ - clist = cls.all_mn_name[instr.name] - clist = [x for x in clist] - vals = [] - candidates = [] - args = instr.resolve_args_with_symbols(loc_db) - - for cc in clist: - - for c in cls.get_cls_instance( - cc, instr.mode, instr.additional_info): - - cannot_parse = False - if len(c.args) != len(instr.args): - continue - - # only fix args expr - for i in range(len(c.args)): - c.args[i].expr = args[i] - - v = c.value(instr.mode) - if not v: - log.debug("cannot encode %r", c) - cannot_parse = True - if cannot_parse: - continue - vals += v - candidates.append((c, v)) - if len(vals) == 0: - raise ValueError( - 'cannot asm %r %r' % - (instr.name, [str(x) for x in instr.args]) - ) - if len(vals) != 1: - log.debug('asm multiple args ret default') - - vals = cls.filter_asm_candidates(instr, candidates) - return vals - - @classmethod - def filter_asm_candidates(cls, instr, candidates): - o = [] - for _, v in candidates: - o += v - o.sort(key=len) - return o - - def value(self, mode): - todo = [(0, 0, [(x, self.fields_order[x]) for x in self.to_decode[::-1]])] - - result = [] - done = [] - - while todo: - index, cur_len, to_decode = todo.pop() - # TEST XXX - for _, f in to_decode: - setattr(self, f.fname, f) - if (index, [x[1].value for x in to_decode]) in done: - continue - done.append((index, [x[1].value for x in to_decode])) - - can_encode = True - for i, f in to_decode[index:]: - f.parent.l = cur_len - ret = f.encode() - if not ret: - log.debug('cannot encode %r', f) - can_encode = False - break - - if f.value is not None and f.l: - if f.value > f.lmask: - log.debug('cannot encode %r', f) - can_encode = False - break - cur_len += f.l - index += 1 - if ret is True: - continue - - for _ in ret: - o = [] - if ((index, cur_len, [xx[1].value for xx in to_decode]) in todo or - (index, cur_len, [xx[1].value for xx in to_decode]) in done): - raise NotImplementedError('not fully functional') - - for p, f in to_decode: - fnew = f.clone() - o.append((p, fnew)) - todo.append((index, cur_len, o)) - can_encode = False - - break - if not can_encode: - continue - result.append(to_decode) - - return self.decoded2bytes(result) - - def encodefields(self, decoded): - bits = bitobj() - for _, f in decoded: - setattr(self, f.fname, f) - - if f.value is None: - continue - bits.putbits(f.value, f.l) - - return bits.tostring() - - def decoded2bytes(self, result): - if not result: - return [] - - out = [] - for decoded in result: - decoded.sort() - - o = self.encodefields(decoded) - if o is None: - continue - out.append(o) - out = list(set(out)) - return out - - def gen_args(self, args): - out = ', '.join([str(x) for x in args]) - return out - - def args2str(self): - args = [] - for arg in self.args: - # XXX todo test - if not (isinstance(arg, m2_expr.Expr) or - isinstance(arg.expr, m2_expr.Expr)): - raise ValueError('zarb arg type') - x = str(arg) - args.append(x) - return args - - def __str__(self): - o = "%-10s " % self.name - args = [] - for arg in self.args: - # XXX todo test - if not (isinstance(arg, m2_expr.Expr) or - isinstance(arg.expr, m2_expr.Expr)): - raise ValueError('zarb arg type') - x = str(arg) - args.append(x) - - o += self.gen_args(args) - return o - - def parse_prefix(self, v): - return 0 - - def set_dst_symbol(self, loc_db): - dst = self.getdstflow(loc_db) - args = [] - for d in dst: - if isinstance(d, m2_expr.ExprInt): - l = loc_db.get_or_create_offset_location(int(d)) - - a = m2_expr.ExprId(l.name, d.size) - else: - a = d - args.append(a) - self.args_symb = args - - def getdstflow(self, loc_db): - return [self.args[0].expr] - - -class imm_noarg(object): - intsize = 32 - intmask = (1 << intsize) - 1 - - def int2expr(self, v): - if (v & ~self.intmask) != 0: - return None - return m2_expr.ExprInt(v, self.intsize) - - def expr2int(self, e): - if not isinstance(e, m2_expr.ExprInt): - return None - v = int(e) - if v & ~self.intmask != 0: - return None - return v - - def fromstring(self, text, loc_db, parser_result=None): - if parser_result: - e, start, stop = parser_result[self.parser] - else: - try: - e, start, stop = next(self.parser.scanString(text)) - except StopIteration: - return None, None - if e == [None]: - return None, None - - assert(m2_expr.is_expr(e)) - self.expr = e - if self.expr is None: - log.debug('cannot fromstring int %r', text) - return None, None - return start, stop - - def decodeval(self, v): - return v - - def encodeval(self, v): - return v - - def decode(self, v): - v = v & self.lmask - v = self.decodeval(v) - e = self.int2expr(v) - if not e: - return False - self.expr = e - return True - - def encode(self): - v = self.expr2int(self.expr) - if v is None: - return False - v = self.encodeval(v) - if v is False: - return False - self.value = v - return True - - -class imm08_noarg(object): - int2expr = lambda self, x: m2_expr.ExprInt(x, 8) - - -class imm16_noarg(object): - int2expr = lambda self, x: m2_expr.ExprInt(x, 16) - - -class imm32_noarg(object): - int2expr = lambda self, x: m2_expr.ExprInt(x, 32) - - -class imm64_noarg(object): - int2expr = lambda self, x: m2_expr.ExprInt(x, 64) - - -class int32_noarg(imm_noarg): - intsize = 32 - intmask = (1 << intsize) - 1 - - def decode(self, v): - v = sign_ext(v, self.l, self.intsize) - v = self.decodeval(v) - self.expr = self.int2expr(v) - return True - - def encode(self): - if not isinstance(self.expr, m2_expr.ExprInt): - return False - v = int(self.expr) - if sign_ext(v & self.lmask, self.l, self.intsize) != v: - return False - v = self.encodeval(v & self.lmask) - if v is False: - return False - self.value = v & self.lmask - return True - -class bs8(bs): - prio = default_prio - - def __init__(self, v, cls=None, fname=None, **kargs): - super(bs8, self).__init__(int2bin(v, 8), 8, - cls=cls, fname=fname, **kargs) - - - - -def swap_uint(size, i): - if size == 8: - return i & 0xff - elif size == 16: - return struct.unpack('<H', struct.pack('>H', i & 0xffff))[0] - elif size == 32: - return struct.unpack('<I', struct.pack('>I', i & 0xffffffff))[0] - elif size == 64: - return struct.unpack('<Q', struct.pack('>Q', i & 0xffffffffffffffff))[0] - raise ValueError('unknown int len %r' % size) - - -def swap_sint(size, i): - if size == 8: - return i - elif size == 16: - return struct.unpack('<h', struct.pack('>H', i & 0xffff))[0] - elif size == 32: - return struct.unpack('<i', struct.pack('>I', i & 0xffffffff))[0] - elif size == 64: - return struct.unpack('<q', struct.pack('>Q', i & 0xffffffffffffffff))[0] - raise ValueError('unknown int len %r' % size) - - -def sign_ext(v, s_in, s_out): - assert(s_in <= s_out) - v &= (1 << s_in) - 1 - sign_in = v & (1 << (s_in - 1)) - if not sign_in: - return v - m = (1 << (s_out)) - 1 - m ^= (1 << s_in) - 1 - v |= m - return v diff --git a/miasm/core/ctypesmngr.py b/miasm/core/ctypesmngr.py deleted file mode 100644 index 94c96f7e..00000000 --- a/miasm/core/ctypesmngr.py +++ /dev/null @@ -1,771 +0,0 @@ -import re - -from pycparser import c_parser, c_ast - -RE_HASH_CMT = re.compile(r'^#\s*\d+.*$', flags=re.MULTILINE) - -# Ref: ISO/IEC 9899:TC2 -# http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1124.pdf - - -def c_to_ast(parser, c_str): - """Transform a @c_str into a C ast - Note: will ignore lines containing code refs ie: - # 23 "miasm.h" - - @parser: pycparser instance - @c_str: c string - """ - - new_str = re.sub(RE_HASH_CMT, "", c_str) - return parser.parse(new_str, filename='<stdin>') - - -class CTypeBase(object): - """Object to represent the 3 forms of C type: - * object types - * function types - * incomplete types - """ - - def __init__(self): - self.__repr = str(self) - self.__hash = hash(self.__repr) - - @property - def _typerepr(self): - return self.__repr - - def __eq__(self, other): - raise NotImplementedError("Abstract method") - - def __ne__(self, other): - return not self.__eq__(other) - - def eq_base(self, other): - """Trivial common equality test""" - return self.__class__ == other.__class__ - - def __hash__(self): - return self.__hash - - def __repr__(self): - return self._typerepr - - -class CTypeId(CTypeBase): - """C type id: - int - unsigned int - """ - - def __init__(self, *names): - # Type specifier order does not matter - # so the canonical form is ordered - self.names = tuple(sorted(names)) - super(CTypeId, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.names)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.names == other.names) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Id:%s>" % ', '.join(self.names) - - -class CTypeArray(CTypeBase): - """C type for array: - typedef int XXX[4]; - """ - - def __init__(self, target, size): - assert isinstance(target, CTypeBase) - self.target = target - self.size = size - super(CTypeArray, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.target, self.size)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.target == other.target and - self.size == other.size) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Array[%s]:%s>" % (self.size, str(self.target)) - - -class CTypePtr(CTypeBase): - """C type for pointer: - typedef int* XXX; - """ - - def __init__(self, target): - assert isinstance(target, CTypeBase) - self.target = target - super(CTypePtr, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.target)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.target == other.target) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Ptr:%s>" % str(self.target) - - -class CTypeStruct(CTypeBase): - """C type for structure""" - - def __init__(self, name, fields=None): - assert name is not None - self.name = name - if fields is None: - fields = () - for field_name, field in fields: - assert field_name is not None - assert isinstance(field, CTypeBase) - self.fields = tuple(fields) - super(CTypeStruct, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.name, self.fields)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.name == other.name and - self.fields == other.fields) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - out = [] - out.append("<Struct:%s>" % self.name) - for name, field in self.fields: - out.append("\t%-10s %s" % (name, field)) - return '\n'.join(out) - - -class CTypeUnion(CTypeBase): - """C type for union""" - - def __init__(self, name, fields=None): - assert name is not None - self.name = name - if fields is None: - fields = [] - for field_name, field in fields: - assert field_name is not None - assert isinstance(field, CTypeBase) - self.fields = tuple(fields) - super(CTypeUnion, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.name, self.fields)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.name == other.name and - self.fields == other.fields) - - def __str__(self): - out = [] - out.append("<Union:%s>" % self.name) - for name, field in self.fields: - out.append("\t%-10s %s" % (name, field)) - return '\n'.join(out) - - -class CTypeEnum(CTypeBase): - """C type for enums""" - - def __init__(self, name): - self.name = name - super(CTypeEnum, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.name)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.name == other.name) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Enum:%s>" % self.name - - -class CTypeFunc(CTypeBase): - """C type for enums""" - - def __init__(self, name, abi=None, type_ret=None, args=None): - if type_ret: - assert isinstance(type_ret, CTypeBase) - if args: - for arg_name, arg in args: - assert isinstance(arg, CTypeBase) - args = tuple(args) - else: - args = tuple() - self.name = name - self.abi = abi - self.type_ret = type_ret - self.args = args - super(CTypeFunc, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.name, self.abi, - self.type_ret, self.args)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.name == other.name and - self.abi == other.abi and - self.type_ret == other.type_ret and - self.args == other.args) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Func:%s (%s) %s(%s)>" % (self.type_ret, - self.abi, - self.name, - ", ".join(["%s %s" % (name, arg) for (name, arg) in self.args])) - - -class CTypeEllipsis(CTypeBase): - """C type for ellipsis argument (...)""" - - def __hash__(self): - return hash((self.__class__)) - - def __eq__(self, other): - return self.eq_base(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Ellipsis>" - - -class CTypeSizeof(CTypeBase): - """C type for sizeof""" - - def __init__(self, target): - self.target = target - super(CTypeSizeof, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.target)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.target == other.target) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - return "<Sizeof(%s)>" % self.target - - -class CTypeOp(CTypeBase): - """C type for operator (+ * ...)""" - - def __init__(self, operator, *args): - self.operator = operator - self.args = tuple(args) - super(CTypeOp, self).__init__() - - def __hash__(self): - return hash((self.__class__, self.operator, self.args)) - - def __eq__(self, other): - return (self.eq_base(other) and - self.operator == other.operator and - self.args == other.args) - - def __str__(self): - return "<CTypeOp(%s, %s)>" % (self.operator, - ', '.join([str(arg) for arg in self.args])) - - -class FuncNameIdentifier(c_ast.NodeVisitor): - """Visit an c_ast to find IdentifierType""" - - def __init__(self): - super(FuncNameIdentifier, self).__init__() - self.node_name = None - - def visit_TypeDecl(self, node): - """Retrieve the name in a function declaration: - Only one IdentifierType is present""" - self.node_name = node - - -class CAstTypes(object): - """Store all defined C types and typedefs""" - INTERNAL_PREFIX = "__GENTYPE__" - ANONYMOUS_PREFIX = "__ANONYMOUS__" - - def __init__(self, knowntypes=None, knowntypedefs=None): - if knowntypes is None: - knowntypes = {} - if knowntypedefs is None: - knowntypedefs = {} - - self._types = dict(knowntypes) - self._typedefs = dict(knowntypedefs) - self.cpt = 0 - self.loc_to_decl_info = {} - self.parser = c_parser.CParser() - self._cpt_decl = 0 - - - self.ast_to_typeid_rules = { - c_ast.Struct: self.ast_to_typeid_struct, - c_ast.Union: self.ast_to_typeid_union, - c_ast.IdentifierType: self.ast_to_typeid_identifiertype, - c_ast.TypeDecl: self.ast_to_typeid_typedecl, - c_ast.Decl: self.ast_to_typeid_decl, - c_ast.Typename: self.ast_to_typeid_typename, - c_ast.FuncDecl: self.ast_to_typeid_funcdecl, - c_ast.Enum: self.ast_to_typeid_enum, - c_ast.PtrDecl: self.ast_to_typeid_ptrdecl, - c_ast.EllipsisParam: self.ast_to_typeid_ellipsisparam, - c_ast.ArrayDecl: self.ast_to_typeid_arraydecl, - } - - self.ast_parse_rules = { - c_ast.Struct: self.ast_parse_struct, - c_ast.Union: self.ast_parse_union, - c_ast.Typedef: self.ast_parse_typedef, - c_ast.TypeDecl: self.ast_parse_typedecl, - c_ast.IdentifierType: self.ast_parse_identifiertype, - c_ast.Decl: self.ast_parse_decl, - c_ast.PtrDecl: self.ast_parse_ptrdecl, - c_ast.Enum: self.ast_parse_enum, - c_ast.ArrayDecl: self.ast_parse_arraydecl, - c_ast.FuncDecl: self.ast_parse_funcdecl, - c_ast.FuncDef: self.ast_parse_funcdef, - c_ast.Pragma: self.ast_parse_pragma, - } - - def gen_uniq_name(self): - """Generate uniq name for unnamed strucs/union""" - cpt = self.cpt - self.cpt += 1 - return self.INTERNAL_PREFIX + "%d" % cpt - - def gen_anon_name(self): - """Generate name for anonymous strucs/union""" - cpt = self.cpt - self.cpt += 1 - return self.ANONYMOUS_PREFIX + "%d" % cpt - - def is_generated_name(self, name): - """Return True if the name is internal""" - return name.startswith(self.INTERNAL_PREFIX) - - def is_anonymous_name(self, name): - """Return True if the name is anonymous""" - return name.startswith(self.ANONYMOUS_PREFIX) - - def add_type(self, type_id, type_obj): - """Add new C type - @type_id: Type descriptor (CTypeBase instance) - @type_obj: Obj* instance""" - assert isinstance(type_id, CTypeBase) - if type_id in self._types: - assert self._types[type_id] == type_obj - else: - self._types[type_id] = type_obj - - def add_typedef(self, type_new, type_src): - """Add new typedef - @type_new: CTypeBase instance of the new type name - @type_src: CTypeBase instance of the target type""" - assert isinstance(type_src, CTypeBase) - self._typedefs[type_new] = type_src - - def get_type(self, type_id): - """Get ObjC corresponding to the @type_id - @type_id: Type descriptor (CTypeBase instance) - """ - assert isinstance(type_id, CTypeBase) - if isinstance(type_id, CTypePtr): - subobj = self.get_type(type_id.target) - return CTypePtr(subobj) - if type_id in self._types: - return self._types[type_id] - elif type_id in self._typedefs: - return self.get_type(self._typedefs[type_id]) - return type_id - - def is_known_type(self, type_id): - """Return true if @type_id is known - @type_id: Type descriptor (CTypeBase instance) - """ - if isinstance(type_id, CTypePtr): - return self.is_known_type(type_id.target) - if type_id in self._types: - return True - if type_id in self._typedefs: - return self.is_known_type(self._typedefs[type_id]) - return False - - def add_c_decl_from_ast(self, ast): - """ - Adds types from a C ast - @ast: C ast - """ - self.ast_parse_declarations(ast) - - - def digest_decl(self, c_str): - - char_id = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" - - - # Seek deck - index_decl = [] - index = 0 - for decl in ['__cdecl__', '__stdcall__']: - index = 0 - while True: - index = c_str.find(decl, index) - if index == -1: - break - decl_off = index - decl_len = len(decl) - - index = index+len(decl) - while c_str[index] not in char_id: - index += 1 - - id_start = index - - while c_str[index] in char_id: - index += 1 - id_stop = index - - name = c_str[id_start:id_stop] - index_decl.append((decl_off, decl_len, id_start, id_stop, decl, )) - - index_decl.sort() - - # Remove decl - off = 0 - offsets = [] - for decl_off, decl_len, id_start, id_stop, decl in index_decl: - decl_off -= off - c_str = c_str[:decl_off] + c_str[decl_off+decl_len:] - off += decl_len - offsets.append((id_start-off, id_stop-off, decl)) - - index = 0 - lineno = 1 - - # Index to lineno, column - for id_start, id_stop, decl in offsets: - nbr = c_str.count('\n', index, id_start) - lineno += nbr - last_cr = c_str.rfind('\n', 0, id_start) - # column starts at 1 - column = id_start - last_cr - index = id_start - self.loc_to_decl_info[(lineno, column)] = decl - return c_str - - - def add_c_decl(self, c_str): - """ - Adds types from a C string types declaring - Note: will ignore lines containing code refs ie: - '# 23 "miasm.h"' - Returns the C ast - @c_str: C string containing C types declarations - """ - c_str = self.digest_decl(c_str) - - ast = c_to_ast(self.parser, c_str) - self.add_c_decl_from_ast(ast) - - return ast - - def ast_eval_int(self, ast): - """Eval a C ast object integer - - @ast: parsed pycparser.c_ast object - """ - - if isinstance(ast, c_ast.BinaryOp): - left = self.ast_eval_int(ast.left) - right = self.ast_eval_int(ast.right) - is_pure_int = (isinstance(left, int) and - isinstance(right, int)) - - if is_pure_int: - if ast.op == '*': - result = left * right - elif ast.op == '/': - assert left % right == 0 - result = left // right - elif ast.op == '+': - result = left + right - elif ast.op == '-': - result = left - right - elif ast.op == '<<': - result = left << right - elif ast.op == '>>': - result = left >> right - else: - raise NotImplementedError("Not implemented!") - else: - result = CTypeOp(ast.op, left, right) - - elif isinstance(ast, c_ast.UnaryOp): - if ast.op == 'sizeof' and isinstance(ast.expr, c_ast.Typename): - subobj = self.ast_to_typeid(ast.expr) - result = CTypeSizeof(subobj) - else: - raise NotImplementedError("Not implemented!") - - elif isinstance(ast, c_ast.Constant): - result = int(ast.value, 0) - elif isinstance(ast, c_ast.Cast): - # TODO: Can trunc integers? - result = self.ast_eval_int(ast.expr) - else: - raise NotImplementedError("Not implemented!") - return result - - def ast_to_typeid_struct(self, ast): - """Return the CTypeBase of an Struct ast""" - name = self.gen_uniq_name() if ast.name is None else ast.name - args = [] - if ast.decls: - for arg in ast.decls: - if arg.name is None: - arg_name = self.gen_anon_name() - else: - arg_name = arg.name - args.append((arg_name, self.ast_to_typeid(arg))) - decl = CTypeStruct(name, args) - return decl - - def ast_to_typeid_union(self, ast): - """Return the CTypeBase of an Union ast""" - name = self.gen_uniq_name() if ast.name is None else ast.name - args = [] - if ast.decls: - for arg in ast.decls: - if arg.name is None: - arg_name = self.gen_anon_name() - else: - arg_name = arg.name - args.append((arg_name, self.ast_to_typeid(arg))) - decl = CTypeUnion(name, args) - return decl - - def ast_to_typeid_identifiertype(self, ast): - """Return the CTypeBase of an IdentifierType ast""" - return CTypeId(*ast.names) - - def ast_to_typeid_typedecl(self, ast): - """Return the CTypeBase of a TypeDecl ast""" - return self.ast_to_typeid(ast.type) - - def ast_to_typeid_decl(self, ast): - """Return the CTypeBase of a Decl ast""" - return self.ast_to_typeid(ast.type) - - def ast_to_typeid_typename(self, ast): - """Return the CTypeBase of a TypeName ast""" - return self.ast_to_typeid(ast.type) - - def get_funcname(self, ast): - """Return the name of a function declaration ast""" - funcnameid = FuncNameIdentifier() - funcnameid.visit(ast) - node_name = funcnameid.node_name - if node_name.coord is not None: - lineno, column = node_name.coord.line, node_name.coord.column - decl_info = self.loc_to_decl_info.get((lineno, column), None) - else: - decl_info = None - return node_name.declname, decl_info - - def ast_to_typeid_funcdecl(self, ast): - """Return the CTypeBase of an FuncDecl ast""" - type_ret = self.ast_to_typeid(ast.type) - name, decl_info = self.get_funcname(ast.type) - if ast.args: - args = [] - for arg in ast.args.params: - typeid = self.ast_to_typeid(arg) - if isinstance(typeid, CTypeEllipsis): - arg_name = None - else: - arg_name = arg.name - args.append((arg_name, typeid)) - else: - args = [] - - obj = CTypeFunc(name, decl_info, type_ret, args) - decl = CTypeFunc(name) - if not self.is_known_type(decl): - self.add_type(decl, obj) - return obj - - def ast_to_typeid_enum(self, ast): - """Return the CTypeBase of an Enum ast""" - name = self.gen_uniq_name() if ast.name is None else ast.name - return CTypeEnum(name) - - def ast_to_typeid_ptrdecl(self, ast): - """Return the CTypeBase of a PtrDecl ast""" - return CTypePtr(self.ast_to_typeid(ast.type)) - - def ast_to_typeid_ellipsisparam(self, _): - """Return the CTypeBase of an EllipsisParam ast""" - return CTypeEllipsis() - - def ast_to_typeid_arraydecl(self, ast): - """Return the CTypeBase of an ArrayDecl ast""" - target = self.ast_to_typeid(ast.type) - if ast.dim is None: - value = None - else: - value = self.ast_eval_int(ast.dim) - return CTypeArray(target, value) - - def ast_to_typeid(self, ast): - """Return the CTypeBase of the @ast - @ast: pycparser.c_ast instance""" - cls = ast.__class__ - if not cls in self.ast_to_typeid_rules: - raise NotImplementedError("Strange type %r" % ast) - return self.ast_to_typeid_rules[cls](ast) - - # Ast parse type declarators - - def ast_parse_decl(self, ast): - """Parse ast Decl""" - return self.ast_parse_declaration(ast.type) - - def ast_parse_typedecl(self, ast): - """Parse ast Typedecl""" - return self.ast_parse_declaration(ast.type) - - def ast_parse_struct(self, ast): - """Parse ast Struct""" - obj = self.ast_to_typeid(ast) - if ast.decls and ast.name is not None: - # Add struct to types if named - decl = CTypeStruct(ast.name) - if not self.is_known_type(decl): - self.add_type(decl, obj) - return obj - - def ast_parse_union(self, ast): - """Parse ast Union""" - obj = self.ast_to_typeid(ast) - if ast.decls and ast.name is not None: - # Add union to types if named - decl = CTypeUnion(ast.name) - if not self.is_known_type(decl): - self.add_type(decl, obj) - return obj - - def ast_parse_typedef(self, ast): - """Parse ast TypeDef""" - decl = CTypeId(ast.name) - obj = self.ast_parse_declaration(ast.type) - if (isinstance(obj, (CTypeStruct, CTypeUnion)) and - self.is_generated_name(obj.name)): - # Add typedef name to default name - # for a question of clarity - obj.name += "__%s" % ast.name - self.add_typedef(decl, obj) - # Typedef does not return any object - return None - - def ast_parse_identifiertype(self, ast): - """Parse ast IdentifierType""" - return CTypeId(*ast.names) - - def ast_parse_ptrdecl(self, ast): - """Parse ast PtrDecl""" - return CTypePtr(self.ast_parse_declaration(ast.type)) - - def ast_parse_enum(self, ast): - """Parse ast Enum""" - return self.ast_to_typeid(ast) - - def ast_parse_arraydecl(self, ast): - """Parse ast ArrayDecl""" - return self.ast_to_typeid(ast) - - def ast_parse_funcdecl(self, ast): - """Parse ast FuncDecl""" - return self.ast_to_typeid(ast) - - def ast_parse_funcdef(self, ast): - """Parse ast FuncDef""" - return self.ast_to_typeid(ast.decl) - - def ast_parse_pragma(self, _): - """Prama does not return any object""" - return None - - def ast_parse_declaration(self, ast): - """Add one ast type declaration to the type manager - (packed style in type manager) - - @ast: parsed pycparser.c_ast object - """ - cls = ast.__class__ - if not cls in self.ast_parse_rules: - raise NotImplementedError("Strange declaration %r" % cls) - return self.ast_parse_rules[cls](ast) - - def ast_parse_declarations(self, ast): - """Add ast types declaration to the type manager - (packed style in type manager) - - @ast: parsed pycparser.c_ast object - """ - for ext in ast.ext: - ret = self.ast_parse_declaration(ext) - - def parse_c_type(self, c_str): - """Parse a C string representing a C type and return the associated - Miasm C object. - @c_str: C string of a C type - """ - - new_str = "%s __MIASM_INTERNAL_%s;" % (c_str, self._cpt_decl) - ret = self.parser.cparser.parse(input=new_str, lexer=self.parser.clex) - self._cpt_decl += 1 - return ret diff --git a/miasm/core/graph.py b/miasm/core/graph.py deleted file mode 100644 index debea38e..00000000 --- a/miasm/core/graph.py +++ /dev/null @@ -1,1123 +0,0 @@ -from collections import defaultdict, namedtuple - -from future.utils import viewitems, viewvalues -import re - - -class DiGraph(object): - - """Implementation of directed graph""" - - # Stand for a cell in a dot node rendering - DotCellDescription = namedtuple("DotCellDescription", - ["text", "attr"]) - - def __init__(self): - self._nodes = set() - self._edges = [] - # N -> Nodes N2 with a edge (N -> N2) - self._nodes_succ = {} - # N -> Nodes N2 with a edge (N2 -> N) - self._nodes_pred = {} - - self.escape_chars = re.compile('[' + re.escape('{}[]') + '&|<>' + ']') - - - def __repr__(self): - out = [] - for node in self._nodes: - out.append(str(node)) - for src, dst in self._edges: - out.append("%s -> %s" % (src, dst)) - return '\n'.join(out) - - def nodes(self): - return self._nodes - - def edges(self): - return self._edges - - def merge(self, graph): - """Merge the current graph with @graph - @graph: DiGraph instance - """ - for node in graph._nodes: - self.add_node(node) - for edge in graph._edges: - self.add_edge(*edge) - - def __add__(self, graph): - """Wrapper on `.merge`""" - self.merge(graph) - return self - - def copy(self): - """Copy the current graph instance""" - graph = self.__class__() - return graph + self - - def __eq__(self, graph): - if not isinstance(graph, self.__class__): - return False - if self._nodes != graph.nodes(): - return False - return sorted(self._edges) == sorted(graph.edges()) - - def __ne__(self, other): - return not self.__eq__(other) - - def add_node(self, node): - """Add the node @node to the graph. - If the node was already present, return False. - Otherwise, return True - """ - if node in self._nodes: - return False - self._nodes.add(node) - self._nodes_succ[node] = [] - self._nodes_pred[node] = [] - return True - - def del_node(self, node): - """Delete the @node of the graph; Also delete every edge to/from this - @node""" - - if node in self._nodes: - self._nodes.remove(node) - for pred in self.predecessors(node): - self.del_edge(pred, node) - for succ in self.successors(node): - self.del_edge(node, succ) - - def add_edge(self, src, dst): - if not src in self._nodes: - self.add_node(src) - if not dst in self._nodes: - self.add_node(dst) - self._edges.append((src, dst)) - self._nodes_succ[src].append(dst) - self._nodes_pred[dst].append(src) - - def add_uniq_edge(self, src, dst): - """Add an edge from @src to @dst if it doesn't already exist""" - if (src not in self._nodes_succ or - dst not in self._nodes_succ[src]): - self.add_edge(src, dst) - - def del_edge(self, src, dst): - self._edges.remove((src, dst)) - self._nodes_succ[src].remove(dst) - self._nodes_pred[dst].remove(src) - - def discard_edge(self, src, dst): - """Remove edge between @src and @dst if it exits""" - if (src, dst) in self._edges: - self.del_edge(src, dst) - - def predecessors_iter(self, node): - if not node in self._nodes_pred: - return - for n_pred in self._nodes_pred[node]: - yield n_pred - - def predecessors(self, node): - return [x for x in self.predecessors_iter(node)] - - def successors_iter(self, node): - if not node in self._nodes_succ: - return - for n_suc in self._nodes_succ[node]: - yield n_suc - - def successors(self, node): - return [x for x in self.successors_iter(node)] - - def leaves_iter(self): - for node in self._nodes: - if not self._nodes_succ[node]: - yield node - - def leaves(self): - return [x for x in self.leaves_iter()] - - def heads_iter(self): - for node in self._nodes: - if not self._nodes_pred[node]: - yield node - - def heads(self): - return [x for x in self.heads_iter()] - - def find_path(self, src, dst, cycles_count=0, done=None): - """ - Searches for paths from @src to @dst - @src: loc_key of basic block from which it should start - @dst: loc_key of basic block where it should stop - @cycles_count: maximum number of times a basic block can be processed - @done: dictionary of already processed loc_keys, it's value is number of times it was processed - @out: list of paths from @src to @dst - """ - if done is None: - done = {} - if dst in done and done[dst] > cycles_count: - return [[]] - if src == dst: - return [[src]] - out = [] - for node in self.predecessors(dst): - done_n = dict(done) - done_n[dst] = done_n.get(dst, 0) + 1 - for path in self.find_path(src, node, cycles_count, done_n): - if path and path[0] == src: - out.append(path + [dst]) - return out - - def find_path_from_src(self, src, dst, cycles_count=0, done=None): - """ - This function does the same as function find_path. - But it searches the paths from src to dst, not vice versa like find_path. - This approach might be more efficient in some cases. - @src: loc_key of basic block from which it should start - @dst: loc_key of basic block where it should stop - @cycles_count: maximum number of times a basic block can be processed - @done: dictionary of already processed loc_keys, it's value is number of times it was processed - @out: list of paths from @src to @dst - """ - - if done is None: - done = {} - if src == dst: - return [[src]] - if src in done and done[src] > cycles_count: - return [[]] - out = [] - for node in self.successors(src): - done_n = dict(done) - done_n[src] = done_n.get(src, 0) + 1 - for path in self.find_path_from_src(node, dst, cycles_count, done_n): - if path and path[len(path)-1] == dst: - out.append([src] + path) - return out - - def nodeid(self, node): - """ - Returns uniq id for a @node - @node: a node of the graph - """ - return hash(node) & 0xFFFFFFFFFFFFFFFF - - def node2lines(self, node): - """ - Returns an iterator on cells of the dot @node. - A DotCellDescription or a list of DotCellDescription are accepted - @node: a node of the graph - """ - yield self.DotCellDescription(text=str(node), attr={}) - - def node_attr(self, node): - """ - Returns a dictionary of the @node's attributes - @node: a node of the graph - """ - return {} - - def edge_attr(self, src, dst): - """ - Return a dictionary of attributes for the edge between @src and @dst - @src: the source node of the edge - @dst: the destination node of the edge - """ - return {} - - @staticmethod - def _fix_chars(token): - return "&#%04d;" % ord(token.group()) - - @staticmethod - def _attr2str(default_attr, attr): - return ' '.join( - '%s="%s"' % (name, value) - for name, value in - viewitems(dict(default_attr, - **attr)) - ) - - def escape_text(self, text): - return self.escape_chars.sub(self._fix_chars, text) - - def dot(self): - """Render dot graph with HTML""" - - td_attr = {'align': 'left'} - nodes_attr = {'shape': 'Mrecord', - 'fontname': 'Courier New'} - - out = ["digraph asm_graph {"] - - # Generate basic nodes - out_nodes = [] - for node in self.nodes(): - node_id = self.nodeid(node) - out_node = '%s [\n' % node_id - out_node += self._attr2str(nodes_attr, self.node_attr(node)) - out_node += 'label =<<table border="0" cellborder="0" cellpadding="3">' - - node_html_lines = [] - - for lineDesc in self.node2lines(node): - out_render = "" - if isinstance(lineDesc, self.DotCellDescription): - lineDesc = [lineDesc] - for col in lineDesc: - out_render += "<td %s>%s</td>" % ( - self._attr2str(td_attr, col.attr), - self.escape_text(str(col.text))) - node_html_lines.append(out_render) - - node_html_lines = ('<tr>' + - ('</tr><tr>').join(node_html_lines) + - '</tr>') - - out_node += node_html_lines + "</table>> ];" - out_nodes.append(out_node) - - out += out_nodes - - # Generate links - for src, dst in self.edges(): - attrs = self.edge_attr(src, dst) - - attrs = ' '.join( - '%s="%s"' % (name, value) - for name, value in viewitems(attrs) - ) - - out.append('%s -> %s' % (self.nodeid(src), self.nodeid(dst)) + - '[' + attrs + '];') - - out.append("}") - return '\n'.join(out) - - - def graphviz(self): - try: - import re - import graphviz - - - self.gv = graphviz.Digraph('html_table') - self._dot_offset = False - td_attr = {'align': 'left'} - nodes_attr = {'shape': 'Mrecord', - 'fontname': 'Courier New'} - - for node in self.nodes(): - elements = [x for x in self.node2lines(node)] - node_id = self.nodeid(node) - out_node = '<<table border="0" cellborder="0" cellpadding="3">' - - node_html_lines = [] - for lineDesc in elements: - out_render = "" - if isinstance(lineDesc, self.DotCellDescription): - lineDesc = [lineDesc] - for col in lineDesc: - out_render += "<td %s>%s</td>" % ( - self._attr2str(td_attr, col.attr), - self.escape_text(str(col.text))) - node_html_lines.append(out_render) - - node_html_lines = ('<tr>' + - ('</tr><tr>').join(node_html_lines) + - '</tr>') - - out_node += node_html_lines + "</table>>" - attrs = dict(nodes_attr) - attrs.update(self.node_attr(node)) - self.gv.node( - "%s" % node_id, - out_node, - attrs, - ) - - - for src, dst in self.edges(): - attrs = self.edge_attr(src, dst) - self.gv.edge( - str(self.nodeid(src)), - str(self.nodeid(dst)), - "", - attrs, - ) - - return self.gv - except ImportError: - # Skip as graphviz is not installed - return None - - - @staticmethod - def _reachable_nodes(head, next_cb): - """Generic algorithm to compute all nodes reachable from/to node - @head""" - - todo = set([head]) - reachable = set() - while todo: - node = todo.pop() - if node in reachable: - continue - reachable.add(node) - yield node - for next_node in next_cb(node): - todo.add(next_node) - - def predecessors_stop_node_iter(self, node, head): - if node == head: - return - for next_node in self.predecessors_iter(node): - yield next_node - - def reachable_sons(self, head): - """Compute all nodes reachable from node @head. Each son is an - immediate successor of an arbitrary, already yielded son of @head""" - return self._reachable_nodes(head, self.successors_iter) - - def reachable_parents(self, leaf): - """Compute all parents of node @leaf. Each parent is an immediate - predecessor of an arbitrary, already yielded parent of @leaf""" - return self._reachable_nodes(leaf, self.predecessors_iter) - - def reachable_parents_stop_node(self, leaf, head): - """Compute all parents of node @leaf. Each parent is an immediate - predecessor of an arbitrary, already yielded parent of @leaf. - Do not compute reachables past @head node""" - return self._reachable_nodes( - leaf, - lambda node_cur: self.predecessors_stop_node_iter( - node_cur, head - ) - ) - - - @staticmethod - def _compute_generic_dominators(head, reachable_cb, prev_cb, next_cb): - """Generic algorithm to compute either the dominators or postdominators - of the graph. - @head: the head/leaf of the graph - @reachable_cb: sons/parents of the head/leaf - @prev_cb: return predecessors/successors of a node - @next_cb: return successors/predecessors of a node - """ - - nodes = set(reachable_cb(head)) - dominators = {} - for node in nodes: - dominators[node] = set(nodes) - - dominators[head] = set([head]) - todo = set(nodes) - - while todo: - node = todo.pop() - - # Heads state must not be changed - if node == head: - continue - - # Compute intersection of all predecessors'dominators - new_dom = None - for pred in prev_cb(node): - if not pred in nodes: - continue - if new_dom is None: - new_dom = set(dominators[pred]) - new_dom.intersection_update(dominators[pred]) - - # We are not a head to we have at least one dominator - assert(new_dom is not None) - - new_dom.update(set([node])) - - # If intersection has changed, add sons to the todo list - if new_dom == dominators[node]: - continue - - dominators[node] = new_dom - for succ in next_cb(node): - todo.add(succ) - return dominators - - def compute_dominators(self, head): - """Compute the dominators of the graph""" - return self._compute_generic_dominators(head, - self.reachable_sons, - self.predecessors_iter, - self.successors_iter) - - def compute_postdominators(self, leaf): - """Compute the postdominators of the graph""" - return self._compute_generic_dominators(leaf, - self.reachable_parents, - self.successors_iter, - self.predecessors_iter) - - - - - def compute_dominator_tree(self, head): - """ - Computes the dominator tree of a graph - :param head: head of graph - :return: DiGraph - """ - idoms = self.compute_immediate_dominators(head) - dominator_tree = DiGraph() - for node in idoms: - dominator_tree.add_edge(idoms[node], node) - - return dominator_tree - - @staticmethod - def _walk_generic_dominator(node, gen_dominators, succ_cb): - """Generic algorithm to return an iterator of the ordered list of - @node's dominators/post_dominator. - - The function doesn't return the self reference in dominators. - @node: The start node - @gen_dominators: The dictionary containing at least node's - dominators/post_dominators - @succ_cb: return predecessors/successors of a node - - """ - # Init - done = set() - if node not in gen_dominators: - # We are in a branch which doesn't reach head - return - node_gen_dominators = set(gen_dominators[node]) - todo = set([node]) - - # Avoid working on itself - node_gen_dominators.remove(node) - - # For each level - while node_gen_dominators: - new_node = None - - # Worklist pattern - while todo: - node = todo.pop() - if node in done: - continue - if node in node_gen_dominators: - new_node = node - break - - # Avoid loops - done.add(node) - - # Look for the next level - for pred in succ_cb(node): - todo.add(pred) - - # Return the node; it's the next starting point - assert(new_node is not None) - yield new_node - node_gen_dominators.remove(new_node) - todo = set([new_node]) - - def walk_dominators(self, node, dominators): - """Return an iterator of the ordered list of @node's dominators - The function doesn't return the self reference in dominators. - @node: The start node - @dominators: The dictionary containing at least node's dominators - """ - return self._walk_generic_dominator(node, - dominators, - self.predecessors_iter) - - def walk_postdominators(self, node, postdominators): - """Return an iterator of the ordered list of @node's postdominators - The function doesn't return the self reference in postdominators. - @node: The start node - @postdominators: The dictionary containing at least node's - postdominators - - """ - return self._walk_generic_dominator(node, - postdominators, - self.successors_iter) - - def compute_immediate_dominators(self, head): - """Compute the immediate dominators of the graph""" - dominators = self.compute_dominators(head) - idoms = {} - - for node in dominators: - for predecessor in self.walk_dominators(node, dominators): - if predecessor in dominators[node] and node != predecessor: - idoms[node] = predecessor - break - return idoms - - def compute_immediate_postdominators(self,tail): - """Compute the immediate postdominators of the graph""" - postdominators = self.compute_postdominators(tail) - ipdoms = {} - - for node in postdominators: - for successor in self.walk_postdominators(node, postdominators): - if successor in postdominators[node] and node != successor: - ipdoms[node] = successor - break - return ipdoms - - def compute_dominance_frontier(self, head): - """ - Compute the dominance frontier of the graph - - Source: Cooper, Keith D., Timothy J. Harvey, and Ken Kennedy. - "A simple, fast dominance algorithm." - Software Practice & Experience 4 (2001), p. 9 - """ - idoms = self.compute_immediate_dominators(head) - frontier = {} - - for node in idoms: - if len(self._nodes_pred[node]) >= 2: - for predecessor in self.predecessors_iter(node): - runner = predecessor - if runner not in idoms: - continue - while runner != idoms[node]: - if runner not in frontier: - frontier[runner] = set() - - frontier[runner].add(node) - runner = idoms[runner] - return frontier - - def _walk_generic_first(self, head, flag, succ_cb): - """ - Generic algorithm to compute breadth or depth first search - for a node. - @head: the head of the graph - @flag: denotes if @todo is used as queue or stack - @succ_cb: returns a node's predecessors/successors - :return: next node - """ - todo = [head] - done = set() - - while todo: - node = todo.pop(flag) - if node in done: - continue - done.add(node) - - for succ in succ_cb(node): - todo.append(succ) - - yield node - - def walk_breadth_first_forward(self, head): - """Performs a breadth first search on the graph from @head""" - return self._walk_generic_first(head, 0, self.successors_iter) - - def walk_depth_first_forward(self, head): - """Performs a depth first search on the graph from @head""" - return self._walk_generic_first(head, -1, self.successors_iter) - - def walk_breadth_first_backward(self, head): - """Performs a breadth first search on the reversed graph from @head""" - return self._walk_generic_first(head, 0, self.predecessors_iter) - - def walk_depth_first_backward(self, head): - """Performs a depth first search on the reversed graph from @head""" - return self._walk_generic_first(head, -1, self.predecessors_iter) - - def has_loop(self): - """Return True if the graph contains at least a cycle""" - todo = list(self.nodes()) - # tested nodes - done = set() - # current DFS nodes - current = set() - while todo: - node = todo.pop() - if node in done: - continue - - if node in current: - # DFS branch end - for succ in self.successors_iter(node): - if succ in current: - return True - # A node cannot be in current AND in done - current.remove(node) - done.add(node) - else: - # Launch DFS from node - todo.append(node) - current.add(node) - todo += self.successors(node) - - return False - - def compute_natural_loops(self, head): - """ - Computes all natural loops in the graph. - - Source: Aho, Alfred V., Lam, Monica S., Sethi, R. and Jeffrey Ullman. - "Compilers: Principles, Techniques, & Tools, Second Edition" - Pearson/Addison Wesley (2007), Chapter 9.6.6 - :param head: head of the graph - :return: yield a tuple of the form (back edge, loop body) - """ - for a, b in self.compute_back_edges(head): - body = self._compute_natural_loop_body(b, a) - yield ((a, b), body) - - def compute_back_edges(self, head): - """ - Computes all back edges from a node to a - dominator in the graph. - :param head: head of graph - :return: yield a back edge - """ - dominators = self.compute_dominators(head) - - # traverse graph - for node in self.walk_depth_first_forward(head): - for successor in self.successors_iter(node): - # check for a back edge to a dominator - if successor in dominators[node]: - edge = (node, successor) - yield edge - - def _compute_natural_loop_body(self, head, leaf): - """ - Computes the body of a natural loop by a depth-first - search on the reversed control flow graph. - :param head: leaf of the loop - :param leaf: header of the loop - :return: set containing loop body - """ - todo = [leaf] - done = {head} - - while todo: - node = todo.pop() - if node in done: - continue - done.add(node) - - for predecessor in self.predecessors_iter(node): - todo.append(predecessor) - return done - - def compute_strongly_connected_components(self): - """ - Partitions the graph into strongly connected components. - - Iterative implementation of Gabow's path-based SCC algorithm. - Source: Gabow, Harold N. - "Path-based depth-first search for strong and biconnected components." - Information Processing Letters 74.3 (2000), pp. 109--110 - - The iterative implementation is inspired by Mark Dickinson's - code: - http://code.activestate.com/recipes/ - 578507-strongly-connected-components-of-a-directed-graph/ - :return: yield a strongly connected component - """ - stack = [] - boundaries = [] - counter = len(self.nodes()) - - # init index with 0 - index = {v: 0 for v in self.nodes()} - - # state machine for worklist algorithm - VISIT, HANDLE_RECURSION, MERGE = 0, 1, 2 - NodeState = namedtuple('NodeState', ['state', 'node']) - - for node in self.nodes(): - # next node if node was already visited - if index[node]: - continue - - todo = [NodeState(VISIT, node)] - done = set() - - while todo: - current = todo.pop() - - if current.node in done: - continue - - # node is unvisited - if current.state == VISIT: - stack.append(current.node) - index[current.node] = len(stack) - boundaries.append(index[current.node]) - - todo.append(NodeState(MERGE, current.node)) - # follow successors - for successor in self.successors_iter(current.node): - todo.append(NodeState(HANDLE_RECURSION, successor)) - - # iterative handling of recursion algorithm - elif current.state == HANDLE_RECURSION: - # visit unvisited successor - if index[current.node] == 0: - todo.append(NodeState(VISIT, current.node)) - else: - # contract cycle if necessary - while index[current.node] < boundaries[-1]: - boundaries.pop() - - # merge strongly connected component - else: - if index[current.node] == boundaries[-1]: - boundaries.pop() - counter += 1 - scc = set() - - while index[current.node] <= len(stack): - popped = stack.pop() - index[popped] = counter - scc.add(popped) - - done.add(current.node) - - yield scc - - - def compute_weakly_connected_components(self): - """ - Return the weakly connected components - """ - remaining = set(self.nodes()) - components = [] - while remaining: - node = remaining.pop() - todo = set() - todo.add(node) - component = set() - done = set() - while todo: - node = todo.pop() - if node in done: - continue - done.add(node) - remaining.discard(node) - component.add(node) - todo.update(self.predecessors(node)) - todo.update(self.successors(node)) - components.append(component) - return components - - - - def replace_node(self, node, new_node): - """ - Replace @node by @new_node - """ - - predecessors = self.predecessors(node) - successors = self.successors(node) - self.del_node(node) - for predecessor in predecessors: - if predecessor == node: - predecessor = new_node - self.add_uniq_edge(predecessor, new_node) - for successor in successors: - if successor == node: - successor = new_node - self.add_uniq_edge(new_node, successor) - -class DiGraphSimplifier(object): - - """Wrapper on graph simplification passes. - - Instance handle passes lists. - """ - - def __init__(self): - self.passes = [] - - def enable_passes(self, passes): - """Add @passes to passes to applied - @passes: sequence of function (DiGraphSimplifier, DiGraph) -> None - """ - self.passes += passes - - def apply_simp(self, graph): - """Apply enabled simplifications on graph @graph - @graph: DiGraph instance - """ - while True: - new_graph = graph.copy() - for simp_func in self.passes: - simp_func(self, new_graph) - - if new_graph == graph: - break - graph = new_graph - return new_graph - - def __call__(self, graph): - """Wrapper on 'apply_simp'""" - return self.apply_simp(graph) - - -class MatchGraphJoker(object): - - """MatchGraphJoker are joker nodes of MatchGraph, that is to say nodes which - stand for any node. Restrictions can be added to jokers. - - If j1, j2 and j3 are MatchGraphJoker, one can quickly build a matcher for - the pattern: - | - +----v----+ - | (j1) | - +----+----+ - | - +----v----+ - | (j2) |<---+ - +----+--+-+ | - | +------+ - +----v----+ - | (j3) | - +----+----+ - | - v - Using: - >>> matcher = j1 >> j2 >> j3 - >>> matcher += j2 >> j2 - Or: - >>> matcher = j1 >> j2 >> j2 >> j3 - - """ - - def __init__(self, restrict_in=True, restrict_out=True, filt=None, - name=None): - """Instantiate a MatchGraphJoker, with restrictions - @restrict_in: (optional) if set, the number of predecessors of the - matched node must be the same than the joker node in the - associated MatchGraph - @restrict_out: (optional) counterpart of @restrict_in for successors - @filt: (optional) function(graph, node) -> boolean for filtering - candidate node - @name: (optional) helper for displaying the current joker - """ - if filt is None: - filt = lambda graph, node: True - self.filt = filt - if name is None: - name = str(id(self)) - self._name = name - self.restrict_in = restrict_in - self.restrict_out = restrict_out - - def __rshift__(self, joker): - """Helper for describing a MatchGraph from @joker - J1 >> J2 stands for an edge going to J2 from J1 - @joker: MatchGraphJoker instance - """ - assert isinstance(joker, MatchGraphJoker) - - graph = MatchGraph() - graph.add_node(self) - graph.add_node(joker) - graph.add_edge(self, joker) - - # For future "A >> B" idiom construction - graph._last_node = joker - - return graph - - def __str__(self): - info = [] - if not self.restrict_in: - info.append("In:*") - if not self.restrict_out: - info.append("Out:*") - return "Joker %s %s" % (self._name, - "(%s)" % " ".join(info) if info else "") - - -class MatchGraph(DiGraph): - - """MatchGraph intends to be the counterpart of match_expr, but for DiGraph - - This class provides API to match a given DiGraph pattern, with addidionnal - restrictions. - The implemented algorithm is a naive approach. - - The recommended way to instantiate a MatchGraph is the use of - MatchGraphJoker. - """ - - def __init__(self, *args, **kwargs): - super(MatchGraph, self).__init__(*args, **kwargs) - # Construction helper - self._last_node = None - - # Construction helpers - def __rshift__(self, joker): - """Construction helper, adding @joker to the current graph as a son of - _last_node - @joker: MatchGraphJoker instance""" - assert isinstance(joker, MatchGraphJoker) - assert isinstance(self._last_node, MatchGraphJoker) - - self.add_node(joker) - self.add_edge(self._last_node, joker) - self._last_node = joker - return self - - def __add__(self, graph): - """Construction helper, merging @graph with self - @graph: MatchGraph instance - """ - assert isinstance(graph, MatchGraph) - - # Reset helpers flag - self._last_node = None - graph._last_node = None - - # Merge graph into self - for node in graph.nodes(): - self.add_node(node) - for edge in graph.edges(): - self.add_edge(*edge) - - return self - - # Graph matching - def _check_node(self, candidate, expected, graph, partial_sol=None): - """Check if @candidate can stand for @expected in @graph, given @partial_sol - @candidate: @graph's node - @expected: MatchGraphJoker instance - @graph: DiGraph instance - @partial_sol: (optional) dictionary of MatchGraphJoker -> @graph's node - standing for a partial solution - """ - # Avoid having 2 different joker for the same node - if partial_sol and candidate in viewvalues(partial_sol): - return False - - # Check lambda filtering - if not expected.filt(graph, candidate): - return False - - # Check arity - # If filter_in/out, then arity must be the same - # Otherwise, arity of the candidate must be at least equal - if ((expected.restrict_in == True and - len(self.predecessors(expected)) != len(graph.predecessors(candidate))) or - (expected.restrict_in == False and - len(self.predecessors(expected)) > len(graph.predecessors(candidate)))): - return False - if ((expected.restrict_out == True and - len(self.successors(expected)) != len(graph.successors(candidate))) or - (expected.restrict_out == False and - len(self.successors(expected)) > len(graph.successors(candidate)))): - return False - - # Check edges with partial solution if any - if not partial_sol: - return True - for pred in self.predecessors(expected): - if (pred in partial_sol and - partial_sol[pred] not in graph.predecessors(candidate)): - return False - - for succ in self.successors(expected): - if (succ in partial_sol and - partial_sol[succ] not in graph.successors(candidate)): - return False - - # All checks OK - return True - - def _propagate_sol(self, node, partial_sol, graph, todo, propagator): - """ - Try to extend the current @partial_sol by propagating the solution using - @propagator on @node. - New solutions are added to @todo - """ - real_node = partial_sol[node] - for candidate in propagator(self, node): - # Edge already in the partial solution, skip it - if candidate in partial_sol: - continue - - # Check candidate - for candidate_real in propagator(graph, real_node): - if self._check_node(candidate_real, candidate, graph, - partial_sol): - temp_sol = partial_sol.copy() - temp_sol[candidate] = candidate_real - if temp_sol not in todo: - todo.append(temp_sol) - - @staticmethod - def _propagate_successors(graph, node): - """Propagate through @node successors in @graph""" - return graph.successors_iter(node) - - @staticmethod - def _propagate_predecessors(graph, node): - """Propagate through @node predecessors in @graph""" - return graph.predecessors_iter(node) - - def match(self, graph): - """Naive subgraph matching between graph and self. - Iterator on matching solution, as dictionary MatchGraphJoker -> @graph - @graph: DiGraph instance - In order to obtained correct and complete results, @graph must be - connected. - """ - # Partial solution: nodes corrects, edges between these nodes corrects - # A partial solution is a dictionary MatchGraphJoker -> @graph's node - todo = list() # Dictionaries containing partial solution - done = list() # Already computed partial solutions - - # Elect first candidates - to_match = next(iter(self._nodes)) - for node in graph.nodes(): - if self._check_node(node, to_match, graph): - to_add = {to_match: node} - if to_add not in todo: - todo.append(to_add) - - while todo: - # When a partial_sol is computed, if more precise partial solutions - # are found, they will be added to 'todo' - # -> using last entry of todo first performs a "depth first" - # approach on solutions - # -> the algorithm may converge faster to a solution, a desired - # behavior while doing graph simplification (stopping after one - # sol) - partial_sol = todo.pop() - - # Avoid infinite loop and recurrent work - if partial_sol in done: - continue - done.append(partial_sol) - - # If all nodes are matching, this is a potential solution - if len(partial_sol) == len(self._nodes): - yield partial_sol - continue - - # Find node to tests using edges - for node in partial_sol: - self._propagate_sol(node, partial_sol, graph, todo, - MatchGraph._propagate_successors) - self._propagate_sol(node, partial_sol, graph, todo, - MatchGraph._propagate_predecessors) diff --git a/miasm/core/interval.py b/miasm/core/interval.py deleted file mode 100644 index 172197c0..00000000 --- a/miasm/core/interval.py +++ /dev/null @@ -1,284 +0,0 @@ -from __future__ import print_function - -INT_EQ = 0 # Equivalent -INT_B_IN_A = 1 # B in A -INT_A_IN_B = -1 # A in B -INT_DISJOIN = 2 # Disjoint -INT_JOIN = 3 # Overlap -INT_JOIN_AB = 4 # B starts at the end of A -INT_JOIN_BA = 5 # A starts at the end of B - - -def cmp_interval(inter1, inter2): - """Compare @inter1 and @inter2 and returns the associated INT_* case - @inter1, @inter2: interval instance - """ - if inter1 == inter2: - return INT_EQ - - inter1_start, inter1_stop = inter1 - inter2_start, inter2_stop = inter2 - result = INT_JOIN - if inter1_start <= inter2_start and inter1_stop >= inter2_stop: - result = INT_B_IN_A - if inter2_start <= inter1_start and inter2_stop >= inter1_stop: - result = INT_A_IN_B - if inter1_stop + 1 == inter2_start: - result = INT_JOIN_AB - if inter2_stop + 1 == inter1_start: - result = INT_JOIN_BA - if inter1_start > inter2_stop + 1 or inter2_start > inter1_stop + 1: - result = INT_DISJOIN - return result - - -class interval(object): - """Stands for intervals with integer bounds - - Offers common methods to work with interval""" - - def __init__(self, bounds=None): - """Instance an interval object - @bounds: (optional) list of (int, int) and/or interval instance - """ - if bounds is None: - bounds = [] - elif isinstance(bounds, interval): - bounds = bounds.intervals - self.is_cannon = False - self.intervals = bounds - self.cannon() - - def __iter__(self): - """Iterate on intervals""" - for inter in self.intervals: - yield inter - - @staticmethod - def cannon_list(tmp): - """ - Return a cannonizes list of intervals - @tmp: list of (int, int) - """ - tmp = sorted([x for x in tmp if x[0] <= x[1]]) - out = [] - if not tmp: - return out - out.append(tmp.pop()) - while tmp: - x = tmp.pop() - rez = cmp_interval(out[-1], x) - - if rez == INT_EQ: - continue - elif rez == INT_DISJOIN: - out.append(x) - elif rez == INT_B_IN_A: - continue - elif rez in [INT_JOIN, INT_JOIN_AB, INT_JOIN_BA, INT_A_IN_B]: - u, v = x - while out and cmp_interval(out[-1], (u, v)) in [ - INT_JOIN, INT_JOIN_AB, INT_JOIN_BA, INT_A_IN_B]: - u = min(u, out[-1][0]) - v = max(v, out[-1][1]) - out.pop() - out.append((u, v)) - else: - raise ValueError('unknown state', rez) - return out[::-1] - - def cannon(self): - "Apply .cannon_list() on self contained intervals" - if self.is_cannon is True: - return - self.intervals = interval.cannon_list(self.intervals) - self.is_cannon = True - - def __repr__(self): - if self.intervals: - o = " U ".join(["[0x%X 0x%X]" % (x[0], x[1]) - for x in self.intervals]) - else: - o = "[]" - return o - - def __contains__(self, other): - if isinstance(other, interval): - for intervalB in other.intervals: - is_in = False - for intervalA in self.intervals: - if cmp_interval(intervalA, intervalB) in [INT_EQ, INT_B_IN_A]: - is_in = True - break - if not is_in: - return False - return True - else: - for intervalA in self.intervals: - if intervalA[0] <= other <= intervalA[1]: - return True - return False - - def __eq__(self, i): - return self.intervals == i.intervals - - def __ne__(self, other): - return not self.__eq__(other) - - def union(self, other): - """ - Return the union of intervals - @other: interval instance - """ - - if isinstance(other, interval): - other = other.intervals - other = interval(self.intervals + other) - return other - - def difference(self, other): - """ - Return the difference of intervals - @other: interval instance - """ - - to_test = self.intervals[:] - i = -1 - to_del = other.intervals[:] - while i < len(to_test) - 1: - i += 1 - x = to_test[i] - if x[0] > x[1]: - del to_test[i] - i -= 1 - continue - - while to_del and to_del[0][1] < x[0]: - del to_del[0] - - for y in to_del: - if y[0] > x[1]: - break - rez = cmp_interval(x, y) - if rez == INT_DISJOIN: - continue - elif rez == INT_EQ: - del to_test[i] - i -= 1 - break - elif rez == INT_A_IN_B: - del to_test[i] - i -= 1 - break - elif rez == INT_B_IN_A: - del to_test[i] - i1 = (x[0], y[0] - 1) - i2 = (y[1] + 1, x[1]) - to_test[i:i] = [i1, i2] - i -= 1 - break - elif rez in [INT_JOIN_AB, INT_JOIN_BA]: - continue - elif rez == INT_JOIN: - del to_test[i] - if x[0] < y[0]: - to_test[i:i] = [(x[0], y[0] - 1)] - else: - to_test[i:i] = [(y[1] + 1, x[1])] - i -= 1 - break - else: - raise ValueError('unknown state', rez) - return interval(to_test) - - def intersection(self, other): - """ - Return the intersection of intervals - @other: interval instance - """ - - out = [] - for x in self.intervals: - if x[0] > x[1]: - continue - for y in other.intervals: - rez = cmp_interval(x, y) - - if rez == INT_DISJOIN: - continue - elif rez == INT_EQ: - out.append(x) - continue - elif rez == INT_A_IN_B: - out.append(x) - continue - elif rez == INT_B_IN_A: - out.append(y) - continue - elif rez == INT_JOIN_AB: - continue - elif rez == INT_JOIN_BA: - continue - elif rez == INT_JOIN: - if x[0] < y[0]: - out.append((y[0], x[1])) - else: - out.append((x[0], y[1])) - continue - else: - raise ValueError('unknown state', rez) - return interval(out) - - - def __add__(self, other): - return self.union(other) - - def __and__(self, other): - return self.intersection(other) - - def __sub__(self, other): - return self.difference(other) - - def hull(self): - "Return the first and the last bounds of intervals" - if not self.intervals: - return None, None - return self.intervals[0][0], self.intervals[-1][1] - - - @property - def empty(self): - """Return True iff the interval is empty""" - return not self.intervals - - def show(self, img_x=1350, img_y=20, dry_run=False): - """ - show image representing the interval - """ - try: - import Image - import ImageDraw - except ImportError: - print('cannot import python PIL imaging') - return - - img = Image.new('RGB', (img_x, img_y), (100, 100, 100)) - draw = ImageDraw.Draw(img) - i_min, i_max = self.hull() - - print(hex(i_min), hex(i_max)) - - addr2x = lambda addr: ((addr - i_min) * img_x) // (i_max - i_min) - for a, b in self.intervals: - draw.rectangle((addr2x(a), 0, addr2x(b), img_y), (200, 0, 0)) - - if dry_run is False: - img.show() - - @property - def length(self): - """ - Return the cumulated length of intervals - """ - # Do not use __len__ because we may return a value > 32 bits - return sum((stop - start + 1) for start, stop in self.intervals) diff --git a/miasm/core/locationdb.py b/miasm/core/locationdb.py deleted file mode 100644 index b7e16ea2..00000000 --- a/miasm/core/locationdb.py +++ /dev/null @@ -1,495 +0,0 @@ -import warnings -from builtins import int as int_types - -from functools import reduce -from future.utils import viewitems, viewvalues - -from miasm.core.utils import printable -from miasm.expression.expression import LocKey, ExprLoc - - -class LocationDB(object): - """ - LocationDB is a "database" of information associated to location. - - An entry in a LocationDB is uniquely identified with a LocKey. - Additional information which can be associated with a LocKey are: - - an offset (uniq per LocationDB) - - several names (each are uniqs per LocationDB) - - As a schema: - loc_key 1 <-> 0..1 offset - 1 <-> 0..n name - - >>> loc_db = LocationDB() - # Add a location with no additional information - >>> loc_key1 = loc_db.add_location() - # Add a location with an offset - >>> loc_key2 = loc_db.add_location(offset=0x1234) - # Add a location with several names - >>> loc_key3 = loc_db.add_location(name="first_name") - >>> loc_db.add_location_name(loc_key3, "second_name") - # Associate an offset to an existing location - >>> loc_db.set_location_offset(loc_key3, 0x5678) - # Remove a name from an existing location - >>> loc_db.remove_location_name(loc_key3, "second_name") - - # Get back offset - >>> loc_db.get_location_offset(loc_key1) - None - >>> loc_db.get_location_offset(loc_key2) - 0x1234 - >>> loc_db.get_location_offset("first_name") - 0x5678 - - # Display a location - >>> loc_db.pretty_str(loc_key1) - loc_key_1 - >>> loc_db.pretty_str(loc_key2) - loc_1234 - >>> loc_db.pretty_str(loc_key3) - first_name - """ - - def __init__(self): - # Known LocKeys - self._loc_keys = set() - - # Association tables - self._loc_key_to_offset = {} - self._loc_key_to_names = {} - self._name_to_loc_key = {} - self._offset_to_loc_key = {} - - # Counter for new LocKey generation - self._loc_key_num = 0 - - def get_location_offset(self, loc_key): - """ - Return the offset of @loc_key if any, None otherwise. - @loc_key: LocKey instance - """ - assert isinstance(loc_key, LocKey) - return self._loc_key_to_offset.get(loc_key) - - def get_location_names(self, loc_key): - """ - Return the frozenset of names associated to @loc_key - @loc_key: LocKey instance - """ - assert isinstance(loc_key, LocKey) - return frozenset(self._loc_key_to_names.get(loc_key, set())) - - def get_name_location(self, name): - """ - Return the LocKey of @name if any, None otherwise. - @name: target name - """ - assert isinstance(name, str) - return self._name_to_loc_key.get(name) - - def get_or_create_name_location(self, name): - """ - Return the LocKey of @name if any, create one otherwise. - @name: target name - """ - assert isinstance(name, str) - loc_key = self._name_to_loc_key.get(name) - if loc_key is not None: - return loc_key - return self.add_location(name=name) - - def get_offset_location(self, offset): - """ - Return the LocKey of @offset if any, None otherwise. - @offset: target offset - """ - return self._offset_to_loc_key.get(offset) - - def get_or_create_offset_location(self, offset): - """ - Return the LocKey of @offset if any, create one otherwise. - @offset: target offset - """ - loc_key = self._offset_to_loc_key.get(offset) - if loc_key is not None: - return loc_key - return self.add_location(offset=offset) - - def get_name_offset(self, name): - """ - Return the offset of @name if any, None otherwise. - @name: target name - """ - assert isinstance(name, str) - loc_key = self.get_name_location(name) - if loc_key is None: - return None - return self.get_location_offset(loc_key) - - def add_location_name(self, loc_key, name): - """Associate a name @name to a given @loc_key - @name: str instance - @loc_key: LocKey instance - """ - assert isinstance(name, str) - assert loc_key in self._loc_keys - already_existing_loc = self._name_to_loc_key.get(name) - if already_existing_loc is not None and already_existing_loc != loc_key: - raise KeyError("%r is already associated to a different loc_key " - "(%r)" % (name, already_existing_loc)) - self._loc_key_to_names.setdefault(loc_key, set()).add(name) - self._name_to_loc_key[name] = loc_key - - def remove_location_name(self, loc_key, name): - """Disassociate a name @name from a given @loc_key - Fail if @name is not already associated to @loc_key - @name: str instance - @loc_key: LocKey instance - """ - assert loc_key in self._loc_keys - assert isinstance(name, str) - already_existing_loc = self._name_to_loc_key.get(name) - if already_existing_loc is None: - raise KeyError("%r is not already associated" % name) - if already_existing_loc != loc_key: - raise KeyError("%r is already associated to a different loc_key " - "(%r)" % (name, already_existing_loc)) - del self._name_to_loc_key[name] - self._loc_key_to_names[loc_key].remove(name) - - def set_location_offset(self, loc_key, offset, force=False): - """Associate the offset @offset to an LocKey @loc_key - - If @force is set, override silently. Otherwise, if an offset is already - associated to @loc_key, an error will be raised - """ - assert loc_key in self._loc_keys - already_existing_loc = self.get_offset_location(offset) - if already_existing_loc is not None and already_existing_loc != loc_key: - raise KeyError("%r is already associated to a different loc_key " - "(%r)" % (offset, already_existing_loc)) - already_existing_off = self._loc_key_to_offset.get(loc_key) - if (already_existing_off is not None and - already_existing_off != offset): - if not force: - raise ValueError( - "%r already has an offset (0x%x). Use 'force=True'" - " for silent overriding" % ( - loc_key, already_existing_off - )) - else: - self.unset_location_offset(loc_key) - self._offset_to_loc_key[offset] = loc_key - self._loc_key_to_offset[loc_key] = offset - - def unset_location_offset(self, loc_key): - """Disassociate LocKey @loc_key's offset - - Fail if there is already no offset associate with it - @loc_key: LocKey - """ - assert loc_key in self._loc_keys - already_existing_off = self._loc_key_to_offset.get(loc_key) - if already_existing_off is None: - raise ValueError("%r already has no offset" % (loc_key)) - del self._offset_to_loc_key[already_existing_off] - del self._loc_key_to_offset[loc_key] - - def consistency_check(self): - """Ensure internal structures are consistent with each others""" - assert set(self._loc_key_to_names).issubset(self._loc_keys) - assert set(self._loc_key_to_offset).issubset(self._loc_keys) - assert self._loc_key_to_offset == {v: k for k, v in viewitems(self._offset_to_loc_key)} - assert reduce( - lambda x, y:x.union(y), - viewvalues(self._loc_key_to_names), - set(), - ) == set(self._name_to_loc_key) - for name, loc_key in viewitems(self._name_to_loc_key): - assert name in self._loc_key_to_names[loc_key] - - def find_free_name(self, name): - """ - If @name is not known in DB, return it - Else append an index to it corresponding to the next unknown name - - @name: string - """ - assert isinstance(name, str) - if self.get_name_location(name) is None: - return name - i = 0 - while True: - new_name = "%s_%d" % (name, i) - if self.get_name_location(new_name) is None: - return new_name - i += 1 - - def add_location(self, name=None, offset=None, strict=True): - """Add a new location in the locationDB. Returns the corresponding LocKey. - If @name is set, also associate a name to this new location. - If @offset is set, also associate an offset to this new location. - - Strict mode (set by @strict, default): - If a location with @offset or @name already exists, an error will be - raised. - Otherwise: - If a location with @offset or @name already exists, the corresponding - LocKey may be updated and will be returned. - """ - - # Deprecation handling - if isinstance(name, int_types): - assert offset is None or offset == name - warnings.warn("Deprecated API: use 'add_location(offset=)' instead." - " An additional 'name=' can be provided to also " - "associate a name (there is no more default name)") - offset = name - name = None - - # Argument cleaning - offset_loc_key = None - if offset is not None: - offset = int(offset) - offset_loc_key = self.get_offset_location(offset) - - # Test for collisions - name_loc_key = None - if name is not None: - assert isinstance(name, str) - name_loc_key = self.get_name_location(name) - - if strict: - if name_loc_key is not None: - raise ValueError("An entry for %r already exists (%r), and " - "strict mode is enabled" % ( - name, name_loc_key - )) - if offset_loc_key is not None: - raise ValueError("An entry for 0x%x already exists (%r), and " - "strict mode is enabled" % ( - offset, offset_loc_key - )) - else: - # Non-strict mode - if name_loc_key is not None: - known_offset = self.get_offset_location(name_loc_key) - if known_offset is None: - if offset is not None: - self.set_location_offset(name_loc_key, offset) - elif known_offset != offset: - raise ValueError( - "Location with name '%s' already have an offset: 0x%x " - "(!= 0x%x)" % (name, offset, known_offset) - ) - # Name already known, same offset -> nothing to do - return name_loc_key - - elif offset_loc_key is not None: - if name is not None: - # Check for already known name are checked above - return self.add_location_name(offset_loc_key, name) - # Offset already known, no name specified - return offset_loc_key - - # No collision, this is a brand new location - loc_key = LocKey(self._loc_key_num) - self._loc_key_num += 1 - self._loc_keys.add(loc_key) - - if offset is not None: - assert offset not in self._offset_to_loc_key - self._offset_to_loc_key[offset] = loc_key - self._loc_key_to_offset[loc_key] = offset - - if name is not None: - self._name_to_loc_key[name] = loc_key - self._loc_key_to_names[loc_key] = set([name]) - - return loc_key - - def remove_location(self, loc_key): - """ - Delete the location corresponding to @loc_key - @loc_key: LocKey instance - """ - assert isinstance(loc_key, LocKey) - if loc_key not in self._loc_keys: - raise KeyError("Unknown loc_key %r" % loc_key) - names = self._loc_key_to_names.pop(loc_key, []) - for name in names: - del self._name_to_loc_key[name] - offset = self._loc_key_to_offset.pop(loc_key, None) - self._offset_to_loc_key.pop(offset, None) - self._loc_keys.remove(loc_key) - - def pretty_str(self, loc_key): - """Return a human readable version of @loc_key, according to information - available in this LocationDB instance""" - names = self.get_location_names(loc_key) - new_names = set() - for name in names: - try: - name = name.decode() - except AttributeError: - pass - new_names.add(name) - names = new_names - if names: - return ",".join(names) - offset = self.get_location_offset(loc_key) - if offset is not None: - return "loc_%x" % offset - return str(loc_key) - - @property - def loc_keys(self): - """Return all loc_keys""" - return self._loc_keys - - @property - def names(self): - """Return all known names""" - return list(self._name_to_loc_key) - - @property - def offsets(self): - """Return all known offsets""" - return list(self._offset_to_loc_key) - - def __str__(self): - out = [] - for loc_key in self._loc_keys: - names = self.get_location_names(loc_key) - offset = self.get_location_offset(loc_key) - out.append( - "%s: %s - %s" % ( - loc_key, - "0x%x" % offset if offset is not None else None, - ",".join(printable(name) for name in names) - ) - ) - return "\n".join(out) - - def merge(self, location_db): - """Merge with another LocationDB @location_db - - WARNING: old reference to @location_db information (such as LocKeys) - must be retrieved from the updated version of this instance. The - dedicated "get_*" APIs may be used for this task - """ - # A simple merge is not doable here, because LocKey will certainly - # collides - - for foreign_loc_key in location_db.loc_keys: - foreign_names = location_db.get_location_names(foreign_loc_key) - foreign_offset = location_db.get_location_offset(foreign_loc_key) - if foreign_names: - init_name = list(foreign_names)[0] - else: - init_name = None - loc_key = self.add_location(offset=foreign_offset, name=init_name, - strict=False) - cur_names = self.get_location_names(loc_key) - for name in foreign_names: - if name not in cur_names and name != init_name: - self.add_location_name(loc_key, name=name) - - def canonize_to_exprloc(self, expr): - """ - If expr is ExprInt, return ExprLoc with corresponding loc_key - Else, return expr - - @expr: Expr instance - """ - if expr.is_int(): - loc_key = self.get_or_create_offset_location(int(expr)) - ret = ExprLoc(loc_key, expr.size) - return ret - return expr - - # Deprecated APIs - @property - def items(self): - """Return all loc_keys""" - warnings.warn('DEPRECATION WARNING: use "loc_keys" instead of "items"') - return list(self._loc_keys) - - def __getitem__(self, item): - warnings.warn('DEPRECATION WARNING: use "get_name_location" or ' - '"get_offset_location"') - if item in self._name_to_loc_key: - return self._name_to_loc_key[item] - if item in self._offset_to_loc_key: - return self._offset_to_loc_key[item] - raise KeyError('unknown symbol %r' % item) - - def __contains__(self, item): - warnings.warn('DEPRECATION WARNING: use "get_name_location" or ' - '"get_offset_location", or ".offsets" or ".names"') - return item in self._name_to_loc_key or item in self._offset_to_loc_key - - def loc_key_to_name(self, loc_key): - """[DEPRECATED API], see 'get_location_names'""" - warnings.warn("Deprecated API: use 'get_location_names'") - return sorted(self.get_location_names(loc_key))[0] - - def loc_key_to_offset(self, loc_key): - """[DEPRECATED API], see 'get_location_offset'""" - warnings.warn("Deprecated API: use 'get_location_offset'") - return self.get_location_offset(loc_key) - - def remove_loc_key(self, loc_key): - """[DEPRECATED API], see 'remove_location'""" - warnings.warn("Deprecated API: use 'remove_location'") - self.remove_location(loc_key) - - def del_loc_key_offset(self, loc_key): - """[DEPRECATED API], see 'unset_location_offset'""" - warnings.warn("Deprecated API: use 'unset_location_offset'") - self.unset_location_offset(loc_key) - - def getby_offset(self, offset): - """[DEPRECATED API], see 'get_offset_location'""" - warnings.warn("Deprecated API: use 'get_offset_location'") - return self.get_offset_location(offset) - - def getby_name(self, name): - """[DEPRECATED API], see 'get_name_location'""" - warnings.warn("Deprecated API: use 'get_name_location'") - return self.get_name_location(name) - - def getby_offset_create(self, offset): - """[DEPRECATED API], see 'get_or_create_offset_location'""" - warnings.warn("Deprecated API: use 'get_or_create_offset_location'") - return self.get_or_create_offset_location(offset) - - def getby_name_create(self, name): - """[DEPRECATED API], see 'get_or_create_name_location'""" - warnings.warn("Deprecated API: use 'get_or_create_name_location'") - return self.get_or_create_name_location(name) - - def rename_location(self, loc_key, newname): - """[DEPRECATED API], see 'add_name_location' and 'remove_location_name' - """ - warnings.warn("Deprecated API: use 'add_location_name' and " - "'remove_location_name'") - for name in self.get_location_names(loc_key): - self.remove_location_name(loc_key, name) - self.add_location_name(loc_key, name) - - def set_offset(self, loc_key, offset): - """[DEPRECATED API], see 'set_location_offset'""" - warnings.warn("Deprecated API: use 'set_location_offset'") - self.set_location_offset(loc_key, offset, force=True) - - def gen_loc_key(self): - """[DEPRECATED API], see 'add_location'""" - warnings.warn("Deprecated API: use 'add_location'") - return self.add_location() - - def str_loc_key(self, loc_key): - """[DEPRECATED API], see 'pretty_str'""" - warnings.warn("Deprecated API: use 'pretty_str'") - return self.pretty_str(loc_key) diff --git a/miasm/core/modint.py b/miasm/core/modint.py deleted file mode 100644 index 14b4dc2c..00000000 --- a/miasm/core/modint.py +++ /dev/null @@ -1,270 +0,0 @@ -#-*- coding:utf-8 -*- - -from builtins import range -from functools import total_ordering - -@total_ordering -class moduint(object): - - def __init__(self, arg): - self.arg = int(arg) % self.__class__.limit - assert(self.arg >= 0 and self.arg < self.__class__.limit) - - def __repr__(self): - return self.__class__.__name__ + '(' + hex(self.arg) + ')' - - def __hash__(self): - return hash(self.arg) - - @classmethod - def maxcast(cls, c2): - c2 = c2.__class__ - if cls.size > c2.size: - return cls - else: - return c2 - - def __eq__(self, y): - if isinstance(y, moduint): - return self.arg == y.arg - return self.arg == y - - def __ne__(self, y): - # required Python 2.7.14 - return not self == y - - def __lt__(self, y): - if isinstance(y, moduint): - return self.arg < y.arg - return self.arg < y - - def __add__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg + y.arg) - else: - return self.__class__(self.arg + y) - - def __and__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg & y.arg) - else: - return self.__class__(self.arg & y) - - def __div__(self, y): - # Python: 8 / -7 == -2 (C-like: -1) - # int(float) trick cannot be used, due to information loss - # Examples: - # - # 42 / 10 => 4 - # 42 % 10 => 2 - # - # -42 / 10 => -4 - # -42 % 10 => -2 - # - # 42 / -10 => -4 - # 42 % -10 => 2 - # - # -42 / -10 => 4 - # -42 % -10 => -2 - - den = int(y) - num = int(self) - result_sign = 1 if (den * num) >= 0 else -1 - cls = self.__class__ - if isinstance(y, moduint): - cls = self.maxcast(y) - return (abs(num) // abs(den)) * result_sign - - def __floordiv__(self, y): - return self.__div__(y) - - def __int__(self): - return int(self.arg) - - def __long__(self): - return int(self.arg) - - def __index__(self): - return int(self.arg) - - def __invert__(self): - return self.__class__(~self.arg) - - def __lshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg << y.arg) - else: - return self.__class__(self.arg << y) - - def __mod__(self, y): - # See __div__ for implementation choice - cls = self.__class__ - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg - y * (self // y)) - - def __mul__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg * y.arg) - else: - return self.__class__(self.arg * y) - - def __neg__(self): - return self.__class__(-self.arg) - - def __or__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg | y.arg) - else: - return self.__class__(self.arg | y) - - def __radd__(self, y): - return self.__add__(y) - - def __rand__(self, y): - return self.__and__(y) - - def __rdiv__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg // self.arg) - else: - return self.__class__(y // self.arg) - - def __rfloordiv__(self, y): - return self.__rdiv__(y) - - def __rlshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg << self.arg) - else: - return self.__class__(y << self.arg) - - def __rmod__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg % self.arg) - else: - return self.__class__(y % self.arg) - - def __rmul__(self, y): - return self.__mul__(y) - - def __ror__(self, y): - return self.__or__(y) - - def __rrshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg >> self.arg) - else: - return self.__class__(y >> self.arg) - - def __rshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg >> y.arg) - else: - return self.__class__(self.arg >> y) - - def __rsub__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg - self.arg) - else: - return self.__class__(y - self.arg) - - def __rxor__(self, y): - return self.__xor__(y) - - def __sub__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg - y.arg) - else: - return self.__class__(self.arg - y) - - def __xor__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg ^ y.arg) - else: - return self.__class__(self.arg ^ y) - - def __hex__(self): - return hex(self.arg) - - def __abs__(self): - return abs(self.arg) - - def __rpow__(self, v): - return v ** self.arg - - def __pow__(self, v): - return self.__class__(self.arg ** v) - - -class modint(moduint): - - def __init__(self, arg): - if isinstance(arg, moduint): - arg = arg.arg - a = arg % self.__class__.limit - if a >= self.__class__.limit // 2: - a -= self.__class__.limit - self.arg = a - assert( - self.arg >= -self.__class__.limit // 2 and - self.arg < self.__class__.limit - ) - - -def is_modint(a): - return isinstance(a, moduint) - - -mod_size2uint = {} -mod_size2int = {} - -mod_uint2size = {} -mod_int2size = {} - -def define_int(size): - """Build the 'modint' instance corresponding to size @size""" - global mod_size2int, mod_int2size - - name = 'int%d' % size - cls = type(name, (modint,), {"size": size, "limit": 1 << size}) - globals()[name] = cls - mod_size2int[size] = cls - mod_int2size[cls] = size - return cls - -def define_uint(size): - """Build the 'moduint' instance corresponding to size @size""" - global mod_size2uint, mod_uint2size - - name = 'uint%d' % size - cls = type(name, (moduint,), {"size": size, "limit": 1 << size}) - globals()[name] = cls - mod_size2uint[size] = cls - mod_uint2size[cls] = size - return cls - -def define_common_int(): - "Define common int" - common_int = range(1, 257) - - for i in common_int: - define_int(i) - - for i in common_int: - define_uint(i) - -define_common_int() diff --git a/miasm/core/objc.py b/miasm/core/objc.py deleted file mode 100644 index 24ee84ab..00000000 --- a/miasm/core/objc.py +++ /dev/null @@ -1,1763 +0,0 @@ -""" -C helper for Miasm: -* raw C to Miasm expression -* Miasm expression to raw C -* Miasm expression to C type -""" - -from builtins import zip -from builtins import int as int_types - -import warnings -from pycparser import c_parser, c_ast -from functools import total_ordering - -from miasm.core.utils import cmp_elts -from miasm.expression.expression_reduce import ExprReducer -from miasm.expression.expression import ExprInt, ExprId, ExprOp, ExprMem -from miasm.arch.x86.arch import is_op_segm - -from miasm.core.ctypesmngr import CTypeUnion, CTypeStruct, CTypeId, CTypePtr,\ - CTypeArray, CTypeOp, CTypeSizeof, CTypeEnum, CTypeFunc, CTypeEllipsis - - -PADDING_TYPE_NAME = "___padding___" - -def missing_definition(objtype): - warnings.warn("Null size type: Missing definition? %r" % objtype) - -""" -Display C type -source: "The C Programming Language - 2nd Edition - Ritchie Kernighan.pdf" -p. 124 -""" - -def objc_to_str(objc, result=None): - if result is None: - result = "" - while True: - if isinstance(objc, ObjCArray): - result += "[%d]" % objc.elems - objc = objc.objtype - elif isinstance(objc, ObjCPtr): - if not result and isinstance(objc.objtype, ObjCFunc): - result = objc.objtype.name - if isinstance(objc.objtype, (ObjCPtr, ObjCDecl, ObjCStruct, ObjCUnion)): - result = "*%s" % result - else: - result = "(*%s)" % result - - objc = objc.objtype - elif isinstance(objc, (ObjCDecl, ObjCStruct, ObjCUnion)): - if result: - result = "%s %s" % (objc, result) - else: - result = str(objc) - break - elif isinstance(objc, ObjCFunc): - args_str = [] - for name, arg in objc.args: - args_str.append(objc_to_str(arg, name)) - args = ", ".join(args_str) - result += "(%s)" % args - objc = objc.type_ret - elif isinstance(objc, ObjCInt): - return "int" - elif isinstance(objc, ObjCEllipsis): - return "..." - else: - raise TypeError("Unknown c type") - return result - - -@total_ordering -class ObjC(object): - """Generic ObjC""" - - def __init__(self, align, size): - self._align = align - self._size = size - - @property - def align(self): - """Alignment (in bytes) of the C object""" - return self._align - - @property - def size(self): - """Size (in bytes) of the C object""" - return self._size - - def cmp_base(self, other): - assert self.__class__ in OBJC_PRIO - assert other.__class__ in OBJC_PRIO - - if OBJC_PRIO[self.__class__] != OBJC_PRIO[other.__class__]: - return cmp_elts( - OBJC_PRIO[self.__class__], - OBJC_PRIO[other.__class__] - ) - if self.align != other.align: - return cmp_elts(self.align, other.align) - return cmp_elts(self.size, other.size) - - def __hash__(self): - return hash((self.__class__, self._align, self._size)) - - def __str__(self): - return objc_to_str(self) - - def __eq__(self, other): - return self.cmp_base(other) == 0 - - def __ne__(self, other): - # required Python 2.7.14 - return not self == other - - def __lt__(self, other): - return self.cmp_base(other) < 0 - - -@total_ordering -class ObjCDecl(ObjC): - """C Declaration identified""" - - def __init__(self, name, align, size): - super(ObjCDecl, self).__init__(align, size) - self._name = name - - name = property(lambda self: self._name) - - def __hash__(self): - return hash((super(ObjCDecl, self).__hash__(), self._name)) - - def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.name) - - def __str__(self): - return str(self.name) - - def __eq__(self, other): - ret = self.cmp_base(other) - if ret: - return False - return self.name == other.name - - def __lt__(self, other): - ret = self.cmp_base(other) - if ret: - if ret < 0: - return True - return False - return self.name < other.name - - -class ObjCInt(ObjC): - """C integer""" - - def __init__(self): - super(ObjCInt, self).__init__(None, 0) - - def __str__(self): - return 'int' - - -@total_ordering -class ObjCPtr(ObjC): - """C Pointer""" - - def __init__(self, objtype, void_p_align, void_p_size): - """Init ObjCPtr - - @objtype: pointer target ObjC - @void_p_align: pointer alignment (in bytes) - @void_p_size: pointer size (in bytes) - """ - - super(ObjCPtr, self).__init__(void_p_align, void_p_size) - self._lock = False - - self.objtype = objtype - if objtype is None: - self._lock = False - - def get_objtype(self): - assert self._lock is True - return self._objtype - - def set_objtype(self, objtype): - assert self._lock is False - self._lock = True - self._objtype = objtype - - objtype = property(get_objtype, set_objtype) - - def __hash__(self): - # Don't try to hash on an unlocked Ptr (still mutable) - assert self._lock - return hash((super(ObjCPtr, self).__hash__(), hash(self._objtype))) - - def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.objtype.__class__ - ) - - def __eq__(self, other): - ret = self.cmp_base(other) - if ret: - return False - return self.objtype == other.objtype - - def __lt__(self, other): - ret = self.cmp_base(other) - if ret: - if ret < 0: - return True - return False - return self.objtype < other.objtype - - -@total_ordering -class ObjCArray(ObjC): - """C array (test[XX])""" - - def __init__(self, objtype, elems): - """Init ObjCArray - - @objtype: pointer target ObjC - @elems: number of elements in the array - """ - - super(ObjCArray, self).__init__(objtype.align, elems * objtype.size) - self._elems = elems - self._objtype = objtype - - objtype = property(lambda self: self._objtype) - elems = property(lambda self: self._elems) - - def __hash__(self): - return hash((super(ObjCArray, self).__hash__(), self._elems, hash(self._objtype))) - - def __repr__(self): - return '<%r[%d]>' % (self.objtype, self.elems) - - def __eq__(self, other): - ret = self.cmp_base(other) - if ret: - return False - if self.objtype != other.objtype: - return False - return self.elems == other.elems - - def __lt__(self, other): - ret = self.cmp_base(other) - if ret > 0: - return False - if self.objtype > other.objtype: - return False - return self.elems < other.elems - -@total_ordering -class ObjCStruct(ObjC): - """C object for structures""" - - def __init__(self, name, align, size, fields): - super(ObjCStruct, self).__init__(align, size) - self._name = name - self._fields = tuple(fields) - - name = property(lambda self: self._name) - fields = property(lambda self: self._fields) - - def __hash__(self): - return hash((super(ObjCStruct, self).__hash__(), self._name)) - - def __repr__(self): - out = [] - out.append("Struct %s: (align: %d)" % (self.name, self.align)) - out.append(" off sz name") - for name, objtype, offset, size in self.fields: - out.append(" 0x%-3x %-3d %-10s %r" % - (offset, size, name, objtype.__class__.__name__)) - return '\n'.join(out) - - def __str__(self): - return 'struct %s' % (self.name) - - def __eq__(self, other): - ret = self.cmp_base(other) - if ret: - return False - return self.name == other.name - - def __lt__(self, other): - ret = self.cmp_base(other) - if ret: - if ret < 0: - return True - return False - return self.name < other.name - - -@total_ordering -class ObjCUnion(ObjC): - """C object for unions""" - - def __init__(self, name, align, size, fields): - super(ObjCUnion, self).__init__(align, size) - self._name = name - self._fields = tuple(fields) - - name = property(lambda self: self._name) - fields = property(lambda self: self._fields) - - def __hash__(self): - return hash((super(ObjCUnion, self).__hash__(), self._name)) - - def __repr__(self): - out = [] - out.append("Union %s: (align: %d)" % (self.name, self.align)) - out.append(" off sz name") - for name, objtype, offset, size in self.fields: - out.append(" 0x%-3x %-3d %-10s %r" % - (offset, size, name, objtype)) - return '\n'.join(out) - - def __str__(self): - return 'union %s' % (self.name) - - def __eq__(self, other): - ret = self.cmp_base(other) - if ret: - return False - return self.name == other.name - - def __lt__(self, other): - ret = self.cmp_base(other) - if ret: - if ret < 0: - return True - return False - return self.name < other.name - -class ObjCEllipsis(ObjC): - """C integer""" - - def __init__(self): - super(ObjCEllipsis, self).__init__(None, None) - - align = property(lambda self: self._align) - size = property(lambda self: self._size) - -@total_ordering -class ObjCFunc(ObjC): - """C object for Functions""" - - def __init__(self, name, abi, type_ret, args, void_p_align, void_p_size): - super(ObjCFunc, self).__init__(void_p_align, void_p_size) - self._name = name - self._abi = abi - self._type_ret = type_ret - self._args = tuple(args) - - args = property(lambda self: self._args) - type_ret = property(lambda self: self._type_ret) - abi = property(lambda self: self._abi) - name = property(lambda self: self._name) - - def __hash__(self): - return hash((super(ObjCFunc, self).__hash__(), hash(self._args), self._name)) - - def __repr__(self): - return "<%s %s>" % ( - self.__class__.__name__, - self.name - ) - - def __str__(self): - out = [] - out.append("Function (%s) %s: (align: %d)" % (self.abi, self.name, self.align)) - out.append(" ret: %s" % (str(self.type_ret))) - out.append(" Args:") - for name, arg in self.args: - out.append(" %s %s" % (name, arg)) - return '\n'.join(out) - - def __eq__(self, other): - ret = self.cmp_base(other) - if ret: - return False - return self.name == other.name - - def __lt__(self, other): - ret = self.cmp_base(other) - if ret: - if ret < 0: - return True - return False - return self.name < other.name - -OBJC_PRIO = { - ObjC: 0, - ObjCDecl:1, - ObjCInt:2, - ObjCPtr:3, - ObjCArray:4, - ObjCStruct:5, - ObjCUnion:6, - ObjCEllipsis:7, - ObjCFunc:8, -} - - -def access_simplifier(expr): - """Expression visitor to simplify a C access represented in Miasm - - @expr: Miasm expression representing the C access - - Example: - - IN: (In c: ['*(&((&((*(ptr_Test)).a))[0]))']) - [ExprOp('deref', ExprOp('addr', ExprOp('[]', ExprOp('addr', - ExprOp('field', ExprOp('deref', ExprId('ptr_Test', 64)), - ExprId('a', 64))), ExprInt(0x0, 64))))] - - OUT: (In c: ['(ptr_Test)->a']) - [ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64))] - """ - - if (expr.is_op("addr") and - expr.args[0].is_op("[]") and - expr.args[0].args[1] == ExprInt(0, 64)): - return expr.args[0].args[0] - elif (expr.is_op("[]") and - expr.args[0].is_op("addr") and - expr.args[1] == ExprInt(0, 64)): - return expr.args[0].args[0] - elif (expr.is_op("addr") and - expr.args[0].is_op("deref")): - return expr.args[0].args[0] - elif (expr.is_op("deref") and - expr.args[0].is_op("addr")): - return expr.args[0].args[0] - elif (expr.is_op("field") and - expr.args[0].is_op("deref")): - return ExprOp("->", expr.args[0].args[0], expr.args[1]) - return expr - - -def access_str(expr): - """Return the C string of a C access represented in Miasm - - @expr: Miasm expression representing the C access - - In: - ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64)) - OUT: - '(ptr_Test)->a' - """ - - if isinstance(expr, ExprId): - out = str(expr) - elif isinstance(expr, ExprInt): - out = str(int(expr)) - elif expr.is_op("addr"): - out = "&(%s)" % access_str(expr.args[0]) - elif expr.is_op("deref"): - out = "*(%s)" % access_str(expr.args[0]) - elif expr.is_op("field"): - out = "(%s).%s" % (access_str(expr.args[0]), access_str(expr.args[1])) - elif expr.is_op("->"): - out = "(%s)->%s" % (access_str(expr.args[0]), access_str(expr.args[1])) - elif expr.is_op("[]"): - out = "(%s)[%s]" % (access_str(expr.args[0]), access_str(expr.args[1])) - else: - raise RuntimeError("unknown op") - - return out - - -class CGen(object): - """Generic object to represent a C expression""" - - default_size = 64 - - - def __init__(self, ctype): - self._ctype = ctype - - @property - def ctype(self): - """Type (ObjC instance) of the current object""" - return self._ctype - - def __hash__(self): - return hash(self.__class__) - - def __eq__(self, other): - return (self.__class__ == other.__class__ and - self._ctype == other.ctype) - - def __ne__(self, other): - return not self.__eq__(other) - - def to_c(self): - """Generate corresponding C""" - - raise NotImplementedError("Virtual") - - def to_expr(self): - """Generate Miasm expression representing the C access""" - - raise NotImplementedError("Virtual") - - -class CGenInt(CGen): - """Int C object""" - - def __init__(self, integer): - assert isinstance(integer, int_types) - self._integer = integer - super(CGenInt, self).__init__(ObjCInt()) - - @property - def integer(self): - """Value of the object""" - return self._integer - - def __hash__(self): - return hash((super(CGenInt, self).__hash__(), self._integer)) - - def __eq__(self, other): - return (super(CGenInt, self).__eq__(other) and - self._integer == other.integer) - - def __ne__(self, other): - return not self.__eq__(other) - - def to_c(self): - """Generate corresponding C""" - - return "0x%X" % self.integer - - def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, - self.integer) - - def to_expr(self): - """Generate Miasm expression representing the C access""" - - return ExprInt(self.integer, self.default_size) - - -class CGenId(CGen): - """ID of a C object""" - - def __init__(self, ctype, name): - self._name = name - assert isinstance(name, str) - super(CGenId, self).__init__(ctype) - - @property - def name(self): - """Name of the Id""" - return self._name - - def __hash__(self): - return hash((super(CGenId, self).__hash__(), self._name)) - - def __eq__(self, other): - return (super(CGenId, self).__eq__(other) and - self._name == other.name) - - def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, - self.name) - - def to_c(self): - """Generate corresponding C""" - - return "%s" % (self.name) - - def to_expr(self): - """Generate Miasm expression representing the C access""" - - return ExprId(self.name, self.default_size) - - -class CGenField(CGen): - """ - Field of a C struct/union - - IN: - - struct (not ptr struct) - - field name - OUT: - - input type of the field => output type - - X[] => X[] - - X => X* - """ - - def __init__(self, struct, field, fieldtype, void_p_align, void_p_size): - self._struct = struct - self._field = field - assert isinstance(field, str) - if isinstance(fieldtype, ObjCArray): - ctype = fieldtype - else: - ctype = ObjCPtr(fieldtype, void_p_align, void_p_size) - super(CGenField, self).__init__(ctype) - - @property - def struct(self): - """Structure containing the field""" - return self._struct - - @property - def field(self): - """Field name""" - return self._field - - def __hash__(self): - return hash((super(CGenField, self).__hash__(), self._struct, self._field)) - - def __eq__(self, other): - return (super(CGenField, self).__eq__(other) and - self._struct == other.struct and - self._field == other.field) - - def to_c(self): - """Generate corresponding C""" - - if isinstance(self.ctype, ObjCArray): - return "(%s).%s" % (self.struct.to_c(), self.field) - elif isinstance(self.ctype, ObjCPtr): - return "&((%s).%s)" % (self.struct.to_c(), self.field) - else: - raise RuntimeError("Strange case") - - def __repr__(self): - return "<%s %s %s>" % (self.__class__.__name__, - self.struct, - self.field) - - def to_expr(self): - """Generate Miasm expression representing the C access""" - - if isinstance(self.ctype, ObjCArray): - return ExprOp("field", - self.struct.to_expr(), - ExprId(self.field, self.default_size)) - elif isinstance(self.ctype, ObjCPtr): - return ExprOp("addr", - ExprOp("field", - self.struct.to_expr(), - ExprId(self.field, self.default_size))) - else: - raise RuntimeError("Strange case") - - -class CGenArray(CGen): - """ - C Array - - This object does *not* deref the source, it only do object casting. - - IN: - - obj - OUT: - - X* => X* - - ..[][] => ..[] - - X[] => X* - """ - - def __init__(self, base, elems, void_p_align, void_p_size): - ctype = base.ctype - if isinstance(ctype, ObjCPtr): - pass - elif isinstance(ctype, ObjCArray) and isinstance(ctype.objtype, ObjCArray): - ctype = ctype.objtype - elif isinstance(ctype, ObjCArray): - ctype = ObjCPtr(ctype.objtype, void_p_align, void_p_size) - else: - raise TypeError("Strange case") - self._base = base - self._elems = elems - super(CGenArray, self).__init__(ctype) - - @property - def base(self): - """Base object supporting the array""" - return self._base - - @property - def elems(self): - """Number of elements in the array""" - return self._elems - - def __hash__(self): - return hash((super(CGenArray, self).__hash__(), self._base, self._elems)) - - def __eq__(self, other): - return (super(CGenField, self).__eq__(other) and - self._base == other.base and - self._elems == other.elems) - - def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, - self.base) - - def to_c(self): - """Generate corresponding C""" - - if isinstance(self.ctype, ObjCPtr): - out_str = "&((%s)[%d])" % (self.base.to_c(), self.elems) - elif isinstance(self.ctype, ObjCArray): - out_str = "(%s)[%d]" % (self.base.to_c(), self.elems) - else: - raise RuntimeError("Strange case") - return out_str - - def to_expr(self): - """Generate Miasm expression representing the C access""" - - if isinstance(self.ctype, ObjCPtr): - return ExprOp("addr", - ExprOp("[]", - self.base.to_expr(), - ExprInt(self.elems, self.default_size))) - elif isinstance(self.ctype, ObjCArray): - return ExprOp("[]", - self.base.to_expr(), - ExprInt(self.elems, self.default_size)) - else: - raise RuntimeError("Strange case") - - -class CGenDeref(CGen): - """ - C dereference - - IN: - - ptr - OUT: - - X* => X - """ - - def __init__(self, ptr): - assert isinstance(ptr.ctype, ObjCPtr) - self._ptr = ptr - super(CGenDeref, self).__init__(ptr.ctype.objtype) - - @property - def ptr(self): - """Pointer object""" - return self._ptr - - def __hash__(self): - return hash((super(CGenDeref, self).__hash__(), self._ptr)) - - def __eq__(self, other): - return (super(CGenField, self).__eq__(other) and - self._ptr == other.ptr) - - def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, - self.ptr) - - def to_c(self): - """Generate corresponding C""" - - if not isinstance(self.ptr.ctype, ObjCPtr): - raise RuntimeError() - return "*(%s)" % (self.ptr.to_c()) - - def to_expr(self): - """Generate Miasm expression representing the C access""" - - if not isinstance(self.ptr.ctype, ObjCPtr): - raise RuntimeError() - return ExprOp("deref", self.ptr.to_expr()) - - -def ast_get_c_access_expr(ast, expr_types, lvl=0): - """Transform C ast object into a C Miasm expression - - @ast: parsed pycparser.c_ast object - @expr_types: a dictionary linking ID names to their types - @lvl: actual recursion level - - Example: - - IN: - StructRef: -> - ID: ptr_Test - ID: a - - OUT: - ExprOp('->', ExprId('ptr_Test', 64), ExprId('a', 64)) - """ - - if isinstance(ast, c_ast.Constant): - obj = ExprInt(int(ast.value), 64) - elif isinstance(ast, c_ast.StructRef): - name, field = ast.name, ast.field.name - name = ast_get_c_access_expr(name, expr_types) - if ast.type == "->": - s_name = name - s_field = ExprId(field, 64) - obj = ExprOp('->', s_name, s_field) - elif ast.type == ".": - s_name = name - s_field = ExprId(field, 64) - obj = ExprOp("field", s_name, s_field) - else: - raise RuntimeError("Unknown struct access") - elif isinstance(ast, c_ast.UnaryOp) and ast.op == "&": - tmp = ast_get_c_access_expr(ast.expr, expr_types, lvl + 1) - obj = ExprOp("addr", tmp) - elif isinstance(ast, c_ast.ArrayRef): - tmp = ast_get_c_access_expr(ast.name, expr_types, lvl + 1) - index = ast_get_c_access_expr(ast.subscript, expr_types, lvl + 1) - obj = ExprOp("[]", tmp, index) - elif isinstance(ast, c_ast.ID): - assert ast.name in expr_types - obj = ExprId(ast.name, 64) - elif isinstance(ast, c_ast.UnaryOp) and ast.op == "*": - tmp = ast_get_c_access_expr(ast.expr, expr_types, lvl + 1) - obj = ExprOp("deref", tmp) - else: - raise NotImplementedError("Unknown type") - return obj - - -def parse_access(c_access): - """Parse C access - - @c_access: C access string - """ - - main = ''' - int main() { - %s; - } - ''' % c_access - - parser = c_parser.CParser() - node = parser.parse(main, filename='<stdin>') - access = node.ext[-1].body.block_items[0] - return access - - -class ExprToAccessC(ExprReducer): - """ - Generate the C access object(s) for a given native Miasm expression - Example: - IN: - @32[ptr_Test] - OUT: - [<CGenDeref <CGenArray <CGenField <CGenDeref <CGenId ptr_Test>> a>>>] - - An expression may be represented by multiple accessor (due to unions). - """ - - def __init__(self, expr_types, types_mngr, enforce_strict_access=True): - """Init GenCAccess - - @expr_types: a dictionary linking ID names to their types - @types_mngr: types manager - @enforce_strict_access: If false, generate access even on expression - pointing to a middle of an object. If true, raise exception if such a - pointer is encountered - """ - - self.expr_types = expr_types - self.types_mngr = types_mngr - self.enforce_strict_access = enforce_strict_access - - def updt_expr_types(self, expr_types): - """Update expr_types - @expr_types: Dictionary associating name to type - """ - - self.expr_types = expr_types - - def cgen_access(self, cgenobj, base_type, offset, deref, lvl=0): - """Return the access(es) which lead to the element at @offset of an - object of type @base_type - - In case of no @deref, stops recursion as soon as we reached the base of - an object. - In other cases, we need to go down to the final dereferenced object - - @cgenobj: current object access - @base_type: type of main object - @offset: offset (in bytes) of the target sub object - @deref: get type for a pointer or a deref - @lvl: actual recursion level - - - IN: - - base_type: struct Toto{ - int a - int b - } - - base_name: var - - 4 - OUT: - - CGenField(var, b) - - - - IN: - - base_type: int a - - 0 - OUT: - - CGenAddr(a) - - IN: - - base_type: X = int* a - - 0 - OUT: - - CGenAddr(X) - - IN: - - X = int* a - - 8 - OUT: - - ASSERT - - - IN: - - struct toto{ - int a - int b[10] - } - - 8 - OUT: - - CGenArray(CGenField(toto, b), 1) - """ - if base_type.size == 0: - missing_definition(base_type) - return set() - - - void_type = self.types_mngr.void_ptr - if isinstance(base_type, ObjCStruct): - if not 0 <= offset < base_type.size: - return set() - - if offset == 0 and not deref: - # In this case, return the struct* - return set([cgenobj]) - - for fieldname, subtype, field_offset, size in base_type.fields: - if not field_offset <= offset < field_offset + size: - continue - fieldptr = CGenField(CGenDeref(cgenobj), fieldname, subtype, - void_type.align, void_type.size) - new_type = self.cgen_access(fieldptr, subtype, - offset - field_offset, - deref, lvl + 1) - break - else: - return set() - elif isinstance(base_type, ObjCArray): - if base_type.objtype.size == 0: - missing_definition(base_type.objtype) - return set() - element_num = offset // (base_type.objtype.size) - field_offset = offset % base_type.objtype.size - if element_num >= base_type.elems: - return set() - if offset == 0 and not deref: - # In this case, return the array - return set([cgenobj]) - - curobj = CGenArray(cgenobj, element_num, - void_type.align, - void_type.size) - if field_offset == 0: - # We point to the start of the sub object, - # return it directly - return set([curobj]) - new_type = self.cgen_access(curobj, base_type.objtype, - field_offset, deref, lvl + 1) - - elif isinstance(base_type, ObjCDecl): - if self.enforce_strict_access and offset % base_type.size != 0: - return set() - elem_num = offset // base_type.size - - nobj = CGenArray(cgenobj, elem_num, - void_type.align, void_type.size) - new_type = set([nobj]) - - elif isinstance(base_type, ObjCUnion): - if offset == 0 and not deref: - # In this case, return the struct* - return set([cgenobj]) - - out = set() - for fieldname, objtype, field_offset, size in base_type.fields: - if not field_offset <= offset < field_offset + size: - continue - field = CGenField(CGenDeref(cgenobj), fieldname, objtype, - void_type.align, void_type.size) - out.update(self.cgen_access(field, objtype, - offset - field_offset, - deref, lvl + 1)) - new_type = out - - elif isinstance(base_type, ObjCPtr): - elem_num = offset // base_type.size - if self.enforce_strict_access and offset % base_type.size != 0: - return set() - nobj = CGenArray(cgenobj, elem_num, - void_type.align, void_type.size) - new_type = set([nobj]) - - else: - raise NotImplementedError("deref type %r" % base_type) - return new_type - - def reduce_known_expr(self, node, ctxt, **kwargs): - """Generate access for known expr""" - if node.expr in ctxt: - objcs = ctxt[node.expr] - return set(CGenId(objc, str(node.expr)) for objc in objcs) - return None - - def reduce_int(self, node, **kwargs): - """Generate access for ExprInt""" - - if not isinstance(node.expr, ExprInt): - return None - return set([CGenInt(int(node.expr))]) - - def get_solo_type(self, node): - """Return the type of the @node if it has only one possible type, - different from not None. In other cases, return None. - """ - if node.info is None or len(node.info) != 1: - return None - return type(list(node.info)[0].ctype) - - def reduce_op(self, node, lvl=0, **kwargs): - """Generate access for ExprOp""" - if not (node.expr.is_op("+") or is_op_segm(node.expr)) \ - or len(node.args) != 2: - return None - type_arg1 = self.get_solo_type(node.args[1]) - if type_arg1 != ObjCInt: - return None - arg0, arg1 = node.args - if arg0.info is None: - return None - void_type = self.types_mngr.void_ptr - out = set() - if not arg1.expr.is_int(): - return None - ptr_offset = int(arg1.expr) - for info in arg0.info: - if isinstance(info.ctype, ObjCArray): - field_type = info.ctype - elif isinstance(info.ctype, ObjCPtr): - field_type = info.ctype.objtype - else: - continue - target_type = info.ctype.objtype - - # Array-like: int* ptr; ptr[1] = X - out.update(self.cgen_access(info, field_type, ptr_offset, False, lvl)) - return out - - def reduce_mem(self, node, lvl=0, **kwargs): - """Generate access for ExprMem: - * @NN[ptr<elem>] -> elem (type) - * @64[ptr<ptr<elem>>] -> ptr<elem> - * @32[ptr<struct>] -> struct.00 - """ - - if not isinstance(node.expr, ExprMem): - return None - if node.ptr.info is None: - return None - assert isinstance(node.ptr.info, set) - void_type = self.types_mngr.void_ptr - found = set() - for subcgenobj in node.ptr.info: - if isinstance(subcgenobj.ctype, ObjCArray): - nobj = CGenArray(subcgenobj, 0, - void_type.align, - void_type.size) - target = nobj.ctype.objtype - for finalcgenobj in self.cgen_access(nobj, target, 0, True, lvl): - assert isinstance(finalcgenobj.ctype, ObjCPtr) - if self.enforce_strict_access and finalcgenobj.ctype.objtype.size != node.expr.size // 8: - continue - found.add(CGenDeref(finalcgenobj)) - - elif isinstance(subcgenobj.ctype, ObjCPtr): - target = subcgenobj.ctype.objtype - # target : type(elem) - if isinstance(target, (ObjCStruct, ObjCUnion)): - for finalcgenobj in self.cgen_access(subcgenobj, target, 0, True, lvl): - target = finalcgenobj.ctype.objtype - if self.enforce_strict_access and target.size != node.expr.size // 8: - continue - found.add(CGenDeref(finalcgenobj)) - elif isinstance(target, ObjCArray): - if self.enforce_strict_access and subcgenobj.ctype.size != node.expr.size // 8: - continue - found.update(self.cgen_access(CGenDeref(subcgenobj), target, - 0, False, lvl)) - else: - if self.enforce_strict_access and target.size != node.expr.size // 8: - continue - found.add(CGenDeref(subcgenobj)) - if not found: - return None - return found - - reduction_rules = [reduce_known_expr, - reduce_int, - reduce_op, - reduce_mem, - ] - - def get_accesses(self, expr, expr_context=None): - """Generate C access(es) for the native Miasm expression @expr - @expr: native Miasm expression - @expr_context: a dictionary linking known expressions to their - types. An expression is linked to a tuple of types. - """ - if expr_context is None: - expr_context = self.expr_types - ret = self.reduce(expr, ctxt=expr_context) - if ret.info is None: - return set() - return ret.info - - -class ExprCToExpr(ExprReducer): - """Translate a Miasm expression (representing a C access) into a native - Miasm expression and its C type: - - Example: - - IN: ((ptr_struct -> f_mini) field x) - OUT: @32[ptr_struct + 0x80], int - - - Tricky cases: - Struct S0 { - int x; - int y[0x10]; - } - - Struct S1 { - int a; - S0 toto; - } - - S1* ptr; - - Case 1: - ptr->toto => ptr + 0x4 - &(ptr->toto) => ptr + 0x4 - - Case 2: - (ptr->toto).x => @32[ptr + 0x4] - &((ptr->toto).x) => ptr + 0x4 - - Case 3: - (ptr->toto).y => ptr + 0x8 - &((ptr->toto).y) => ptr + 0x8 - - Case 4: - (ptr->toto).y[1] => @32[ptr + 0x8 + 0x4] - &((ptr->toto).y[1]) => ptr + 0x8 + 0x4 - - """ - - def __init__(self, expr_types, types_mngr): - """Init ExprCAccess - - @expr_types: a dictionary linking ID names to their types - @types_mngr: types manager - """ - - self.expr_types = expr_types - self.types_mngr = types_mngr - - def updt_expr_types(self, expr_types): - """Update expr_types - @expr_types: Dictionary associating name to type - """ - - self.expr_types = expr_types - - CST = "CST" - - def reduce_known_expr(self, node, ctxt, **kwargs): - """Reduce known expressions""" - if str(node.expr) in ctxt: - objc = ctxt[str(node.expr)] - out = (node.expr, objc) - elif node.expr.is_id(): - out = (node.expr, None) - else: - out = None - return out - - def reduce_int(self, node, **kwargs): - """Reduce ExprInt""" - - if not isinstance(node.expr, ExprInt): - return None - return self.CST - - def reduce_op_memberof(self, node, **kwargs): - """Reduce -> operator""" - - if not node.expr.is_op('->'): - return None - assert len(node.args) == 2 - out = [] - assert isinstance(node.args[1].expr, ExprId) - field = node.args[1].expr.name - src, src_type = node.args[0].info - if src_type is None: - return None - assert isinstance(src_type, (ObjCPtr, ObjCArray)) - struct_dst = src_type.objtype - assert isinstance(struct_dst, ObjCStruct) - - found = False - for name, objtype, offset, _ in struct_dst.fields: - if name != field: - continue - expr = src + ExprInt(offset, src.size) - if isinstance(objtype, (ObjCArray, ObjCStruct, ObjCUnion)): - pass - else: - expr = ExprMem(expr, objtype.size * 8) - assert not found - found = True - out = (expr, objtype) - assert found - return out - - def reduce_op_field(self, node, **kwargs): - """Reduce field operator (Struct or Union)""" - - if not node.expr.is_op('field'): - return None - assert len(node.args) == 2 - out = [] - assert isinstance(node.args[1].expr, ExprId) - field = node.args[1].expr.name - src, src_type = node.args[0].info - struct_dst = src_type - - if isinstance(struct_dst, ObjCStruct): - found = False - for name, objtype, offset, _ in struct_dst.fields: - if name != field: - continue - expr = src + ExprInt(offset, src.size) - if isinstance(objtype, ObjCArray): - # Case 4 - pass - elif isinstance(objtype, (ObjCStruct, ObjCUnion)): - # Case 1 - pass - else: - # Case 2 - expr = ExprMem(expr, objtype.size * 8) - assert not found - found = True - out = (expr, objtype) - elif isinstance(struct_dst, ObjCUnion): - found = False - for name, objtype, offset, _ in struct_dst.fields: - if name != field: - continue - expr = src + ExprInt(offset, src.size) - if isinstance(objtype, ObjCArray): - # Case 4 - pass - elif isinstance(objtype, (ObjCStruct, ObjCUnion)): - # Case 1 - pass - else: - # Case 2 - expr = ExprMem(expr, objtype.size * 8) - assert not found - found = True - out = (expr, objtype) - else: - raise NotImplementedError("unknown ObjC") - assert found - return out - - def reduce_op_array(self, node, **kwargs): - """Reduce array operator""" - - if not node.expr.is_op('[]'): - return None - assert len(node.args) == 2 - out = [] - assert isinstance(node.args[1].expr, ExprInt) - cst = node.args[1].expr - src, src_type = node.args[0].info - objtype = src_type.objtype - expr = src + cst * ExprInt(objtype.size, cst.size) - if isinstance(src_type, ObjCPtr): - if isinstance(objtype, ObjCArray): - final = objtype.objtype - expr = src + cst * ExprInt(final.size, cst.size) - objtype = final - expr = ExprMem(expr, final.size * 8) - found = True - else: - expr = ExprMem(expr, objtype.size * 8) - found = True - elif isinstance(src_type, ObjCArray): - if isinstance(objtype, ObjCArray): - final = objtype - found = True - elif isinstance(objtype, ObjCStruct): - found = True - else: - expr = ExprMem(expr, objtype.size * 8) - found = True - else: - raise NotImplementedError("Unknown access" % node.expr) - assert found - out = (expr, objtype) - return out - - def reduce_op_addr(self, node, **kwargs): - """Reduce addr operator""" - - if not node.expr.is_op('addr'): - return None - assert len(node.args) == 1 - out = [] - src, src_type = node.args[0].info - - void_type = self.types_mngr.void_ptr - - if isinstance(src_type, ObjCArray): - out = (src.arg, ObjCPtr(src_type.objtype, - void_type.align, void_type.size)) - elif isinstance(src, ExprMem): - out = (src.ptr, ObjCPtr(src_type, - void_type.align, void_type.size)) - elif isinstance(src_type, ObjCStruct): - out = (src, ObjCPtr(src_type, - void_type.align, void_type.size)) - elif isinstance(src_type, ObjCUnion): - out = (src, ObjCPtr(src_type, - void_type.align, void_type.size)) - else: - raise NotImplementedError("unk type") - return out - - def reduce_op_deref(self, node, **kwargs): - """Reduce deref operator""" - - if not node.expr.is_op('deref'): - return None - out = [] - src, src_type = node.args[0].info - assert isinstance(src_type, (ObjCPtr, ObjCArray)) - void_type = self.types_mngr.void_ptr - if isinstance(src_type, ObjCPtr): - if isinstance(src_type.objtype, ObjCArray): - size = void_type.size*8 - else: - size = src_type.objtype.size * 8 - out = (ExprMem(src, size), (src_type.objtype)) - else: - size = src_type.objtype.size * 8 - out = (ExprMem(src, size), (src_type.objtype)) - return out - - reduction_rules = [reduce_known_expr, - reduce_int, - reduce_op_memberof, - reduce_op_field, - reduce_op_array, - reduce_op_addr, - reduce_op_deref, - ] - - def get_expr(self, expr, c_context): - """Translate a Miasm expression @expr (representing a C access) into a - tuple composed of a native Miasm expression and its C type. - @expr: Miasm expression (representing a C access) - @c_context: a dictionary linking known tokens (strings) to their - types. A token is linked to only one type. - """ - ret = self.reduce(expr, ctxt=c_context) - if ret.info is None: - return (None, None) - return ret.info - - -class CTypesManager(object): - """Represent a C object, without any layout information""" - - def __init__(self, types_ast, leaf_types): - self.types_ast = types_ast - self.leaf_types = leaf_types - - @property - def void_ptr(self): - """Retrieve a void* objc""" - return self.leaf_types.types.get(CTypePtr(CTypeId('void'))) - - @property - def padding(self): - """Retrieve a padding ctype""" - return CTypeId(PADDING_TYPE_NAME) - - def _get_objc(self, type_id, resolved=None, to_fix=None, lvl=0): - if resolved is None: - resolved = {} - if to_fix is None: - to_fix = [] - if type_id in resolved: - return resolved[type_id] - type_id = self.types_ast.get_type(type_id) - fixed = True - if isinstance(type_id, CTypeId): - out = self.leaf_types.types.get(type_id, None) - assert out is not None - elif isinstance(type_id, CTypeUnion): - args = [] - align_max, size_max = 0, 0 - for name, field in type_id.fields: - objc = self._get_objc(field, resolved, to_fix, lvl + 1) - resolved[field] = objc - align_max = max(align_max, objc.align) - size_max = max(size_max, objc.size) - args.append((name, objc, 0, objc.size)) - - align, size = self.union_compute_align_size(align_max, size_max) - out = ObjCUnion(type_id.name, align, size, args) - - elif isinstance(type_id, CTypeStruct): - align_max, size_max = 0, 0 - - args = [] - offset, align_max = 0, 1 - pad_index = 0 - for name, field in type_id.fields: - objc = self._get_objc(field, resolved, to_fix, lvl + 1) - resolved[field] = objc - align_max = max(align_max, objc.align) - new_offset = self.struct_compute_field_offset(objc, offset) - if new_offset - offset: - pad_name = "__PAD__%d__" % pad_index - pad_index += 1 - size = new_offset - offset - pad_objc = self._get_objc(CTypeArray(self.padding, size), resolved, to_fix, lvl + 1) - args.append((pad_name, pad_objc, offset, pad_objc.size)) - offset = new_offset - args.append((name, objc, offset, objc.size)) - offset += objc.size - - align, size = self.struct_compute_align_size(align_max, offset) - out = ObjCStruct(type_id.name, align, size, args) - - elif isinstance(type_id, CTypePtr): - target = type_id.target - out = ObjCPtr(None, self.void_ptr.align, self.void_ptr.size) - fixed = False - - elif isinstance(type_id, CTypeArray): - target = type_id.target - objc = self._get_objc(target, resolved, to_fix, lvl + 1) - resolved[target] = objc - if type_id.size is None: - # case: toto[] - # return ObjCPtr - out = ObjCPtr(objc, self.void_ptr.align, self.void_ptr.size) - else: - size = self.size_to_int(type_id.size) - if size is None: - raise RuntimeError('Enable to compute objc size') - else: - out = ObjCArray(objc, size) - assert out.size is not None and out.align is not None - elif isinstance(type_id, CTypeEnum): - # Enum are integer - return self.leaf_types.types.get(CTypeId('int')) - elif isinstance(type_id, CTypeFunc): - type_ret = self._get_objc( - type_id.type_ret, resolved, to_fix, lvl + 1) - resolved[type_id.type_ret] = type_ret - args = [] - for name, arg in type_id.args: - objc = self._get_objc(arg, resolved, to_fix, lvl + 1) - resolved[arg] = objc - args.append((name, objc)) - out = ObjCFunc(type_id.name, type_id.abi, type_ret, args, - self.void_ptr.align, self.void_ptr.size) - elif isinstance(type_id, CTypeEllipsis): - out = ObjCEllipsis() - else: - raise TypeError("Unknown type %r" % type_id.__class__) - if not isinstance(out, ObjCEllipsis): - assert out.align is not None and out.size is not None - - if fixed: - resolved[type_id] = out - else: - to_fix.append((type_id, out)) - return out - - def get_objc(self, type_id): - """Get the ObjC corresponding to the CType @type_id - @type_id: CTypeBase instance""" - resolved = {} - to_fix = [] - out = self._get_objc(type_id, resolved, to_fix) - # Fix sub objects - while to_fix: - type_id, objc_to_fix = to_fix.pop() - objc = self._get_objc(type_id.target, resolved, to_fix) - objc_to_fix.objtype = objc - self.check_objc(out) - return out - - def check_objc(self, objc, done=None): - """Ensure each sub ObjC is resolved - @objc: ObjC instance""" - if done is None: - done = set() - if objc in done: - return True - done.add(objc) - if isinstance(objc, (ObjCDecl, ObjCInt, ObjCEllipsis)): - return True - elif isinstance(objc, (ObjCPtr, ObjCArray)): - assert self.check_objc(objc.objtype, done) - return True - elif isinstance(objc, (ObjCStruct, ObjCUnion)): - for _, field, _, _ in objc.fields: - assert self.check_objc(field, done) - return True - elif isinstance(objc, ObjCFunc): - assert self.check_objc(objc.type_ret, done) - for name, arg in objc.args: - assert self.check_objc(arg, done) - return True - else: - assert False - - def size_to_int(self, size): - """Resolve an array size - @size: CTypeOp or integer""" - if isinstance(size, CTypeOp): - assert len(size.args) == 2 - arg0, arg1 = [self.size_to_int(arg) for arg in size.args] - if size.operator == "+": - return arg0 + arg1 - elif size.operator == "-": - return arg0 - arg1 - elif size.operator == "*": - return arg0 * arg1 - elif size.operator == "/": - return arg0 // arg1 - elif size.operator == "<<": - return arg0 << arg1 - elif size.operator == ">>": - return arg0 >> arg1 - else: - raise ValueError("Unknown operator %s" % size.operator) - elif isinstance(size, int_types): - return size - elif isinstance(size, CTypeSizeof): - obj = self._get_objc(size.target) - return obj.size - else: - raise TypeError("Unknown size type") - - def struct_compute_field_offset(self, obj, offset): - """Compute the offset of the field @obj in the current structure""" - raise NotImplementedError("Abstract method") - - def struct_compute_align_size(self, align_max, size): - """Compute the alignment and size of the current structure""" - raise NotImplementedError("Abstract method") - - def union_compute_align_size(self, align_max, size): - """Compute the alignment and size of the current union""" - raise NotImplementedError("Abstract method") - - -class CTypesManagerNotPacked(CTypesManager): - """Store defined C types (not packed)""" - - def struct_compute_field_offset(self, obj, offset): - """Compute the offset of the field @obj in the current structure - (not packed)""" - - if obj.align > 1: - offset = (offset + obj.align - 1) & ~(obj.align - 1) - return offset - - def struct_compute_align_size(self, align_max, size): - """Compute the alignment and size of the current structure - (not packed)""" - if align_max > 1: - size = (size + align_max - 1) & ~(align_max - 1) - return align_max, size - - def union_compute_align_size(self, align_max, size): - """Compute the alignment and size of the current union - (not packed)""" - return align_max, size - - -class CTypesManagerPacked(CTypesManager): - """Store defined C types (packed form)""" - - def struct_compute_field_offset(self, _, offset): - """Compute the offset of the field @obj in the current structure - (packed form)""" - return offset - - def struct_compute_align_size(self, _, size): - """Compute the alignment and size of the current structure - (packed form)""" - return 1, size - - def union_compute_align_size(self, align_max, size): - """Compute the alignment and size of the current union - (packed form)""" - return 1, size - - -class CHandler(object): - """ - C manipulator for Miasm - Miasm expr <-> C - """ - - exprCToExpr_cls = ExprCToExpr - exprToAccessC_cls = ExprToAccessC - - def __init__(self, types_mngr, expr_types=None, - C_types=None, - simplify_c=access_simplifier, - enforce_strict_access=True): - self.exprc2expr = self.exprCToExpr_cls(expr_types, types_mngr) - self.access_c_gen = self.exprToAccessC_cls(expr_types, - types_mngr, - enforce_strict_access) - self.types_mngr = types_mngr - self.simplify_c = simplify_c - if expr_types is None: - expr_types = {} - self.expr_types = expr_types - if C_types is None: - C_types = {} - self.C_types = C_types - - def updt_expr_types(self, expr_types): - """Update expr_types - @expr_types: Dictionary associating name to type - """ - - self.expr_types = expr_types - self.exprc2expr.updt_expr_types(expr_types) - self.access_c_gen.updt_expr_types(expr_types) - - def expr_to_c_access(self, expr, expr_context=None): - """Generate the C access object(s) for a given native Miasm expression. - @expr: Miasm expression - @expr_context: a dictionary linking known expressions to a set of types - """ - - if expr_context is None: - expr_context = self.expr_types - return self.access_c_gen.get_accesses(expr, expr_context) - - - def expr_to_c_and_types(self, expr, expr_context=None): - """Generate the C access string and corresponding type for a given - native Miasm expression. - @expr_context: a dictionary linking known expressions to a set of types - """ - - accesses = set() - for access in self.expr_to_c_access(expr, expr_context): - c_str = access_str(access.to_expr().visit(self.simplify_c)) - accesses.add((c_str, access.ctype)) - return accesses - - def expr_to_c(self, expr, expr_context=None): - """Convert a Miasm @expr into it's C equivalent string - @expr_context: a dictionary linking known expressions to a set of types - """ - - return set(access[0] - for access in self.expr_to_c_and_types(expr, expr_context)) - - def expr_to_types(self, expr, expr_context=None): - """Get the possible types of the Miasm @expr - @expr_context: a dictionary linking known expressions to a set of types - """ - - return set(access.ctype - for access in self.expr_to_c_access(expr, expr_context)) - - def c_to_expr_and_type(self, c_str, c_context=None): - """Convert a C string expression to a Miasm expression and it's - corresponding c type - @c_str: C string - @c_context: (optional) dictionary linking known tokens (strings) to its - type. - """ - - ast = parse_access(c_str) - if c_context is None: - c_context = self.C_types - access_c = ast_get_c_access_expr(ast, c_context) - return self.exprc2expr.get_expr(access_c, c_context) - - def c_to_expr(self, c_str, c_context=None): - """Convert a C string expression to a Miasm expression - @c_str: C string - @c_context: (optional) dictionary linking known tokens (strings) to its - type. - """ - - if c_context is None: - c_context = self.C_types - expr, _ = self.c_to_expr_and_type(c_str, c_context) - return expr - - def c_to_type(self, c_str, c_context=None): - """Get the type of a C string expression - @expr: Miasm expression - @c_context: (optional) dictionary linking known tokens (strings) to its - type. - """ - - if c_context is None: - c_context = self.C_types - _, ctype = self.c_to_expr_and_type(c_str, c_context) - return ctype - - -class CLeafTypes(object): - """Define C types sizes/alignment for a given architecture""" - pass diff --git a/miasm/core/parse_asm.py b/miasm/core/parse_asm.py deleted file mode 100644 index 79ef416d..00000000 --- a/miasm/core/parse_asm.py +++ /dev/null @@ -1,288 +0,0 @@ -#-*- coding:utf-8 -*- -import re -import codecs -from builtins import range - -from miasm.core.utils import force_str -from miasm.expression.expression import ExprId, ExprInt, ExprOp, LocKey -import miasm.core.asmblock as asmblock -from miasm.core.cpu import instruction, base_expr -from miasm.core.asm_ast import AstInt, AstId, AstOp - -declarator = {'byte': 8, - 'word': 16, - 'dword': 32, - 'qword': 64, - 'long': 32, - } - -size2pck = {8: 'B', - 16: 'H', - 32: 'I', - 64: 'Q', - } - -EMPTY_RE = re.compile(r'\s*$') -COMMENT_RE = re.compile(r'\s*;\S*') -LOCAL_LABEL_RE = re.compile(r'\s*(\.L\S+)\s*:') -DIRECTIVE_START_RE = re.compile(r'\s*\.') -DIRECTIVE_RE = re.compile(r'\s*\.(\S+)') -LABEL_RE = re.compile(r'\s*(\S+)\s*:') -FORGET_LABEL_RE = re.compile(r'\s*\.LF[BE]\d\s*:') - - -class Directive(object): - - """Stand for Directive""" - - pass - -class DirectiveAlign(Directive): - - """Stand for alignment representation""" - - def __init__(self, alignment=1): - self.alignment = alignment - - def __str__(self): - return "Alignment %s" % self.alignment - - -class DirectiveSplit(Directive): - - """Stand for alignment representation""" - - pass - - -class DirectiveDontSplit(Directive): - - """Stand for alignment representation""" - - pass - - -STATE_NO_BLOC = 0 -STATE_IN_BLOC = 1 - - -def asm_ast_to_expr_with_size(arg, loc_db, size): - if isinstance(arg, AstId): - return ExprId(force_str(arg.name), size) - if isinstance(arg, AstOp): - args = [asm_ast_to_expr_with_size(tmp, loc_db, size) for tmp in arg.args] - return ExprOp(arg.op, *args) - if isinstance(arg, AstInt): - return ExprInt(arg.value, size) - return None - -def parse_txt(mnemo, attrib, txt, loc_db): - """Parse an assembly listing. Returns an AsmCfg instance - - @mnemo: architecture used - @attrib: architecture attribute - @txt: assembly listing - @loc_db: the LocationDB instance used to handle labels of the listing - - """ - - C_NEXT = asmblock.AsmConstraint.c_next - C_TO = asmblock.AsmConstraint.c_to - - lines = [] - # parse each line - for line in txt.split('\n'): - # empty - if EMPTY_RE.match(line): - continue - # comment - if COMMENT_RE.match(line): - continue - # labels to forget - if FORGET_LABEL_RE.match(line): - continue - # label beginning with .L - match_re = LABEL_RE.match(line) - if match_re: - label_name = match_re.group(1) - label = loc_db.get_or_create_name_location(label_name) - lines.append(label) - continue - # directive - if DIRECTIVE_START_RE.match(line): - match_re = DIRECTIVE_RE.match(line) - directive = match_re.group(1) - if directive in ['text', 'data', 'bss']: - continue - if directive in ['string', 'ascii']: - # XXX HACK - line = line.replace(r'\n', '\n').replace(r'\r', '\r') - raw = line[line.find(r'"') + 1:line.rfind(r'"')] - raw = codecs.escape_decode(raw)[0] - if directive == 'string': - raw += b"\x00" - lines.append(asmblock.AsmRaw(raw)) - continue - if directive == 'ustring': - # XXX HACK - line = line.replace(r'\n', '\n').replace(r'\r', '\r') - raw = line[line.find(r'"') + 1:line.rfind(r'"')] + "\x00" - raw = codecs.escape_decode(raw)[0] - out = b'' - for i in range(len(raw)): - out += raw[i:i+1] + b'\x00' - lines.append(asmblock.AsmRaw(out)) - continue - if directive in declarator: - data_raw = line[match_re.end():].split(' ', 1)[1] - data_raw = data_raw.split(',') - size = declarator[directive] - expr_list = [] - - # parser - - for element in data_raw: - element = element.strip() - element_parsed = base_expr.parseString(element)[0] - element_expr = asm_ast_to_expr_with_size(element_parsed, loc_db, size) - expr_list.append(element_expr) - - raw_data = asmblock.AsmRaw(expr_list) - raw_data.element_size = size - lines.append(raw_data) - continue - if directive == 'comm': - # TODO - continue - if directive == 'split': # custom command - lines.append(DirectiveSplit()) - continue - if directive == 'dontsplit': # custom command - lines.append(DirectiveDontSplit()) - continue - if directive == "align": - align_value = int(line[match_re.end():], 0) - lines.append(DirectiveAlign(align_value)) - continue - if directive in ['file', 'intel_syntax', 'globl', 'local', - 'type', 'size', 'align', 'ident', 'section']: - continue - if directive[0:4] == 'cfi_': - continue - - raise ValueError("unknown directive %s" % directive) - - # label - match_re = LABEL_RE.match(line) - if match_re: - label_name = match_re.group(1) - label = loc_db.get_or_create_name_location(label_name) - lines.append(label) - continue - - # code - if ';' in line: - line = line[:line.find(';')] - line = line.strip(' ').strip('\t') - instr = mnemo.fromstring(line, loc_db, attrib) - lines.append(instr) - - asmblock.log_asmblock.info("___pre asm oki___") - # make asmcfg - - cur_block = None - state = STATE_NO_BLOC - i = 0 - asmcfg = asmblock.AsmCFG(loc_db) - block_to_nlink = None - delayslot = 0 - while i < len(lines): - if delayslot: - delayslot -= 1 - if delayslot == 0: - state = STATE_NO_BLOC - line = lines[i] - # no current block - if state == STATE_NO_BLOC: - if isinstance(line, DirectiveDontSplit): - block_to_nlink = cur_block - i += 1 - continue - elif isinstance(line, DirectiveSplit): - block_to_nlink = None - i += 1 - continue - elif not isinstance(line, LocKey): - # First line must be a label. If it's not the case, generate - # it. - loc = loc_db.add_location() - cur_block = asmblock.AsmBlock(loc_db, loc, alignment=mnemo.alignment) - else: - cur_block = asmblock.AsmBlock(loc_db, line, alignment=mnemo.alignment) - i += 1 - # Generate the current block - asmcfg.add_block(cur_block) - state = STATE_IN_BLOC - if block_to_nlink: - block_to_nlink.addto( - asmblock.AsmConstraint( - cur_block.loc_key, - C_NEXT - ) - ) - block_to_nlink = None - continue - - # in block - elif state == STATE_IN_BLOC: - if isinstance(line, DirectiveSplit): - state = STATE_NO_BLOC - block_to_nlink = None - elif isinstance(line, DirectiveDontSplit): - state = STATE_NO_BLOC - block_to_nlink = cur_block - elif isinstance(line, DirectiveAlign): - cur_block.alignment = line.alignment - elif isinstance(line, asmblock.AsmRaw): - cur_block.addline(line) - block_to_nlink = cur_block - elif isinstance(line, LocKey): - if block_to_nlink: - cur_block.addto( - asmblock.AsmConstraint(line, C_NEXT) - ) - block_to_nlink = None - state = STATE_NO_BLOC - continue - # instruction - elif isinstance(line, instruction): - cur_block.addline(line) - block_to_nlink = cur_block - if not line.breakflow(): - i += 1 - continue - if delayslot: - raise RuntimeError("Cannot have breakflow in delayslot") - if line.dstflow(): - for dst in line.getdstflow(loc_db): - if not isinstance(dst, ExprId): - continue - if dst in mnemo.regs.all_regs_ids: - continue - cur_block.addto(asmblock.AsmConstraint(dst.name, C_TO)) - - if not line.splitflow(): - block_to_nlink = None - - delayslot = line.delayslot + 1 - else: - raise RuntimeError("unknown class %s" % line.__class__) - i += 1 - - for block in asmcfg.blocks: - # Fix multiple constraints - block.fix_constraints() - - # Log block - asmblock.log_asmblock.info(block) - return asmcfg diff --git a/miasm/core/sembuilder.py b/miasm/core/sembuilder.py deleted file mode 100644 index 9843ee6a..00000000 --- a/miasm/core/sembuilder.py +++ /dev/null @@ -1,341 +0,0 @@ -"Helper to quickly build instruction's semantic side effects" - -import inspect -import ast -import re - -from future.utils import PY3 - -import miasm.expression.expression as m2_expr -from miasm.ir.ir import IRBlock, AssignBlock - - -class MiasmTransformer(ast.NodeTransformer): - """AST visitor translating DSL to Miasm expression - - memX[Y] -> ExprMem(Y, X) - iX(Y) -> ExprIntX(Y) - X if Y else Z -> ExprCond(Y, X, Z) - 'X'(Y) -> ExprOp('X', Y) - ('X' % Y)(Z) -> ExprOp('X' % Y, Z) - {a, b} -> ExprCompose(((a, 0, a.size), (b, a.size, a.size + b.size))) - """ - - # Parsers - parse_integer = re.compile(r"^i([0-9]+)$") - parse_mem = re.compile(r"^mem([0-9]+)$") - - # Visitors - def visit_Call(self, node): - """iX(Y) -> ExprIntX(Y), - 'X'(Y) -> ExprOp('X', Y), ('X' % Y)(Z) -> ExprOp('X' % Y, Z)""" - - # Recursive visit - node = self.generic_visit(node) - if isinstance(node.func, ast.Name): - # iX(Y) -> ExprInt(Y, X) - fc_name = node.func.id - - # Match the function name - new_name = fc_name - integer = self.parse_integer.search(fc_name) - - # Do replacement - if integer is not None: - size = int(integer.groups()[0]) - new_name = "ExprInt" - # Replace in the node - node.func.id = new_name - node.args.append(ast.Num(n=size)) - - elif (isinstance(node.func, ast.Str) or - (isinstance(node.func, ast.BinOp) and - isinstance(node.func.op, ast.Mod) and - isinstance(node.func.left, ast.Str))): - # 'op'(args...) -> ExprOp('op', args...) - # ('op' % (fmt))(args...) -> ExprOp('op' % (fmt), args...) - op_name = node.func - - # Do replacement - node.func = ast.Name(id="ExprOp", ctx=ast.Load()) - node.args[0:0] = [op_name] - - return node - - def visit_IfExp(self, node): - """X if Y else Z -> ExprCond(Y, X, Z)""" - # Recursive visit - node = self.generic_visit(node) - - # Build the new ExprCond - call = ast.Call(func=ast.Name(id='ExprCond', ctx=ast.Load()), - args=[self.visit(node.test), - self.visit(node.body), - self.visit(node.orelse)], - keywords=[], starargs=None, kwargs=None) - return call - - def visit_Set(self, node): - "{a, b} -> ExprCompose(a, b)" - if len(node.elts) == 0: - return node - - # Recursive visit - node = self.generic_visit(node) - - return ast.Call(func=ast.Name(id='ExprCompose', - ctx=ast.Load()), - args=node.elts, - keywords=[], - starargs=None, - kwargs=None) - -if PY3: - def get_arg_name(name): - return name.arg - def gen_arg(name, ctx): - return ast.arg(arg=name, ctx=ctx) -else: - def get_arg_name(name): - return name.id - def gen_arg(name, ctx): - return ast.Name(id=name, ctx=ctx) - - -class SemBuilder(object): - """Helper for building instruction's semantic side effects method - - This class provides a decorator @parse to use on them. - The context in which the function will be parsed must be supplied on - instantiation - """ - - def __init__(self, ctx): - """Create a SemBuilder - @ctx: context dictionary used during parsing - """ - # Init - self.transformer = MiasmTransformer() - self._ctx = dict(m2_expr.__dict__) - self._ctx["IRBlock"] = IRBlock - self._ctx["AssignBlock"] = AssignBlock - self._functions = {} - - # Update context - self._ctx.update(ctx) - - @property - def functions(self): - """Return a dictionary name -> func of parsed functions""" - return self._functions.copy() - - @staticmethod - def _create_labels(loc_else=False): - """Return the AST standing for label creations - @loc_else (optional): if set, create a label 'loc_else'""" - loc_end = "loc_end = ir.get_next_loc_key(instr)" - loc_end_expr = "loc_end_expr = ExprLoc(loc_end, ir.IRDst.size)" - out = ast.parse(loc_end).body - out += ast.parse(loc_end_expr).body - loc_if = "loc_if = ir.loc_db.add_location()" - loc_if_expr = "loc_if_expr = ExprLoc(loc_if, ir.IRDst.size)" - out += ast.parse(loc_if).body - out += ast.parse(loc_if_expr).body - if loc_else: - loc_else = "loc_else = ir.loc_db.add_location()" - loc_else_expr = "loc_else_expr = ExprLoc(loc_else, ir.IRDst.size)" - out += ast.parse(loc_else).body - out += ast.parse(loc_else_expr).body - return out - - def _parse_body(self, body, argument_names): - """Recursive function transforming a @body to a block expression - Return: - - AST to append to body (real python statements) - - a list of blocks, ie list of affblock, ie list of ExprAssign (AST)""" - - # Init - ## Real instructions - real_body = [] - ## Final blocks - blocks = [[[]]] - - for statement in body: - - if isinstance(statement, ast.Assign): - src = self.transformer.visit(statement.value) - dst = self.transformer.visit(statement.targets[0]) - - if (isinstance(dst, ast.Name) and - dst.id not in argument_names and - dst.id not in self._ctx and - dst.id not in self._local_ctx): - - # Real variable declaration - statement.value = src - real_body.append(statement) - self._local_ctx[dst.id] = src - continue - - dst.ctx = ast.Load() - - res = ast.Call(func=ast.Name(id='ExprAssign', - ctx=ast.Load()), - args=[dst, src], - keywords=[], - starargs=None, - kwargs=None) - - blocks[-1][-1].append(res) - - elif (isinstance(statement, ast.Expr) and - isinstance(statement.value, ast.Str)): - # String (docstring, comment, ...) -> keep it - real_body.append(statement) - - elif isinstance(statement, ast.If): - # Create jumps : ir.IRDst = loc_if if cond else loc_end - # if .. else .. are also handled - cond = statement.test - real_body += self._create_labels(loc_else=True) - - loc_end = ast.Name(id='loc_end_expr', ctx=ast.Load()) - loc_if = ast.Name(id='loc_if_expr', ctx=ast.Load()) - loc_else = ast.Name(id='loc_else_expr', ctx=ast.Load()) \ - if statement.orelse else loc_end - dst = ast.Call(func=ast.Name(id='ExprCond', - ctx=ast.Load()), - args=[cond, - loc_if, - loc_else], - keywords=[], - starargs=None, - kwargs=None) - - if (isinstance(cond, ast.UnaryOp) and - isinstance(cond.op, ast.Not)): - ## if not cond -> switch exprCond - dst.args[1:] = dst.args[1:][::-1] - dst.args[0] = cond.operand - - IRDst = ast.Attribute(value=ast.Name(id='ir', - ctx=ast.Load()), - attr='IRDst', ctx=ast.Load()) - loc_db = ast.Attribute(value=ast.Name(id='ir', - ctx=ast.Load()), - attr='loc_db', ctx=ast.Load()) - blocks[-1][-1].append(ast.Call(func=ast.Name(id='ExprAssign', - ctx=ast.Load()), - args=[IRDst, dst], - keywords=[], - starargs=None, - kwargs=None)) - - # Create the new blocks - elements = [(statement.body, 'loc_if')] - if statement.orelse: - elements.append((statement.orelse, 'loc_else')) - for content, loc_name in elements: - sub_blocks, sub_body = self._parse_body(content, - argument_names) - if len(sub_blocks) > 1: - raise RuntimeError("Imbricated if unimplemented") - - ## Close the last block - jmp_end = ast.Call(func=ast.Name(id='ExprAssign', - ctx=ast.Load()), - args=[IRDst, loc_end], - keywords=[], - starargs=None, - kwargs=None) - sub_blocks[-1][-1].append(jmp_end) - - - instr = ast.Name(id='instr', ctx=ast.Load()) - effects = ast.List(elts=sub_blocks[-1][-1], - ctx=ast.Load()) - assignblk = ast.Call(func=ast.Name(id='AssignBlock', - ctx=ast.Load()), - args=[effects, instr], - keywords=[], - starargs=None, - kwargs=None) - - - ## Replace the block with a call to 'IRBlock' - loc_if_name = ast.Name(id=loc_name, ctx=ast.Load()) - - assignblks = ast.List(elts=[assignblk], - ctx=ast.Load()) - - sub_blocks[-1] = ast.Call(func=ast.Name(id='IRBlock', - ctx=ast.Load()), - args=[ - loc_db, - loc_if_name, - assignblks - ], - keywords=[], - starargs=None, - kwargs=None) - blocks += sub_blocks - real_body += sub_body - - # Prepare a new block for following statement - blocks.append([[]]) - - else: - # TODO: real var, +=, /=, -=, <<=, >>=, if/else, ... - raise RuntimeError("Unimplemented %s" % statement) - - return blocks, real_body - - def parse(self, func): - """Function decorator, returning a correct method from a pseudo-Python - one""" - - # Get the function AST - parsed = ast.parse(inspect.getsource(func)) - fc_ast = parsed.body[0] - argument_names = [get_arg_name(name) for name in fc_ast.args.args] - - # Init local cache - self._local_ctx = {} - - # Translate (blocks[0][0] is the current instr) - blocks, body = self._parse_body(fc_ast.body, argument_names) - - # Build the new function - fc_ast.args.args[0:0] = [ - gen_arg('ir', ast.Param()), - gen_arg('instr', ast.Param()) - ] - cur_instr = blocks[0][0] - if len(blocks[-1][0]) == 0: - ## Last block can be empty - blocks.pop() - other_blocks = blocks[1:] - body.append(ast.Return(value=ast.Tuple(elts=[ast.List(elts=cur_instr, - ctx=ast.Load()), - ast.List(elts=other_blocks, - ctx=ast.Load())], - ctx=ast.Load()))) - - ret = ast.parse('') - ret.body = [ast.FunctionDef(name=fc_ast.name, - args=fc_ast.args, - body=body, - decorator_list=[])] - - # To display the generated function, use codegen.to_source - # codegen: https://github.com/andreif/codegen - - # Compile according to the context - fixed = ast.fix_missing_locations(ret) - codeobj = compile(fixed, '<string>', 'exec') - ctx = self._ctx.copy() - eval(codeobj, ctx) - - # Get the function back - self._functions[fc_ast.name] = ctx[fc_ast.name] - return ctx[fc_ast.name] diff --git a/miasm/core/types.py b/miasm/core/types.py deleted file mode 100644 index 4f99627d..00000000 --- a/miasm/core/types.py +++ /dev/null @@ -1,1693 +0,0 @@ -"""This module provides classes to manipulate pure C types as well as their -representation in memory. A typical usecase is to use this module to -easily manipylate structures backed by a VmMngr object (a miasm sandbox virtual -memory): - - class ListNode(MemStruct): - fields = [ - ("next", Ptr("<I", Self())), - ("data", Ptr("<I", Void())), - ] - - class LinkedList(MemStruct): - fields = [ - ("head", Ptr("<I", ListNode)), - ("tail", Ptr("<I", ListNode)), - ("size", Num("<I")), - ] - - link = LinkedList(vm, addr1) - link.memset() - node = ListNode(vm, addr2) - node.memset() - link.head = node.get_addr() - link.tail = node.get_addr() - link.size += 1 - assert link.head.deref == node - data = Num("<I").lval(vm, addr3) - data.val = 5 - node.data = data.get_addr() - # see examples/jitter/types.py for more info - - -It provides two families of classes, Type-s (Num, Ptr, Str...) and their -associated MemType-s. A Type subclass instance represents a fully defined C -type. A MemType subclass instance represents a C LValue (or variable): it is -a type attached to the memory. Available types are: - - - Num: for number (float or int) handling - - Ptr: a pointer to another Type - - Struct: equivalent to a C struct definition - - Union: similar to union in C, list of Types at the same offset in a - structure; the union has the size of the biggest Type (~ Struct with all - the fields at offset 0) - - Array: an array of items of the same type; can have a fixed size or - not (e.g. char[3] vs char* used as an array in C) - - BitField: similar to C bitfields, a list of - [(<field_name>, <number_of_bits>),]; creates fields that correspond to - certain bits of the field; analogous to a Union of Bits (see Bits below) - - Str: a character string, with an encoding; not directly mapped to a C - type, it is a higher level notion provided for ease of use - - Void: analogous to C void, can be a placeholder in void*-style cases. - - Self: special marker to reference a Struct inside itself (FIXME: to - remove?) - -And some less common types: - - - Bits: mask only some bits of a Num - - RawStruct: abstraction over a simple struct pack/unpack (no mapping to a - standard C type) - -For each type, the `.lval` property returns a MemType subclass that -allows to access the field in memory. - - -The easiest way to use the API to declare and manipulate new structures is to -subclass MemStruct and define a list of (<field_name>, <field_definition>): - - class MyStruct(MemStruct): - fields = [ - # Scalar field: just struct.pack field with one value - ("num", Num("I")), - ("flags", Num("B")), - # Ptr fields contain two fields: "val", for the numerical value, - # and "deref" to get the pointed object - ("other", Ptr("I", OtherStruct)), - # Ptr to a variable length String - ("s", Ptr("I", Str())), - ("i", Ptr("I", Num("I"))), - ] - -And access the fields: - - mstruct = MyStruct(jitter.vm, addr) - mstruct.num = 3 - assert mstruct.num == 3 - mstruct.other.val = addr2 - # Also works: - mstruct.other = addr2 - mstruct.other.deref = OtherStruct(jitter.vm, addr) - -MemUnion and MemBitField can also be subclassed, the `fields` field being -in the format expected by, respectively, Union and BitField. - -The `addr` argument can be omitted if an allocator is set, in which case the -structure will be automatically allocated in memory: - - my_heap = miasm.os_dep.common.heap() - # the allocator is a func(VmMngr) -> integer_address - set_allocator(my_heap) - -Note that some structures (e.g. MemStr or MemArray) do not have a static -size and cannot be allocated automatically. -""" - -from builtins import range, zip -from builtins import int as int_types -import itertools -import logging -import struct -from future.utils import PY3 -from future.utils import viewitems, with_metaclass - -log = logging.getLogger(__name__) -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) -log.addHandler(console_handler) -log.setLevel(logging.WARN) - -# Cache for dynamically generated MemTypes -DYN_MEM_STRUCT_CACHE = {} - -def set_allocator(alloc_func): - """Shorthand to set the default allocator of MemType. See - MemType.set_allocator doc for more information. - """ - MemType.set_allocator(alloc_func) - - -# Helpers - -def to_type(obj): - """If possible, return the Type associated with @obj, otherwise raises - a ValueError. - - Works with a Type instance (returns obj) or a MemType subclass or instance - (returns obj.get_type()). - """ - # obj is a python type - if isinstance(obj, type): - if issubclass(obj, MemType): - if obj.get_type() is None: - raise ValueError("%r has no static type; use a subclasses " - "with a non null _type or use a " - "Type instance" % obj) - return obj.get_type() - # obj is not not a type - else: - if isinstance(obj, Type): - return obj - elif isinstance(obj, MemType): - return obj.get_type() - raise ValueError("%r is not a Type or a MemType" % obj) - -def indent(s, size=4): - """Indent a string with @size spaces""" - return ' '*size + ('\n' + ' '*size).join(s.split('\n')) - - -# String generic getter/setter/len-er -# TODO: make miasm.os_dep.common and jitter ones use these ones - -def get_str(vm, addr, enc, max_char=None, end=u'\x00'): - """Get a @end (by default '\\x00') terminated @enc encoded string from a - VmMngr. - - For example: - - get_str(vm, addr, "ascii") will read "foo\\x00" in memory and - return u"foo" - - get_str(vm, addr, "utf-16le") will read "f\\x00o\\x00o\\x00\\x00\\x00" - in memory and return u"foo" as well. - - Setting @max_char=<n> and @end='' allows to read non null terminated strings - from memory. - - @vm: VmMngr instance - @addr: the address at which to read the string - @enc: the encoding of the string to read. - @max_char: max number of bytes to get in memory - @end: the unencoded ending sequence of the string, by default "\\x00". - Unencoded here means that the actual ending sequence that this function - will look for is end.encode(enc), not directly @end. - """ - s = [] - end_char= end.encode(enc) - step = len(end_char) - i = 0 - while max_char is None or i < max_char: - c = vm.get_mem(addr + i, step) - if c == end_char: - break - s.append(c) - i += step - return b''.join(s).decode(enc) - -def raw_str(s, enc, end=u'\x00'): - """Returns a string representing @s as an @end (by default \\x00) - terminated @enc encoded string. - - @s: the unicode str to serialize - @enc: the encoding to apply to @s and @end before serialization. - @end: the ending string/character to append to the string _before encoding_ - and serialization (by default '\\x00') - """ - return (s + end).encode(enc) - -def set_str(vm, addr, s, enc, end=u'\x00'): - """Encode a string to an @end (by default \\x00) terminated @enc encoded - string and set it in a VmMngr memory. - - @vm: VmMngr instance - @addr: start address to serialize the string to - @s: the unicode str to serialize - @enc: the encoding to apply to @s and @end before serialization. - @end: the ending string/character to append to the string _before encoding_ - and serialization (by default '\\x00') - """ - s = raw_str(s, enc, end=end) - vm.set_mem(addr, s) - -def raw_len(py_unic_str, enc, end=u'\x00'): - """Returns the length in bytes of @py_unic_str in memory (once @end has been - added and the full str has been encoded). It returns exactly the room - necessary to call set_str with similar arguments. - - @py_unic_str: the unicode str to work with - @enc: the encoding to encode @py_unic_str to - @end: the ending string/character to append to the string _before encoding_ - (by default \\x00) - """ - return len(raw_str(py_unic_str, enc)) - -def enc_triplet(enc, max_char=None, end=u'\x00'): - """Returns a triplet of functions (get_str_enc, set_str_enc, raw_len_enc) - for a given encoding (as needed by Str to add an encoding). The prototypes - are: - - - get_str_end: same as get_str without the @enc argument - - set_str_end: same as set_str without the @enc argument - - raw_len_enc: same as raw_len without the @enc argument - """ - return ( - lambda vm, addr, max_char=max_char, end=end: \ - get_str(vm, addr, enc, max_char=max_char, end=end), - lambda vm, addr, s, end=end: set_str(vm, addr, s, enc, end=end), - lambda s, end=end: raw_len(s, enc, end=end), - ) - - -# Type classes - -class Type(object): - """Base class to provide methods to describe a type, as well as how to set - and get fields from virtual mem. - - Each Type subclass is linked to a MemType subclass (e.g. Struct to - MemStruct, Ptr to MemPtr, etc.). - - When nothing is specified, MemValue is used to access the type in memory. - MemValue instances have one `.val` field, setting and getting it call - the set and get of the Type. - - Subclasses can either override _pack and _unpack, or get and set if data - serialization requires more work (see Struct implementation for an example). - - TODO: move any trace of vm and addr out of these classes? - """ - - _self_type = None - - def _pack(self, val): - """Serializes the python value @val to a raw str""" - raise NotImplementedError() - - def _unpack(self, raw_str): - """Deserializes a raw str to an object representing the python value - of this field. - """ - raise NotImplementedError() - - def set(self, vm, addr, val): - """Set a VmMngr memory from a value. - - @vm: VmMngr instance - @addr: the start address in memory to set - @val: the python value to serialize in @vm at @addr - """ - raw = self._pack(val) - vm.set_mem(addr, raw) - - def get(self, vm, addr): - """Get the python value of a field from a VmMngr memory at @addr.""" - raw = vm.get_mem(addr, self.size) - return self._unpack(raw) - - @property - def lval(self): - """Returns a class with a (vm, addr) constructor that allows to - interact with this type in memory. - - In compilation terms, it returns a class allowing to instantiate an - lvalue of this type. - - @return: a MemType subclass. - """ - if self in DYN_MEM_STRUCT_CACHE: - return DYN_MEM_STRUCT_CACHE[self] - pinned_type = self._build_pinned_type() - DYN_MEM_STRUCT_CACHE[self] = pinned_type - return pinned_type - - def _build_pinned_type(self): - """Builds the MemType subclass allowing to interact with this type. - - Called by self.lval when it is not in cache. - """ - pinned_base_class = self._get_pinned_base_class() - pinned_type = type( - "Mem%r" % self, - (pinned_base_class,), - {'_type': self} - ) - return pinned_type - - def _get_pinned_base_class(self): - """Return the MemType subclass that maps this type in memory""" - return MemValue - - def _get_self_type(self): - """Used for the Self trick.""" - return self._self_type - - def _set_self_type(self, self_type): - """If this field refers to MemSelf/Self, replace it with @self_type - (a Type instance) when using it. Generally not used outside this - module. - """ - self._self_type = self_type - - @property - def size(self): - """Return the size in bytes of the serialized version of this field""" - raise NotImplementedError() - - def __len__(self): - return self.size - - def __neq__(self, other): - return not self == other - - def __eq__(self, other): - raise NotImplementedError("Abstract method") - - def __ne__(self, other): - return not self == other - - -class RawStruct(Type): - """Dumb struct.pack/unpack field. Mainly used to factorize code. - - Value is a tuple corresponding to the struct @fmt passed to the constructor. - """ - - def __init__(self, fmt): - self._fmt = fmt - - def _pack(self, fields): - return struct.pack(self._fmt, *fields) - - def _unpack(self, raw_str): - return struct.unpack(self._fmt, raw_str) - - @property - def size(self): - return struct.calcsize(self._fmt) - - def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, self._fmt) - - def __eq__(self, other): - return self.__class__ == other.__class__ and self._fmt == other._fmt - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((self.__class__, self._fmt)) - - -class Num(RawStruct): - """Represents a number (integer or float). The number is encoded with - a struct-style format which must represent only one value. - - TODO: use u32, i16, etc. for format. - """ - - def _pack(self, number): - return super(Num, self)._pack([number]) - - def _unpack(self, raw_str): - upck = super(Num, self)._unpack(raw_str) - if len(upck) != 1: - raise ValueError("Num format string unpacks to multiple values, " - "should be 1") - return upck[0] - - -class Ptr(Num): - """Special case of number of which value indicates the address of a - MemType. - - Mapped to MemPtr (see its doc for more info): - - assert isinstance(mystruct.ptr, MemPtr) - mystruct.ptr = 0x4000 # Assign the Ptr numeric value - mystruct.ptr.val = 0x4000 # Also assigns the Ptr numeric value - assert isinstance(mystruct.ptr.val, int) # Get the Ptr numeric value - mystruct.ptr.deref # Get the pointed MemType - mystruct.ptr.deref = other # Set the pointed MemType - """ - - def __init__(self, fmt, dst_type, *type_args, **type_kwargs): - """ - @fmt: (str) Num compatible format that will be the Ptr representation - in memory - @dst_type: (MemType or Type) the Type this Ptr points to. - If a Type is given, it is transformed into a MemType with - TheType.lval. - *type_args, **type_kwargs: arguments to pass to the the pointed - MemType when instantiating it (e.g. for MemStr encoding or - MemArray field_type). - """ - if (not isinstance(dst_type, Type) and - not (isinstance(dst_type, type) and - issubclass(dst_type, MemType)) and - not dst_type == MemSelf): - raise ValueError("dst_type of Ptr must be a MemType type, a " - "Type instance, the MemSelf marker or a class " - "name.") - super(Ptr, self).__init__(fmt) - if isinstance(dst_type, Type): - # Patch the field to propagate the MemSelf replacement - dst_type._get_self_type = lambda: self._get_self_type() - # dst_type cannot be patched here, since _get_self_type of the outer - # class has not yet been set. Patching dst_type involves calling - # dst_type.lval, which will only return a type that does not point - # on MemSelf but on the right class only when _get_self_type of the - # outer class has been replaced by _MetaMemStruct. - # In short, dst_type = dst_type.lval is not valid here, it is done - # lazily in _fix_dst_type - self._dst_type = dst_type - self._type_args = type_args - self._type_kwargs = type_kwargs - - def _fix_dst_type(self): - if self._dst_type in [MemSelf, SELF_TYPE_INSTANCE]: - if self._get_self_type() is not None: - self._dst_type = self._get_self_type() - else: - raise ValueError("Unsupported usecase for (Mem)Self, sorry") - self._dst_type = to_type(self._dst_type) - - @property - def dst_type(self): - """Return the type (MemType subtype) this Ptr points to.""" - self._fix_dst_type() - return self._dst_type - - def set(self, vm, addr, val): - """A Ptr field can be set with a MemPtr or an int""" - if isinstance(val, MemType) and isinstance(val.get_type(), Ptr): - self.set_val(vm, addr, val.val) - else: - super(Ptr, self).set(vm, addr, val) - - def get(self, vm, addr): - return self.lval(vm, addr) - - def get_val(self, vm, addr): - """Get the numeric value of a Ptr""" - return super(Ptr, self).get(vm, addr) - - def set_val(self, vm, addr, val): - """Set the numeric value of a Ptr""" - return super(Ptr, self).set(vm, addr, val) - - def deref_get(self, vm, addr): - """Deserializes the data in @vm (VmMngr) at @addr to self.dst_type. - Equivalent to a pointer dereference rvalue in C. - """ - dst_addr = self.get_val(vm, addr) - return self.dst_type.lval(vm, dst_addr, - *self._type_args, **self._type_kwargs) - - def deref_set(self, vm, addr, val): - """Serializes the @val MemType subclass instance in @vm (VmMngr) at - @addr. Equivalent to a pointer dereference assignment in C. - """ - # Sanity check - if self.dst_type != val.get_type(): - log.warning("Original type was %s, overridden by value of type %s", - self._dst_type.__name__, val.__class__.__name__) - - # Actual job - dst_addr = self.get_val(vm, addr) - vm.set_mem(dst_addr, bytes(val)) - - def _get_pinned_base_class(self): - return MemPtr - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self.dst_type) - - def __eq__(self, other): - return super(Ptr, self).__eq__(other) and \ - self.dst_type == other.dst_type and \ - self._type_args == other._type_args and \ - self._type_kwargs == other._type_kwargs - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((super(Ptr, self).__hash__(), self.dst_type, - self._type_args)) - - -class Struct(Type): - """Equivalent to a C struct type. Composed of a name, and a - (<field_name (str)>, <Type_subclass_instance>) list describing the fields - of the struct. - - Mapped to MemStruct. - - NOTE: The `.lval` property of Struct creates classes on the fly. If an - equivalent structure is created by subclassing MemStruct, an exception - is raised to prevent creating multiple classes designating the same type. - - Example: - s = Struct("Toto", [("f1", Num("I")), ("f2", Num("I"))]) - - Toto1 = s.lval - - # This raises an exception, because it describes the same structure as - # Toto1 - class Toto(MemStruct): - fields = [("f1", Num("I")), ("f2", Num("I"))] - - Anonymous Struct, Union or BitField can be used if their field name - evaluates to False ("" or None). Such anonymous Struct field will generate - fields to the parent Struct, e.g.: - bla = Struct("Bla", [ - ("a", Num("B")), - ("", Union([("b1", Num("B")), ("b2", Num("H"))])), - ("", Struct("", [("c1", Num("B")), ("c2", Num("B"))])), - ] - Will have a b1, b2 and c1, c2 field directly accessible. The anonymous - fields are renamed to "__anon_<num>", with <num> an incremented number. - - In such case, bla.fields will not contain b1, b2, c1 and c2 (only the 3 - actual fields, with the anonymous ones renamed), but bla.all_fields will - return the 3 fields + b1, b2, c1 and c2 (and an information telling if it - has been generated from an anonymous Struct/Union). - - bla.get_field(vm, addr, "b1") will work. - """ - - def __init__(self, name, fields): - self.name = name - # generates self._fields and self._fields_desc - self._gen_fields(fields) - - def _gen_fields(self, fields): - """Precompute useful metadata on self.fields.""" - self._fields_desc = {} - offset = 0 - - # Build a proper (name, Field()) list, handling cases where the user - # supplies a MemType subclass instead of a Type instance - real_fields = [] - uniq_count = 0 - for fname, field in fields: - field = to_type(field) - - # For reflexion - field._set_self_type(self) - - # Anonymous Struct/Union - if not fname and isinstance(field, Struct): - # Generate field information - updated_fields = { - name: { - # Same field type than the anon field subfield - 'field': fd['field'], - # But the current offset is added - 'offset': fd['offset'] + offset, - } - for name, fd in viewitems(field._fields_desc) - } - - # Add the newly generated fields from the anon field - self._fields_desc.update(updated_fields) - real_fields += [(name, fld, True) - for name, fld in field.fields] - - # Rename the anonymous field - fname = '__anon_%x' % uniq_count - uniq_count += 1 - - self._fields_desc[fname] = {"field": field, "offset": offset} - real_fields.append((fname, field, False)) - offset = self._next_offset(field, offset) - - # fields is immutable - self._fields = tuple(real_fields) - - def _next_offset(self, field, orig_offset): - return orig_offset + field.size - - @property - def fields(self): - """Returns a sequence of (name, field) describing the fields of this - Struct, in order of offset. - - Fields generated from anonymous Unions or Structs are excluded from - this sequence. - """ - return tuple((name, field) for name, field, anon in self._fields - if not anon) - - @property - def all_fields(self): - """Returns a sequence of (<name>, <field (Type instance)>, <is_anon>), - where is_anon is True when a field is generated from an anonymous - Struct or Union, and False for the fields that have been provided as is. - """ - return self._fields - - def set(self, vm, addr, val): - raw = bytes(val) - vm.set_mem(addr, raw) - - def get(self, vm, addr): - return self.lval(vm, addr) - - def get_field(self, vm, addr, name): - """Get a field value by @name and base structure @addr in @vm VmMngr.""" - if name not in self._fields_desc: - raise ValueError("'%s' type has no field '%s'" % (self, name)) - field = self.get_field_type(name) - offset = self.get_offset(name) - return field.get(vm, addr + offset) - - def set_field(self, vm, addr, name, val): - """Set a field value by @name and base structure @addr in @vm VmMngr. - @val is the python value corresponding to this field type. - """ - if name not in self._fields_desc: - raise AttributeError("'%s' object has no attribute '%s'" - % (self.__class__.__name__, name)) - field = self.get_field_type(name) - offset = self.get_offset(name) - field.set(vm, addr + offset, val) - - @property - def size(self): - return sum(field.size for _, field in self.fields) - - def get_offset(self, field_name): - """ - @field_name: (str, optional) the name of the field to get the - offset of - """ - if field_name not in self._fields_desc: - raise ValueError("This structure has no %s field" % field_name) - return self._fields_desc[field_name]['offset'] - - def get_field_type(self, name): - """Return the Type subclass instance describing field @name.""" - return self._fields_desc[name]['field'] - - def _get_pinned_base_class(self): - return MemStruct - - def __repr__(self): - return "struct %s" % self.name - - def __eq__(self, other): - return self.__class__ == other.__class__ and \ - self.fields == other.fields and \ - self.name == other.name - - def __ne__(self, other): - return not self == other - - def __hash__(self): - # Only hash name, not fields, because if a field is a Ptr to this - # Struct type, an infinite loop occurs - return hash((self.__class__, self.name)) - - -class Union(Struct): - """Represents a C union. - - Allows to put multiple fields at the same offset in a MemStruct, - similar to unions in C. The Union will have the size of the largest of its - fields. - - Mapped to MemUnion. - - Example: - - class Example(MemStruct): - fields = [("uni", Union([ - ("f1", Num("<B")), - ("f2", Num("<H")) - ]) - )] - - ex = Example(vm, addr) - ex.uni.f2 = 0x1234 - assert ex.uni.f1 == 0x34 - """ - - def __init__(self, field_list): - """@field_list: a [(name, field)] list, see the class doc""" - super(Union, self).__init__("union", field_list) - - @property - def size(self): - return max(field.size for _, field in self.fields) - - def _next_offset(self, field, orig_offset): - return orig_offset - - def _get_pinned_base_class(self): - return MemUnion - - def __repr__(self): - fields_repr = ', '.join("%s: %r" % (name, field) - for name, field in self.fields) - return "%s(%s)" % (self.__class__.__name__, fields_repr) - - -class Array(Type): - """An array (contiguous sequence) of a Type subclass elements. - - Can be sized (similar to something like the char[10] type in C) or unsized - if no @array_len is given to the constructor (similar to char* used as an - array). - - Mapped to MemArray or MemSizedArray, depending on if the Array is - sized or not. - - Getting an array field actually returns a MemSizedArray. Setting it is - possible with either a list or a MemSizedArray instance. Examples of - syntax: - - class Example(MemStruct): - fields = [("array", Array(Num("B"), 4))] - - mystruct = Example(vm, addr) - mystruct.array[3] = 27 - mystruct.array = [1, 4, 8, 9] - mystruct.array = MemSizedArray(vm, addr2, Num("B"), 4) - """ - - def __init__(self, field_type, array_len=None): - # Handle both Type instance and MemType subclasses - self.field_type = to_type(field_type) - self.array_len = array_len - - def _set_self_type(self, self_type): - super(Array, self)._set_self_type(self_type) - self.field_type._set_self_type(self_type) - - def set(self, vm, addr, val): - # MemSizedArray assignment - if isinstance(val, MemSizedArray): - if val.array_len != self.array_len or len(val) != self.size: - raise ValueError("Size mismatch in MemSizedArray assignment") - raw = bytes(val) - vm.set_mem(addr, raw) - - # list assignment - elif isinstance(val, list): - if len(val) != self.array_len: - raise ValueError("Size mismatch in MemSizedArray assignment ") - offset = 0 - for elt in val: - self.field_type.set(vm, addr + offset, elt) - offset += self.field_type.size - - else: - raise RuntimeError( - "Assignment only implemented for list and MemSizedArray") - - def get(self, vm, addr): - return self.lval(vm, addr) - - @property - def size(self): - if self.is_sized(): - return self.get_offset(self.array_len) - else: - raise ValueError("%s is unsized, use an array with a fixed " - "array_len instead." % self) - - def get_offset(self, idx): - """Returns the offset of the item at index @idx.""" - return self.field_type.size * idx - - def get_item(self, vm, addr, idx): - """Get the item(s) at index @idx. - - @idx: int, long or slice - """ - if isinstance(idx, slice): - res = [] - idx = self._normalize_slice(idx) - for i in range(idx.start, idx.stop, idx.step): - res.append(self.field_type.get(vm, addr + self.get_offset(i))) - return res - else: - idx = self._normalize_idx(idx) - return self.field_type.get(vm, addr + self.get_offset(idx)) - - def set_item(self, vm, addr, idx, item): - """Sets one or multiple items in this array (@idx can be an int, long - or slice). - """ - if isinstance(idx, slice): - idx = self._normalize_slice(idx) - if len(item) != len(range(idx.start, idx.stop, idx.step)): - raise ValueError("Mismatched lengths in slice assignment") - for i, val in zip(range(idx.start, idx.stop, idx.step), - item): - self.field_type.set(vm, addr + self.get_offset(i), val) - else: - idx = self._normalize_idx(idx) - self.field_type.set(vm, addr + self.get_offset(idx), item) - - def is_sized(self): - """True if this is a sized array (non None self.array_len), False - otherwise. - """ - return self.array_len is not None - - def _normalize_idx(self, idx): - # Noop for this type - if self.is_sized(): - if idx < 0: - idx = self.array_len + idx - self._check_bounds(idx) - return idx - - def _normalize_slice(self, slice_): - start = slice_.start if slice_.start is not None else 0 - stop = slice_.stop if slice_.stop is not None else self.get_size() - step = slice_.step if slice_.step is not None else 1 - start = self._normalize_idx(start) - stop = self._normalize_idx(stop) - return slice(start, stop, step) - - def _check_bounds(self, idx): - if not isinstance(idx, int_types): - raise ValueError("index must be an int or a long") - if idx < 0 or (self.is_sized() and idx >= self.size): - raise IndexError("Index %s out of bounds" % idx) - - def _get_pinned_base_class(self): - if self.is_sized(): - return MemSizedArray - else: - return MemArray - - def __repr__(self): - return "[%r; %s]" % (self.field_type, self.array_len or "unsized") - - def __eq__(self, other): - return self.__class__ == other.__class__ and \ - self.field_type == other.field_type and \ - self.array_len == other.array_len - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((self.__class__, self.field_type, self.array_len)) - - -class Bits(Type): - """Helper class for BitField, not very useful on its own. Represents some - bits of a Num. - - The @backing_num is used to know how to serialize/deserialize data in vm, - but getting/setting this fields only assign bits from @bit_offset to - @bit_offset + @bits. Masking and shifting is handled by the class, the aim - is to provide a transparent way to set and get some bits of a num. - """ - - def __init__(self, backing_num, bits, bit_offset): - if not isinstance(backing_num, Num): - raise ValueError("backing_num should be a Num instance") - self._num = backing_num - self._bits = bits - self._bit_offset = bit_offset - - def set(self, vm, addr, val): - val_mask = (1 << self._bits) - 1 - val_shifted = (val & val_mask) << self._bit_offset - num_size = self._num.size * 8 - - full_num_mask = (1 << num_size) - 1 - num_mask = (~(val_mask << self._bit_offset)) & full_num_mask - - num_val = self._num.get(vm, addr) - res_val = (num_val & num_mask) | val_shifted - self._num.set(vm, addr, res_val) - - def get(self, vm, addr): - val_mask = (1 << self._bits) - 1 - num_val = self._num.get(vm, addr) - res_val = (num_val >> self._bit_offset) & val_mask - return res_val - - @property - def size(self): - return self._num.size - - @property - def bit_size(self): - """Number of bits read/written by this class""" - return self._bits - - @property - def bit_offset(self): - """Offset in bits (beginning at 0, the LSB) from which to read/write - bits. - """ - return self._bit_offset - - def __repr__(self): - return "%s%r(%d:%d)" % (self.__class__.__name__, self._num, - self._bit_offset, self._bit_offset + self._bits) - - def __eq__(self, other): - return self.__class__ == other.__class__ and \ - self._num == other._num and self._bits == other._bits and \ - self._bit_offset == other._bit_offset - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((self.__class__, self._num, self._bits, self._bit_offset)) - - -class BitField(Union): - """A C-like bitfield. - - Constructed with a list [(<field_name>, <number_of_bits>)] and a - @backing_num. The @backing_num is a Num instance that determines the total - size of the bitfield and the way the bits are serialized/deserialized (big - endian int, little endian short...). Can be seen (and implemented) as a - Union of Bits fields. - - Mapped to MemBitField. - - Creates fields that allow to access the bitfield fields easily. Example: - - class Example(MemStruct): - fields = [("bf", BitField(Num("B"), [ - ("f1", 2), - ("f2", 4), - ("f3", 1) - ]) - )] - - ex = Example(vm, addr) - ex.memset() - ex.f2 = 2 - ex.f1 = 5 # 5 does not fit on two bits, it will be binarily truncated - assert ex.f1 == 3 - assert ex.f2 == 2 - assert ex.f3 == 0 # previously memset() - assert ex.bf == 3 + 2 << 2 - """ - - def __init__(self, backing_num, bit_list): - """@backing num: Num instance, @bit_list: [(name, n_bits)]""" - self._num = backing_num - fields = [] - offset = 0 - for name, bits in bit_list: - fields.append((name, Bits(self._num, bits, offset))) - offset += bits - if offset > self._num.size * 8: - raise ValueError("sum of bit lengths is > to the backing num size") - super(BitField, self).__init__(fields) - - def set(self, vm, addr, val): - self._num.set(vm, addr, val) - - def _get_pinned_base_class(self): - return MemBitField - - def __eq__(self, other): - return self.__class__ == other.__class__ and \ - self._num == other._num and super(BitField, self).__eq__(other) - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((super(BitField, self).__hash__(), self._num)) - - def __repr__(self): - fields_repr = ', '.join("%s: %r" % (name, field.bit_size) - for name, field in self.fields) - return "%s(%s)" % (self.__class__.__name__, fields_repr) - - -class Str(Type): - """A string type that handles encoding. This type is unsized (no static - size). - - The @encoding is passed to the constructor, and is one of the keys of - Str.encodings, currently: - - ascii - - latin1 - - ansi (= latin1) - - utf8 (= utf-8le) - - utf16 (= utf-16le, Windows UCS-2 compatible) - New encodings can be added with Str.add_encoding. - If an unknown encoding is passed to the constructor, Str will try to add it - to the available ones with Str.add_encoding. - - Mapped to MemStr. - """ - - # Dict of {name: (getter, setter, raw_len)} - # Where: - # - getter(vm, addr) -> unicode - # - setter(vm, addr, unicode) - # - raw_len(unicode_str) -> int (length of the str value one encoded in - # memory) - # See enc_triplet() - # - # NOTE: this appears like it could be implemented only with - # (getter, raw_str), but this would cause trouble for length-prefixed str - # encoding (Pascal-style strings). - encodings = { - "ascii": enc_triplet("ascii"), - "latin1": enc_triplet("latin1"), - "ansi": enc_triplet("latin1"), - "utf8": enc_triplet("utf8"), - "utf16": enc_triplet("utf-16le"), - } - - def __init__(self, encoding="ansi"): - if encoding not in self.encodings: - self.add_encoding(encoding) - self._enc = encoding - - @classmethod - def add_encoding(cls, enc_name, str_enc=None, getter=None, setter=None, - raw_len=None): - """Add an available Str encoding. - - @enc_name: the name that will be used to designate this encoding in the - Str constructor - @str_end: (optional) the actual str encoding name if it differs from - @enc_name - @getter: (optional) func(vm, addr) -> unicode, to force usage of this - function to retrieve the str from memory - @setter: (optional) func(vm, addr, unicode), to force usage of this - function to set the str in memory - @raw_len: (optional) func(unicode_str) -> int (length of the str value - one encoded in memory), to force usage of this function to compute - the length of this string once in memory - """ - default = enc_triplet(str_enc or enc_name) - actual = ( - getter or default[0], - setter or default[1], - raw_len or default[2], - ) - cls.encodings[enc_name] = actual - - def get(self, vm, addr): - """Set the string value in memory""" - get_str = self.encodings[self.enc][0] - return get_str(vm, addr) - - def set(self, vm, addr, s): - """Get the string value from memory""" - set_str = self.encodings[self.enc][1] - set_str(vm, addr, s) - - @property - def size(self): - """This type is unsized.""" - raise ValueError("Str is unsized") - - def value_size(self, py_str): - """Returns the in-memory size of a @py_str for this Str type (handles - encoding, i.e. will not return the same size for "utf16" and "ansi"). - """ - raw_len = self.encodings[self.enc][2] - return raw_len(py_str) - - @property - def enc(self): - """This Str's encoding name (as a str).""" - return self._enc - - def _get_pinned_base_class(self): - return MemStr - - def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, self.enc) - - def __eq__(self, other): - return self.__class__ == other.__class__ and self._enc == other._enc - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((self.__class__, self._enc)) - - -class Void(Type): - """Represents the C void type. - - Mapped to MemVoid. - """ - - def _build_pinned_type(self): - return MemVoid - - def __eq__(self, other): - return self.__class__ == other.__class__ - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash(self.__class__) - - def __repr__(self): - return self.__class__.__name__ - - -class Self(Void): - """Special marker to reference a type inside itself. - - Mapped to MemSelf. - - Example: - class ListNode(MemStruct): - fields = [ - ("next", Ptr("<I", Self())), - ("data", Ptr("<I", Void())), - ] - """ - - def _build_pinned_type(self): - return MemSelf - -# To avoid reinstantiation when testing equality -SELF_TYPE_INSTANCE = Self() -VOID_TYPE_INSTANCE = Void() - - -# MemType classes - -class _MetaMemType(type): - def __repr__(cls): - return cls.__name__ - - -class _MetaMemStruct(_MetaMemType): - """MemStruct metaclass. Triggers the magic that generates the class - fields from the cls.fields list. - - Just calls MemStruct.gen_fields() if the fields class attribute has been - set, the actual implementation can seen be there. - """ - - def __init__(cls, name, bases, dct): - super(_MetaMemStruct, cls).__init__(name, bases, dct) - if cls.fields is not None: - cls.fields = tuple(cls.fields) - # Am I able to generate fields? (if not, let the user do it manually - # later) - if cls.get_type() is not None or cls.fields is not None: - cls.gen_fields() - - -class MemType(with_metaclass(_MetaMemType, object)): - """Base class for classes that allow to map python objects to C types in - virtual memory. Represents an lvalue of a given type. - - Globally, MemTypes are not meant to be used directly: specialized - subclasses are generated by Type(...).lval and should be used instead. - The main exception is MemStruct, which you may want to subclass yourself - for syntactic ease. - """ - - # allocator is a function(vm, size) -> allocated_address - allocator = None - - _type = None - - def __init__(self, vm, addr=None, type_=None): - self._vm = vm - if addr is None: - self._addr = self.alloc(vm, self.get_size()) - else: - self._addr = addr - if type_ is not None: - self._type = type_ - if self._type is None: - raise ValueError("Subclass MemType and define cls._type or pass " - "a type to the constructor") - - @classmethod - def alloc(cls, vm, size): - """Returns an allocated page of size @size if cls.allocator is set. - Raises ValueError otherwise. - """ - if cls.allocator is None: - raise ValueError("Cannot provide None address to MemType() if" - "%s.set_allocator has not been called." - % __name__) - return cls.allocator(vm, size) - - @classmethod - def set_allocator(cls, alloc_func): - """Set an allocator for this class; allows to instantiate statically - sized MemTypes (i.e. sizeof() is implemented) without specifying the - address (the object is allocated by @alloc_func in the vm). - - You may call set_allocator on specific MemType classes if you want - to use a different allocator. - - @alloc_func: func(VmMngr) -> integer_address - """ - cls.allocator = alloc_func - - def get_addr(self, field=None): - """Return the address of this MemType or one of its fields. - - @field: (str, optional) used by subclasses to specify the name or index - of the field to get the address of - """ - if field is not None: - raise NotImplementedError("Getting a field's address is not " - "implemented for this class.") - return self._addr - - @classmethod - def get_type(cls): - """Returns the Type subclass instance representing the C type of this - MemType. - """ - return cls._type - - @classmethod - def sizeof(cls): - """Return the static size of this type. By default, it is the size - of the underlying Type. - """ - return cls._type.size - - def get_size(self): - """Return the dynamic size of this structure (e.g. the size of an - instance). Defaults to sizeof for this base class. - - For example, MemStr defines get_size but not sizeof, as an instance - has a fixed size (at least its value has), but all the instance do not - have the same size. - """ - return self.sizeof() - - def memset(self, byte=b'\x00'): - """Fill the memory space of this MemType with @byte ('\x00' by - default). The size is retrieved with self.get_size() (dynamic size). - """ - # TODO: multibyte patterns - if not isinstance(byte, bytes) or len(byte) != 1: - raise ValueError("byte must be a 1-lengthed str") - self._vm.set_mem(self.get_addr(), byte * self.get_size()) - - def cast(self, other_type): - """Cast this MemType to another MemType (same address, same vm, - but different type). Return the casted MemType. - - @other_type: either a Type instance (other_type.lval is used) or a - MemType subclass - """ - if isinstance(other_type, Type): - other_type = other_type.lval - return other_type(self._vm, self.get_addr()) - - def cast_field(self, field, other_type, *type_args, **type_kwargs): - """ABSTRACT: Same as cast, but the address of the returned MemType - is the address at which @field is in the current MemType. - - @field: field specification, for example its name for a struct, or an - index in an array. See the subclass doc. - @other_type: either a Type instance (other_type.lval is used) or a - MemType subclass - """ - raise NotImplementedError("Abstract") - - def raw(self): - """Raw binary (str) representation of the MemType as it is in - memory. - """ - return self._vm.get_mem(self.get_addr(), self.get_size()) - - def __len__(self): - return self.get_size() - - def __str__(self): - if PY3: - return repr(self) - return self.__bytes__() - - def __bytes__(self): - return self.raw() - - def __repr__(self): - return "Mem%r" % self._type - - def __eq__(self, other): - return self.__class__ == other.__class__ and \ - self.get_type() == other.get_type() and \ - bytes(self) == bytes(other) - - def __ne__(self, other): - return not self == other - - -class MemValue(MemType): - """Simple MemType that gets and sets the Type through the `.val` - attribute. - """ - - @property - def val(self): - return self._type.get(self._vm, self._addr) - - @val.setter - def val(self, value): - self._type.set(self._vm, self._addr, value) - - def __repr__(self): - return "%r: %r" % (self.__class__, self.val) - - -class MemStruct(with_metaclass(_MetaMemStruct, MemType)): - """Base class to easily implement VmMngr backed C-like structures in miasm. - Represents a structure in virtual memory. - - The mechanism is the following: - - set a "fields" class field to be a list of - (<field_name (str)>, <Type_subclass_instance>) - - instances of this class will have properties to interact with these - fields. - - Example: - class MyStruct(MemStruct): - fields = [ - # Scalar field: just struct.pack field with one value - ("num", Num("I")), - ("flags", Num("B")), - # Ptr fields contain two fields: "val", for the numerical value, - # and "deref" to get the pointed object - ("other", Ptr("I", OtherStruct)), - # Ptr to a variable length String - ("s", Ptr("I", Str())), - ("i", Ptr("I", Num("I"))), - ] - - mstruct = MyStruct(vm, addr) - - # Field assignment modifies virtual memory - mstruct.num = 3 - assert mstruct.num == 3 - memval = struct.unpack("I", vm.get_mem(mstruct.get_addr(), - 4))[0] - assert memval == mstruct.num - - # Memset sets the whole structure - mstruct.memset() - assert mstruct.num == 0 - mstruct.memset('\x11') - assert mstruct.num == 0x11111111 - - other = OtherStruct(vm, addr2) - mstruct.other = other.get_addr() - assert mstruct.other.val == other.get_addr() - assert mstruct.other.deref == other - assert mstruct.other.deref.foo == 0x1234 - - Note that: - MyStruct = Struct("MyStruct", <same fields>).lval - is equivalent to the previous MyStruct declaration. - - See the various Type-s doc for more information. See MemStruct.gen_fields - doc for more information on how to handle recursive types and cyclic - dependencies. - """ - fields = None - - def get_addr(self, field_name=None): - """ - @field_name: (str, optional) the name of the field to get the - address of - """ - if field_name is not None: - offset = self._type.get_offset(field_name) - else: - offset = 0 - return self._addr + offset - - @classmethod - def get_offset(cls, field_name): - """Shorthand for self.get_type().get_offset(field_name).""" - return cls.get_type().get_offset(field_name) - - def get_field(self, name): - """Get a field value by name. - - useless most of the time since fields are accessible via self.<name>. - """ - return self._type.get_field(self._vm, self.get_addr(), name) - - def set_field(self, name, val): - """Set a field value by name. @val is the python value corresponding to - this field type. - - useless most of the time since fields are accessible via self.<name>. - """ - return self._type.set_field(self._vm, self.get_addr(), name, val) - - def cast_field(self, field, other_type): - """In this implementation, @field is a field name""" - if isinstance(other_type, Type): - other_type = other_type.lval - return other_type(self._vm, self.get_addr(field)) - - # Field generation method, voluntarily public to be able to gen fields - # after class definition - @classmethod - def gen_fields(cls, fields=None): - """Generate the fields of this class (so that they can be accessed with - self.<field_name>) from a @fields list, as described in the class doc. - - Useful in case of a type cyclic dependency. For example, the following - is not possible in python: - - class A(MemStruct): - fields = [("b", Ptr("I", B))] - - class B(MemStruct): - fields = [("a", Ptr("I", A))] - - With gen_fields, the following is the legal equivalent: - - class A(MemStruct): - pass - - class B(MemStruct): - fields = [("a", Ptr("I", A))] - - A.gen_fields([("b", Ptr("I", B))]) - """ - if fields is not None: - if cls.fields is not None: - raise ValueError("Cannot regen fields of a class. Setting " - "cls.fields at class definition and calling " - "gen_fields are mutually exclusive.") - cls.fields = fields - - if cls._type is None: - if cls.fields is None: - raise ValueError("Cannot create a MemStruct subclass without" - " a cls._type or a cls.fields") - cls._type = cls._gen_type(cls.fields) - - if cls._type in DYN_MEM_STRUCT_CACHE: - # FIXME: Maybe a warning would be better? - raise RuntimeError("Another MemType has the same type as this " - "one. Use it instead.") - - # Register this class so that another one will not be created when - # calling cls._type.lval - DYN_MEM_STRUCT_CACHE[cls._type] = cls - - cls._gen_attributes() - - @classmethod - def _gen_attributes(cls): - # Generate self.<name> getter and setters - for name, _, _ in cls._type.all_fields: - setattr(cls, name, property( - lambda self, name=name: self.get_field(name), - lambda self, val, name=name: self.set_field(name, val) - )) - - @classmethod - def _gen_type(cls, fields): - return Struct(cls.__name__, fields) - - def __repr__(self): - out = [] - for name, field in self._type.fields: - val_repr = repr(self.get_field(name)) - if '\n' in val_repr: - val_repr = '\n' + indent(val_repr, 4) - out.append("%s: %r = %s" % (name, field, val_repr)) - return '%r:\n' % self.__class__ + indent('\n'.join(out), 2) - - -class MemUnion(MemStruct): - """Same as MemStruct but all fields have a 0 offset in the struct.""" - @classmethod - def _gen_type(cls, fields): - return Union(fields) - - -class MemBitField(MemUnion): - """MemUnion of Bits(...) fields.""" - @classmethod - def _gen_type(cls, fields): - return BitField(fields) - - -class MemSelf(MemStruct): - """Special Marker class for reference to current class in a Ptr or Array - (mostly Array of Ptr). See Self doc. - """ - def __repr__(self): - return self.__class__.__name__ - - -class MemVoid(MemType): - """Placeholder for e.g. Ptr to an undetermined type. Useful mostly when - casted to another type. Allows to implement C's "void*" pattern. - """ - _type = Void() - - def __repr__(self): - return self.__class__.__name__ - - -class MemPtr(MemValue): - """Mem version of a Ptr, provides two properties: - - val, to set and get the numeric value of the Ptr - - deref, to set and get the pointed type - """ - @property - def val(self): - return self._type.get_val(self._vm, self._addr) - - @val.setter - def val(self, value): - return self._type.set_val(self._vm, self._addr, value) - - @property - def deref(self): - return self._type.deref_get(self._vm, self._addr) - - @deref.setter - def deref(self, val): - return self._type.deref_set(self._vm, self._addr, val) - - def __repr__(self): - return "*%s" % hex(self.val) - - -class MemStr(MemValue): - """Implements a string representation in memory. - - The string value can be got or set (with python str/unicode) through the - self.val attribute. String encoding/decoding is handled by the class, - - This type is dynamically sized only (get_size is implemented, not sizeof). - """ - - def get_size(self): - """This get_size implementation is quite unsafe: it reads the string - underneath to determine the size, it may therefore read a lot of memory - and provoke mem faults (analogous to strlen). - """ - val = self.val - return self.get_type().value_size(val) - - @classmethod - def from_str(cls, vm, py_str): - """Allocates a MemStr with the global allocator with value py_str. - Raises a ValueError if allocator is not set. - """ - size = cls._type.value_size(py_str) - addr = cls.alloc(vm, size) - memstr = cls(vm, addr) - memstr.val = py_str - return memstr - - def raw(self): - raw = self._vm.get_mem(self.get_addr(), self.get_size()) - return raw - - def __repr__(self): - return "%r: %r" % (self.__class__, self.val) - - -class MemArray(MemType): - """An unsized array of type @field_type (a Type subclass instance). - This class has no static or dynamic size. - - It can be indexed for setting and getting elements, example: - - array = Array(Num("I")).lval(vm, addr)) - array[2] = 5 - array[4:8] = [0, 1, 2, 3] - print array[20] - """ - - @property - def field_type(self): - """Return the Type subclass instance that represents the type of - this MemArray items. - """ - return self.get_type().field_type - - def get_addr(self, idx=0): - return self._addr + self.get_type().get_offset(idx) - - @classmethod - def get_offset(cls, idx): - """Shorthand for self.get_type().get_offset(idx).""" - return cls.get_type().get_offset(idx) - - def __getitem__(self, idx): - return self.get_type().get_item(self._vm, self._addr, idx) - - def __setitem__(self, idx, item): - self.get_type().set_item(self._vm, self._addr, idx, item) - - def raw(self): - raise ValueError("%s is unsized, which prevents from getting its full " - "raw representation. Use MemSizedArray instead." % - self.__class__) - - def __repr__(self): - return "[%r, ...] [%r]" % (self[0], self.field_type) - - -class MemSizedArray(MemArray): - """A fixed size MemArray. - - This type is dynamically sized. Generate a fixed @field_type and @array_len - array which has a static size by using Array(type, size).lval. - """ - - @property - def array_len(self): - """The length, in number of elements, of this array.""" - return self.get_type().array_len - - def get_size(self): - return self.get_type().size - - def __iter__(self): - for i in range(self.get_type().array_len): - yield self[i] - - def raw(self): - return self._vm.get_mem(self.get_addr(), self.get_size()) - - def __repr__(self): - item_reprs = [repr(item) for item in self] - if self.array_len > 0 and '\n' in item_reprs[0]: - items = '\n' + indent(',\n'.join(item_reprs), 2) + '\n' - else: - items = ', '.join(item_reprs) - return "[%s] [%r; %s]" % (items, self.field_type, self.array_len) - diff --git a/miasm/core/utils.py b/miasm/core/utils.py deleted file mode 100644 index 291c5f4d..00000000 --- a/miasm/core/utils.py +++ /dev/null @@ -1,292 +0,0 @@ -from __future__ import print_function -import re -import sys -from builtins import range -import struct -import inspect - -try: - from collections.abc import MutableMapping as DictMixin -except ImportError: - from collections import MutableMapping as DictMixin - -from operator import itemgetter -import codecs - -from future.utils import viewitems - -import collections - -COLOR_INT = "azure4" -COLOR_ID = "forestgreen"#"chartreuse3" -COLOR_MEM = "deeppink4" -COLOR_OP_FUNC = "blue1" -COLOR_LOC = "darkslateblue" -COLOR_OP = "black" - -COLOR_MNEMO = "blue1" - -ESCAPE_CHARS = re.compile('[' + re.escape('{}[]') + '&|<>' + ']') - - - -def set_html_text_color(text, color): - return '<font color="%s">%s</font>' % (color, text) - - -def _fix_chars(token): - return "&#%04d;" % ord(token.group()) - - -def fix_html_chars(text): - return ESCAPE_CHARS.sub(_fix_chars, str(text)) - -BRACKET_O = fix_html_chars('[') -BRACKET_C = fix_html_chars(']') - -upck8 = lambda x: struct.unpack('B', x)[0] -upck16 = lambda x: struct.unpack('H', x)[0] -upck32 = lambda x: struct.unpack('I', x)[0] -upck64 = lambda x: struct.unpack('Q', x)[0] -pck8 = lambda x: struct.pack('B', x) -pck16 = lambda x: struct.pack('H', x) -pck32 = lambda x: struct.pack('I', x) -pck64 = lambda x: struct.pack('Q', x) - -# Little endian -upck8le = lambda x: struct.unpack('<B', x)[0] -upck16le = lambda x: struct.unpack('<H', x)[0] -upck32le = lambda x: struct.unpack('<I', x)[0] -upck64le = lambda x: struct.unpack('<Q', x)[0] -pck8le = lambda x: struct.pack('<B', x) -pck16le = lambda x: struct.pack('<H', x) -pck32le = lambda x: struct.pack('<I', x) -pck64le = lambda x: struct.pack('<Q', x) - -# Big endian -upck8be = lambda x: struct.unpack('>B', x)[0] -upck16be = lambda x: struct.unpack('>H', x)[0] -upck32be = lambda x: struct.unpack('>I', x)[0] -upck64be = lambda x: struct.unpack('>Q', x)[0] -pck8be = lambda x: struct.pack('>B', x) -pck16be = lambda x: struct.pack('>H', x) -pck32be = lambda x: struct.pack('>I', x) -pck64be = lambda x: struct.pack('>Q', x) - - -LITTLE_ENDIAN = 1 -BIG_ENDIAN = 2 - - -pck = {8: pck8, - 16: pck16, - 32: pck32, - 64: pck64} - - -def get_caller_name(caller_num=0): - """Get the nth caller's name - @caller_num: 0 = the caller of get_caller_name, 1 = next parent, ...""" - pystk = inspect.stack() - if len(pystk) > 1 + caller_num: - return pystk[1 + caller_num][3] - else: - return "Bad caller num" - - -def whoami(): - """Returns the caller's name""" - return get_caller_name(1) - - -class Disasm_Exception(Exception): - pass - - -def printable(string): - if isinstance(string, bytes): - return "".join( - c.decode() if b" " <= c < b"~" else "." - for c in (string[i:i+1] for i in range(len(string))) - ) - return string - - -def force_bytes(value): - if isinstance(value, bytes): - return value - if not isinstance(value, str): - return value - out = [] - for c in value: - c = ord(c) - assert c < 0x100 - out.append(c) - return bytes(out) - - -def force_str(value): - if isinstance(value, str): - return value - elif isinstance(value, bytes): - out = "" - for i in range(len(value)): - # For Python2/Python3 compatibility - c = ord(value[i:i+1]) - out += chr(c) - value = out - else: - raise ValueError("Unsupported type") - return value - - -def iterbytes(string): - for i in range(len(string)): - yield string[i:i+1] - - -def int_to_byte(value): - return struct.pack('B', value) - -def cmp_elts(elt1, elt2): - return (elt1 > elt2) - (elt1 < elt2) - - -_DECODE_HEX = codecs.getdecoder("hex_codec") -_ENCODE_HEX = codecs.getencoder("hex_codec") - -def decode_hex(value): - return _DECODE_HEX(value)[0] - -def encode_hex(value): - return _ENCODE_HEX(value)[0] - -def size2mask(size): - """Return the bit mask of size @size""" - return (1 << size) - 1 - -def hexdump(src, length=16): - lines = [] - for c in range(0, len(src), length): - chars = src[c:c + length] - hexa = ' '.join("%02x" % ord(x) for x in iterbytes(chars)) - printable = ''.join( - x.decode() if 32 <= ord(x) <= 126 else '.' for x in iterbytes(chars) - ) - lines.append("%04x %-*s %s\n" % (c, length * 3, hexa, printable)) - print(''.join(lines)) - - -# stackoverflow.com/questions/2912231 -class keydefaultdict(collections.defaultdict): - - def __missing__(self, key): - if self.default_factory is None: - raise KeyError(key) - value = self[key] = self.default_factory(key) - return value - - -class BoundedDict(DictMixin): - - """Limited in size dictionary. - - To reduce combinatory cost, once an upper limit @max_size is reached, - @max_size - @min_size elements are suppressed. - The targeted elements are the less accessed. - - One can define a callback called when an element is removed - """ - - def __init__(self, max_size, min_size=None, initialdata=None, - delete_cb=None): - """Create a BoundedDict - @max_size: maximum size of the dictionary - @min_size: (optional) number of most used element to keep when resizing - @initialdata: (optional) dict instance with initial data - @delete_cb: (optional) callback called when an element is removed - """ - self._data = initialdata.copy() if initialdata else {} - self._min_size = min_size if min_size else max_size // 3 - self._max_size = max_size - self._size = len(self._data) - # Do not use collections.Counter as it is quite slow - self._counter = {k: 1 for k in self._data} - self._delete_cb = delete_cb - - def __setitem__(self, asked_key, value): - if asked_key not in self._data: - # Update internal size and use's counter - self._size += 1 - - # Bound can only be reached on a new element - if (self._size >= self._max_size): - most_common = sorted( - viewitems(self._counter), - key=itemgetter(1), - reverse=True - ) - - # Handle callback - if self._delete_cb is not None: - for key, _ in most_common[self._min_size - 1:]: - self._delete_cb(key) - - # Keep only the most @_min_size used - self._data = {key: self._data[key] - for key, _ in most_common[:self._min_size - 1]} - self._size = self._min_size - - # Reset use's counter - self._counter = {k: 1 for k in self._data} - - # Avoid rechecking in dict: set to 1 here, add 1 otherwise - self._counter[asked_key] = 1 - else: - self._counter[asked_key] += 1 - - self._data[asked_key] = value - - def __contains__(self, key): - # Do not call has_key to avoid adding function call overhead - return key in self._data - - def has_key(self, key): - return key in self._data - - def keys(self): - "Return the list of dict's keys" - return list(self._data) - - @property - def data(self): - "Return the current instance as a dictionary" - return self._data - - def __getitem__(self, key): - # Retrieve data first to raise the proper exception on error - data = self._data[key] - # Should never raise, since the key is in self._data - self._counter[key] += 1 - return data - - def __delitem__(self, key): - if self._delete_cb is not None: - self._delete_cb(key) - del self._data[key] - self._size -= 1 - del self._counter[key] - - def __del__(self): - """Ensure the callback is called when last reference is lost""" - if self._delete_cb: - for key in self._data: - self._delete_cb(key) - - - def __len__(self): - return len(self._data) - - def __iter__(self): - return iter(self._data) - |