diff options
Diffstat (limited to 'miasm2/expression')
| -rw-r--r-- | miasm2/expression/__init__.py | 18 | ||||
| -rw-r--r-- | miasm2/expression/expression.py | 2035 | ||||
| -rw-r--r-- | miasm2/expression/expression_helper.py | 628 | ||||
| -rw-r--r-- | miasm2/expression/expression_reduce.py | 280 | ||||
| -rw-r--r-- | miasm2/expression/modint.py | 259 | ||||
| -rw-r--r-- | miasm2/expression/parser.py | 84 | ||||
| -rw-r--r-- | miasm2/expression/simplifications.py | 207 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 1556 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_cond.py | 178 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_explicit.py | 159 | ||||
| -rw-r--r-- | miasm2/expression/smt2_helper.py | 296 |
11 files changed, 0 insertions, 5700 deletions
diff --git a/miasm2/expression/__init__.py b/miasm2/expression/__init__.py deleted file mode 100644 index 67f567f7..00000000 --- a/miasm2/expression/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation; either version 2 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -# -"Intermediate language implementation" diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py deleted file mode 100644 index 03febbfd..00000000 --- a/miasm2/expression/expression.py +++ /dev/null @@ -1,2035 +0,0 @@ -# -# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation; either version 2 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -# -# These module implements Miasm IR components and basic operations related. -# IR components are : -# - ExprInt -# - ExprId -# - ExprLoc -# - ExprAssign -# - ExprCond -# - ExprMem -# - ExprOp -# - ExprSlice -# - ExprCompose -# - - -from builtins import zip -from builtins import range -import warnings -import itertools -from builtins import int as int_types -from functools import cmp_to_key, total_ordering -from future.utils import viewitems - -from miasm2.core.utils import force_bytes, cmp_elts -from miasm2.expression.modint import mod_size2uint, is_modint, size2mask, \ - define_uint -from miasm2.core.graph import DiGraph -from functools import reduce - -# Define tokens -TOK_INF = "<" -TOK_INF_SIGNED = TOK_INF + "s" -TOK_INF_UNSIGNED = TOK_INF + "u" -TOK_INF_EQUAL = "<=" -TOK_INF_EQUAL_SIGNED = TOK_INF_EQUAL + "s" -TOK_INF_EQUAL_UNSIGNED = TOK_INF_EQUAL + "u" -TOK_EQUAL = "==" -TOK_POS = "pos" -TOK_POS_STRICT = "Spos" - -# Hashing constants -EXPRINT = 1 -EXPRID = 2 -EXPRLOC = 3 -EXPRASSIGN = 4 -EXPRCOND = 5 -EXPRMEM = 6 -EXPROP = 7 -EXPRSLICE = 8 -EXPRCOMPOSE = 9 - - -priorities_list = [ - [ '+' ], - [ '*', '/', '%' ], - [ '**' ], - [ '-' ], # Unary '-', associativity with + not handled -] - -# dictionary from 'op' to priority, derived from above -priorities = dict((op, prio) - for prio, l in enumerate(priorities_list) - for op in l) -PRIORITY_MAX = len(priorities_list) - 1 - -def should_parenthesize_child(child, parent): - if (isinstance(child, ExprId) or isinstance(child, ExprInt) or - isinstance(child, ExprCompose) or isinstance(child, ExprMem) or - isinstance(child, ExprSlice)): - return False - elif isinstance(child, ExprOp) and not child.is_infix(): - return False - elif (isinstance(child, ExprCond) or isinstance(parent, ExprSlice)): - return True - elif (isinstance(child, ExprOp) and isinstance(parent, ExprOp)): - pri_child = priorities.get(child.op, -1) - pri_parent = priorities.get(parent.op, PRIORITY_MAX + 1) - return pri_child < pri_parent - else: - return True - -def str_protected_child(child, parent): - return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child) - -def visit_chk(visitor): - "Function decorator launching callback on Expression visit" - def wrapped(expr, callback, test_visit=lambda x: True): - if (test_visit is not None) and (not test_visit(expr)): - return expr - expr_new = visitor(expr, callback, test_visit) - if expr_new is None: - return None - expr_new2 = callback(expr_new) - return expr_new2 - return wrapped - - -# Expression display - - -class DiGraphExpr(DiGraph): - - """Enhanced graph for Expression display - Expression are displayed as a tree with node and edge labeled - with only relevant information""" - - def node2str(self, node): - if isinstance(node, ExprOp): - return node.op - elif isinstance(node, ExprId): - return node.name - elif isinstance(node, ExprLoc): - return "%s" % node.loc_key - elif isinstance(node, ExprMem): - return "@%d" % node.size - elif isinstance(node, ExprCompose): - return "{ %d }" % node.size - elif isinstance(node, ExprCond): - return "? %d" % node.size - elif isinstance(node, ExprSlice): - return "[%d:%d]" % (node.start, node.stop) - return str(node) - - def edge2str(self, nfrom, nto): - if isinstance(nfrom, ExprCompose): - for i in nfrom.args: - if i[0] == nto: - return "[%s, %s]" % (i[1], i[2]) - elif isinstance(nfrom, ExprCond): - if nfrom.cond == nto: - return "?" - elif nfrom.src1 == nto: - return "True" - elif nfrom.src2 == nto: - return "False" - - return "" - - -@total_ordering -class LocKey(object): - def __init__(self, key): - self._key = key - - key = property(lambda self: self._key) - - def __hash__(self): - return hash(self._key) - - def __eq__(self, other): - if self is other: - return True - if self.__class__ is not other.__class__: - return False - return self.key == other.key - - def __ne__(self, other): - # required Python 2.7.14 - return not self == other - - def __lt__(self, other): - return self.key < other.key - - def __repr__(self): - return "<%s %d>" % (self.__class__.__name__, self._key) - - def __str__(self): - return "loc_key_%d" % self.key - -# IR definitions - -class Expr(object): - - "Parent class for Miasm Expressions" - - __slots__ = ["_hash", "_repr", "_size"] - - args2expr = {} - canon_exprs = set() - use_singleton = True - - def set_size(self, _): - raise ValueError('size is not mutable') - - def __init__(self, size): - """Instantiate an Expr with size @size - @size: int - """ - # Common attribute - self._size = size - - # Lazy cache needs - self._hash = None - self._repr = None - - size = property(lambda self: self._size) - - @staticmethod - def get_object(expr_cls, args): - if not expr_cls.use_singleton: - return object.__new__(expr_cls) - - expr = Expr.args2expr.get((expr_cls, args)) - if expr is None: - expr = object.__new__(expr_cls) - Expr.args2expr[(expr_cls, args)] = expr - return expr - - def get_is_canon(self): - return self in Expr.canon_exprs - - def set_is_canon(self, value): - assert value is True - Expr.canon_exprs.add(self) - - is_canon = property(get_is_canon, set_is_canon) - - # Common operations - - def __str__(self): - raise NotImplementedError("Abstract Method") - - def __getitem__(self, i): - if not isinstance(i, slice): - raise TypeError("Expression: Bad slice: %s" % i) - start, stop, step = i.indices(self.size) - if step != 1: - raise ValueError("Expression: Bad slice: %s" % i) - return ExprSlice(self, start, stop) - - def get_size(self): - raise DeprecationWarning("use X.size instead of X.get_size()") - - def is_function_call(self): - """Returns true if the considered Expr is a function call - """ - return False - - def __repr__(self): - if self._repr is None: - self._repr = self._exprrepr() - return self._repr - - def __hash__(self): - if self._hash is None: - self._hash = self._exprhash() - return self._hash - - def __eq__(self, other): - if self is other: - return True - elif self.use_singleton: - # In case of Singleton, pointer comparison is sufficient - # Avoid computation of hash and repr - return False - - if self.__class__ is not other.__class__: - return False - if hash(self) != hash(other): - return False - return repr(self) == repr(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __add__(self, other): - return ExprOp('+', self, other) - - def __sub__(self, other): - return ExprOp('+', self, ExprOp('-', other)) - - def __div__(self, other): - return ExprOp('/', self, other) - - def __floordiv__(self, other): - return self.__div__(other) - - def __mod__(self, other): - return ExprOp('%', self, other) - - def __mul__(self, other): - return ExprOp('*', self, other) - - def __lshift__(self, other): - return ExprOp('<<', self, other) - - def __rshift__(self, other): - return ExprOp('>>', self, other) - - def __xor__(self, other): - return ExprOp('^', self, other) - - def __or__(self, other): - return ExprOp('|', self, other) - - def __and__(self, other): - return ExprOp('&', self, other) - - def __neg__(self): - return ExprOp('-', self) - - def __pow__(self, other): - return ExprOp("**", self, other) - - def __invert__(self): - return ExprOp('^', self, self.mask) - - def copy(self): - "Deep copy of the expression" - return self.visit(lambda x: x) - - def __deepcopy__(self, _): - return self.copy() - - def replace_expr(self, dct): - """Find and replace sub expression using dct - @dct: dictionary associating replaced Expr to its new Expr value - """ - return self.visit(lambda expr: dct.get(expr, expr)) - - def canonize(self): - "Canonize the Expression" - - def must_canon(expr): - return not expr.is_canon - - def canonize_visitor(expr): - if expr.is_canon: - return expr - if isinstance(expr, ExprOp): - if expr.is_associative(): - # ((a+b) + c) => (a + b + c) - args = [] - for arg in expr.args: - if isinstance(arg, ExprOp) and expr.op == arg.op: - args += arg.args - else: - args.append(arg) - args = canonize_expr_list(args) - new_e = ExprOp(expr.op, *args) - else: - new_e = expr - else: - new_e = expr - new_e.is_canon = True - return new_e - - return self.visit(canonize_visitor, must_canon) - - def msb(self): - "Return the Most Significant Bit" - return self[self.size - 1:self.size] - - def zeroExtend(self, size): - """Zero extend to size - @size: int - """ - assert self.size <= size - if self.size == size: - return self - return ExprOp('zeroExt_%d' % size, self) - - def signExtend(self, size): - """Sign extend to size - @size: int - """ - assert self.size <= size - if self.size == size: - return self - return ExprOp('signExt_%d' % size, self) - - def graph_recursive(self, graph): - """Recursive method used by graph - @graph: miasm2.core.graph.DiGraph instance - Update @graph instance to include sons - This is an Abstract method""" - - raise ValueError("Abstract method") - - def graph(self): - """Return a DiGraph instance standing for Expr tree - Instance's display functions have been override for better visibility - Wrapper on graph_recursive""" - - # Create recursively the graph - graph = DiGraphExpr() - self.graph_recursive(graph) - - return graph - - def set_mask(self, value): - raise ValueError('mask is not mutable') - - mask = property(lambda self: ExprInt(-1, self.size)) - - def is_int(self, value=None): - return False - - def is_id(self, name=None): - return False - - def is_loc(self, label=None): - return False - - def is_aff(self): - return False - - def is_cond(self): - return False - - def is_mem(self): - return False - - def is_op(self, op=None): - return False - - def is_slice(self, start=None, stop=None): - return False - - def is_compose(self): - return False - - def is_op_segm(self): - """Returns True if is ExprOp and op == 'segm'""" - return False - - def is_mem_segm(self): - """Returns True if is ExprMem and ptr is_op_segm""" - return False - -class ExprInt(Expr): - - """An ExprInt represent a constant in Miasm IR. - - Some use cases: - - Constant 0x42 - - Constant -0x30 - - Constant 0x12345678 on 32bits - """ - - __slots__ = Expr.__slots__ + ["_arg"] - - - def __init__(self, arg, size): - """Create an ExprInt from a modint or num/size - @arg: 'intable' number - @size: int size""" - super(ExprInt, self).__init__(size) - # Work for ._arg is done in __new__ - - arg = property(lambda self: self._arg) - - def __reduce__(self): - state = int(self._arg), self._size - return self.__class__, state - - def __new__(cls, arg, size): - """Create an ExprInt from a modint or num/size - @arg: 'intable' number - @size: int size""" - - if is_modint(arg): - assert size == arg.size - # Avoid a common blunder - assert not isinstance(arg, ExprInt) - - # Ensure arg is always a moduint - arg = int(arg) - if size not in mod_size2uint: - define_uint(size) - arg = mod_size2uint[size](arg) - - # Get the Singleton instance - expr = Expr.get_object(cls, (arg, size)) - - # Save parameters (__init__ is called with parameters unchanged) - expr._arg = arg - return expr - - def _get_int(self): - "Return self integer representation" - return int(self._arg & size2mask(self._size)) - - def __str__(self): - if self._arg < 0: - return str("-0x%X" % (- self._get_int())) - else: - return str("0x%X" % self._get_int()) - - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - - def get_w(self): - return set() - - def _exprhash(self): - return hash((EXPRINT, self._arg, self._size)) - - def _exprrepr(self): - return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(), - self._size) - - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - - def copy(self): - return ExprInt(self._arg, self._size) - - def depth(self): - return 1 - - def graph_recursive(self, graph): - graph.add_node(self) - - def __int__(self): - return int(self.arg) - - def __long__(self): - return int(self.arg) - - def is_int(self, value=None): - if value is not None and self._arg != value: - return False - return True - - -class ExprId(Expr): - - """An ExprId represent an identifier in Miasm IR. - - Some use cases: - - EAX register - - 'start' offset - - variable v1 - """ - - __slots__ = Expr.__slots__ + ["_name"] - - def __init__(self, name, size=None): - """Create an identifier - @name: str, identifier's name - @size: int, identifier's size - """ - if size is None: - warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprId(name, SIZE)') - size = 32 - assert isinstance(name, (str, bytes)) - super(ExprId, self).__init__(size) - self._name = name - - name = property(lambda self: self._name) - - def __reduce__(self): - state = self._name, self._size - return self.__class__, state - - def __new__(cls, name, size=None): - if size is None: - warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprId(name, SIZE)') - size = 32 - return Expr.get_object(cls, (name, size)) - - def __str__(self): - return str(self._name) - - def get_r(self, mem_read=False, cst_read=False): - return set([self]) - - def get_w(self): - return set([self]) - - def _exprhash(self): - return hash((EXPRID, self._name, self._size)) - - def _exprrepr(self): - return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) - - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - - def copy(self): - return ExprId(self._name, self._size) - - def depth(self): - return 1 - - def graph_recursive(self, graph): - graph.add_node(self) - - def is_id(self, name=None): - if name is not None and self._name != name: - return False - return True - - -class ExprLoc(Expr): - - """An ExprLoc represent a Label in Miasm IR. - """ - - __slots__ = Expr.__slots__ + ["_loc_key"] - - def __init__(self, loc_key, size): - """Create an identifier - @loc_key: int, label loc_key - @size: int, identifier's size - """ - assert isinstance(loc_key, LocKey) - super(ExprLoc, self).__init__(size) - self._loc_key = loc_key - - loc_key= property(lambda self: self._loc_key) - - def __reduce__(self): - state = self._loc_key, self._size - return self.__class__, state - - def __new__(cls, loc_key, size): - return Expr.get_object(cls, (loc_key, size)) - - def __str__(self): - return str(self._loc_key) - - def get_r(self, mem_read=False, cst_read=False): - return set() - - def get_w(self): - return set() - - def _exprhash(self): - return hash((EXPRLOC, self._loc_key, self._size)) - - def _exprrepr(self): - return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size) - - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - - def copy(self): - return ExprLoc(self._loc_key, self._size) - - def depth(self): - return 1 - - def graph_recursive(self, graph): - graph.add_node(self) - - def is_loc(self, loc_key=None): - if loc_key is not None and self._loc_key != loc_key: - return False - return True - - -class ExprAssign(Expr): - - """An ExprAssign represent an assignment from an Expression to another one. - - Some use cases: - - var1 <- 2 - """ - - __slots__ = Expr.__slots__ + ["_dst", "_src"] - - def __init__(self, dst, src): - """Create an ExprAssign for dst <- src - @dst: Expr, assignment destination - @src: Expr, assignment source - """ - # dst & src must be Expr - assert isinstance(dst, Expr) - assert isinstance(src, Expr) - - if dst.size != src.size: - raise ValueError( - "sanitycheck: ExprAssign args must have same size! %s" % - ([(str(arg), arg.size) for arg in [dst, src]])) - - super(ExprAssign, self).__init__(self.dst.size) - - dst = property(lambda self: self._dst) - src = property(lambda self: self._src) - - - def __reduce__(self): - state = self._dst, self._src - return self.__class__, state - - def __new__(cls, dst, src): - if dst.is_slice() and dst.arg.size == src.size: - new_dst, new_src = dst.arg, src - elif dst.is_slice(): - # Complete the source with missing slice parts - new_dst = dst.arg - rest = [(ExprSlice(dst.arg, r[0], r[1]), r[0], r[1]) - for r in dst.slice_rest()] - all_a = [(src, dst.start, dst.stop)] + rest - all_a.sort(key=lambda x: x[1]) - args = [expr for (expr, _, _) in all_a] - new_src = ExprCompose(*args) - else: - new_dst, new_src = dst, src - expr = Expr.get_object(cls, (new_dst, new_src)) - expr._dst, expr._src = new_dst, new_src - return expr - - def __str__(self): - return "%s = %s" % (str(self._dst), str(self._src)) - - def get_r(self, mem_read=False, cst_read=False): - elements = self._src.get_r(mem_read, cst_read) - if isinstance(self._dst, ExprMem) and mem_read: - elements.update(self._dst.ptr.get_r(mem_read, cst_read)) - return elements - - def get_w(self): - if isinstance(self._dst, ExprMem): - return set([self._dst]) # [memreg] - else: - return self._dst.get_w() - - def _exprhash(self): - return hash((EXPRASSIGN, hash(self._dst), hash(self._src))) - - def _exprrepr(self): - return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) - - def __contains__(self, expr): - return (self == expr or - self._src.__contains__(expr) or - self._dst.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit) - if dst == self._dst and src == self._src: - return self - else: - return ExprAssign(dst, src) - - def copy(self): - return ExprAssign(self._dst.copy(), self._src.copy()) - - def depth(self): - return max(self._src.depth(), self._dst.depth()) + 1 - - def graph_recursive(self, graph): - graph.add_node(self) - for arg in [self._src, self._dst]: - arg.graph_recursive(graph) - graph.add_uniq_edge(self, arg) - - def is_aff(self): - return True - - -class ExprAff(ExprAssign): - """ - DEPRECATED class. - Use ExprAssign instead of ExprAff - """ - - def __init__(self, dst, src): - warnings.warn('DEPRECATION WARNING: use ExprAssign instead of ExprAff') - super(ExprAff, self).__init__(dst, src) - - -class ExprCond(Expr): - - """An ExprCond stand for a condition on an Expr - - Use cases: - - var1 < var2 - - min(var1, var2) - - if (cond) then ... else ... - """ - - __slots__ = Expr.__slots__ + ["_cond", "_src1", "_src2"] - - def __init__(self, cond, src1, src2): - """Create an ExprCond - @cond: Expr, condition - @src1: Expr, value if condition is evaled to not zero - @src2: Expr, value if condition is evaled zero - """ - - # cond, src1, src2 must be Expr - assert isinstance(cond, Expr) - assert isinstance(src1, Expr) - assert isinstance(src2, Expr) - - self._cond, self._src1, self._src2 = cond, src1, src2 - assert src1.size == src2.size - super(ExprCond, self).__init__(self.src1.size) - - cond = property(lambda self: self._cond) - src1 = property(lambda self: self._src1) - src2 = property(lambda self: self._src2) - - def __reduce__(self): - state = self._cond, self._src1, self._src2 - return self.__class__, state - - def __new__(cls, cond, src1, src2): - return Expr.get_object(cls, (cond, src1, src2)) - - def __str__(self): - return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2)) - - def get_r(self, mem_read=False, cst_read=False): - out_src1 = self.src1.get_r(mem_read, cst_read) - out_src2 = self.src2.get_r(mem_read, cst_read) - return self.cond.get_r(mem_read, - cst_read).union(out_src1).union(out_src2) - - def get_w(self): - return set() - - def _exprhash(self): - return hash((EXPRCOND, hash(self.cond), - hash(self._src1), hash(self._src2))) - - def _exprrepr(self): - return "%s(%r, %r, %r)" % (self.__class__.__name__, - self._cond, self._src1, self._src2) - - def __contains__(self, expr): - return (self == expr or - self.cond.__contains__(expr) or - self.src1.__contains__(expr) or - self.src2.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - cond = self._cond.visit(callback, test_visit) - src1 = self._src1.visit(callback, test_visit) - src2 = self._src2.visit(callback, test_visit) - if cond == self._cond and src1 == self._src1 and src2 == self._src2: - return self - return ExprCond(cond, src1, src2) - - def copy(self): - return ExprCond(self._cond.copy(), - self._src1.copy(), - self._src2.copy()) - - def depth(self): - return max(self._cond.depth(), - self._src1.depth(), - self._src2.depth()) + 1 - - def graph_recursive(self, graph): - graph.add_node(self) - for arg in [self._cond, self._src1, self._src2]: - arg.graph_recursive(graph) - graph.add_uniq_edge(self, arg) - - def is_cond(self): - return True - - -class ExprMem(Expr): - - """An ExprMem stand for a memory access - - Use cases: - - Memory read - - Memory write - """ - - __slots__ = Expr.__slots__ + ["_ptr"] - - def __init__(self, ptr, size=None): - """Create an ExprMem - @ptr: Expr, memory access address - @size: int, memory access size - """ - if size is None: - warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(ptr, SIZE)') - size = 32 - - # ptr must be Expr - assert isinstance(ptr, Expr) - assert isinstance(size, int_types) - - if not isinstance(ptr, Expr): - raise ValueError( - 'ExprMem: ptr must be an Expr (not %s)' % type(ptr)) - - super(ExprMem, self).__init__(size) - self._ptr = ptr - - def get_arg(self): - warnings.warn('DEPRECATION WARNING: use exprmem.ptr instead of exprmem.arg') - return self.ptr - - def set_arg(self, value): - warnings.warn('DEPRECATION WARNING: use exprmem.ptr instead of exprmem.arg') - self.ptr = value - - ptr = property(lambda self: self._ptr) - arg = property(get_arg, set_arg) - - def __reduce__(self): - state = self._ptr, self._size - return self.__class__, state - - def __new__(cls, ptr, size=None): - if size is None: - warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(ptr, SIZE)') - size = 32 - - return Expr.get_object(cls, (ptr, size)) - - def __str__(self): - return "@%d[%s]" % (self.size, str(self.ptr)) - - def get_r(self, mem_read=False, cst_read=False): - if mem_read: - return set(self._ptr.get_r(mem_read, cst_read).union(set([self]))) - else: - return set([self]) - - def get_w(self): - return set([self]) # [memreg] - - def _exprhash(self): - return hash((EXPRMEM, hash(self._ptr), self._size)) - - def _exprrepr(self): - return "%s(%r, %r)" % (self.__class__.__name__, - self._ptr, self._size) - - def __contains__(self, expr): - return self == expr or self._ptr.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - ptr = self._ptr.visit(callback, test_visit) - if ptr == self._ptr: - return self - return ExprMem(ptr, self.size) - - def copy(self): - ptr = self.ptr.copy() - return ExprMem(ptr, size=self.size) - - def is_mem_segm(self): - """Returns True if is ExprMem and ptr is_op_segm""" - return self._ptr.is_op_segm() - - def depth(self): - return self._ptr.depth() + 1 - - def graph_recursive(self, graph): - graph.add_node(self) - self._ptr.graph_recursive(graph) - graph.add_uniq_edge(self, self._ptr) - - def is_mem(self): - return True - - -class ExprOp(Expr): - - """An ExprOp stand for an operation between Expr - - Use cases: - - var1 XOR var2 - - var1 + var2 + var3 - - parity bit(var1) - """ - - __slots__ = Expr.__slots__ + ["_op", "_args"] - - def __init__(self, op, *args): - """Create an ExprOp - @op: str, operation - @*args: Expr, operand list - """ - - # args must be Expr - assert all(isinstance(arg, Expr) for arg in args) - - sizes = set([arg.size for arg in args]) - - if len(sizes) != 1: - # Special cases : operande sizes can differ - if op not in [ - "segm", - "FLAG_EQ_ADDWC", "FLAG_EQ_SUBWC", - "FLAG_SIGN_ADDWC", "FLAG_SIGN_SUBWC", - "FLAG_ADDWC_CF", "FLAG_ADDWC_OF", - "FLAG_SUBWC_CF", "FLAG_SUBWC_OF", - - ]: - raise ValueError( - "sanitycheck: ExprOp args must have same size! %s" % - ([(str(arg), arg.size) for arg in args])) - - if not isinstance(op, str): - raise ValueError("ExprOp: 'op' argument must be a string") - - assert isinstance(args, tuple) - self._op, self._args = op, args - - # Set size for special cases - if self._op in [ - TOK_EQUAL, 'parity', 'fcom_c0', 'fcom_c1', 'fcom_c2', 'fcom_c3', - 'fxam_c0', 'fxam_c1', 'fxam_c2', 'fxam_c3', - "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf", - "ucomiss_zf", "ucomiss_pf", "ucomiss_cf", - "ucomisd_zf", "ucomisd_pf", "ucomisd_cf"]: - size = 1 - elif self._op in [TOK_INF, TOK_INF_SIGNED, - TOK_INF_UNSIGNED, TOK_INF_EQUAL, - TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, - TOK_EQUAL, TOK_POS, - TOK_POS_STRICT, - ]: - size = 1 - elif self._op.startswith("fp_to_sint"): - size = int(self._op[len("fp_to_sint"):]) - elif self._op.startswith("fpconvert_fp"): - size = int(self._op[len("fpconvert_fp"):]) - elif self._op in [ - "FLAG_ADD_CF", "FLAG_SUB_CF", - "FLAG_ADD_OF", "FLAG_SUB_OF", - "FLAG_EQ", "FLAG_EQ_CMP", - "FLAG_SIGN_SUB", "FLAG_SIGN_ADD", - "FLAG_EQ_AND", - "FLAG_EQ_ADDWC", "FLAG_EQ_SUBWC", - "FLAG_SIGN_ADDWC", "FLAG_SIGN_SUBWC", - "FLAG_ADDWC_CF", "FLAG_ADDWC_OF", - "FLAG_SUBWC_CF", "FLAG_SUBWC_OF", - ]: - size = 1 - - elif self._op.startswith('signExt_'): - size = int(self._op[8:]) - elif self._op.startswith('zeroExt_'): - size = int(self._op[8:]) - elif self._op in ['segm']: - size = self._args[1].size - else: - if None in sizes: - size = None - else: - # All arguments have the same size - size = list(sizes)[0] - - super(ExprOp, self).__init__(size) - - op = property(lambda self: self._op) - args = property(lambda self: self._args) - - def __reduce__(self): - state = tuple([self._op] + list(self._args)) - return self.__class__, state - - def __new__(cls, op, *args): - return Expr.get_object(cls, (op, args)) - - def __str__(self): - if self._op == '-': # Unary minus - return '-' + str_protected_child(self._args[0], self) - if self.is_associative() or self.is_infix(): - return (' ' + self._op + ' ').join([str_protected_child(arg, self) - for arg in self._args]) - return (self._op + '(' + - ', '.join([str(arg) for arg in self._args]) + ')') - - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - - def get_w(self): - raise ValueError('op cannot be written!', self) - - def _exprhash(self): - h_hargs = [hash(arg) for arg in self._args] - return hash((EXPROP, self._op, tuple(h_hargs))) - - def _exprrepr(self): - return "%s(%r, %s)" % (self.__class__.__name__, self._op, - ', '.join(repr(arg) for arg in self._args)) - - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg.__contains__(expr): - return True - return False - - def is_function_call(self): - return self._op.startswith('call') - - def is_infix(self): - return self._op in [ - '-', '+', '*', '^', '&', '|', '>>', '<<', - 'a>>', '>>>', '<<<', '/', '%', '**', - TOK_INF_UNSIGNED, - TOK_INF_SIGNED, - TOK_INF_EQUAL_UNSIGNED, - TOK_INF_EQUAL_SIGNED, - TOK_EQUAL - ] - - def is_associative(self): - "Return True iff current operation is associative" - return (self._op in ['+', '*', '^', '&', '|']) - - def is_commutative(self): - "Return True iff current operation is commutative" - return (self._op in ['+', '*', '^', '&', '|']) - - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) - if modified: - return ExprOp(self._op, *args) - return self - - def copy(self): - args = [arg.copy() for arg in self._args] - return ExprOp(self._op, *args) - - def depth(self): - depth = [arg.depth() for arg in self._args] - return max(depth) + 1 - - def graph_recursive(self, graph): - graph.add_node(self) - for arg in self._args: - arg.graph_recursive(graph) - graph.add_uniq_edge(self, arg) - - def is_op(self, op=None): - if op is None: - return True - return self.op == op - - def is_op_segm(self): - """Returns True if is ExprOp and op == 'segm'""" - return self.is_op('segm') - -class ExprSlice(Expr): - - __slots__ = Expr.__slots__ + ["_arg", "_start", "_stop"] - - def __init__(self, arg, start, stop): - - # arg must be Expr - assert isinstance(arg, Expr) - assert isinstance(start, int_types) - assert isinstance(stop, int_types) - assert start < stop - - self._arg, self._start, self._stop = arg, start, stop - super(ExprSlice, self).__init__(self._stop - self._start) - - arg = property(lambda self: self._arg) - start = property(lambda self: self._start) - stop = property(lambda self: self._stop) - - def __reduce__(self): - state = self._arg, self._start, self._stop - return self.__class__, state - - def __new__(cls, arg, start, stop): - return Expr.get_object(cls, (arg, start, stop)) - - def __str__(self): - return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop) - - def get_r(self, mem_read=False, cst_read=False): - return self._arg.get_r(mem_read, cst_read) - - def get_w(self): - return self._arg.get_w() - - def _exprhash(self): - return hash((EXPRSLICE, hash(self._arg), self._start, self._stop)) - - def _exprrepr(self): - return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg, - self._start, self._stop) - - def __contains__(self, expr): - if self == expr: - return True - return self._arg.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - arg = self._arg.visit(callback, test_visit) - if arg == self._arg: - return self - return ExprSlice(arg, self._start, self._stop) - - def copy(self): - return ExprSlice(self._arg.copy(), self._start, self._stop) - - def depth(self): - return self._arg.depth() + 1 - - def slice_rest(self): - "Return the completion of the current slice" - size = self._arg.size - if self._start >= size or self._stop > size: - raise ValueError('bad slice rest %s %s %s' % - (size, self._start, self._stop)) - - if self._start == self._stop: - return [(0, size)] - - rest = [] - if self._start != 0: - rest.append((0, self._start)) - if self._stop < size: - rest.append((self._stop, size)) - - return rest - - def graph_recursive(self, graph): - graph.add_node(self) - self._arg.graph_recursive(graph) - graph.add_uniq_edge(self, self._arg) - - def is_slice(self, start=None, stop=None): - if start is not None and self._start != start: - return False - if stop is not None and self._stop != stop: - return False - return True - - -class ExprCompose(Expr): - - """ - Compose is like a hambuger. It concatenate Expressions - """ - - __slots__ = Expr.__slots__ + ["_args"] - - def __init__(self, *args): - """Create an ExprCompose - The ExprCompose is contiguous and starts at 0 - @args: [Expr, Expr, ...] - DEPRECATED: - @args: [(Expr, int, int), (Expr, int, int), ...] - """ - - # args must be Expr - assert all(isinstance(arg, Expr) for arg in args) - - assert isinstance(args, tuple) - self._args = args - super(ExprCompose, self).__init__(sum(arg.size for arg in args)) - - args = property(lambda self: self._args) - - def __reduce__(self): - state = self._args - return self.__class__, state - - def __new__(cls, *args): - return Expr.get_object(cls, args) - - def __str__(self): - return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}' - - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - - def get_w(self): - return reduce(lambda elements, arg: - elements.union(arg.get_w()), self._args, set()) - - def _exprhash(self): - h_args = [EXPRCOMPOSE] + [hash(arg) for arg in self._args] - return hash(tuple(h_args)) - - def _exprrepr(self): - return "%s%r" % (self.__class__.__name__, self._args) - - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg == expr: - return True - if arg.__contains__(expr): - return True - return False - - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)]) - if modified: - return ExprCompose(*args) - return self - - def copy(self): - args = [arg.copy() for arg in self._args] - return ExprCompose(*args) - - def depth(self): - depth = [arg.depth() for arg in self._args] - return max(depth) + 1 - - def graph_recursive(self, graph): - graph.add_node(self) - for arg in self.args: - arg.graph_recursive(graph) - graph.add_uniq_edge(self, arg) - - def iter_args(self): - index = 0 - for arg in self._args: - yield index, arg - index += arg.size - - def is_compose(self): - return True - -# Expression order for comparison -EXPR_ORDER_DICT = { - ExprId: 1, - ExprLoc: 2, - ExprCond: 3, - ExprMem: 4, - ExprOp: 5, - ExprSlice: 6, - ExprCompose: 7, - ExprInt: 8, -} - - -def compare_exprs_compose(expr1, expr2): - # Sort by start bit address, then expr, then stop bit address - ret = cmp_elts(expr1[1], expr2[1]) - if ret: - return ret - ret = compare_exprs(expr1[0], expr2[0]) - if ret: - return ret - ret = cmp_elts(expr1[2], expr2[2]) - return ret - - -def compare_expr_list_compose(l1_e, l2_e): - # Sort by list elements in incremental order, then by list size - for i in range(min(len(l1_e), len(l2_e))): - ret = compare_exprs(l1_e[i], l2_e[i]) - if ret: - return ret - return cmp_elts(len(l1_e), len(l2_e)) - - -def compare_expr_list(l1_e, l2_e): - # Sort by list elements in incremental order, then by list size - for i in range(min(len(l1_e), len(l2_e))): - ret = compare_exprs(l1_e[i], l2_e[i]) - if ret: - return ret - return cmp_elts(len(l1_e), len(l2_e)) - - -def compare_exprs(expr1, expr2): - """Compare 2 expressions for canonization - @expr1: Expr - @expr2: Expr - 0 => == - 1 => expr1 > expr2 - -1 => expr1 < expr2 - """ - cls1 = expr1.__class__ - cls2 = expr2.__class__ - if cls1 != cls2: - return cmp_elts(EXPR_ORDER_DICT[cls1], EXPR_ORDER_DICT[cls2]) - if expr1 == expr2: - return 0 - if cls1 == ExprInt: - ret = cmp_elts(expr1.size, expr2.size) - if ret != 0: - return ret - return cmp_elts(expr1.arg, expr2.arg) - elif cls1 == ExprId: - name1 = force_bytes(expr1.name) - name2 = force_bytes(expr2.name) - ret = cmp_elts(name1, name2) - if ret: - return ret - return cmp_elts(expr1.size, expr2.size) - elif cls1 == ExprLoc: - ret = cmp_elts(expr1.loc_key, expr2.loc_key) - if ret: - return ret - return cmp_elts(expr1.size, expr2.size) - elif cls1 == ExprAssign: - raise NotImplementedError( - "Comparison from an ExprAssign not yet implemented" - ) - elif cls2 == ExprCond: - ret = compare_exprs(expr1.cond, expr2.cond) - if ret: - return ret - ret = compare_exprs(expr1.src1, expr2.src1) - if ret: - return ret - ret = compare_exprs(expr1.src2, expr2.src2) - return ret - elif cls1 == ExprMem: - ret = compare_exprs(expr1.ptr, expr2.ptr) - if ret: - return ret - return cmp_elts(expr1.size, expr2.size) - elif cls1 == ExprOp: - if expr1.op != expr2.op: - return cmp_elts(expr1.op, expr2.op) - return compare_expr_list(expr1.args, expr2.args) - elif cls1 == ExprSlice: - ret = compare_exprs(expr1.arg, expr2.arg) - if ret: - return ret - ret = cmp_elts(expr1.start, expr2.start) - if ret: - return ret - ret = cmp_elts(expr1.stop, expr2.stop) - return ret - elif cls1 == ExprCompose: - return compare_expr_list_compose(expr1.args, expr2.args) - raise NotImplementedError( - "Comparison between %r %r not implemented" % (expr1, expr2) - ) - - -def canonize_expr_list(expr_list): - return sorted(expr_list, key=cmp_to_key(compare_exprs)) - - -def canonize_expr_list_compose(expr_list): - return sorted(expr_list, key=cmp_to_key(compare_exprs_compose)) - -# Generate ExprInt with common size - - -def ExprInt1(i): - warnings.warn('DEPRECATION WARNING: use ExprInt(i, 1) instead of '\ - 'ExprInt1(i))') - return ExprInt(i, 1) - - -def ExprInt8(i): - warnings.warn('DEPRECATION WARNING: use ExprInt(i, 8) instead of '\ - 'ExprInt8(i))') - return ExprInt(i, 8) - - -def ExprInt16(i): - warnings.warn('DEPRECATION WARNING: use ExprInt(i, 16) instead of '\ - 'ExprInt16(i))') - return ExprInt(i, 16) - - -def ExprInt32(i): - warnings.warn('DEPRECATION WARNING: use ExprInt(i, 32) instead of '\ - 'ExprInt32(i))') - return ExprInt(i, 32) - - -def ExprInt64(i): - warnings.warn('DEPRECATION WARNING: use ExprInt(i, 64) instead of '\ - 'ExprInt64(i))') - return ExprInt(i, 64) - - -def ExprInt_from(expr, i): - "Generate ExprInt with size equal to expression" - warnings.warn('DEPRECATION WARNING: use ExprInt(i, expr.size) instead of'\ - 'ExprInt_from(expr, i))') - return ExprInt(i, expr.size) - - -def get_expr_ids_visit(expr, ids): - """Visitor to retrieve ExprId in @expr - @expr: Expr""" - if expr.is_id(): - ids.add(expr) - return expr - - -def get_expr_locs_visit(expr, locs): - """Visitor to retrieve ExprLoc in @expr - @expr: Expr""" - if expr.is_loc(): - locs.add(expr) - return expr - - -def get_expr_ids(expr): - """Retrieve ExprId in @expr - @expr: Expr""" - ids = set() - expr.visit(lambda x: get_expr_ids_visit(x, ids)) - return ids - - -def get_expr_locs(expr): - """Retrieve ExprLoc in @expr - @expr: Expr""" - locs = set() - expr.visit(lambda x: get_expr_locs_visit(x, locs)) - return locs - - -def test_set(expr, pattern, tks, result): - """Test if v can correspond to e. If so, update the context in result. - Otherwise, return False - @expr : Expr to match - @pattern : pattern Expr - @tks : list of ExprId, available jokers - @result : dictionary of ExprId -> Expr, current context - """ - - if not pattern in tks: - return expr == pattern - if pattern in result and result[pattern] != expr: - return False - result[pattern] = expr - return result - - -def match_expr(expr, pattern, tks, result=None): - """Try to match the @pattern expression with the pattern @expr with @tks jokers. - Result is output dictionary with matching joker values. - @expr : Expr pattern - @pattern : Targeted Expr to match - @tks : list of ExprId, available jokers - @result : dictionary of ExprId -> Expr, output matching context - """ - - if result is None: - result = {} - - if pattern in tks: - # pattern is a Joker - return test_set(expr, pattern, tks, result) - - if expr.is_int(): - return test_set(expr, pattern, tks, result) - - elif expr.is_id(): - return test_set(expr, pattern, tks, result) - - elif expr.is_loc(): - return test_set(expr, pattern, tks, result) - - elif expr.is_op(): - - # expr need to be the same operation than pattern - if not pattern.is_op(): - return False - if expr.op != pattern.op: - return False - if len(expr.args) != len(pattern.args): - return False - - # Perform permutation only if the current operation is commutative - if expr.is_commutative(): - permutations = itertools.permutations(expr.args) - else: - permutations = [expr.args] - - # For each permutations of arguments - for permut in permutations: - good = True - # We need to use a copy of result to not override it - myresult = dict(result) - for sub_expr, sub_pattern in zip(permut, pattern.args): - ret = match_expr(sub_expr, sub_pattern, tks, myresult) - # If the current permutation do not match EVERY terms - if ret is False: - good = False - break - if good is True: - # We found a possibility - for joker, value in viewitems(myresult): - # Updating result in place (to keep pointer in recursion) - result[joker] = value - return result - return False - - # Recursive tests - - elif expr.is_mem(): - if not pattern.is_mem(): - return False - if expr.size != pattern.size: - return False - return match_expr(expr.ptr, pattern.ptr, tks, result) - - elif expr.is_slice(): - if not pattern.is_slice(): - return False - if expr.start != pattern.start or expr.stop != pattern.stop: - return False - return match_expr(expr.arg, pattern.arg, tks, result) - - elif expr.is_cond(): - if not pattern.is_cond(): - return False - if match_expr(expr.cond, pattern.cond, tks, result) is False: - return False - if match_expr(expr.src1, pattern.src1, tks, result) is False: - return False - if match_expr(expr.src2, pattern.src2, tks, result) is False: - return False - return result - - elif expr.is_compose(): - if not pattern.is_compose(): - return False - for sub_expr, sub_pattern in zip(expr.args, pattern.args): - if match_expr(sub_expr, sub_pattern, tks, result) is False: - return False - return result - - elif expr.is_aff(): - if not pattern.is_aff(): - return False - if match_expr(expr.src, pattern.src, tks, result) is False: - return False - if match_expr(expr.dst, pattern.dst, tks, result) is False: - return False - return result - - else: - raise NotImplementedError("match_expr: Unknown type: %s" % type(expr)) - - -def MatchExpr(expr, pattern, tks, result=None): - warnings.warn('DEPRECATION WARNING: use match_expr instead of MatchExpr') - return match_expr(expr, pattern, tks, result) - - -def get_rw(exprs): - o_r = set() - o_w = set() - for expr in exprs: - o_r.update(expr.get_r(mem_read=True)) - for expr in exprs: - o_w.update(expr.get_w()) - return o_r, o_w - - -def get_list_rw(exprs, mem_read=False, cst_read=True): - """Return list of read/write reg/cst/mem for each @exprs - @exprs: list of expressions - @mem_read: walk though memory accesses - @cst_read: retrieve constants - """ - list_rw = [] - # cst_num = 0 - for expr in exprs: - o_r = set() - o_w = set() - # get r/w - o_r.update(expr.get_r(mem_read=mem_read, cst_read=cst_read)) - if isinstance(expr.dst, ExprMem): - o_r.update(expr.dst.arg.get_r(mem_read=mem_read, cst_read=cst_read)) - o_w.update(expr.get_w()) - # each cst is indexed - o_r_rw = set() - for read in o_r: - o_r_rw.add(read) - o_r = o_r_rw - list_rw.append((o_r, o_w)) - - return list_rw - - -def get_expr_ops(expr): - """Retrieve operators of an @expr - @expr: Expr""" - def visit_getops(expr, out=None): - if out is None: - out = set() - if isinstance(expr, ExprOp): - out.add(expr.op) - return expr - ops = set() - expr.visit(lambda x: visit_getops(x, ops)) - return ops - - -def get_expr_mem(expr): - """Retrieve memory accesses of an @expr - @expr: Expr""" - def visit_getmem(expr, out=None): - if out is None: - out = set() - if isinstance(expr, ExprMem): - out.add(expr) - return expr - ops = set() - expr.visit(lambda x: visit_getmem(x, ops)) - return ops - - -def _expr_compute_cf(op1, op2): - """ - Get carry flag of @op1 - @op2 - Ref: x86 cf flag - @op1: Expression - @op2: Expression - """ - res = op1 - op2 - cf = (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() - return cf - -def _expr_compute_of(op1, op2): - """ - Get overflow flag of @op1 - @op2 - Ref: x86 of flag - @op1: Expression - @op2: Expression - """ - res = op1 - op2 - of = (((op1 ^ res) & (op1 ^ op2))).msb() - return of - -def _expr_compute_zf(op1, op2): - """ - Get zero flag of @op1 - @op2 - @op1: Expression - @op2: Expression - """ - res = op1 - op2 - zf = ExprCond(res, - ExprInt(0, 1), - ExprInt(1, 1)) - return zf - - -def _expr_compute_nf(op1, op2): - """ - Get negative (or sign) flag of @op1 - @op2 - @op1: Expression - @op2: Expression - """ - res = op1 - op2 - nf = res.msb() - return nf - - -def expr_is_equal(op1, op2): - """ - if op1 == op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - zf = _expr_compute_zf(op1, op2) - return zf - - -def expr_is_not_equal(op1, op2): - """ - if op1 != op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - zf = _expr_compute_zf(op1, op2) - return ~zf - - -def expr_is_unsigned_greater(op1, op2): - """ - UNSIGNED cmp - if op1 > op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - cf = _expr_compute_cf(op1, op2) - zf = _expr_compute_zf(op1, op2) - return ~(cf | zf) - - -def expr_is_unsigned_greater_or_equal(op1, op2): - """ - Unsigned cmp - if op1 >= op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - cf = _expr_compute_cf(op1, op2) - return ~cf - - -def expr_is_unsigned_lower(op1, op2): - """ - Unsigned cmp - if op1 < op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - cf = _expr_compute_cf(op1, op2) - return cf - - -def expr_is_unsigned_lower_or_equal(op1, op2): - """ - Unsigned cmp - if op1 <= op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - cf = _expr_compute_cf(op1, op2) - zf = _expr_compute_zf(op1, op2) - return cf | zf - - -def expr_is_signed_greater(op1, op2): - """ - Signed cmp - if op1 > op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - nf = _expr_compute_nf(op1, op2) - of = _expr_compute_of(op1, op2) - zf = _expr_compute_zf(op1, op2) - return ~(zf | (nf ^ of)) - - -def expr_is_signed_greater_or_equal(op1, op2): - """ - Signed cmp - if op1 > op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - nf = _expr_compute_nf(op1, op2) - of = _expr_compute_of(op1, op2) - return ~(nf ^ of) - - -def expr_is_signed_lower(op1, op2): - """ - Signed cmp - if op1 < op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - nf = _expr_compute_nf(op1, op2) - of = _expr_compute_of(op1, op2) - return nf ^ of - - -def expr_is_signed_lower_or_equal(op1, op2): - """ - Signed cmp - if op1 <= op2: - Return ExprInt(1, 1) - else: - Return ExprInt(0, 1) - """ - - nf = _expr_compute_nf(op1, op2) - of = _expr_compute_of(op1, op2) - zf = _expr_compute_zf(op1, op2) - return zf | (nf ^ of) - -# sign bit | exponent | significand -size_to_IEEE754_info = { - 16: { - "exponent": 5, - "significand": 10, - }, - 32: { - "exponent": 8, - "significand": 23, - }, - 64: { - "exponent": 11, - "significand": 52, - }, -} - -def expr_is_NaN(expr): - """Return 1 or 0 on 1 bit if expr represent a NaN value according to IEEE754 - """ - info = size_to_IEEE754_info[expr.size] - exponent = expr[info["significand"]: info["significand"] + info["exponent"]] - - # exponent is full of 1s and significand is not NULL - return ExprCond(exponent - ExprInt(-1, exponent.size), - ExprInt(0, 1), - ExprCond(expr[:info["significand"]], ExprInt(1, 1), - ExprInt(0, 1))) - - -def expr_is_infinite(expr): - """Return 1 or 0 on 1 bit if expr represent an infinite value according to - IEEE754 - """ - info = size_to_IEEE754_info[expr.size] - exponent = expr[info["significand"]: info["significand"] + info["exponent"]] - - # exponent is full of 1s and significand is NULL - return ExprCond(exponent - ExprInt(-1, exponent.size), - ExprInt(0, 1), - ExprCond(expr[:info["significand"]], ExprInt(0, 1), - ExprInt(1, 1))) - - -def expr_is_IEEE754_zero(expr): - """Return 1 or 0 on 1 bit if expr represent a zero value according to - IEEE754 - """ - # Sign is the msb - expr_no_sign = expr[:expr.size - 1] - return ExprCond(expr_no_sign, ExprInt(0, 1), ExprInt(1, 1)) - - -def expr_is_IEEE754_denormal(expr): - """Return 1 or 0 on 1 bit if expr represent a denormalized value according - to IEEE754 - """ - info = size_to_IEEE754_info[expr.size] - exponent = expr[info["significand"]: info["significand"] + info["exponent"]] - # exponent is full of 0s - return ExprCond(exponent, ExprInt(0, 1), ExprInt(1, 1)) - - -def expr_is_qNaN(expr): - """Return 1 or 0 on 1 bit if expr represent a qNaN (quiet) value according to - IEEE754 - """ - info = size_to_IEEE754_info[expr.size] - significand_top = expr[info["significand"]: info["significand"] + 1] - return expr_is_NaN(expr) & significand_top - - -def expr_is_sNaN(expr): - """Return 1 or 0 on 1 bit if expr represent a sNaN (signalling) value according - to IEEE754 - """ - info = size_to_IEEE754_info[expr.size] - significand_top = expr[info["significand"]: info["significand"] + 1] - return expr_is_NaN(expr) & ~significand_top - - -def expr_is_float_lower(op1, op2): - """Return 1 on 1 bit if @op1 < @op2, 0 otherwise. - /!\ Assume @op1 and @op2 are not NaN - Comparison is the floating point one, defined in IEEE754 - """ - sign1, sign2 = op1.msb(), op2.msb() - magn1, magn2 = op1[:-1], op2[:-1] - return ExprCond(sign1 ^ sign2, - # Sign different, only the sign matters - sign1, # sign1 ? op1 < op2 : op1 >= op2 - # Sign equals, the result is inversed for negatives - sign1 ^ (expr_is_unsigned_lower(magn1, magn2))) - - -def expr_is_float_equal(op1, op2): - """Return 1 on 1 bit if @op1 == @op2, 0 otherwise. - /!\ Assume @op1 and @op2 are not NaN - Comparison is the floating point one, defined in IEEE754 - """ - sign1, sign2 = op1.msb(), op2.msb() - magn1, magn2 = op1[:-1], op2[:-1] - return ExprCond(magn1 ^ magn2, - ExprInt(0, 1), - ExprCond(magn1, - # magn1 == magn2, are the signal equals? - ~(sign1 ^ sign2), - # Special case: -0.0 == +0.0 - ExprInt(1, 1)) - ) diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py deleted file mode 100644 index a50e0d5b..00000000 --- a/miasm2/expression/expression_helper.py +++ /dev/null @@ -1,628 +0,0 @@ -# -# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation; either version 2 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -# - -# Expressions manipulation functions -from builtins import range -import itertools -import collections -import random -import string -import warnings - -from future.utils import viewitems, viewvalues - -import miasm2.expression.expression as m2_expr - - -def parity(a): - tmp = (a) & 0xFF - cpt = 1 - while tmp != 0: - cpt ^= tmp & 1 - tmp >>= 1 - return cpt - - -def merge_sliceto_slice(expr): - """ - Apply basic factorisation on ExprCompose sub components - @expr: ExprCompose - """ - - out_args = [] - last_index = 0 - for index, arg in expr.iter_args(): - # Init - if len(out_args) == 0: - out_args.append(arg) - continue - - last_value = out_args[-1] - # Consecutive - - if last_index + last_value.size == index: - # Merge consecutive integers - if (isinstance(arg, m2_expr.ExprInt) and - isinstance(last_value, m2_expr.ExprInt)): - new_size = last_value.size + arg.size - value = int(arg) << last_value.size - value |= int(last_value) - out_args[-1] = m2_expr.ExprInt(value, size=new_size) - continue - - # Merge consecuvite slice - elif (isinstance(arg, m2_expr.ExprSlice) and - isinstance(last_value, m2_expr.ExprSlice)): - value = arg.arg - if (last_value.arg == value and - last_value.stop == arg.start): - out_args[-1] = value[last_value.start:arg.stop] - continue - - # Unmergeable - last_index = index - out_args.append(arg) - - return out_args - - -op_propag_cst = ['+', '*', '^', '&', '|', '>>', - '<<', "a>>", ">>>", "<<<", - "/", "%", 'sdiv', 'smod', 'umod', 'udiv','**'] - - -def is_pure_int(e): - """ - return True if expr is only composed with integers - /!\ ExprCond returns True is src1 and src2 are integers - """ - def modify_cond(e): - if isinstance(e, m2_expr.ExprCond): - return e.src1 | e.src2 - return e - - def find_int(e, s): - if isinstance(e, m2_expr.ExprId) or isinstance(e, m2_expr.ExprMem): - s.add(e) - return e - s = set() - new_e = e.visit(modify_cond) - new_e.visit(lambda x: find_int(x, s)) - if s: - return False - return True - - -def is_int_or_cond_src_int(e): - if isinstance(e, m2_expr.ExprInt): - return True - if isinstance(e, m2_expr.ExprCond): - return (isinstance(e.src1, m2_expr.ExprInt) and - isinstance(e.src2, m2_expr.ExprInt)) - return False - - -def fast_unify(seq, idfun=None): - # order preserving unifying list function - if idfun is None: - idfun = lambda x: x - seen = {} - result = [] - for item in seq: - marker = idfun(item) - - if marker in seen: - continue - seen[marker] = 1 - result.append(item) - return result - -def get_missing_interval(all_intervals, i_min=0, i_max=32): - """Return a list of missing interval in all_interval - @all_interval: list of (int, int) - @i_min: int, minimal missing interval bound - @i_max: int, maximal missing interval bound""" - - my_intervals = all_intervals[:] - my_intervals.sort() - my_intervals.append((i_max, i_max)) - - missing_i = [] - last_pos = i_min - for start, stop in my_intervals: - if last_pos != start: - missing_i.append((last_pos, start)) - last_pos = stop - return missing_i - - -class Variables_Identifier(object): - """Identify variables in an expression. - Returns: - - variables with their corresponding values - - original expression with variables translated - """ - - def __init__(self, expr, var_prefix="v"): - """Set the expression @expr to handle and launch variable identification - process - @expr: Expr instance - @var_prefix: (optional) prefix of the variable name, default is 'v'""" - - # Init - self.var_indice = itertools.count() - self.var_asked = set() - self._vars = {} # VarID -> Expr - self.var_prefix = var_prefix - - # Launch recurrence - self.find_variables_rec(expr) - - # Compute inter-variable dependencies - has_change = True - while has_change: - has_change = False - for var_id, var_value in list(viewitems(self._vars)): - cur = var_value - - # Do not replace with itself - to_replace = { - v_val:v_id - for v_id, v_val in viewitems(self._vars) - if v_id != var_id - } - var_value = var_value.replace_expr(to_replace) - - if cur != var_value: - # Force @self._vars update - has_change = True - self._vars[var_id] = var_value - break - - # Replace in the original equation - self._equation = expr.replace_expr( - { - v_val: v_id for v_id, v_val - in viewitems(self._vars) - } - ) - - # Compute variables dependencies - self._vars_ordered = collections.OrderedDict() - todo = set(self._vars) - needs = {} - - ## Build initial needs - for var_id, var_expr in viewitems(self._vars): - ### Handle corner cases while using Variable Identifier on an - ### already computed equation - needs[var_id] = [ - var_name - for var_name in var_expr.get_r(mem_read=True) - if self.is_var_identifier(var_name) and \ - var_name in todo and \ - var_name != var_id - ] - - ## Build order list - while todo: - done = set() - for var_id in todo: - all_met = True - for need in needs[var_id]: - if need not in self._vars_ordered: - # A dependency is not met - all_met = False - break - if not all_met: - continue - - # All dependencies are already met, add current - self._vars_ordered[var_id] = self._vars[var_id] - done.add(var_id) - - # Update the todo list - for element_done in done: - todo.remove(element_done) - - def is_var_identifier(self, expr): - "Return True iff @expr is a variable identifier" - if not isinstance(expr, m2_expr.ExprId): - return False - return expr in self._vars - - def find_variables_rec(self, expr): - """Recursive method called by find_variable to expand @expr. - Set @var_names and @var_values. - This implementation is faster than an expression visitor because - we do not rebuild each expression. - """ - - if (expr in self.var_asked): - # Expr has already been asked - if expr not in viewvalues(self._vars): - # Create var - identifier = m2_expr.ExprId( - "%s%s" % ( - self.var_prefix, - next(self.var_indice) - ), - size = expr.size - ) - self._vars[identifier] = expr - - # Recursion stop case - return - else: - # First time for @expr - self.var_asked.add(expr) - - if isinstance(expr, m2_expr.ExprOp): - for a in expr.args: - self.find_variables_rec(a) - - elif isinstance(expr, m2_expr.ExprInt): - pass - - elif isinstance(expr, m2_expr.ExprId): - pass - - elif isinstance(expr, m2_expr.ExprLoc): - pass - - elif isinstance(expr, m2_expr.ExprMem): - self.find_variables_rec(expr.ptr) - - elif isinstance(expr, m2_expr.ExprCompose): - for arg in expr.args: - self.find_variables_rec(arg) - - elif isinstance(expr, m2_expr.ExprSlice): - self.find_variables_rec(expr.arg) - - elif isinstance(expr, m2_expr.ExprCond): - self.find_variables_rec(expr.cond) - self.find_variables_rec(expr.src1) - self.find_variables_rec(expr.src2) - - else: - raise NotImplementedError("Type not handled: %s" % expr) - - @property - def vars(self): - return self._vars_ordered - - @property - def equation(self): - return self._equation - - def __str__(self): - "Display variables and final equation" - out = "" - for var_id, var_expr in viewitems(self.vars): - out += "%s = %s\n" % (var_id, var_expr) - out += "Final: %s" % self.equation - return out - - -class ExprRandom(object): - """Return an expression randomly generated""" - - # Identifiers length - identifier_len = 5 - # Identifiers' name charset - identifier_charset = string.ascii_letters - # Number max value - number_max = 0xFFFFFFFF - # Available operations - operations_by_args_number = {1: ["-"], - 2: ["<<", "<<<", ">>", ">>>"], - "2+": ["+", "*", "&", "|", "^"], - } - # Maximum number of argument for operations - operations_max_args_number = 5 - # If set, output expression is a perfect tree - perfect_tree = True - # Max argument size in slice, relative to slice size - slice_add_size = 10 - # Maximum number of layer in compose - compose_max_layer = 5 - # Maximum size of memory address in bits - memory_max_address_size = 32 - # Re-use already generated elements to mimic a more realistic behavior - reuse_element = True - generated_elements = {} # (depth, size) -> [Expr] - - @classmethod - def identifier(cls, size=32): - """Return a random identifier - @size: (optional) identifier size - """ - return m2_expr.ExprId("".join([random.choice(cls.identifier_charset) - for _ in range(cls.identifier_len)]), - size=size) - - @classmethod - def number(cls, size=32): - """Return a random number - @size: (optional) number max bits - """ - num = random.randint(0, cls.number_max % (2**size)) - return m2_expr.ExprInt(num, size) - - @classmethod - def atomic(cls, size=32): - """Return an atomic Expression - @size: (optional) Expr size - """ - available_funcs = [cls.identifier, cls.number] - return random.choice(available_funcs)(size=size) - - @classmethod - def operation(cls, size=32, depth=1): - """Return an ExprOp - @size: (optional) Operation size - @depth: (optional) Expression depth - """ - operand_type = random.choice(list(cls.operations_by_args_number)) - if isinstance(operand_type, str) and "+" in operand_type: - number_args = random.randint( - int(operand_type[:-1]), - cls.operations_max_args_number - ) - else: - number_args = operand_type - - args = [cls._gen(size=size, depth=depth - 1) - for _ in range(number_args)] - operand = random.choice(cls.operations_by_args_number[operand_type]) - return m2_expr.ExprOp(operand, - *args) - - @classmethod - def slice(cls, size=32, depth=1): - """Return an ExprSlice - @size: (optional) Operation size - @depth: (optional) Expression depth - """ - start = random.randint(0, size) - stop = start + size - return cls._gen(size=random.randint(stop, stop + cls.slice_add_size), - depth=depth - 1)[start:stop] - - @classmethod - def compose(cls, size=32, depth=1): - """Return an ExprCompose - @size: (optional) Operation size - @depth: (optional) Expression depth - """ - # First layer - upper_bound = random.randint(1, size) - args = [cls._gen(size=upper_bound, depth=depth - 1)] - - # Next layers - while (upper_bound < size): - if len(args) == (cls.compose_max_layer - 1): - # We reach the maximum size - new_upper_bound = size - else: - new_upper_bound = random.randint(upper_bound + 1, size) - - args.append(cls._gen(size=new_upper_bound - upper_bound)) - upper_bound = new_upper_bound - return m2_expr.ExprCompose(*args) - - @classmethod - def memory(cls, size=32, depth=1): - """Return an ExprMem - @size: (optional) Operation size - @depth: (optional) Expression depth - """ - - address_size = random.randint(1, cls.memory_max_address_size) - return m2_expr.ExprMem(cls._gen(size=address_size, - depth=depth - 1), - size=size) - - @classmethod - def _gen(cls, size=32, depth=1): - """Internal function for generating sub-expression according to options - @size: (optional) Operation size - @depth: (optional) Expression depth - /!\ @generated_elements is left modified - """ - # Perfect tree handling - if not cls.perfect_tree: - depth = random.randint(max(0, depth - 2), depth) - - # Element re-use - if cls.reuse_element and random.choice([True, False]) and \ - (depth, size) in cls.generated_elements: - return random.choice(cls.generated_elements[(depth, size)]) - - # Recursion stop - if depth == 0: - return cls.atomic(size=size) - - # Build a more complex expression - available_funcs = [cls.operation, cls.slice, cls.compose, cls.memory] - gen = random.choice(available_funcs)(size=size, depth=depth) - - # Save it - new_value = cls.generated_elements.get((depth, size), []) + [gen] - cls.generated_elements[(depth, size)] = new_value - return gen - - @classmethod - def get(cls, size=32, depth=1, clean=True): - """Return a randomly generated expression - @size: (optional) Operation size - @depth: (optional) Expression depth - @clean: (optional) Clean expression cache between two calls - """ - # Init state - if clean: - cls.generated_elements = {} - - # Get an element - got = cls._gen(size=size, depth=depth) - - # Clear state - if clean: - cls.generated_elements = {} - - return got - -def expr_cmpu(arg1, arg2): - """ - Returns a one bit long Expression: - * 1 if @arg1 is strictly greater than @arg2 (unsigned) - * 0 otherwise. - """ - warnings.warn('DEPRECATION WARNING: use "expr_is_unsigned_greater" instead"') - return m2_expr.expr_is_unsigned_greater(arg1, arg2) - -def expr_cmps(arg1, arg2): - """ - Returns a one bit long Expression: - * 1 if @arg1 is strictly greater than @arg2 (signed) - * 0 otherwise. - """ - warnings.warn('DEPRECATION WARNING: use "expr_is_signed_greater" instead"') - return m2_expr.expr_is_signed_greater(arg1, arg2) - - -class CondConstraint(object): - - """Stand for a constraint on an Expr""" - - # str of the associated operator - operator = "" - - def __init__(self, expr): - self.expr = expr - - def __repr__(self): - return "<%s %s 0>" % (self.expr, self.operator) - - def to_constraint(self): - """Transform itself into a constraint using Expr""" - raise NotImplementedError("Abstract method") - - -class CondConstraintZero(CondConstraint): - - """Stand for a constraint like 'A == 0'""" - operator = m2_expr.TOK_EQUAL - - def to_constraint(self): - return m2_expr.ExprAssign(self.expr, m2_expr.ExprInt(0, self.expr.size)) - - -class CondConstraintNotZero(CondConstraint): - - """Stand for a constraint like 'A != 0'""" - operator = "!=" - - def to_constraint(self): - cst1, cst2 = m2_expr.ExprInt(0, 1), m2_expr.ExprInt(1, 1) - return m2_expr.ExprAssign(cst1, m2_expr.ExprCond(self.expr, cst1, cst2)) - - -ConstrainedValue = collections.namedtuple("ConstrainedValue", - ["constraints", "value"]) - - -class ConstrainedValues(set): - - """Set of ConstrainedValue""" - - def __str__(self): - out = [] - for sol in self: - out.append("%s with constraints:" % sol.value) - for constraint in sol.constraints: - out.append("\t%s" % constraint) - return "\n".join(out) - - -def possible_values(expr): - """Return possible values for expression @expr, associated with their - condition constraint as a ConstrainedValues instance - @expr: Expr instance - """ - - consvals = ConstrainedValues() - - # Terminal expression - if (isinstance(expr, m2_expr.ExprInt) or - isinstance(expr, m2_expr.ExprId) or - isinstance(expr, m2_expr.ExprLoc)): - consvals.add(ConstrainedValue(frozenset(), expr)) - # Unary expression - elif isinstance(expr, m2_expr.ExprSlice): - consvals.update(ConstrainedValue(consval.constraints, - consval.value[expr.start:expr.stop]) - for consval in possible_values(expr.arg)) - elif isinstance(expr, m2_expr.ExprMem): - consvals.update(ConstrainedValue(consval.constraints, - m2_expr.ExprMem(consval.value, - expr.size)) - for consval in possible_values(expr.ptr)) - elif isinstance(expr, m2_expr.ExprAssign): - consvals.update(possible_values(expr.src)) - # Special case: constraint insertion - elif isinstance(expr, m2_expr.ExprCond): - src1cond = CondConstraintNotZero(expr.cond) - src2cond = CondConstraintZero(expr.cond) - consvals.update(ConstrainedValue(consval.constraints.union([src1cond]), - consval.value) - for consval in possible_values(expr.src1)) - consvals.update(ConstrainedValue(consval.constraints.union([src2cond]), - consval.value) - for consval in possible_values(expr.src2)) - # N-ary expression - elif isinstance(expr, m2_expr.ExprOp): - # For details, see ExprCompose - consvals_args = [possible_values(arg) for arg in expr.args] - for consvals_possibility in itertools.product(*consvals_args): - args_value = [consval.value for consval in consvals_possibility] - args_constraint = itertools.chain(*[consval.constraints - for consval in consvals_possibility]) - consvals.add(ConstrainedValue(frozenset(args_constraint), - m2_expr.ExprOp(expr.op, *args_value))) - elif isinstance(expr, m2_expr.ExprCompose): - # Generate each possibility for sub-argument, associated with the start - # and stop bit - consvals_args = [ - list(possible_values(arg)) - for arg in expr.args - ] - for consvals_possibility in itertools.product(*consvals_args): - # Merge constraint of each sub-element - args_constraint = itertools.chain(*[consval.constraints - for consval in consvals_possibility]) - # Gen the corresponding constraints / ExprCompose - args = [consval.value for consval in consvals_possibility] - consvals.add( - ConstrainedValue(frozenset(args_constraint), - m2_expr.ExprCompose(*args))) - else: - raise RuntimeError("Unsupported type for expr: %s" % type(expr)) - - return consvals diff --git a/miasm2/expression/expression_reduce.py b/miasm2/expression/expression_reduce.py deleted file mode 100644 index adad552e..00000000 --- a/miasm2/expression/expression_reduce.py +++ /dev/null @@ -1,280 +0,0 @@ -""" -Expression reducer: -Apply reduction rules to an Expression ast -""" - -import logging -from miasm2.expression.expression import ExprInt, ExprId, ExprLoc, ExprOp, \ - ExprSlice, ExprCompose, ExprMem, ExprCond - -log_reduce = logging.getLogger("expr_reduce") -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) -log_reduce.addHandler(console_handler) -log_reduce.setLevel(logging.WARNING) - - - -class ExprNode(object): - """Clone of Expression object with additional information""" - - def __init__(self, expr): - self.expr = expr - - -class ExprNodeInt(ExprNode): - def __init__(self, expr): - assert expr.is_int() - super(ExprNodeInt, self).__init__(expr) - self.arg = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = str(self.expr) - return out - - -class ExprNodeId(ExprNode): - def __init__(self, expr): - assert expr.is_id() - super(ExprNodeId, self).__init__(expr) - self.arg = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = str(self.expr) - return out - - -class ExprNodeLoc(ExprNode): - def __init__(self, expr): - assert expr.is_loc() - super(ExprNodeLoc, self).__init__(expr) - self.arg = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = str(self.expr) - return out - - -class ExprNodeMem(ExprNode): - def __init__(self, expr): - assert expr.is_mem() - super(ExprNodeMem, self).__init__(expr) - self.ptr = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = "@%d[%r]" % (self.expr.size, self.ptr) - return out - - -class ExprNodeOp(ExprNode): - def __init__(self, expr): - assert expr.is_op() - super(ExprNodeOp, self).__init__(expr) - self.args = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - if len(self.args) == 1: - out = "(%s(%r))" % (self.expr.op, self.args[0]) - else: - out = "(%s)" % self.expr.op.join(repr(arg) for arg in self.args) - return out - - -class ExprNodeSlice(ExprNode): - def __init__(self, expr): - assert expr.is_slice() - super(ExprNodeSlice, self).__init__(expr) - self.arg = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = "%r[%d:%d]" % (self.arg, self.expr.start, self.expr.stop) - return out - - -class ExprNodeCompose(ExprNode): - def __init__(self, expr): - assert expr.is_compose() - super(ExprNodeCompose, self).__init__(expr) - self.args = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = "{%s}" % ', '.join(repr(arg) for arg in self.args) - return out - - -class ExprNodeCond(ExprNode): - def __init__(self, expr): - assert expr.is_cond() - super(ExprNodeCond, self).__init__(expr) - self.cond = None - self.src1 = None - self.src2 = None - - def __repr__(self): - if self.info is not None: - out = repr(self.info) - else: - out = "(%r?%r:%r)" % (self.cond, self.src1, self.src2) - return out - - -class ExprReducer(object): - """Apply reduction rules to an expr - - reduction_rules: list of ordered reduction rules - - List of function representing reduction rules - Function API: - reduction_xxx(self, node, lvl=0) - with: - * node: the ExprNode to qualify - * lvl: [optional] the recursion level - Returns: - * None if the reduction rule is not applied - * the resulting information to store in the ExprNode.info - - allow_none_result: allow missing reduction rules - """ - - reduction_rules = [] - allow_none_result = False - - def expr2node(self, expr): - """Build ExprNode mirror of @expr - - @expr: Expression to analyze - """ - - if isinstance(expr, ExprId): - node = ExprNodeId(expr) - elif isinstance(expr, ExprLoc): - node = ExprNodeLoc(expr) - elif isinstance(expr, ExprInt): - node = ExprNodeInt(expr) - elif isinstance(expr, ExprMem): - son = self.expr2node(expr.ptr) - node = ExprNodeMem(expr) - node.ptr = son - elif isinstance(expr, ExprSlice): - son = self.expr2node(expr.arg) - node = ExprNodeSlice(expr) - node.arg = son - elif isinstance(expr, ExprOp): - sons = [self.expr2node(arg) for arg in expr.args] - node = ExprNodeOp(expr) - node.args = sons - elif isinstance(expr, ExprCompose): - sons = [self.expr2node(arg) for arg in expr.args] - node = ExprNodeCompose(expr) - node.args = sons - elif isinstance(expr, ExprCond): - node = ExprNodeCond(expr) - node.cond = self.expr2node(expr.cond) - node.src1 = self.expr2node(expr.src1) - node.src2 = self.expr2node(expr.src2) - else: - raise TypeError("Unknown Expr Type %r", type(expr)) - return node - - def reduce(self, expr, **kwargs): - """Returns an ExprNode tree mirroring @expr tree. The ExprNode is - computed by applying reduction rules to the expression @expr - - @expr: an Expression - """ - - node = self.expr2node(expr) - return self.categorize(node, lvl=0, **kwargs) - - def categorize(self, node, lvl=0, **kwargs): - """Recursively apply rules to @node - - @node: ExprNode to analyze - @lvl: actual recursion level - """ - - expr = node.expr - log_reduce.debug("\t" * lvl + "Reduce...: %s", node.expr) - if isinstance(expr, ExprId): - node = ExprNodeId(expr) - elif isinstance(expr, ExprInt): - node = ExprNodeInt(expr) - elif isinstance(expr, ExprLoc): - node = ExprNodeLoc(expr) - elif isinstance(expr, ExprMem): - ptr = self.categorize(node.ptr, lvl=lvl + 1, **kwargs) - node = ExprNodeMem(ExprMem(ptr.expr, expr.size)) - node.ptr = ptr - elif isinstance(expr, ExprSlice): - arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs) - node = ExprNodeSlice(ExprSlice(arg.expr, expr.start, expr.stop)) - node.arg = arg - elif isinstance(expr, ExprOp): - new_args = [] - for arg in node.args: - new_a = self.categorize(arg, lvl=lvl + 1, **kwargs) - assert new_a.expr.size == arg.expr.size - new_args.append(new_a) - node = ExprNodeOp(ExprOp(expr.op, *[x.expr for x in new_args])) - node.args = new_args - expr = node.expr - elif isinstance(expr, ExprCompose): - new_args = [] - new_expr_args = [] - for arg in node.args: - arg = self.categorize(arg, lvl=lvl + 1, **kwargs) - new_args.append(arg) - new_expr_args.append(arg.expr) - new_expr = ExprCompose(*new_expr_args) - node = ExprNodeCompose(new_expr) - node.args = new_args - elif isinstance(expr, ExprCond): - cond = self.categorize(node.cond, lvl=lvl + 1, **kwargs) - src1 = self.categorize(node.src1, lvl=lvl + 1, **kwargs) - src2 = self.categorize(node.src2, lvl=lvl + 1, **kwargs) - node = ExprNodeCond(ExprCond(cond.expr, src1.expr, src2.expr)) - node.cond, node.src1, node.src2 = cond, src1, src2 - else: - raise TypeError("Unknown Expr Type %r", type(expr)) - - node.info = self.apply_rules(node, lvl=lvl, **kwargs) - log_reduce.debug("\t" * lvl + "Reduce result: %s %r", - node.expr, node.info) - return node - - def apply_rules(self, node, lvl=0, **kwargs): - """Find and apply reduction rules to @node - - @node: ExprNode to analyse - @lvl: actuel recursion level - """ - - for rule in self.reduction_rules: - ret = rule(self, node, lvl=lvl, **kwargs) - - if ret is not None: - log_reduce.debug("\t" * lvl + "Rule found: %r", rule) - return ret - if not self.allow_none_result: - raise RuntimeError('Missing reduction rule for %r' % node.expr) diff --git a/miasm2/expression/modint.py b/miasm2/expression/modint.py deleted file mode 100644 index 22d17b9b..00000000 --- a/miasm2/expression/modint.py +++ /dev/null @@ -1,259 +0,0 @@ -#-*- coding:utf-8 -*- - -from builtins import range -from functools import total_ordering - -@total_ordering -class moduint(object): - - def __init__(self, arg): - self.arg = int(arg) % self.__class__.limit - assert(self.arg >= 0 and self.arg < self.__class__.limit) - - def __repr__(self): - return self.__class__.__name__ + '(' + hex(self.arg) + ')' - - def __hash__(self): - return hash(self.arg) - - @classmethod - def maxcast(cls, c2): - c2 = c2.__class__ - if cls.size > c2.size: - return cls - else: - return c2 - - def __eq__(self, y): - if isinstance(y, moduint): - return self.arg == y.arg - return self.arg == y - - def __ne__(self, y): - # required Python 2.7.14 - return not self == y - - def __lt__(self, y): - if isinstance(y, moduint): - return self.arg < y.arg - return self.arg < y - - def __add__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg + y.arg) - else: - return self.__class__(self.arg + y) - - def __and__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg & y.arg) - else: - return self.__class__(self.arg & y) - - def __div__(self, y): - # Python: 8 / -7 == -2 (C-like: -1) - # int(float) trick cannot be used, due to information loss - den = int(y) - num = int(self) - result_sign = 1 if (den * num) >= 0 else -1 - cls = self.__class__ - if isinstance(y, moduint): - cls = self.maxcast(y) - return (abs(num) // abs(den)) * result_sign - - def __floordiv__(self, y): - return self.__div__(y) - - def __int__(self): - return int(self.arg) - - def __long__(self): - return int(self.arg) - - def __index__(self): - return int(self.arg) - - def __invert__(self): - return self.__class__(~self.arg) - - def __lshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg << y.arg) - else: - return self.__class__(self.arg << y) - - def __mod__(self, y): - # See __div__ for implementation choice - cls = self.__class__ - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg - y * (self // y)) - - def __mul__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg * y.arg) - else: - return self.__class__(self.arg * y) - - def __neg__(self): - return self.__class__(-self.arg) - - def __or__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg | y.arg) - else: - return self.__class__(self.arg | y) - - def __radd__(self, y): - return self.__add__(y) - - def __rand__(self, y): - return self.__and__(y) - - def __rdiv__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg // self.arg) - else: - return self.__class__(y // self.arg) - - def __rfloordiv__(self, y): - return self.__rdiv__(y) - - def __rlshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg << self.arg) - else: - return self.__class__(y << self.arg) - - def __rmod__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg % self.arg) - else: - return self.__class__(y % self.arg) - - def __rmul__(self, y): - return self.__mul__(y) - - def __ror__(self, y): - return self.__or__(y) - - def __rrshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg >> self.arg) - else: - return self.__class__(y >> self.arg) - - def __rshift__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg >> y.arg) - else: - return self.__class__(self.arg >> y) - - def __rsub__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(y.arg - self.arg) - else: - return self.__class__(y - self.arg) - - def __rxor__(self, y): - return self.__xor__(y) - - def __sub__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg - y.arg) - else: - return self.__class__(self.arg - y) - - def __xor__(self, y): - if isinstance(y, moduint): - cls = self.maxcast(y) - return cls(self.arg ^ y.arg) - else: - return self.__class__(self.arg ^ y) - - def __hex__(self): - return hex(self.arg) - - def __abs__(self): - return abs(self.arg) - - def __rpow__(self, v): - return v ** self.arg - - def __pow__(self, v): - return self.__class__(self.arg ** v) - - -class modint(moduint): - - def __init__(self, arg): - if isinstance(arg, moduint): - arg = arg.arg - a = arg % self.__class__.limit - if a >= self.__class__.limit // 2: - a -= self.__class__.limit - self.arg = a - assert( - self.arg >= -self.__class__.limit // 2 and - self.arg < self.__class__.limit - ) - - -def is_modint(a): - return isinstance(a, moduint) - - -def size2mask(size): - return (1 << size) - 1 - -mod_size2uint = {} -mod_size2int = {} - -mod_uint2size = {} -mod_int2size = {} - -def define_int(size): - """Build the 'modint' instance corresponding to size @size""" - global mod_size2int, mod_int2size - - name = 'int%d' % size - cls = type(name, (modint,), {"size": size, "limit": 1 << size}) - globals()[name] = cls - mod_size2int[size] = cls - mod_int2size[cls] = size - return cls - -def define_uint(size): - """Build the 'moduint' instance corresponding to size @size""" - global mod_size2uint, mod_uint2size - - name = 'uint%d' % size - cls = type(name, (moduint,), {"size": size, "limit": 1 << size}) - globals()[name] = cls - mod_size2uint[size] = cls - mod_uint2size[cls] = size - return cls - -def define_common_int(): - "Define common int" - common_int = range(1, 257) - - for i in common_int: - define_int(i) - - for i in common_int: - define_uint(i) - -define_common_int() diff --git a/miasm2/expression/parser.py b/miasm2/expression/parser.py deleted file mode 100644 index 71efc849..00000000 --- a/miasm2/expression/parser.py +++ /dev/null @@ -1,84 +0,0 @@ -import pyparsing -from miasm2.expression.expression import ExprInt, ExprId, ExprLoc, ExprSlice, \ - ExprMem, ExprCond, ExprCompose, ExprOp, ExprAssign, LocKey - -integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda t: - int(t[0])) -hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums) -hex_int = pyparsing.Combine(hex_word).setParseAction(lambda t: - int(t[0], 16)) - -str_int_pos = (hex_int | integer) -str_int_neg = (pyparsing.Suppress('-') + \ - (hex_int | integer)).setParseAction(lambda t: -t[0]) - -str_int = str_int_pos | str_int_neg - -STR_EXPRINT = pyparsing.Suppress("ExprInt") -STR_EXPRID = pyparsing.Suppress("ExprId") -STR_EXPRLOC = pyparsing.Suppress("ExprLoc") -STR_EXPRSLICE = pyparsing.Suppress("ExprSlice") -STR_EXPRMEM = pyparsing.Suppress("ExprMem") -STR_EXPRCOND = pyparsing.Suppress("ExprCond") -STR_EXPRCOMPOSE = pyparsing.Suppress("ExprCompose") -STR_EXPROP = pyparsing.Suppress("ExprOp") -STR_EXPRASSIGN = pyparsing.Suppress("ExprAssign") - -LOCKEY = pyparsing.Suppress("LocKey") - -STR_COMMA = pyparsing.Suppress(",") -LPARENTHESIS = pyparsing.Suppress("(") -RPARENTHESIS = pyparsing.Suppress(")") - - -T_INF = pyparsing.Suppress("<") -T_SUP = pyparsing.Suppress(">") - - -string_quote = pyparsing.QuotedString(quoteChar="'", escChar='\\', escQuote='\\') -string_dquote = pyparsing.QuotedString(quoteChar='"', escChar='\\', escQuote='\\') - - -string = string_quote | string_dquote - -expr = pyparsing.Forward() - -expr_int = STR_EXPRINT + LPARENTHESIS + str_int + STR_COMMA + str_int + RPARENTHESIS -expr_id = STR_EXPRID + LPARENTHESIS + string + STR_COMMA + str_int + RPARENTHESIS -expr_loc = STR_EXPRLOC + LPARENTHESIS + T_INF + LOCKEY + str_int + T_SUP + STR_COMMA + str_int + RPARENTHESIS -expr_slice = STR_EXPRSLICE + LPARENTHESIS + expr + STR_COMMA + str_int + STR_COMMA + str_int + RPARENTHESIS -expr_mem = STR_EXPRMEM + LPARENTHESIS + expr + STR_COMMA + str_int + RPARENTHESIS -expr_cond = STR_EXPRCOND + LPARENTHESIS + expr + STR_COMMA + expr + STR_COMMA + expr + RPARENTHESIS -expr_compose = STR_EXPRCOMPOSE + LPARENTHESIS + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS -expr_op = STR_EXPROP + LPARENTHESIS + string + STR_COMMA + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS -expr_aff = STR_EXPRASSIGN + LPARENTHESIS + expr + STR_COMMA + expr + RPARENTHESIS - -expr << (expr_int | expr_id | expr_loc | expr_slice | expr_mem | expr_cond | \ - expr_compose | expr_op | expr_aff) - -def parse_loc_key(t): - assert len(t) == 2 - loc_key, size = LocKey(t[0]), t[1] - return ExprLoc(loc_key, size) - -expr_int.setParseAction(lambda t: ExprInt(*t)) -expr_id.setParseAction(lambda t: ExprId(*t)) -expr_loc.setParseAction(parse_loc_key) -expr_slice.setParseAction(lambda t: ExprSlice(*t)) -expr_mem.setParseAction(lambda t: ExprMem(*t)) -expr_cond.setParseAction(lambda t: ExprCond(*t)) -expr_compose.setParseAction(lambda t: ExprCompose(*t)) -expr_op.setParseAction(lambda t: ExprOp(*t)) -expr_aff.setParseAction(lambda t: ExprAssign(*t)) - - -def str_to_expr(str_in): - """Parse the @str_in and return the corresponoding Expression - @str_in: repr string of an Expression""" - - try: - value = expr.parseString(str_in) - except: - raise RuntimeError("Cannot parse expression %s" % str_in) - assert len(value) == 1 - return value[0] diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py deleted file mode 100644 index 331018ae..00000000 --- a/miasm2/expression/simplifications.py +++ /dev/null @@ -1,207 +0,0 @@ -# # -# Simplification methods library # -# # - -import logging - -from future.utils import viewitems - -from miasm2.expression import simplifications_common -from miasm2.expression import simplifications_cond -from miasm2.expression import simplifications_explicit -from miasm2.expression.expression_helper import fast_unify -import miasm2.expression.expression as m2_expr - -# Expression Simplifier -# --------------------- - -log_exprsimp = logging.getLogger("exprsimp") -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) -log_exprsimp.addHandler(console_handler) -log_exprsimp.setLevel(logging.WARNING) - - -class ExpressionSimplifier(object): - - """Wrapper on expression simplification passes. - - Instance handle passes lists. - - Available passes lists are: - - commons: common passes such as constant folding - - heavy : rare passes (for instance, in case of obfuscation) - """ - - # Common passes - PASS_COMMONS = { - m2_expr.ExprOp: [ - simplifications_common.simp_cst_propagation, - simplifications_common.simp_cond_op_int, - simplifications_common.simp_cond_factor, - simplifications_common.simp_add_multiple, - # CC op - simplifications_common.simp_cc_conds, - simplifications_common.simp_subwc_cf, - simplifications_common.simp_subwc_of, - simplifications_common.simp_sign_subwc_cf, - simplifications_common.simp_double_zeroext, - simplifications_common.simp_double_signext, - simplifications_common.simp_zeroext_eq_cst, - simplifications_common.simp_ext_eq_ext, - - simplifications_common.simp_cmp_int, - simplifications_common.simp_sign_inf_zeroext, - simplifications_common.simp_cmp_int_int, - simplifications_common.simp_ext_cst, - simplifications_common.simp_zeroext_and_cst_eq_cst, - simplifications_common.simp_test_signext_inf, - simplifications_common.simp_test_zeroext_inf, - simplifications_common.simp_cond_inf_eq_unsigned_zero, - - ], - - m2_expr.ExprSlice: [ - simplifications_common.simp_slice, - simplifications_common.simp_slice_of_ext, - simplifications_common.simp_slice_of_op_ext, - ], - m2_expr.ExprCompose: [simplifications_common.simp_compose], - m2_expr.ExprCond: [ - simplifications_common.simp_cond, - simplifications_common.simp_cond_zeroext, - simplifications_common.simp_cond_add, - # CC op - simplifications_common.simp_cond_flag, - simplifications_common.simp_cmp_int_arg, - - simplifications_common.simp_cond_eq_zero, - simplifications_common.simp_x_and_cst_eq_cst, - simplifications_common.simp_cond_logic_ext, - simplifications_common.simp_cond_sign_bit, - simplifications_common.simp_cond_eq_1_0, - ], - m2_expr.ExprMem: [simplifications_common.simp_mem], - - } - - - # Heavy passes - PASS_HEAVY = {} - - # Cond passes - PASS_COND = { - m2_expr.ExprSlice: [ - simplifications_cond.expr_simp_inf_signed, - simplifications_cond.expr_simp_inf_unsigned_inversed - ], - m2_expr.ExprOp: [ - simplifications_cond.expr_simp_inverse, - ], - m2_expr.ExprCond: [ - simplifications_cond.expr_simp_equal - ] - } - - - # Available passes lists are: - # - highlevel: transform high level operators to explicit computations - PASS_HIGH_TO_EXPLICIT = { - m2_expr.ExprOp: [ - simplifications_explicit.simp_flags, - simplifications_explicit.simp_ext, - ], - } - - - def __init__(self): - self.expr_simp_cb = {} - self.simplified_exprs = set() - - def enable_passes(self, passes): - """Add passes from @passes - @passes: dict(Expr class : list(callback)) - - Callback signature: Expr callback(ExpressionSimplifier, Expr) - """ - - # Clear cache of simplifiied expressions when adding a new pass - self.simplified_exprs.clear() - - for k, v in viewitems(passes): - self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v) - - def apply_simp(self, expression): - """Apply enabled simplifications on expression - @expression: Expr instance - Return an Expr instance""" - - cls = expression.__class__ - debug_level = log_exprsimp.level >= logging.DEBUG - for simp_func in self.expr_simp_cb.get(cls, []): - # Apply simplifications - before = expression - expression = simp_func(self, expression) - after = expression - - if debug_level and before != after: - log_exprsimp.debug("[%s] %s => %s", simp_func, before, after) - - # If class changes, stop to prevent wrong simplifications - if expression.__class__ is not cls: - break - - return expression - - def expr_simp(self, expression): - """Apply enabled simplifications on expression and find a stable state - @expression: Expr instance - Return an Expr instance""" - - if expression in self.simplified_exprs: - return expression - - # Find a stable state - while True: - # Canonize and simplify - e_new = self.apply_simp(expression.canonize()) - if e_new == expression: - break - - # Launch recursivity - expression = self.expr_simp_wrapper(e_new) - self.simplified_exprs.add(expression) - # Mark expression as simplified - self.simplified_exprs.add(e_new) - - return e_new - - def expr_simp_wrapper(self, expression, callback=None): - """Apply enabled simplifications on expression - @expression: Expr instance - @manual_callback: If set, call this function instead of normal one - Return an Expr instance""" - - if expression in self.simplified_exprs: - return expression - - if callback is None: - callback = self.expr_simp - - return expression.visit(callback, lambda e: e not in self.simplified_exprs) - - def __call__(self, expression, callback=None): - "Wrapper on expr_simp_wrapper" - return self.expr_simp_wrapper(expression, callback) - - -# Public ExprSimplificationPass instance with commons passes -expr_simp = ExpressionSimplifier() -expr_simp.enable_passes(ExpressionSimplifier.PASS_COMMONS) - -expr_simp_high_to_explicit = ExpressionSimplifier() -expr_simp_high_to_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT) - -expr_simp_explicit = ExpressionSimplifier() -expr_simp_explicit.enable_passes(ExpressionSimplifier.PASS_COMMONS) -expr_simp_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT) diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py deleted file mode 100644 index ddcfc668..00000000 --- a/miasm2/expression/simplifications_common.py +++ /dev/null @@ -1,1556 +0,0 @@ -# ----------------------------- # -# Common simplifications passes # -# ----------------------------- # - -from future.utils import viewitems - -from miasm2.expression.modint import mod_size2int, mod_size2uint -from miasm2.expression.expression import ExprInt, ExprSlice, ExprMem, \ - ExprCond, ExprOp, ExprCompose, TOK_INF_SIGNED, TOK_INF_UNSIGNED, \ - TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, TOK_EQUAL -from miasm2.expression.expression_helper import parity, op_propag_cst, \ - merge_sliceto_slice - - -def simp_cst_propagation(e_s, expr): - """This passe includes: - - Constant folding - - Common logical identities - - Common binary identities - """ - - # merge associatif op - args = list(expr.args) - op_name = expr.op - # simpl integer manip - # int OP int => int - # TODO: <<< >>> << >> are architecture dependent - if op_name in op_propag_cst: - while (len(args) >= 2 and - args[-1].is_int() and - args[-2].is_int()): - int2 = args.pop() - int1 = args.pop() - if op_name == '+': - out = int1.arg + int2.arg - elif op_name == '*': - out = int1.arg * int2.arg - elif op_name == '**': - out =int1.arg ** int2.arg - elif op_name == '^': - out = int1.arg ^ int2.arg - elif op_name == '&': - out = int1.arg & int2.arg - elif op_name == '|': - out = int1.arg | int2.arg - elif op_name == '>>': - if int(int2) > int1.size: - out = 0 - else: - out = int1.arg >> int2.arg - elif op_name == '<<': - if int(int2) > int1.size: - out = 0 - else: - out = int1.arg << int2.arg - elif op_name == 'a>>': - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - if tmp2 > int1.size: - is_signed = int(int1) & (1 << (int1.size - 1)) - if is_signed: - out = -1 - else: - out = 0 - else: - out = mod_size2uint[int1.arg.size](tmp1 >> tmp2) - elif op_name == '>>>': - shifter = int2.arg % int2.size - out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter)) - elif op_name == '<<<': - shifter = int2.arg % int2.size - out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter)) - elif op_name == '/': - out = int1.arg // int2.arg - elif op_name == '%': - out = int1.arg % int2.arg - elif op_name == 'sdiv': - assert int2.arg.arg - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2int[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 // tmp2) - elif op_name == 'smod': - assert int2.arg.arg - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2int[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 % tmp2) - elif op_name == 'umod': - assert int2.arg.arg - tmp1 = mod_size2uint[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 % tmp2) - elif op_name == 'udiv': - assert int2.arg.arg - tmp1 = mod_size2uint[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 // tmp2) - - - - args.append(ExprInt(out, int1.size)) - - # cnttrailzeros(int) => int - if op_name == "cnttrailzeros" and args[0].is_int(): - i = 0 - while args[0].arg & (1 << i) == 0 and i < args[0].size: - i += 1 - return ExprInt(i, args[0].size) - - # cntleadzeros(int) => int - if op_name == "cntleadzeros" and args[0].is_int(): - if args[0].arg == 0: - return ExprInt(args[0].size, args[0].size) - i = args[0].size - 1 - while args[0].arg & (1 << i) == 0: - i -= 1 - return ExprInt(expr.size - (i + 1), args[0].size) - - # -(-(A)) => A - if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and - len(args[0].args) == 1): - return args[0].args[0] - - # -(int) => -int - if op_name == '-' and len(args) == 1 and args[0].is_int(): - return ExprInt(-int(args[0]), expr.size) - # A op 0 =>A - if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1: - if args[-1].is_int(0): - args.pop() - # A - 0 =>A - if op_name == '-' and len(args) > 1 and args[-1].is_int(0): - assert len(args) == 2 # Op '-' with more than 2 args: SantityCheckError - return args[0] - - # A * 1 =>A - if op_name == "*" and len(args) > 1 and args[-1].is_int(1): - args.pop() - - # for cannon form - # A * -1 => - A - if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask: - args.pop() - args[-1] = - args[-1] - - # op A => A - if op_name in ['+', '*', '^', '&', '|', '>>', '<<', - 'a>>', '<<<', '>>>', 'sdiv', 'smod', 'umod', 'udiv'] and len(args) == 1: - return args[0] - - # A-B => A + (-B) - if op_name == '-' and len(args) > 1: - if len(args) > 2: - raise ValueError( - 'sanity check fail on expr -: should have one or 2 args ' + - '%r %s' % (expr, expr) - ) - return ExprOp('+', args[0], -args[1]) - - # A op 0 => 0 - if op_name in ['&', "*"] and args[-1].is_int(0): - return ExprInt(0, expr.size) - - # - (A + B +...) => -A + -B + -C - if op_name == '-' and len(args) == 1 and args[0].is_op('+'): - args = [-a for a in args[0].args] - return ExprOp('+', *args) - - # -(a?int1:int2) => (a?-int1:-int2) - if (op_name == '-' and len(args) == 1 and - args[0].is_cond() and - args[0].src1.is_int() and args[0].src2.is_int()): - int1 = args[0].src1 - int2 = args[0].src2 - int1 = ExprInt(-int1.arg, int1.size) - int2 = ExprInt(-int2.arg, int2.size) - return ExprCond(args[0].cond, int1, int2) - - i = 0 - while i < len(args) - 1: - j = i + 1 - while j < len(args): - # A ^ A => 0 - if op_name == '^' and args[i] == args[j]: - args[i] = ExprInt(0, args[i].size) - del args[j] - continue - # A + (- A) => 0 - if op_name == '+' and args[j].is_op("-"): - if len(args[j].args) == 1 and args[i] == args[j].args[0]: - args[i] = ExprInt(0, args[i].size) - del args[j] - continue - # (- A) + A => 0 - if op_name == '+' and args[i].is_op("-"): - if len(args[i].args) == 1 and args[j] == args[i].args[0]: - args[i] = ExprInt(0, args[i].size) - del args[j] - continue - # A | A => A - if op_name == '|' and args[i] == args[j]: - del args[j] - continue - # A & A => A - if op_name == '&' and args[i] == args[j]: - del args[j] - continue - j += 1 - i += 1 - - if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1: - return args[0] - - # A <<< A.size => A - if (op_name in ['<<<', '>>>'] and - args[1].is_int() and - args[1].arg == args[0].size): - return args[0] - - # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow - if (op_name in ['<<<', '>>>'] and - args[0].is_op() and - args[0].op in ['<<<', '>>>']): - A = args[0].args[0] - X = args[0].args[1] - Y = args[1] - if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size): - return args[0].args[0] - elif X.is_int() and Y.is_int(): - new_X = int(X) % expr.size - new_Y = int(Y) % expr.size - if op_name == args[0].op: - rot = (new_X + new_Y) % expr.size - op = op_name - else: - rot = new_Y - new_X - op = op_name - if rot < 0: - rot = - rot - op = {">>>": "<<<", "<<<": ">>>"}[op_name] - args = [A, ExprInt(rot, expr.size)] - op_name = op - - else: - # Do not consider this case, too tricky (overflow on addition / - # subtraction) - pass - - # A >> X >> Y => A >> (X+Y) if X + Y does not overflow - # To be sure, only consider the simplification when X.msb and Y.msb are 0 - if (op_name in ['<<', '>>'] and - args[0].is_op(op_name)): - X = args[0].args[1] - Y = args[1] - if (e_s(X.msb()) == ExprInt(0, 1) and - e_s(Y.msb()) == ExprInt(0, 1)): - args = [args[0].args[0], X + Y] - - # ((var >> int1) << int1) => var & mask - # ((var << int1) >> int1) => var & mask - if (op_name in ['<<', '>>'] and - args[0].is_op() and - args[0].op in ['<<', '>>'] and - op_name != args[0]): - var = args[0].args[0] - int1 = args[0].args[1] - int2 = args[1] - if int1 == int2 and int1.is_int() and int(int1) < expr.size: - if op_name == '>>': - mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size) - else: - mask = ExprInt( - ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1), - expr.size - ) - ret = var & mask - return ret - - # ((A & A.mask) - if op_name == "&" and args[-1] == expr.mask: - return ExprOp('&', *args[:-1]) - - # ((A | A.mask) - if op_name == "|" and args[-1] == expr.mask: - return args[-1] - - # ! (!X + int) => X - int - # TODO - - # ((A & mask) >> shift) with mask < 2**shift => 0 - if op_name == ">>" and args[1].is_int() and args[0].is_op("&"): - if (args[0].args[1].is_int() and - 2 ** args[1].arg > args[0].args[1].arg): - return ExprInt(0, args[0].size) - - # parity(int) => int - if op_name == 'parity' and args[0].is_int(): - return ExprInt(parity(int(args[0])), 1) - - # (-a) * b * (-c) * (-d) => (-a) * b * c * d - if op_name == "*" and len(args) > 1: - new_args = [] - counter = 0 - for arg in args: - if arg.is_op('-') and len(arg.args) == 1: - new_args.append(arg.args[0]) - counter += 1 - else: - new_args.append(arg) - if counter % 2: - return -ExprOp(op_name, *new_args) - args = new_args - - # -(a * b * int) => a * b * (-int) - if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int(): - args = args[0].args - return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)])) - - - # A << int with A ExprCompose => move index - if (op_name == "<<" and args[0].is_compose() and - args[1].is_int() and int(args[1]) != 0): - final_size = args[0].size - shift = int(args[1]) - new_args = [] - # shift indexes - for index, arg in args[0].iter_args(): - new_args.append((arg, index+shift, index+shift+arg.size)) - # filter out expression - filter_args = [] - min_index = final_size - for tmp, start, stop in new_args: - if start >= final_size: - continue - if stop > final_size: - tmp = tmp[:tmp.size - (stop - final_size)] - filter_args.append(tmp) - min_index = min(start, min_index) - # create entry 0 - assert min_index != 0 - tmp = ExprInt(0, min_index) - args = [tmp] + filter_args - return ExprCompose(*args) - - # A >> int with A ExprCompose => move index - if op_name == ">>" and args[0].is_compose() and args[1].is_int(): - final_size = args[0].size - shift = int(args[1]) - new_args = [] - # shift indexes - for index, arg in args[0].iter_args(): - new_args.append((arg, index-shift, index+arg.size-shift)) - # filter out expression - filter_args = [] - max_index = 0 - for tmp, start, stop in new_args: - if stop <= 0: - continue - if start < 0: - tmp = tmp[-start:] - filter_args.append(tmp) - max_index = max(stop, max_index) - # create entry 0 - tmp = ExprInt(0, final_size - max_index) - args = filter_args + [tmp] - return ExprCompose(*args) - - - # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b) - if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]): - bounds = set() - for arg in args: - bound = tuple([tmp.size for tmp in arg.args]) - bounds.add(bound) - if len(bounds) == 1: - new_args = [[tmp] for tmp in args[0].args] - for sub_arg in args[1:]: - for i, tmp in enumerate(sub_arg.args): - new_args[i].append(tmp) - args = [] - for i, arg in enumerate(new_args): - args.append(ExprOp(op_name, *arg)) - return ExprCompose(*args) - - return ExprOp(op_name, *args) - - -def simp_cond_op_int(_, expr): - "Extract conditions from operations" - - - # x?a:b + x?c:d + e => x?(a+c+e:b+d+e) - if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']: - return expr - if len(expr.args) < 2: - return expr - conds = set() - for arg in expr.args: - if arg.is_cond(): - conds.add(arg) - if len(conds) != 1: - return expr - cond = list(conds).pop() - - args1, args2 = [], [] - for arg in expr.args: - if arg.is_cond(): - args1.append(arg.src1) - args2.append(arg.src2) - else: - args1.append(arg) - args2.append(arg) - - return ExprCond(cond.cond, - ExprOp(expr.op, *args1), - ExprOp(expr.op, *args2)) - - -def simp_cond_factor(e_s, expr): - "Merge similar conditions" - if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']: - return expr - if len(expr.args) < 2: - return expr - - if expr.op in ['>>', '<<', 'a>>']: - assert len(expr.args) == 2 - - # Note: the following code is correct for non-commutative operation only if - # there is 2 arguments. Otherwise, the order is not conserved - - # Regroup sub-expression by similar conditions - conds = {} - not_conds = [] - multi_cond = False - for arg in expr.args: - if not arg.is_cond(): - not_conds.append(arg) - continue - cond = arg.cond - if not cond in conds: - conds[cond] = [] - else: - multi_cond = True - conds[cond].append(arg) - if not multi_cond: - return expr - - # Rebuild the new expression - c_out = not_conds - for cond, vals in viewitems(conds): - new_src1 = [x.src1 for x in vals] - new_src2 = [x.src2 for x in vals] - src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1)) - src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2)) - c_out.append(ExprCond(cond, src1, src2)) - - if len(c_out) == 1: - new_e = c_out[0] - else: - new_e = ExprOp(expr.op, *c_out) - return new_e - - -def simp_slice(e_s, expr): - "Slice optimization" - - # slice(A, 0, a.size) => A - if expr.start == 0 and expr.stop == expr.arg.size: - return expr.arg - # Slice(int) => int - if expr.arg.is_int(): - total_bit = expr.stop - expr.start - mask = (1 << (expr.stop - expr.start)) - 1 - return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit) - # Slice(Slice(A, x), y) => Slice(A, z) - if expr.arg.is_slice(): - if expr.stop - expr.start > expr.arg.stop - expr.arg.start: - raise ValueError('slice in slice: getting more val', str(expr)) - - return ExprSlice(expr.arg.arg, expr.start + expr.arg.start, - expr.start + expr.arg.start + (expr.stop - expr.start)) - if expr.arg.is_compose(): - # Slice(Compose(A), x) => Slice(A, y) - for index, arg in expr.arg.iter_args(): - if index <= expr.start and index+arg.size >= expr.stop: - return arg[expr.start - index:expr.stop - index] - # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C - out = [] - for index, arg in expr.arg.iter_args(): - # arg is before slice start - if expr.start >= index + arg.size: - continue - # arg is after slice stop - elif expr.stop <= index: - continue - # arg is fully included in slice - elif expr.start <= index and index + arg.size <= expr.stop: - out.append(arg) - continue - # arg is truncated at start - if expr.start > index: - slice_start = expr.start - index - else: - # arg is not truncated at start - slice_start = 0 - # a is truncated at stop - if expr.stop < index + arg.size: - slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start - else: - slice_stop = arg.size - out.append(arg[slice_start:slice_stop]) - - return ExprCompose(*out) - - # ExprMem(x, size)[:A] => ExprMem(x, a) - # XXXX todo hum, is it safe? - if (expr.arg.is_mem() and - expr.start == 0 and - expr.arg.size > expr.stop and expr.stop % 8 == 0): - return ExprMem(expr.arg.ptr, size=expr.stop) - # distributivity of slice and & - # (a & int)[x:y] => 0 if int[x:y] == 0 - if expr.arg.is_op("&") and expr.arg.args[-1].is_int(): - tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop]) - if tmp.is_int(0): - return tmp - # distributivity of slice and exprcond - # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y]) - # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y]) - if (expr.arg.is_cond() and - (expr.arg.src1.is_int() or expr.arg.src1.is_compose()) and - (expr.arg.src2.is_int() or expr.arg.src2.is_compose())): - src1 = expr.arg.src1[expr.start:expr.stop] - src2 = expr.arg.src2[expr.start:expr.stop] - return ExprCond(expr.arg.cond, src1, src2) - - # (a * int)[0:y] => (a[0:y] * int[0:y]) - if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int(): - args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args] - return ExprOp(expr.arg.op, *args) - - # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size - # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0 - if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and - expr.arg.args[1].is_int()): - arg, shift = expr.arg.args - shift = int(shift) - if expr.arg.op == ">>": - if shift + expr.stop <= arg.size: - return arg[expr.start + shift:expr.stop + shift] - elif expr.arg.op == "<<": - if expr.start - shift >= 0: - return arg[expr.start - shift:expr.stop - shift] - else: - raise ValueError('Bad case') - - return expr - - -def simp_compose(e_s, expr): - "Commons simplification on ExprCompose" - args = merge_sliceto_slice(expr) - out = [] - # compose of compose - for arg in args: - if arg.is_compose(): - out += arg.args - else: - out.append(arg) - args = out - # Compose(a) with a.size = compose.size => a - if len(args) == 1 and args[0].size == expr.size: - return args[0] - - # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z) - if len(args) == 2 and args[1].is_int(0): - if (args[0].is_slice() and - args[0].stop == args[0].arg.size and - args[0].size + args[1].size == args[0].arg.size): - new_expr = args[0].arg >> ExprInt(args[0].start, args[0].arg.size) - return new_expr - - # {@X[base + i] 0 X, @Y[base + i + X] X (X + Y)} => @(X+Y)[base + i] - for i, arg in enumerate(args[:-1]): - nxt = args[i + 1] - if arg.is_mem() and nxt.is_mem(): - gap = e_s(nxt.ptr - arg.ptr) - if gap.is_int() and arg.size % 8 == 0 and int(gap) == arg.size // 8: - args = args[:i] + [ExprMem(arg.ptr, - arg.size + nxt.size)] + args[i + 2:] - return ExprCompose(*args) - - # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f} - conds = set(arg.cond for arg in expr.args if arg.is_cond()) - if len(conds) == 1: - cond = list(conds)[0] - args1, args2 = [], [] - for arg in expr.args: - if arg.is_cond(): - args1.append(arg.src1) - args2.append(arg.src2) - else: - args1.append(arg) - args2.append(arg) - arg1 = e_s(ExprCompose(*args1)) - arg2 = e_s(ExprCompose(*args2)) - return ExprCond(cond, arg1, arg2) - return ExprCompose(*args) - - -def simp_cond(_, expr): - """ - Common simplifications on ExprCond. - Eval exprcond src1/src2 with satifiable/unsatisfiable condition propagation - """ - if (not expr.cond.is_int()) and expr.cond.size == 1: - src1 = expr.src1.replace_expr({expr.cond: ExprInt(1, 1)}) - src2 = expr.src2.replace_expr({expr.cond: ExprInt(0, 1)}) - if src1 != expr.src1 or src2 != expr.src2: - return ExprCond(expr.cond, src1, src2) - - # -A ? B:C => A ? B:C - if expr.cond.is_op('-') and len(expr.cond.args) == 1: - expr = ExprCond(expr.cond.args[0], expr.src1, expr.src2) - # a?x:x - elif expr.src1 == expr.src2: - expr = expr.src1 - # int ? A:B => A or B - elif expr.cond.is_int(): - if expr.cond.arg == 0: - expr = expr.src2 - else: - expr = expr.src1 - # a?(a?b:c):x => a?b:x - elif expr.src1.is_cond() and expr.cond == expr.src1.cond: - expr = ExprCond(expr.cond, expr.src1.src1, expr.src2) - # a?x:(a?b:c) => a?x:c - elif expr.src2.is_cond() and expr.cond == expr.src2.cond: - expr = ExprCond(expr.cond, expr.src1, expr.src2.src2) - # a|int ? b:c => b with int != 0 - elif (expr.cond.is_op('|') and - expr.cond.args[1].is_int() and - expr.cond.args[1].arg != 0): - return expr.src1 - - # (C?int1:int2)?(A:B) => - elif (expr.cond.is_cond() and - expr.cond.src1.is_int() and - expr.cond.src2.is_int()): - int1 = expr.cond.src1.arg.arg - int2 = expr.cond.src2.arg.arg - if int1 and int2: - expr = expr.src1 - elif int1 == 0 and int2 == 0: - expr = expr.src2 - elif int1 == 0 and int2: - expr = ExprCond(expr.cond.cond, expr.src2, expr.src1) - elif int1 and int2 == 0: - expr = ExprCond(expr.cond.cond, expr.src1, expr.src2) - - elif expr.cond.is_compose(): - # {0, X, 0}?(A:B) => X?(A:B) - args = [arg for arg in expr.cond.args if not arg.is_int(0)] - if len(args) == 1: - arg = args.pop() - return ExprCond(arg, expr.src1, expr.src2) - elif len(args) < len(expr.cond.args): - return ExprCond(ExprCompose(*args), expr.src1, expr.src2) - return expr - - -def simp_mem(_, expr): - """ - Common simplifications on ExprMem: - @32[x?a:b] => x?@32[a]:@32[b] - """ - if expr.ptr.is_cond(): - cond = expr.ptr - ret = ExprCond(cond.cond, - ExprMem(cond.src1, expr.size), - ExprMem(cond.src2, expr.size)) - return ret - return expr - - - - -def test_cc_eq_args(expr, *sons_op): - """ - Return True if expression's arguments match the list in sons_op, and their - sub arguments are identical. Ex: - CC_S<=( - FLAG_SIGN_SUB(A, B), - FLAG_SUB_OF(A, B), - FLAG_EQ_CMP(A, B) - ) - """ - if not expr.is_op(): - return False - if len(expr.args) != len(sons_op): - return False - all_args = set() - for i, arg in enumerate(expr.args): - if not arg.is_op(sons_op[i]): - return False - all_args.add(arg.args) - return len(all_args) == 1 - - -def simp_cc_conds(_, expr): - """ - High level simplifications. Example: - CC_U<(FLAG_SUB_CF(A, B) => A <u B - """ - if (expr.is_op("CC_U>=") and - test_cc_eq_args( - expr, - "FLAG_SUB_CF" - )): - expr = ExprCond( - ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size)) - - elif (expr.is_op("CC_U<") and - test_cc_eq_args( - expr, - "FLAG_SUB_CF" - )): - expr = ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args) - - elif (expr.is_op("CC_NEG") and - test_cc_eq_args( - expr, - "FLAG_SIGN_SUB" - )): - expr = ExprOp(TOK_INF_SIGNED, *expr.args[0].args) - - elif (expr.is_op("CC_POS") and - test_cc_eq_args( - expr, - "FLAG_SIGN_SUB" - )): - expr = ExprCond( - ExprOp(TOK_INF_SIGNED, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - elif (expr.is_op("CC_EQ") and - test_cc_eq_args( - expr, - "FLAG_EQ" - )): - arg = expr.args[0].args[0] - expr = ExprOp(TOK_EQUAL, arg, ExprInt(0, arg.size)) - - elif (expr.is_op("CC_NE") and - test_cc_eq_args( - expr, - "FLAG_EQ" - )): - arg = expr.args[0].args[0] - expr = ExprCond( - ExprOp(TOK_EQUAL,arg, ExprInt(0, arg.size)), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - elif (expr.is_op("CC_NE") and - test_cc_eq_args( - expr, - "FLAG_EQ_CMP" - )): - expr = ExprCond( - ExprOp(TOK_EQUAL, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - elif (expr.is_op("CC_EQ") and - test_cc_eq_args( - expr, - "FLAG_EQ_CMP" - )): - expr = ExprOp(TOK_EQUAL, *expr.args[0].args) - - elif (expr.is_op("CC_NE") and - test_cc_eq_args( - expr, - "FLAG_EQ_AND" - )): - expr = ExprOp("&", *expr.args[0].args) - - elif (expr.is_op("CC_EQ") and - test_cc_eq_args( - expr, - "FLAG_EQ_AND" - )): - expr = ExprCond( - ExprOp("&", *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - elif (expr.is_op("CC_S>") and - test_cc_eq_args( - expr, - "FLAG_SIGN_SUB", - "FLAG_SUB_OF", - "FLAG_EQ_CMP", - )): - expr = ExprCond( - ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - elif (expr.is_op("CC_S>") and - len(expr.args) == 3 and - expr.args[0].is_op("FLAG_SIGN_SUB") and - expr.args[2].is_op("FLAG_EQ_CMP") and - expr.args[0].args == expr.args[2].args and - expr.args[1].is_int(0)): - expr = ExprCond( - ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - - - elif (expr.is_op("CC_S>=") and - test_cc_eq_args( - expr, - "FLAG_SIGN_SUB", - "FLAG_SUB_OF" - )): - expr = ExprCond( - ExprOp(TOK_INF_SIGNED, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - elif (expr.is_op("CC_S<") and - test_cc_eq_args( - expr, - "FLAG_SIGN_SUB", - "FLAG_SUB_OF" - )): - expr = ExprOp(TOK_INF_SIGNED, *expr.args[0].args) - - elif (expr.is_op("CC_S<=") and - test_cc_eq_args( - expr, - "FLAG_SIGN_SUB", - "FLAG_SUB_OF", - "FLAG_EQ_CMP", - )): - expr = ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args) - - elif (expr.is_op("CC_S<=") and - len(expr.args) == 3 and - expr.args[0].is_op("FLAG_SIGN_SUB") and - expr.args[2].is_op("FLAG_EQ_CMP") and - expr.args[0].args == expr.args[2].args and - expr.args[1].is_int(0)): - expr = ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args) - - elif (expr.is_op("CC_U<=") and - test_cc_eq_args( - expr, - "FLAG_SUB_CF", - "FLAG_EQ_CMP", - )): - expr = ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args) - - elif (expr.is_op("CC_U>") and - test_cc_eq_args( - expr, - "FLAG_SUB_CF", - "FLAG_EQ_CMP", - )): - expr = ExprCond( - ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args), - ExprInt(0, expr.size), - ExprInt(1, expr.size) - ) - - elif (expr.is_op("CC_S<") and - test_cc_eq_args( - expr, - "FLAG_SIGN_ADD", - "FLAG_ADD_OF" - )): - arg0, arg1 = expr.args[0].args - expr = ExprOp(TOK_INF_SIGNED, arg0, -arg1) - - return expr - - - -def simp_cond_flag(_, expr): - """FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B""" - cond = expr.cond - if cond.is_op("FLAG_EQ_CMP"): - return ExprCond(ExprOp(TOK_EQUAL, *cond.args), expr.src1, expr.src2) - return expr - - -def simp_cmp_int(expr_simp, expr): - """ - ({X, 0} == int) => X == int[:] - X + int1 == int2 => X == int2-int1 - X ^ int1 == int2 => X == int1^int2 - """ - if (expr.is_op(TOK_EQUAL) and - expr.args[1].is_int() and - expr.args[0].is_compose() and - len(expr.args[0].args) == 2 and - expr.args[0].args[1].is_int(0)): - # ({X, 0} == int) => X == int[:] - src = expr.args[0].args[0] - int_val = int(expr.args[1]) - new_int = ExprInt(int_val, src.size) - expr = expr_simp( - ExprOp(TOK_EQUAL, src, new_int) - ) - elif not expr.is_op(TOK_EQUAL): - return expr - assert len(expr.args) == 2 - - left, right = expr.args - if left.is_int() and not right.is_int(): - left, right = right, left - if not right.is_int(): - return expr - if not (left.is_op() and left.op in ['+', '^']): - return expr - if not left.args[-1].is_int(): - return expr - # X + int1 == int2 => X == int2-int1 - # WARNING: - # X - 0x10 <=u 0x20 gives X in [0x10 0x30] - # which is not equivalet to A <=u 0x10 - - left_orig = left - left, last_int = left.args[:-1], left.args[-1] - - if len(left) == 1: - left = left[0] - else: - left = ExprOp(left.op, *left) - - if left_orig.op == "+": - new_int = expr_simp(right - last_int) - elif left_orig.op == '^': - new_int = expr_simp(right ^ last_int) - else: - raise RuntimeError("Unsupported operator") - - expr = expr_simp( - ExprOp(TOK_EQUAL, left, new_int), - ) - return expr - - - -def simp_cmp_int_arg(_, expr): - """ - (0x10 <= R0) ? A:B - => - (R0 < 0x10) ? B:A - """ - cond = expr.cond - if not cond.is_op(): - return expr - op = cond.op - if op not in [ - TOK_EQUAL, - TOK_INF_SIGNED, - TOK_INF_EQUAL_SIGNED, - TOK_INF_UNSIGNED, - TOK_INF_EQUAL_UNSIGNED - ]: - return expr - arg1, arg2 = cond.args - if arg2.is_int(): - return expr - if not arg1.is_int(): - return expr - src1, src2 = expr.src1, expr.src2 - if op == TOK_EQUAL: - return ExprCond(ExprOp(TOK_EQUAL, arg2, arg1), src1, src2) - - arg1, arg2 = arg2, arg1 - src1, src2 = src2, src1 - if op == TOK_INF_SIGNED: - op = TOK_INF_EQUAL_SIGNED - elif op == TOK_INF_EQUAL_SIGNED: - op = TOK_INF_SIGNED - elif op == TOK_INF_UNSIGNED: - op = TOK_INF_EQUAL_UNSIGNED - elif op == TOK_INF_EQUAL_UNSIGNED: - op = TOK_INF_UNSIGNED - return ExprCond(ExprOp(op, arg1, arg2), src1, src2) - - -def simp_subwc_cf(_, expr): - """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D})""" - if not expr.is_op('FLAG_SUBWC_CF'): - return expr - op3 = expr.args[2] - if not op3.is_op("FLAG_SUB_CF"): - return expr - - op1 = ExprCompose(expr.args[0], op3.args[0]) - op2 = ExprCompose(expr.args[1], op3.args[1]) - - return ExprOp("FLAG_SUB_CF", op1, op2) - - -def simp_subwc_of(_, expr): - """SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D})""" - if not expr.is_op('FLAG_SUBWC_OF'): - return expr - op3 = expr.args[2] - if not op3.is_op("FLAG_SUB_CF"): - return expr - - op1 = ExprCompose(expr.args[0], op3.args[0]) - op2 = ExprCompose(expr.args[1], op3.args[1]) - - return ExprOp("FLAG_SUB_OF", op1, op2) - - -def simp_sign_subwc_cf(_, expr): - """SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D})""" - if not expr.is_op('FLAG_SIGN_SUBWC'): - return expr - op3 = expr.args[2] - if not op3.is_op("FLAG_SUB_CF"): - return expr - - op1 = ExprCompose(expr.args[0], op3.args[0]) - op2 = ExprCompose(expr.args[1], op3.args[1]) - - return ExprOp("FLAG_SIGN_SUB", op1, op2) - -def simp_double_zeroext(_, expr): - """A.zeroExt(X).zeroExt(Y) => A.zeroExt(Y)""" - if not (expr.is_op() and expr.op.startswith("zeroExt")): - return expr - arg1 = expr.args[0] - if not (arg1.is_op() and arg1.op.startswith("zeroExt")): - return expr - arg2 = arg1.args[0] - return ExprOp(expr.op, arg2) - -def simp_double_signext(_, expr): - """A.signExt(X).signExt(Y) => A.signExt(Y)""" - if not (expr.is_op() and expr.op.startswith("signExt")): - return expr - arg1 = expr.args[0] - if not (arg1.is_op() and arg1.op.startswith("signExt")): - return expr - arg2 = arg1.args[0] - return ExprOp(expr.op, arg2) - -def simp_zeroext_eq_cst(_, expr): - """A.zeroExt(X) == int => A == int[:A.size]""" - if not expr.is_op(TOK_EQUAL): - return expr - arg1, arg2 = expr.args - if not arg2.is_int(): - return expr - if not (arg1.is_op() and arg1.op.startswith("zeroExt")): - return expr - src = arg1.args[0] - if int(arg2) > (1 << src.size): - # Always false - return ExprInt(0, expr.size) - return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size)) - -def simp_cond_zeroext(_, expr): - """ - X.zeroExt()?(A:B) => X ? A:B - X.signExt()?(A:B) => X ? A:B - """ - if not ( - expr.cond.is_op() and - ( - expr.cond.op.startswith("zeroExt") or - expr.cond.op.startswith("signExt") - ) - ): - return expr - - ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2) - return ret - -def simp_ext_eq_ext(_, expr): - """ - A.zeroExt(X) == B.zeroExt(X) => A == B - A.signExt(X) == B.signExt(X) => A == B - """ - if not expr.is_op(TOK_EQUAL): - return expr - arg1, arg2 = expr.args - if (not ((arg1.is_op() and arg1.op.startswith("zeroExt") and - arg2.is_op() and arg2.op.startswith("zeroExt")) or - (arg1.is_op() and arg1.op.startswith("signExt") and - arg2.is_op() and arg2.op.startswith("signExt")))): - return expr - if arg1.args[0].size != arg2.args[0].size: - return expr - return ExprOp(TOK_EQUAL, arg1.args[0], arg2.args[0]) - -def simp_cond_eq_zero(_, expr): - """(X == 0)?(A:B) => X?(B:A)""" - cond = expr.cond - if not cond.is_op(TOK_EQUAL): - return expr - arg1, arg2 = cond.args - if not arg2.is_int(0): - return expr - new_expr = ExprCond(arg1, expr.src2, expr.src1) - return new_expr - -def simp_sign_inf_zeroext(expr_s, expr): - """ - /!\ Ensure before: X.zeroExt(X.size) => X - - X.zeroExt() <s 0 => 0 - X.zeroExt() <=s 0 => X == 0 - - X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive) - X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive) - - X.zeroExt() <s cst => 0 (cst negative) - X.zeroExt() <=s cst => 0 (cst negative) - - """ - if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)): - return expr - arg1, arg2 = expr.args - if not arg2.is_int(): - return expr - if not (arg1.is_op() and arg1.op.startswith("zeroExt")): - return expr - src = arg1.args[0] - assert src.size < arg1.size - - # If cst is zero - if arg2.is_int(0): - if expr.is_op(TOK_INF_SIGNED): - # X.zeroExt() <s 0 => 0 - return ExprInt(0, expr.size) - else: - # X.zeroExt() <=s 0 => X == 0 - return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size)) - - # cst is not zero - cst = int(arg2) - if cst & (1 << (arg2.size - 1)): - # cst is negative - return ExprInt(0, expr.size) - # cst is positive - if expr.is_op(TOK_INF_SIGNED): - # X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive) - return ExprOp(TOK_INF_UNSIGNED, src, expr_s(arg2[:src.size])) - # X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive) - return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size])) - - -def simp_zeroext_and_cst_eq_cst(expr_s, expr): - """ - A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size] - """ - if not expr.is_op(TOK_EQUAL): - return expr - arg1, arg2 = expr.args - if not arg2.is_int(): - return expr - if not arg1.is_op('&'): - return expr - is_ok = True - sizes = set() - for arg in arg1.args: - if arg.is_int(): - continue - if (arg.is_op() and - arg.op.startswith("zeroExt")): - sizes.add(arg.args[0].size) - continue - is_ok = False - break - if not is_ok: - return expr - if len(sizes) != 1: - return expr - size = list(sizes)[0] - if int(arg2) > ((1 << size) - 1): - return expr - args = [expr_s(arg[:size]) for arg in arg1.args] - left = ExprOp('&', *args) - right = expr_s(arg2[:size]) - ret = ExprOp(TOK_EQUAL, left, right) - return ret - - -def test_one_bit_set(arg): - """ - Return True if arg has form 1 << X - """ - return arg != 0 and ((arg & (arg - 1)) == 0) - -def simp_x_and_cst_eq_cst(_, expr): - """ - (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B - """ - cond = expr.cond - if not cond.is_op(TOK_EQUAL): - return expr - arg1, mask2 = cond.args - if not mask2.is_int(): - return expr - if not test_one_bit_set(int(mask2)): - return expr - if not arg1.is_op('&'): - return expr - mask1 = arg1.args[-1] - if mask1 != mask2: - return expr - cond = ExprOp('&', *arg1.args) - return ExprCond(cond, expr.src1, expr.src2) - -def simp_cmp_int_int(_, expr): - """ - IntA <s IntB => int - IntA <u IntB => int - IntA <=s IntB => int - IntA <=u IntB => int - IntA == IntB => int - """ - if expr.op not in [ - TOK_EQUAL, - TOK_INF_SIGNED, TOK_INF_UNSIGNED, - TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, - ]: - return expr - if not all(arg.is_int() for arg in expr.args): - return expr - int_a, int_b = expr.args - if expr.is_op(TOK_EQUAL): - if int_a == int_b: - return ExprInt(1, 1) - return ExprInt(0, expr.size) - - if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]: - int_a = int(mod_size2int[int_a.size](int(int_a))) - int_b = int(mod_size2int[int_b.size](int(int_b))) - else: - int_a = int(mod_size2uint[int_a.size](int(int_a))) - int_b = int(mod_size2uint[int_b.size](int(int_b))) - - if expr.op in [TOK_INF_SIGNED, TOK_INF_UNSIGNED]: - ret = int_a < int_b - else: - ret = int_a <= int_b - - if ret: - ret = 1 - else: - ret = 0 - return ExprInt(ret, 1) - - -def simp_ext_cst(_, expr): - """ - Int.zeroExt(X) => Int - Int.signExt(X) => Int - """ - if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")): - return expr - arg = expr.args[0] - if not arg.is_int(): - return expr - if expr.op.startswith("zeroExt"): - ret = int(arg) - else: - ret = int(mod_size2int[arg.size](int(arg))) - ret = ExprInt(ret, expr.size) - return ret - - -def simp_slice_of_ext(_, expr): - """ - C.zeroExt(X)[A:B] => 0 if A >= size(C) - C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) - A.zeroExt(X)[0:Y] => A.zeroExt(Y) - """ - if not expr.arg.is_op(): - return expr - if not expr.arg.op.startswith("zeroExt"): - return expr - arg = expr.arg.args[0] - - if expr.start >= arg.size: - # C.zeroExt(X)[A:B] => 0 if A >= size(C) - return ExprInt(0, expr.size) - if expr.stop <= arg.size: - # C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) - return arg[expr.start:expr.stop] - if expr.start == 0: - # A.zeroExt(X)[0:Y] => A.zeroExt(Y) - return arg.zeroExtend(expr.stop) - return expr - -def simp_slice_of_op_ext(expr_s, expr): - """ - (X.zeroExt() + {Z, } + ... + Int)[0:8] => X + ... + int[:] - (X.zeroExt() | ... | Int)[0:8] => X | ... | int[:] - ... - """ - if expr.start != 0: - return expr - src = expr.arg - if not src.is_op(): - return expr - if src.op not in ['+', '|', '^', '&']: - return expr - is_ok = True - for arg in src.args: - if arg.is_int(): - continue - if (arg.is_op() and - arg.op.startswith("zeroExt") and - arg.args[0].size == expr.stop): - continue - if arg.is_compose(): - continue - is_ok = False - break - if not is_ok: - return expr - args = [expr_s(arg[:expr.stop]) for arg in src.args] - return ExprOp(src.op, *args) - - -def simp_cond_logic_ext(expr_s, expr): - """(X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B""" - cond = expr.cond - if not cond.is_op(): - return expr - if cond.op not in ["&", "^", "|"]: - return expr - is_ok = True - sizes = set() - for arg in cond.args: - if arg.is_int(): - continue - if (arg.is_op() and - arg.op.startswith("zeroExt")): - sizes.add(arg.args[0].size) - continue - is_ok = False - break - if not is_ok: - return expr - if len(sizes) != 1: - return expr - size = list(sizes)[0] - args = [expr_s(arg[:size]) for arg in cond.args] - cond = ExprOp(cond.op, *args) - return ExprCond(cond, expr.src1, expr.src2) - - -def simp_cond_sign_bit(_, expr): - """(a & .. & 0x80000000) ? A:B => (a & ...) <s 0 ? A:B""" - cond = expr.cond - if not cond.is_op('&'): - return expr - last = cond.args[-1] - if not last.is_int(1 << (last.size - 1)): - return expr - zero = ExprInt(0, expr.cond.size) - if len(cond.args) == 2: - args = [cond.args[0], zero] - else: - args = [ExprOp('&', *list(cond.args[:-1])), zero] - cond = ExprOp(TOK_INF_SIGNED, *args) - return ExprCond(cond, expr.src1, expr.src2) - - -def simp_cond_add(expr_s, expr): - """ - (a+b)?X:Y => (a == b)?Y:X - (a^b)?X:Y => (a == b)?Y:X - """ - cond = expr.cond - if not cond.is_op(): - return expr - if cond.op not in ['+', '^']: - return expr - if len(cond.args) != 2: - return expr - arg1, arg2 = cond.args - if cond.is_op('+'): - new_cond = ExprOp('==', arg1, expr_s(-arg2)) - elif cond.is_op('^'): - new_cond = ExprOp('==', arg1, arg2) - else: - raise ValueError('Bad case') - return ExprCond(new_cond, expr.src2, expr.src1) - - -def simp_cond_eq_1_0(expr_s, expr): - """ - (a == b)?ExprInt(1, 1):ExprInt(0, 1) => a == b - (a <s b)?ExprInt(1, 1):ExprInt(0, 1) => a == b - ... - """ - cond = expr.cond - if not cond.is_op(): - return expr - if cond.op not in [ - TOK_EQUAL, - TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED, - TOK_INF_UNSIGNED, TOK_INF_EQUAL_UNSIGNED - ]: - return expr - if expr.src1 != ExprInt(1, 1) or expr.src2 != ExprInt(0, 1): - return expr - return cond - - -def simp_cond_inf_eq_unsigned_zero(expr_s, expr): - """ - (a <=u 0) => a == 0 - """ - if not expr.is_op(TOK_INF_EQUAL_UNSIGNED): - return expr - if not expr.args[1].is_int(0): - return expr - return ExprOp(TOK_EQUAL, expr.args[0], expr.args[1]) - - -def simp_test_signext_inf(expr_s, expr): - """A.signExt() <s int => A <s int[:]""" - if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)): - return expr - arg, cst = expr.args - if not (arg.is_op() and arg.op.startswith("signExt")): - return expr - if not cst.is_int(): - return expr - base = arg.args[0] - tmp = int(mod_size2int[cst.size](int(cst))) - if -(1 << (base.size - 1)) <= tmp < (1 << (base.size - 1)): - # Can trunc integer - return ExprOp(expr.op, base, expr_s(cst[:base.size])) - if (tmp >= (1 << (base.size - 1)) or - tmp < -(1 << (base.size - 1)) ): - return ExprInt(1, 1) - return expr - - -def simp_test_zeroext_inf(expr_s, expr): - """A.zeroExt() <u int => A <u int[:]""" - if not (expr.is_op(TOK_INF_UNSIGNED) or expr.is_op(TOK_INF_EQUAL_UNSIGNED)): - return expr - arg, cst = expr.args - if not (arg.is_op() and arg.op.startswith("zeroExt")): - return expr - if not cst.is_int(): - return expr - base = arg.args[0] - tmp = int(mod_size2uint[cst.size](int(cst))) - if 0 <= tmp < (1 << base.size): - # Can trunc integer - return ExprOp(expr.op, base, expr_s(cst[:base.size])) - if tmp >= (1 << base.size): - return ExprInt(1, 1) - return expr - - -def simp_add_multiple(_, expr): - """ - X + X => 2 * X - X + X * int1 => X * (1 + int1) - X * int1 + (- X) => X * (int1 - 1) - X + (X << int1) => X * (1 + 2 ** int1) - Correct even if addition overflow/underflow - """ - if not expr.is_op('+'): - return expr - - # Extract each argument and its counter - operands = {} - for arg in expr.args: - if arg.is_op('*') and arg.args[1].is_int(): - base_expr, factor = arg.args - operands[base_expr] = operands.get(base_expr, 0) + int(factor) - elif arg.is_op('<<') and arg.args[1].is_int(): - base_expr, factor = arg.args - operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor) - elif arg.is_op("-"): - arg = arg.args[0] - if arg.is_op('<<') and arg.args[1].is_int(): - base_expr, factor = arg.args - operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor)) - else: - operands[arg] = operands.get(arg, 0) - 1 - else: - operands[arg] = operands.get(arg, 0) + 1 - out = [] - - # Best effort to factor common args: - # (a + b) * 3 + a + b => (a + b) * 4 - # Does not factor: - # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a - modified = True - while modified: - modified = False - for arg, count in list(viewitems(operands)): - if not arg.is_op('+'): - continue - components = arg.args - if not all(component in operands for component in components): - continue - counters = set(operands[component] for component in components) - if len(counters) != 1: - continue - counter = counters.pop() - for component in components: - del operands[component] - operands[arg] += counter - modified = True - break - - for arg, count in viewitems(operands): - if count == 0: - continue - if count == 1: - out.append(arg) - continue - out.append(arg * ExprInt(count, expr.size)) - - if len(out) == len(expr.args): - # No reductions - return expr - if not out: - return ExprInt(0, expr.size) - if len(out) == 1: - return out[0] - return ExprOp('+', *out) diff --git a/miasm2/expression/simplifications_cond.py b/miasm2/expression/simplifications_cond.py deleted file mode 100644 index f1c224b7..00000000 --- a/miasm2/expression/simplifications_cond.py +++ /dev/null @@ -1,178 +0,0 @@ -################################################################################ -# -# By choice, Miasm2 does not handle comparison as a single operation, but with -# operations corresponding to comparison computation. -# One may want to detect those comparison; this library is designed to add them -# in Miasm2 engine thanks to : -# - Conditions computation in ExprOp -# - Simplifications to catch known condition forms -# -# Conditions currently supported : -# <u, <s, == -# -# Authors : Fabrice DESCLAUX (CEA/DAM), Camille MOUGEY (CEA/DAM) -# -################################################################################ - -import miasm2.expression.expression as m2_expr - - -# Jokers for expression matching - -jok1 = m2_expr.ExprId("jok1", 32) -jok2 = m2_expr.ExprId("jok2", 32) -jok3 = m2_expr.ExprId("jok3", 32) -jok_small = m2_expr.ExprId("jok_small", 1) - - -# Constructors - -def __ExprOp_cond(op, arg1, arg2): - "Return an ExprOp standing for arg1 op arg2 with size to 1" - ec = m2_expr.ExprOp(op, arg1, arg2) - return ec - - -def ExprOp_inf_signed(arg1, arg2): - "Return an ExprOp standing for arg1 <s arg2" - return __ExprOp_cond(m2_expr.TOK_INF_SIGNED, arg1, arg2) - - -def ExprOp_inf_unsigned(arg1, arg2): - "Return an ExprOp standing for arg1 <s arg2" - return __ExprOp_cond(m2_expr.TOK_INF_UNSIGNED, arg1, arg2) - -def ExprOp_equal(arg1, arg2): - "Return an ExprOp standing for arg1 == arg2" - return __ExprOp_cond(m2_expr.TOK_EQUAL, arg1, arg2) - - -# Catching conditions forms - -def __check_msb(e): - """If @e stand for the most significant bit of its arg, return the arg; - False otherwise""" - - if not isinstance(e, m2_expr.ExprSlice): - return False - - arg = e.arg - if e.start != (arg.size - 1) or e.stop != arg.size: - return False - - return arg - -def __match_expr_wrap(e, to_match, jok_list): - "Wrapper around match_expr to canonize pattern" - - to_match = to_match.canonize() - - r = m2_expr.match_expr(e, to_match, jok_list) - if r is False: - return False - - if r == {}: - return False - - return r - -def expr_simp_inf_signed(expr_simp, e): - "((x - y) ^ ((x ^ y) & ((x - y) ^ x))) [31:32] == x <s y" - - arg = __check_msb(e) - if arg is False: - return e - # We want jok3 = jok1 - jok2 - to_match = jok3 ^ ((jok1 ^ jok2) & (jok3 ^ jok1)) - r = __match_expr_wrap(arg, - to_match, - [jok1, jok2, jok3]) - - if r is False: - return e - - new_j3 = expr_simp(r[jok3]) - sub = expr_simp(r[jok1] - r[jok2]) - - if new_j3 == sub: - return ExprOp_inf_signed(r[jok1], r[jok2]) - else: - return e - -def expr_simp_inf_unsigned_inversed(expr_simp, e): - "((x - y) ^ ((x ^ y) & ((x - y) ^ x))) ^ x ^ y [31:32] == x <u y" - - arg = __check_msb(e) - if arg is False: - return e - - # We want jok3 = jok1 - jok2 - to_match = jok3 ^ ((jok1 ^ jok2) & (jok3 ^ jok1)) ^ jok1 ^ jok2 - r = __match_expr_wrap(arg, - to_match, - [jok1, jok2, jok3]) - - if r is False: - return e - - new_j3 = expr_simp(r[jok3]) - sub = expr_simp(r[jok1] - r[jok2]) - - if new_j3 == sub: - return ExprOp_inf_unsigned(r[jok1], r[jok2]) - else: - return e - -def expr_simp_inverse(expr_simp, e): - """(x <u y) ^ ((x ^ y) [31:32]) == x <s y, - (x <s y) ^ ((x ^ y) [31:32]) == x <u y""" - - to_match = (ExprOp_inf_unsigned(jok1, jok2) ^ jok_small) - r = __match_expr_wrap(e, - to_match, - [jok1, jok2, jok_small]) - - # Check for 2 symmetric cases - if r is False: - to_match = (ExprOp_inf_signed(jok1, jok2) ^ jok_small) - r = __match_expr_wrap(e, - to_match, - [jok1, jok2, jok_small]) - - if r is False: - return e - cur_sig = m2_expr.TOK_INF_SIGNED - else: - cur_sig = m2_expr.TOK_INF_UNSIGNED - - - arg = __check_msb(r[jok_small]) - if arg is False: - return e - - if not isinstance(arg, m2_expr.ExprOp) or arg.op != "^": - return e - - op_args = arg.args - if len(op_args) != 2: - return e - - if r[jok1] not in op_args or r[jok2] not in op_args: - return e - - if cur_sig == m2_expr.TOK_INF_UNSIGNED: - return ExprOp_inf_signed(r[jok1], r[jok2]) - else: - return ExprOp_inf_unsigned(r[jok1], r[jok2]) - -def expr_simp_equal(expr_simp, e): - """(x - y)?(0:1) == (x == y)""" - - to_match = m2_expr.ExprCond(jok1 + jok2, m2_expr.ExprInt(0, 1), m2_expr.ExprInt(1, 1)) - r = __match_expr_wrap(e, - to_match, - [jok1, jok2]) - if r is False: - return e - - return ExprOp_equal(r[jok1], expr_simp(-r[jok2])) diff --git a/miasm2/expression/simplifications_explicit.py b/miasm2/expression/simplifications_explicit.py deleted file mode 100644 index 00892201..00000000 --- a/miasm2/expression/simplifications_explicit.py +++ /dev/null @@ -1,159 +0,0 @@ -from miasm2.expression.modint import size2mask -from miasm2.expression.expression import ExprInt, ExprCond, ExprCompose, \ - TOK_EQUAL - - -def simp_ext(_, expr): - if expr.op.startswith('zeroExt_'): - arg = expr.args[0] - if expr.size == arg.size: - return arg - return ExprCompose(arg, ExprInt(0, expr.size - arg.size)) - - if expr.op.startswith("signExt_"): - arg = expr.args[0] - add_size = expr.size - arg.size - new_expr = ExprCompose( - arg, - ExprCond( - arg.msb(), - ExprInt(size2mask(add_size), add_size), - ExprInt(0, add_size) - ) - ) - return new_expr - return expr - - -def simp_flags(_, expr): - args = expr.args - - if expr.is_op("FLAG_EQ"): - return ExprCond(args[0], ExprInt(0, 1), ExprInt(1, 1)) - - elif expr.is_op("FLAG_EQ_AND"): - op1, op2 = args - return ExprCond(op1 & op2, ExprInt(0, 1), ExprInt(1, 1)) - - elif expr.is_op("FLAG_SIGN_SUB"): - return (args[0] - args[1]).msb() - - elif expr.is_op("FLAG_EQ_CMP"): - return ExprCond( - args[0] - args[1], - ExprInt(0, 1), - ExprInt(1, 1), - ) - - elif expr.is_op("FLAG_ADD_CF"): - op1, op2 = args - res = op1 + op2 - return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() - - elif expr.is_op("FLAG_SUB_CF"): - op1, op2 = args - res = op1 - op2 - return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() - - elif expr.is_op("FLAG_ADD_OF"): - op1, op2 = args - res = op1 + op2 - return (((op1 ^ res) & (~(op1 ^ op2)))).msb() - - elif expr.is_op("FLAG_SUB_OF"): - op1, op2 = args - res = op1 - op2 - return (((op1 ^ res) & (op1 ^ op2))).msb() - - elif expr.is_op("FLAG_EQ_ADDWC"): - op1, op2, op3 = args - return ExprCond( - op1 + op2 + op3.zeroExtend(op1.size), - ExprInt(0, 1), - ExprInt(1, 1), - ) - - elif expr.is_op("FLAG_ADDWC_OF"): - op1, op2, op3 = args - res = op1 + op2 + op3.zeroExtend(op1.size) - return (((op1 ^ res) & (~(op1 ^ op2)))).msb() - - elif expr.is_op("FLAG_SUBWC_OF"): - op1, op2, op3 = args - res = op1 - (op2 + op3.zeroExtend(op1.size)) - return (((op1 ^ res) & (op1 ^ op2))).msb() - - elif expr.is_op("FLAG_ADDWC_CF"): - op1, op2, op3 = args - res = op1 + op2 + op3.zeroExtend(op1.size) - return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() - - elif expr.is_op("FLAG_SUBWC_CF"): - op1, op2, op3 = args - res = op1 - (op2 + op3.zeroExtend(op1.size)) - return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() - - elif expr.is_op("FLAG_SIGN_ADDWC"): - op1, op2, op3 = args - return (op1 + op2 + op3.zeroExtend(op1.size)).msb() - - elif expr.is_op("FLAG_SIGN_SUBWC"): - op1, op2, op3 = args - return (op1 - (op2 + op3.zeroExtend(op1.size))).msb() - - - elif expr.is_op("FLAG_EQ_SUBWC"): - op1, op2, op3 = args - res = op1 - (op2 + op3.zeroExtend(op1.size)) - return ExprCond(res, ExprInt(0, 1), ExprInt(1, 1)) - - elif expr.is_op("CC_U<="): - op_cf, op_zf = args - return op_cf | op_zf - - elif expr.is_op("CC_U>="): - op_cf, = args - return ~op_cf - - elif expr.is_op("CC_S<"): - op_nf, op_of = args - return op_nf ^ op_of - - elif expr.is_op("CC_S>"): - op_nf, op_of, op_zf = args - return ~(op_zf | (op_nf ^ op_of)) - - elif expr.is_op("CC_S<="): - op_nf, op_of, op_zf = args - return op_zf | (op_nf ^ op_of) - - elif expr.is_op("CC_S>="): - op_nf, op_of = args - return ~(op_nf ^ op_of) - - elif expr.is_op("CC_U>"): - op_cf, op_zf = args - return ~(op_cf | op_zf) - - elif expr.is_op("CC_U<"): - op_cf, = args - return op_cf - - elif expr.is_op("CC_NEG"): - op_nf, = args - return op_nf - - elif expr.is_op("CC_EQ"): - op_zf, = args - return op_zf - - elif expr.is_op("CC_NE"): - op_zf, = args - return ~op_zf - - elif expr.is_op("CC_POS"): - op_nf, = args - return ~op_nf - - return expr - diff --git a/miasm2/expression/smt2_helper.py b/miasm2/expression/smt2_helper.py deleted file mode 100644 index 53d323e8..00000000 --- a/miasm2/expression/smt2_helper.py +++ /dev/null @@ -1,296 +0,0 @@ -# Helper functions for the generation of SMT2 expressions -# The SMT2 expressions will be returned as a string. -# The expressions are divided as follows -# -# - generic SMT2 operations -# - definitions of SMT2 structures -# - bit vector operations -# - array operations - -# generic SMT2 operations - -def smt2_eq(a, b): - """ - Assignment: a = b - """ - return "(= {} {})".format(a, b) - - -def smt2_implies(a, b): - """ - Implication: a => b - """ - return "(=> {} {})".format(a, b) - - -def smt2_and(*args): - """ - Conjunction: a and b and c ... - """ - # transform args into strings - args = [str(arg) for arg in args] - return "(and {})".format(' '.join(args)) - - -def smt2_or(*args): - """ - Disjunction: a or b or c ... - """ - # transform args into strings - args = [str(arg) for arg in args] - return "(or {})".format(' '.join(args)) - - -def smt2_ite(cond, a, b): - """ - If-then-else: cond ? a : b - """ - return "(ite {} {} {})".format(cond, a, b) - - -def smt2_distinct(*args): - """ - Distinction: a != b != c != ... - """ - # transform args into strings - args = [str(arg) for arg in args] - return "(distinct {})".format(' '.join(args)) - - -def smt2_assert(expr): - """ - Assertion that @expr holds - """ - return "(assert {})".format(expr) - - -# definitions - -def declare_bv(bv, size): - """ - Declares an bit vector @bv of size @size - """ - return "(declare-fun {} () {})".format(bv, bit_vec(size)) - - -def declare_array(a, bv1, bv2): - """ - Declares an SMT2 array represented as a map - from a bit vector to another bit vector. - :param a: array name - :param bv1: SMT2 bit vector - :param bv2: SMT2 bit vector - """ - return "(declare-fun {} () (Array {} {}))".format(a, bv1, bv2) - - -def bit_vec_val(v, size): - """ - Declares a bit vector value - :param v: int, value of the bit vector - :param size: size of the bit vector - """ - return "(_ bv{} {})".format(v, size) - - -def bit_vec(size): - """ - Returns a bit vector of size @size - """ - return "(_ BitVec {})".format(size) - - -# bit vector operations - -def bvadd(a, b): - """ - Addition: a + b - """ - return "(bvadd {} {})".format(a, b) - - -def bvsub(a, b): - """ - Subtraction: a - b - """ - return "(bvsub {} {})".format(a, b) - - -def bvmul(a, b): - """ - Multiplication: a * b - """ - return "(bvmul {} {})".format(a, b) - - -def bvand(a, b): - """ - Bitwise AND: a & b - """ - return "(bvand {} {})".format(a, b) - - -def bvor(a, b): - """ - Bitwise OR: a | b - """ - return "(bvor {} {})".format(a, b) - - -def bvxor(a, b): - """ - Bitwise XOR: a ^ b - """ - return "(bvxor {} {})".format(a, b) - - -def bvneg(bv): - """ - Unary minus: - bv - """ - return "(bvneg {})".format(bv) - - -def bvsdiv(a, b): - """ - Signed division: a / b - """ - return "(bvsdiv {} {})".format(a, b) - - -def bvudiv(a, b): - """ - Unsigned division: a / b - """ - return "(bvudiv {} {})".format(a, b) - - -def bvsmod(a, b): - """ - Signed modulo: a mod b - """ - return "(bvsmod {} {})".format(a, b) - - -def bvurem(a, b): - """ - Unsigned modulo: a mod b - """ - return "(bvurem {} {})".format(a, b) - - -def bvshl(a, b): - """ - Shift left: a << b - """ - return "(bvshl {} {})".format(a, b) - - -def bvlshr(a, b): - """ - Logical shift right: a >> b - """ - return "(bvlshr {} {})".format(a, b) - - -def bvashr(a, b): - """ - Arithmetic shift right: a a>> b - """ - return "(bvashr {} {})".format(a, b) - - -def bv_rotate_left(a, b, size): - """ - Rotates bits of a to the left b times: a <<< b - - Since ((_ rotate_left b) a) does not support - symbolic values for b, the implementation is - based on a C implementation. - - Therefore, the rotation will be computed as - a << (b & (size - 1))) | (a >> (size - (b & (size - 1)))) - - :param a: bit vector - :param b: bit vector - :param size: size of a - """ - - # define constant - s = bit_vec_val(size, size) - - # shift = b & (size - 1) - shift = bvand(b, bvsub(s, bit_vec_val(1, size))) - - # (a << shift) | (a >> size - shift) - rotate = bvor(bvshl(a, shift), - bvlshr(a, bvsub(s, shift))) - - return rotate - - -def bv_rotate_right(a, b, size): - """ - Rotates bits of a to the right b times: a >>> b - - Since ((_ rotate_right b) a) does not support - symbolic values for b, the implementation is - based on a C implementation. - - Therefore, the rotation will be computed as - a >> (b & (size - 1))) | (a << (size - (b & (size - 1)))) - - :param a: bit vector - :param b: bit vector - :param size: size of a - """ - - # define constant - s = bit_vec_val(size, size) - - # shift = b & (size - 1) - shift = bvand(b, bvsub(s, bit_vec_val(1, size))) - - # (a >> shift) | (a << size - shift) - rotate = bvor(bvlshr(a, shift), - bvshl(a, bvsub(s, shift))) - - return rotate - - -def bv_extract(high, low, bv): - """ - Extracts bits from a bit vector - :param high: end bit - :param low: start bit - :param bv: bit vector - """ - return "((_ extract {} {}) {})".format(high, low, bv) - - -def bv_concat(a, b): - """ - Concatenation of two SMT2 expressions - """ - return "(concat {} {})".format(a, b) - - -# array operations - -def array_select(array, index): - """ - Reads from an SMT2 array at index @index - :param array: SMT2 array - :param index: SMT2 expression, index of the array - """ - return "(select {} {})".format(array, index) - - -def array_store(array, index, value): - """ - Writes an value into an SMT2 array at address @index - :param array: SMT array - :param index: SMT2 expression, index of the array - :param value: SMT2 expression, value to write - """ - return "(store {} {} {})".format(array, index, value) |