diff options
Diffstat (limited to 'src/miasm/expression')
| -rw-r--r-- | src/miasm/expression/__init__.py | 18 | ||||
| -rw-r--r-- | src/miasm/expression/expression.py | 2175 | ||||
| -rw-r--r-- | src/miasm/expression/expression_helper.py | 628 | ||||
| -rw-r--r-- | src/miasm/expression/expression_reduce.py | 280 | ||||
| -rw-r--r-- | src/miasm/expression/parser.py | 84 | ||||
| -rw-r--r-- | src/miasm/expression/simplifications.py | 201 | ||||
| -rw-r--r-- | src/miasm/expression/simplifications_common.py | 1868 | ||||
| -rw-r--r-- | src/miasm/expression/simplifications_cond.py | 178 | ||||
| -rw-r--r-- | src/miasm/expression/simplifications_explicit.py | 159 | ||||
| -rw-r--r-- | src/miasm/expression/smt2_helper.py | 296 |
10 files changed, 5887 insertions, 0 deletions
diff --git a/src/miasm/expression/__init__.py b/src/miasm/expression/__init__.py new file mode 100644 index 00000000..67f567f7 --- /dev/null +++ b/src/miasm/expression/__init__.py @@ -0,0 +1,18 @@ +# +# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +"Intermediate language implementation" diff --git a/src/miasm/expression/expression.py b/src/miasm/expression/expression.py new file mode 100644 index 00000000..4b0bbe6b --- /dev/null +++ b/src/miasm/expression/expression.py @@ -0,0 +1,2175 @@ +# +# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# This module implements the Miasm IR components and the basic operations
# related to them.
# IR components are:
#  - ExprInt
#  - ExprId
#  - ExprLoc
#  - ExprAssign
#  - ExprCond
#  - ExprMem
#  - ExprOp
#  - ExprSlice
#  - ExprCompose
#


from builtins import zip
from builtins import range
import warnings
import itertools
from builtins import int as int_types
from functools import cmp_to_key, total_ordering
from future.utils import viewitems

from miasm.core.utils import force_bytes, cmp_elts
from miasm.core.graph import DiGraph
from functools import reduce

# Comparison / sign tokens used as ExprOp operator names
TOK_INF = "<"
TOK_INF_SIGNED = TOK_INF + "s"
TOK_INF_UNSIGNED = TOK_INF + "u"
TOK_INF_EQUAL = "<="
TOK_INF_EQUAL_SIGNED = TOK_INF_EQUAL + "s"
TOK_INF_EQUAL_UNSIGNED = TOK_INF_EQUAL + "u"
TOK_EQUAL = "=="
TOK_POS = "pos"
TOK_POS_STRICT = "Spos"

# Per-class tags mixed into the hash of each expression kind
EXPRINT = 1
EXPRID = 2
EXPRLOC = 3
EXPRASSIGN = 4
EXPRCOND = 5
EXPRMEM = 6
EXPROP = 7
EXPRSLICE = 8
EXPRCOMPOSE = 9


# Operators grouped by printing priority, lowest priority first
priorities_list = [
    [ '+' ],
    [ '*', '/', '%' ],
    [ '**' ],
    [ '-' ], # Unary '-', associativity with + not handled
]

# Map each operator to the index of its group in priorities_list
priorities = {
    op: prio
    for prio, ops in enumerate(priorities_list)
    for op in ops
}
PRIORITY_MAX = len(priorities_list) - 1


def should_parenthesize_child(child, parent):
    """Tell whether @child has to be wrapped in parentheses when it is
    printed as an operand of @parent.

    @child: Expr, the operand being printed
    @parent: Expr, the expression holding @child
    """
    # Identifiers, constants, compositions, memory accesses and slices are
    # never protected
    if isinstance(child, (ExprId, ExprInt, ExprCompose, ExprMem, ExprSlice)):
        return False

    child_is_op = isinstance(child, ExprOp)
    # Operations that are not infix are never protected either
    if child_is_op and not child.is_infix():
        return False

    if isinstance(child, ExprCond) or isinstance(parent, ExprSlice):
        return True

    if child_is_op and isinstance(parent, ExprOp):
        # Unknown child operators get the lowest priority, unknown parent
        # operators the highest: when in doubt, parenthesize
        child_prio = priorities.get(child.op, -1)
        parent_prio = priorities.get(parent.op, PRIORITY_MAX + 1)
        return child_prio < parent_prio

    return True


def str_protected_child(child, parent):
    """Return str(@child), surrounded by parentheses when
    should_parenthesize_child(@child, @parent) requires it."""
    if should_parenthesize_child(child, parent):
        return "(%s)" % child
    return str(child)


# Expression display


class DiGraphExpr(DiGraph):

    """Graph specialised for Expression display.
    Node and edge labels are reduced to the information that is relevant
    for each kind of expression."""

    def node2str(self, node):
        """Return the label of @node: operator, name, location key or a
        short size/bounds marker depending on the expression class; other
        nodes fall back to str()."""
        if isinstance(node, ExprOp):
            return node.op
        if isinstance(node, ExprId):
            return node.name
        if isinstance(node, ExprLoc):
            return "%s" % node.loc_key
        if isinstance(node, ExprMem):
            return "@%d" % node.size
        if isinstance(node, ExprCompose):
            return "{ %d }" % node.size
        if isinstance(node, ExprCond):
            return "? %d" % node.size
        if isinstance(node, ExprSlice):
            return "[%d:%d]" % (node.start, node.stop)
        return str(node)

    def edge2str(self, nfrom, nto):
        """Return the label of the edge @nfrom -> @nto ("" when the edge
        carries no specific information)."""
        if isinstance(nfrom, ExprCompose):
            # Each argument is indexed as a triple whose first item is the
            # sub expression
            for item in nfrom.args:
                if item[0] == nto:
                    return "[%s, %s]" % (item[1], item[2])
            return ""

        if isinstance(nfrom, ExprCond):
            if nfrom.cond == nto:
                return "?"
            if nfrom.src1 == nto:
                return "True"
            if nfrom.src2 == nto:
                return "False"

        return ""
def is_expr(expr):
    """Return True if @expr is an instance of one of the expression classes
    listed below (ExprAssign is not part of the list)."""
    return isinstance(
        expr,
        (
            ExprInt, ExprId, ExprMem,
            ExprSlice, ExprCompose, ExprCond,
            ExprLoc, ExprOp
        )
    )


def is_associative(expr):
    "Return True iff current operation is associative"
    return (expr.op in ['+', '*', '^', '&', '|'])


def is_commutative(expr):
    "Return True iff current operation is commutative"
    return (expr.op in ['+', '*', '^', '&', '|'])


def canonize_to_exprloc(locdb, expr):
    """
    If expr is ExprInt, return ExprLoc with corresponding loc_key
    Else, return expr

    @locdb: object providing get_or_create_offset_location(offset)
    @expr: Expr instance
    """
    if expr.is_int():
        loc_key = locdb.get_or_create_offset_location(int(expr))
        return ExprLoc(loc_key, expr.size)
    return expr


def is_function_call(expr):
    """Returns true if the considered Expr is a function call
    (an operation whose operator name starts with 'call')
    """
    return expr.is_op() and expr.op.startswith('call')


@total_ordering
class LocKey(object):
    """Hashable, ordered wrapper around a location key.

    Equality requires both objects to be of the exact same class and to hold
    equal keys; ordering compares the keys. The key is formatted with %d in
    __repr__/__str__, so it is presumably an integer.
    """

    def __init__(self, key):
        self._key = key

    # Read-only access to the wrapped key
    key = property(lambda self: self._key)

    def __hash__(self):
        return hash(self._key)

    def __eq__(self, other):
        if self is other:
            return True
        if self.__class__ is not other.__class__:
            return False
        return self.key == other.key

    def __ne__(self, other):
        # required Python 2.7.14
        return not self == other

    def __lt__(self, other):
        return self.key < other.key

    def __repr__(self):
        return "<%s %d>" % (self.__class__.__name__, self._key)

    def __str__(self):
        return "loc_key_%d" % self.key


class ExprWalkBase(object):
    """
    Walk through sub-expressions, call @callback on them.
    Sub-expressions are visited before the expression holding them.
    As soon as a visit returns a value evaluating to True, the walk stops
    and this value is returned.
    """

    def __init__(self, callback):
        """
        @callback: fn(expr, *args, **kwargs), called on each visited node
        """
        self.callback = callback

    def visit(self, expr, *args, **kwargs):
        """Visit the sons of @expr, then @expr itself.

        @expr: Expr instance
        @args, @kwargs: forwarded to each recursive visit and to the callback

        Return the first true value returned while visiting a son, else the
        result of the callback on @expr.
        Raise TypeError if @expr is of no known expression kind.
        """
        if expr.is_int() or expr.is_id() or expr.is_loc():
            sons = ()
        elif expr.is_assign():
            # Both sides are walked. The previous implementation stored the
            # result of the source visit in an unused variable, so a hit in
            # the source was lost.
            sons = (expr.dst, expr.src)
        elif expr.is_cond():
            sons = (expr.cond, expr.src1, expr.src2)
        elif expr.is_mem():
            sons = (expr.ptr,)
        elif expr.is_slice():
            sons = (expr.arg,)
        elif expr.is_op() or expr.is_compose():
            sons = expr.args
        else:
            raise TypeError("Visitor can only take Expr")

        for son in sons:
            ret = self.visit(son, *args, **kwargs)
            if ret:
                return ret

        return self.callback(expr, *args, **kwargs)


class ExprWalk(ExprWalkBase):
    """
    Walk through sub-expressions, call @callback on them.
    As soon as a visit returns a value evaluating to True, the walk stops
    and this value is returned.
    Use cache mechanism: an expression whose visit returned nothing is
    remembered and skipped when met again.
    """
    def __init__(self, callback):
        super(ExprWalk, self).__init__(callback)
        # Expressions already visited without any result
        self.cache = set()

    def visit(self, expr, *args, **kwargs):
        if expr in self.cache:
            return None
        ret = super(ExprWalk, self).visit(expr, *args, **kwargs)
        if ret:
            return ret
        self.cache.add(expr)
        return None


class ExprGetR(ExprWalkBase):
    """
    Return ExprId/ExprMem used by a given expression

    Results are accumulated in the `elements` set.
    @mem_read: if True, also explore the pointer of memory accesses
    @cst_read: if True, also collect ExprInt/ExprLoc
    """
    def __init__(self, mem_read=False, cst_read=False):
        # The walk is only used for its traversal: the callback is a no-op
        super(ExprGetR, self).__init__(lambda x:None)
        self.mem_read = mem_read
        self.cst_read = cst_read
        self.elements = set()
        self.cache = dict()

    def get_r_leaves(self, expr):
        """Add @expr to `elements` if it is a memory access, an identifier
        or, when cst_read is set, a constant/location."""
        if (expr.is_int() or expr.is_loc()) and self.cst_read:
            self.elements.add(expr)
        elif expr.is_mem():
            self.elements.add(expr)
        elif expr.is_id():
            self.elements.add(expr)

    def visit(self, expr, *args, **kwargs):
        """Cached front-end of visit_inner, keyed on @expr and on the
        mem_read/cst_read flags."""
        cache_key = (expr, self.mem_read, self.cst_read)
        if cache_key in self.cache:
            return self.cache[cache_key]
        ret = self.visit_inner(expr, *args, **kwargs)
        self.cache[cache_key] = ret
        return ret

    def visit_inner(self, expr, *args, **kwargs):
        self.get_r_leaves(expr)
        if expr.is_mem() and not self.mem_read:
            # Don't visit memory sons
            return None

        if expr.is_assign():
            if expr.dst.is_mem() and self.mem_read:
                # Explore the sons of a written memory; the destination
                # itself does not go through get_r_leaves
                super(ExprGetR, self).visit(expr.dst, *args, **kwargs)
            if expr.src.is_mem():
                self.elements.add(expr.src)
            self.get_r_leaves(expr.src)
            if expr.src.is_mem() and not self.mem_read:
                return None
            ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs)
            return ret
        ret = super(ExprGetR, self).visit(expr, *args, **kwargs)
        return ret


class ExprVisitorBase(object):
    """
    Rebuild expression by visiting sub-expressions
    """
    def visit(self, expr, *args, **kwargs):
        """Return @expr rebuilt from the visit of its sons.

        ExprInt/ExprId/ExprLoc are returned unchanged.
        Raise TypeError if @expr is of no known expression kind.
        """
        if expr.is_int() or expr.is_id() or expr.is_loc():
            ret = expr
        elif expr.is_assign():
            dst = self.visit(expr.dst, *args, **kwargs)
            src = self.visit(expr.src, *args, **kwargs)
            ret = ExprAssign(dst, src)
        elif expr.is_cond():
            cond = self.visit(expr.cond, *args, **kwargs)
            src1 = self.visit(expr.src1, *args, **kwargs)
            src2 = self.visit(expr.src2, *args, **kwargs)
            ret = ExprCond(cond, src1, src2)
        elif expr.is_mem():
            ptr = self.visit(expr.ptr, *args, **kwargs)
            ret = ExprMem(ptr, expr.size)
        elif expr.is_slice():
            arg = self.visit(expr.arg, *args, **kwargs)
            ret = ExprSlice(arg, expr.start, expr.stop)
        elif expr.is_op():
            # Use a dedicated name: the visitor's own *args must not be
            # rebound while they are forwarded
            new_args = [self.visit(arg, *args, **kwargs) for arg in expr.args]
            ret = ExprOp(expr.op, *new_args)
        elif expr.is_compose():
            new_args = [self.visit(arg, *args, **kwargs) for arg in expr.args]
            ret = ExprCompose(*new_args)
        else:
            raise TypeError("Visitor can only take Expr")
        return ret


class ExprVisitorCallbackTopToBottom(ExprVisitorBase):
    """
    Rebuild expression by visiting sub-expressions
    Call @callback on each sub-expression
    If @callback returns a value evaluating to True, replace current node
    with this value
    Else, continue visit of sub-expressions
    Results are cached per visited expression
    """
    def __init__(self, callback):
        super(ExprVisitorCallbackTopToBottom, self).__init__()
        self.cache = dict()
        self.callback = callback

    def visit(self, expr, *args, **kwargs):
        if expr in self.cache:
            return self.cache[expr]
        ret = self.visit_inner(expr, *args, **kwargs)
        self.cache[expr] = ret
        return ret

    def visit_inner(self, expr, *args, **kwargs):
        # The callback only receives the expression, not the extra arguments
        ret = self.callback(expr)
        if ret:
            return ret
        ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs)
        return ret


class ExprVisitorCallbackBottomToTop(ExprVisitorBase):
    """
    Rebuild expression by visiting sub-expressions
    Call @callback from leaves to root expressions
    Each node is replaced by the value returned by @callback
    Results are cached per visited expression
    """
    def __init__(self, callback):
        super(ExprVisitorCallbackBottomToTop, self).__init__()
        self.cache = dict()
        self.callback = callback

    def visit(self, expr, *args, **kwargs):
        if expr in self.cache:
            return self.cache[expr]
        ret = self.visit_inner(expr, *args, **kwargs)
        self.cache[expr] = ret
        return ret

    def visit_inner(self, expr, *args, **kwargs):
        # Rebuild the sons first, then hand the rebuilt node to the callback
        ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs)
        ret = self.callback(ret)
        return ret
def visit_inner(self, expr, *args, **kwargs): + ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs) + ret = self.callback(ret) + return ret + + +class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop): + def __init__(self): + super(ExprVisitorCanonize, self).__init__(self.canonize) + + def canonize(self, expr): + if not expr.is_op(): + return expr + if not expr.is_associative(): + return expr + + # ((a+b) + c) => (a + b + c) + args = [] + for arg in expr.args: + if isinstance(arg, ExprOp) and expr.op == arg.op: + args += arg.args + else: + args.append(arg) + args = canonize_expr_list(args) + new_expr = ExprOp(expr.op, *args) + return new_expr + + +class ExprVisitorContains(ExprWalkBase): + """ + Visitor to test if a needle is in an Expression + Cache results + """ + def __init__(self): + self.cache = set() + super(ExprVisitorContains, self).__init__(self.eq_expr) + + def eq_expr(self, expr, needle, *args, **kwargs): + if expr == needle: + return True + return None + + def visit(self, expr, needle, *args, **kwargs): + if (expr, needle) in self.cache: + return None + ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs) + if ret: + return ret + self.cache.add((expr, needle)) + return None + + + def contains(self, expr, needle): + return self.visit(expr, needle) + +contains_visitor = ExprVisitorContains() +canonize_visitor = ExprVisitorCanonize() + +# IR definitions + +class Expr(object): + + "Parent class for Miasm Expressions" + + __slots__ = ["_hash", "_repr", "_size"] + + args2expr = {} + canon_exprs = set() + use_singleton = True + + def set_size(self, _): + raise ValueError('size is not mutable') + + def __init__(self, size): + """Instantiate an Expr with size @size + @size: int + """ + # Common attribute + self._size = size + + # Lazy cache needs + self._hash = None + self._repr = None + + size = property(lambda self: self._size) + + @staticmethod + def get_object(expr_cls, args): + if not 
expr_cls.use_singleton: + return object.__new__(expr_cls) + + expr = Expr.args2expr.get((expr_cls, args)) + if expr is None: + expr = object.__new__(expr_cls) + Expr.args2expr[(expr_cls, args)] = expr + return expr + + def get_is_canon(self): + return self in Expr.canon_exprs + + def set_is_canon(self, value): + assert value is True + Expr.canon_exprs.add(self) + + is_canon = property(get_is_canon, set_is_canon) + + # Common operations + + def __str__(self): + raise NotImplementedError("Abstract Method") + + def __getitem__(self, i): + if not isinstance(i, slice): + raise TypeError("Expression: Bad slice: %s" % i) + start, stop, step = i.indices(self.size) + if step != 1: + raise ValueError("Expression: Bad slice: %s" % i) + return ExprSlice(self, start, stop) + + def get_size(self): + raise DeprecationWarning("use X.size instead of X.get_size()") + + def is_function_call(self): + """Returns true if the considered Expr is a function call + """ + return False + + def __repr__(self): + if self._repr is None: + self._repr = self._exprrepr() + return self._repr + + def __hash__(self): + if self._hash is None: + self._hash = self._exprhash() + return self._hash + + def __eq__(self, other): + if self is other: + return True + elif self.use_singleton: + # In case of Singleton, pointer comparison is sufficient + # Avoid computation of hash and repr + return False + + if self.__class__ is not other.__class__: + return False + if hash(self) != hash(other): + return False + return repr(self) == repr(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __lt__(self, other): + weight1 = EXPR_ORDER_DICT[self.__class__] + weight2 = EXPR_ORDER_DICT[other.__class__] + return weight1 < weight2 + + def __add__(self, other): + return ExprOp('+', self, other) + + def __sub__(self, other): + return ExprOp('+', self, ExprOp('-', other)) + + def __truediv__(self, other): + return ExprOp('/', self, other) + + def __floordiv__(self, other): + return 
self.__truediv__(other) + + def __mod__(self, other): + return ExprOp('%', self, other) + + def __mul__(self, other): + return ExprOp('*', self, other) + + def __lshift__(self, other): + return ExprOp('<<', self, other) + + def __rshift__(self, other): + return ExprOp('>>', self, other) + + def __xor__(self, other): + return ExprOp('^', self, other) + + def __or__(self, other): + return ExprOp('|', self, other) + + def __and__(self, other): + return ExprOp('&', self, other) + + def __neg__(self): + return ExprOp('-', self) + + def __pow__(self, other): + return ExprOp("**", self, other) + + def __invert__(self): + return ExprOp('^', self, self.mask) + + def copy(self): + "Deep copy of the expression" + return self.visit(lambda x: x) + + def __deepcopy__(self, _): + return self.copy() + + def replace_expr(self, dct): + """Find and replace sub expression using dct + @dct: dictionary associating replaced Expr to its new Expr value + """ + def replace(expr): + if expr in dct: + return dct[expr] + return None + visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr)) + return visitor.visit(self) + + def canonize(self): + "Canonize the Expression" + return canonize_visitor.visit(self) + + def msb(self): + "Return the Most Significant Bit" + return self[self.size - 1:self.size] + + def zeroExtend(self, size): + """Zero extend to size + @size: int + """ + assert self.size <= size + if self.size == size: + return self + return ExprOp('zeroExt_%d' % size, self) + + def signExtend(self, size): + """Sign extend to size + @size: int + """ + assert self.size <= size + if self.size == size: + return self + return ExprOp('signExt_%d' % size, self) + + def graph_recursive(self, graph): + """Recursive method used by graph + @graph: miasm.core.graph.DiGraph instance + Update @graph instance to include sons + This is an Abstract method""" + + raise ValueError("Abstract method") + + def graph(self): + """Return a DiGraph instance standing for Expr tree + Instance's display 
functions have been override for better visibility + Wrapper on graph_recursive""" + + # Create recursively the graph + graph = DiGraphExpr() + self.graph_recursive(graph) + + return graph + + def set_mask(self, value): + raise ValueError('mask is not mutable') + + mask = property(lambda self: ExprInt(-1, self.size)) + + def is_int(self, value=None): + return False + + def is_id(self, name=None): + return False + + def is_loc(self, label=None): + return False + + def is_aff(self): + warnings.warn('DEPRECATION WARNING: use is_assign()') + return False + + def is_assign(self): + return False + + def is_cond(self): + return False + + def is_mem(self): + return False + + def is_op(self, op=None): + return False + + def is_slice(self, start=None, stop=None): + return False + + def is_compose(self): + return False + + def is_op_segm(self): + """Returns True if is ExprOp and op == 'segm'""" + warnings.warn('DEPRECATION WARNING: use is_op_segm(expr)') + raise RuntimeError("Moved api") + + def is_mem_segm(self): + """Returns True if is ExprMem and ptr is_op_segm""" + warnings.warn('DEPRECATION WARNING: use is_mem_segm(expr)') + raise RuntimeError("Moved api") + + def __contains__(self, expr): + ret = contains_visitor.contains(self, expr) + return ret + + def visit(self, callback): + """ + Apply callback to all sub expression of @self + This function keeps a cache to avoid rerunning @callback on common sub + expressions. + + @callback: fn(Expr) -> Expr + """ + visitor = ExprVisitorCallbackBottomToTop(callback) + return visitor.visit(self) + + def get_r(self, mem_read=False, cst_read=False): + visitor = ExprGetR(mem_read, cst_read) + visitor.visit(self) + return visitor.elements + + + def get_w(self, mem_read=False, cst_read=False): + if self.is_assign(): + return set([self.dst]) + return set() + +class ExprInt(Expr): + + """An ExprInt represent a constant in Miasm IR. 
+ + Some use cases: + - Constant 0x42 + - Constant -0x30 + - Constant 0x12345678 on 32bits + """ + + __slots__ = Expr.__slots__ + ["_arg"] + + + def __init__(self, arg, size): + """Create an ExprInt from num/size + @arg: int/long number + @size: int size""" + super(ExprInt, self).__init__(size) + # Work for ._arg is done in __new__ + + arg = property(lambda self: self._arg) + + def __reduce__(self): + state = int(self._arg), self._size + return self.__class__, state + + def __new__(cls, arg, size): + """Create an ExprInt from num/size + @arg: int/long number + @size: int size""" + + assert isinstance(arg, int_types) + arg = arg & ((1 << size) - 1) + # Get the Singleton instance + expr = Expr.get_object(cls, (arg, size)) + + # Save parameters (__init__ is called with parameters unchanged) + expr._arg = arg + return expr + + def __str__(self): + return str("0x%X" % self.arg) + + def get_w(self): + return set() + + def _exprhash(self): + return hash((EXPRINT, self._arg, self._size)) + + def _exprrepr(self): + return "%s(0x%X, %d)" % (self.__class__.__name__, self.arg, + self._size) + + def copy(self): + return ExprInt(self._arg, self._size) + + def depth(self): + return 1 + + def graph_recursive(self, graph): + graph.add_node(self) + + def __int__(self): + return int(self.arg) + + def __long__(self): + return int(self.arg) + + def is_int(self, value=None): + if value is not None and self._arg != value: + return False + return True + + +class ExprId(Expr): + + """An ExprId represent an identifier in Miasm IR. 
+ + Some use cases: + - EAX register + - 'start' offset + - variable v1 + """ + + __slots__ = Expr.__slots__ + ["_name"] + + def __init__(self, name, size=None): + """Create an identifier + @name: str, identifier's name + @size: int, identifier's size + """ + if size is None: + warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprId(name, SIZE)') + size = 32 + assert isinstance(name, (str, bytes)) + super(ExprId, self).__init__(size) + self._name = name + + name = property(lambda self: self._name) + + def __reduce__(self): + state = self._name, self._size + return self.__class__, state + + def __new__(cls, name, size=None): + if size is None: + warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprId(name, SIZE)') + size = 32 + return Expr.get_object(cls, (name, size)) + + def __str__(self): + return str(self._name) + + def get_w(self): + return set([self]) + + def _exprhash(self): + return hash((EXPRID, self._name, self._size)) + + def _exprrepr(self): + return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) + + def copy(self): + return ExprId(self._name, self._size) + + def depth(self): + return 1 + + def graph_recursive(self, graph): + graph.add_node(self) + + def is_id(self, name=None): + if name is not None and self._name != name: + return False + return True + + +class ExprLoc(Expr): + + """An ExprLoc represent a Label in Miasm IR. 
+ """ + + __slots__ = Expr.__slots__ + ["_loc_key"] + + def __init__(self, loc_key, size): + """Create an identifier + @loc_key: int, label loc_key + @size: int, identifier's size + """ + assert isinstance(loc_key, LocKey) + super(ExprLoc, self).__init__(size) + self._loc_key = loc_key + + loc_key= property(lambda self: self._loc_key) + + def __reduce__(self): + state = self._loc_key, self._size + return self.__class__, state + + def __new__(cls, loc_key, size): + return Expr.get_object(cls, (loc_key, size)) + + def __str__(self): + return str(self._loc_key) + + def get_w(self): + return set() + + def _exprhash(self): + return hash((EXPRLOC, self._loc_key, self._size)) + + def _exprrepr(self): + return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size) + + def copy(self): + return ExprLoc(self._loc_key, self._size) + + def depth(self): + return 1 + + def graph_recursive(self, graph): + graph.add_node(self) + + def is_loc(self, loc_key=None): + if loc_key is not None and self._loc_key != loc_key: + return False + return True + + +class ExprAssign(Expr): + + """An ExprAssign represent an assignment from an Expression to another one. + + Some use cases: + - var1 <- 2 + """ + + __slots__ = Expr.__slots__ + ["_dst", "_src"] + + def __init__(self, dst, src): + """Create an ExprAssign for dst <- src + @dst: Expr, assignment destination + @src: Expr, assignment source + """ + # dst & src must be Expr + assert isinstance(dst, Expr) + assert isinstance(src, Expr) + + if dst.size != src.size: + raise ValueError( + "sanitycheck: ExprAssign args must have same size! 
%s" % + ([(str(arg), arg.size) for arg in [dst, src]])) + + super(ExprAssign, self).__init__(self.dst.size) + + dst = property(lambda self: self._dst) + src = property(lambda self: self._src) + + + def __reduce__(self): + state = self._dst, self._src + return self.__class__, state + + def __new__(cls, dst, src): + if dst.is_slice() and dst.arg.size == src.size: + new_dst, new_src = dst.arg, src + elif dst.is_slice(): + # Complete the source with missing slice parts + new_dst = dst.arg + rest = [(ExprSlice(dst.arg, r[0], r[1]), r[0], r[1]) + for r in dst.slice_rest()] + all_a = [(src, dst.start, dst.stop)] + rest + all_a.sort(key=lambda x: x[1]) + args = [expr for (expr, _, _) in all_a] + new_src = ExprCompose(*args) + else: + new_dst, new_src = dst, src + expr = Expr.get_object(cls, (new_dst, new_src)) + expr._dst, expr._src = new_dst, new_src + return expr + + def __str__(self): + return "%s = %s" % (str(self._dst), str(self._src)) + + def get_w(self): + if isinstance(self._dst, ExprMem): + return set([self._dst]) # [memreg] + else: + return self._dst.get_w() + + def _exprhash(self): + return hash((EXPRASSIGN, hash(self._dst), hash(self._src))) + + def _exprrepr(self): + return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) + + def copy(self): + return ExprAssign(self._dst.copy(), self._src.copy()) + + def depth(self): + return max(self._src.depth(), self._dst.depth()) + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + for arg in [self._src, self._dst]: + arg.graph_recursive(graph) + graph.add_uniq_edge(self, arg) + + + def is_aff(self): + warnings.warn('DEPRECATION WARNING: use is_assign()') + return True + + def is_assign(self): + return True + + +class ExprAff(ExprAssign): + """ + DEPRECATED class. 
+ Use ExprAssign instead of ExprAff + """ + + def __init__(self, dst, src): + warnings.warn('DEPRECATION WARNING: use ExprAssign instead of ExprAff') + super(ExprAff, self).__init__(dst, src) + + +class ExprCond(Expr): + + """An ExprCond stand for a condition on an Expr + + Use cases: + - var1 < var2 + - min(var1, var2) + - if (cond) then ... else ... + """ + + __slots__ = Expr.__slots__ + ["_cond", "_src1", "_src2"] + + def __init__(self, cond, src1, src2): + """Create an ExprCond + @cond: Expr, condition + @src1: Expr, value if condition is evaled to not zero + @src2: Expr, value if condition is evaled zero + """ + + # cond, src1, src2 must be Expr + assert isinstance(cond, Expr) + assert isinstance(src1, Expr) + assert isinstance(src2, Expr) + + self._cond, self._src1, self._src2 = cond, src1, src2 + assert src1.size == src2.size + super(ExprCond, self).__init__(self.src1.size) + + cond = property(lambda self: self._cond) + src1 = property(lambda self: self._src1) + src2 = property(lambda self: self._src2) + + def __reduce__(self): + state = self._cond, self._src1, self._src2 + return self.__class__, state + + def __new__(cls, cond, src1, src2): + return Expr.get_object(cls, (cond, src1, src2)) + + def __str__(self): + return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2)) + + def get_w(self): + return set() + + def _exprhash(self): + return hash((EXPRCOND, hash(self.cond), + hash(self._src1), hash(self._src2))) + + def _exprrepr(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, + self._cond, self._src1, self._src2) + + def copy(self): + return ExprCond(self._cond.copy(), + self._src1.copy(), + self._src2.copy()) + + def depth(self): + return max(self._cond.depth(), + self._src1.depth(), + self._src2.depth()) + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + for arg in [self._cond, self._src1, self._src2]: + arg.graph_recursive(graph) + graph.add_uniq_edge(self, arg) + + def is_cond(self): + 
return True + + +class ExprMem(Expr): + + """An ExprMem stand for a memory access + + Use cases: + - Memory read + - Memory write + """ + + __slots__ = Expr.__slots__ + ["_ptr"] + + def __init__(self, ptr, size=None): + """Create an ExprMem + @ptr: Expr, memory access address + @size: int, memory access size + """ + if size is None: + warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(ptr, SIZE)') + size = 32 + + # ptr must be Expr + assert isinstance(ptr, Expr) + assert isinstance(size, int_types) + + if not isinstance(ptr, Expr): + raise ValueError( + 'ExprMem: ptr must be an Expr (not %s)' % type(ptr)) + + super(ExprMem, self).__init__(size) + self._ptr = ptr + + def get_arg(self): + warnings.warn('DEPRECATION WARNING: use exprmem.ptr instead of exprmem.arg') + return self.ptr + + def set_arg(self, value): + warnings.warn('DEPRECATION WARNING: use exprmem.ptr instead of exprmem.arg') + self.ptr = value + + ptr = property(lambda self: self._ptr) + arg = property(get_arg, set_arg) + + def __reduce__(self): + state = self._ptr, self._size + return self.__class__, state + + def __new__(cls, ptr, size=None): + if size is None: + warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(ptr, SIZE)') + size = 32 + + return Expr.get_object(cls, (ptr, size)) + + def __str__(self): + return "@%d[%s]" % (self.size, str(self.ptr)) + + def get_w(self): + return set([self]) # [memreg] + + def _exprhash(self): + return hash((EXPRMEM, hash(self._ptr), self._size)) + + def _exprrepr(self): + return "%s(%r, %r)" % (self.__class__.__name__, + self._ptr, self._size) + + def copy(self): + ptr = self.ptr.copy() + return ExprMem(ptr, size=self.size) + + def is_mem_segm(self): + """Returns True if is ExprMem and ptr is_op_segm""" + warnings.warn('DEPRECATION WARNING: use is_mem_segm(expr)') + raise RuntimeError("Moved api") + + def depth(self): + return self._ptr.depth() + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + 
self._ptr.graph_recursive(graph) + graph.add_uniq_edge(self, self._ptr) + + def is_mem(self): + return True + + +class ExprOp(Expr): + + """An ExprOp stand for an operation between Expr + + Use cases: + - var1 XOR var2 + - var1 + var2 + var3 + - parity bit(var1) + """ + + __slots__ = Expr.__slots__ + ["_op", "_args"] + + def __init__(self, op, *args): + """Create an ExprOp + @op: str, operation + @*args: Expr, operand list + """ + + # args must be Expr + assert all(isinstance(arg, Expr) for arg in args) + + sizes = set([arg.size for arg in args]) + + if len(sizes) != 1: + # Special cases : operande sizes can differ + if op not in [ + "segm", + "FLAG_EQ_ADDWC", "FLAG_EQ_SUBWC", + "FLAG_SIGN_ADDWC", "FLAG_SIGN_SUBWC", + "FLAG_ADDWC_CF", "FLAG_ADDWC_OF", + "FLAG_SUBWC_CF", "FLAG_SUBWC_OF", + + ]: + raise ValueError( + "sanitycheck: ExprOp args must have same size! %s" % + ([(str(arg), arg.size) for arg in args])) + + if not isinstance(op, str): + raise ValueError("ExprOp: 'op' argument must be a string") + + assert isinstance(args, tuple) + self._op, self._args = op, args + + # Set size for special cases + if self._op in [ + TOK_EQUAL, 'parity', 'fcom_c0', 'fcom_c1', 'fcom_c2', 'fcom_c3', + 'fxam_c0', 'fxam_c1', 'fxam_c2', 'fxam_c3', + "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf", + "ucomiss_zf", "ucomiss_pf", "ucomiss_cf", + "ucomisd_zf", "ucomisd_pf", "ucomisd_cf"]: + size = 1 + elif self._op in [TOK_INF, TOK_INF_SIGNED, + TOK_INF_UNSIGNED, TOK_INF_EQUAL, + TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, + TOK_EQUAL, TOK_POS, + TOK_POS_STRICT, + ]: + size = 1 + elif self._op.startswith("fp_to_sint"): + size = int(self._op[len("fp_to_sint"):]) + elif self._op.startswith("fpconvert_fp"): + size = int(self._op[len("fpconvert_fp"):]) + elif self._op in [ + "FLAG_ADD_CF", "FLAG_SUB_CF", + "FLAG_ADD_OF", "FLAG_SUB_OF", + "FLAG_EQ", "FLAG_EQ_CMP", + "FLAG_SIGN_SUB", "FLAG_SIGN_ADD", + "FLAG_EQ_AND", + "FLAG_EQ_ADDWC", "FLAG_EQ_SUBWC", + "FLAG_SIGN_ADDWC", 
"FLAG_SIGN_SUBWC", + "FLAG_ADDWC_CF", "FLAG_ADDWC_OF", + "FLAG_SUBWC_CF", "FLAG_SUBWC_OF", + ]: + size = 1 + + elif self._op.startswith('signExt_'): + size = int(self._op[8:]) + elif self._op.startswith('zeroExt_'): + size = int(self._op[8:]) + elif self._op in ['segm']: + size = self._args[1].size + else: + if None in sizes: + size = None + else: + # All arguments have the same size + size = list(sizes)[0] + + super(ExprOp, self).__init__(size) + + op = property(lambda self: self._op) + args = property(lambda self: self._args) + + def __reduce__(self): + state = tuple([self._op] + list(self._args)) + return self.__class__, state + + def __new__(cls, op, *args): + return Expr.get_object(cls, (op, args)) + + def __str__(self): + if self._op == '-': # Unary minus + return '-' + str_protected_child(self._args[0], self) + if self.is_associative() or self.is_infix(): + return (' ' + self._op + ' ').join([str_protected_child(arg, self) + for arg in self._args]) + return (self._op + '(' + + ', '.join([str(arg) for arg in self._args]) + ')') + + def get_w(self): + raise ValueError('op cannot be written!', self) + + def _exprhash(self): + h_hargs = [hash(arg) for arg in self._args] + return hash((EXPROP, self._op, tuple(h_hargs))) + + def _exprrepr(self): + return "%s(%r, %s)" % (self.__class__.__name__, self._op, + ', '.join(repr(arg) for arg in self._args)) + + def is_function_call(self): + return self._op.startswith('call') + + def is_infix(self): + return self._op in [ + '-', '+', '*', '^', '&', '|', '>>', '<<', + 'a>>', '>>>', '<<<', '/', '%', '**', + TOK_INF_UNSIGNED, + TOK_INF_SIGNED, + TOK_INF_EQUAL_UNSIGNED, + TOK_INF_EQUAL_SIGNED, + TOK_EQUAL + ] + + def is_associative(self): + "Return True iff current operation is associative" + return (self._op in ['+', '*', '^', '&', '|']) + + def is_commutative(self): + "Return True iff current operation is commutative" + return (self._op in ['+', '*', '^', '&', '|']) + + def copy(self): + args = [arg.copy() for arg in 
class ExprSlice(Expr):

    """An ExprSlice stands for the bit range [start, stop) of an Expr

    The resulting size is `stop - start`.
    """

    __slots__ = Expr.__slots__ + ["_arg", "_start", "_stop"]

    def __init__(self, arg, start, stop):
        """Create an ExprSlice
        @arg: Expr, sliced expression
        @start: int, first bit (included)
        @stop: int, last bit (excluded), must be greater than @start
        """

        # arg must be Expr
        assert isinstance(arg, Expr)
        assert isinstance(start, int_types)
        assert isinstance(stop, int_types)
        assert start < stop

        self._arg, self._start, self._stop = arg, start, stop
        super(ExprSlice, self).__init__(self._stop - self._start)

    # Read-only accessors
    arg = property(lambda self: self._arg)
    start = property(lambda self: self._start)
    stop = property(lambda self: self._stop)

    def __reduce__(self):
        # Pickle support: rebuild through the constructor
        state = self._arg, self._start, self._stop
        return self.__class__, state

    def __new__(cls, arg, start, stop):
        # Instance creation is delegated to Expr.get_object
        return Expr.get_object(cls, (arg, start, stop))

    def __str__(self):
        return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop)

    def get_w(self):
        """Return the written elements of the sliced expression"""
        return self._arg.get_w()

    def _exprhash(self):
        return hash((EXPRSLICE, hash(self._arg), self._start, self._stop))

    def _exprrepr(self):
        return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg,
                                   self._start, self._stop)

    def copy(self):
        """Return an ExprSlice over a copy of the argument, same bounds"""
        return ExprSlice(self._arg.copy(), self._start, self._stop)

    def depth(self):
        """Return the depth of the sliced expression plus one"""
        return self._arg.depth() + 1

    def slice_rest(self):
        """Return the completion of the current slice, as a list of
        (start, stop) bit ranges of the argument not covered by the slice.
        Raise ValueError if the slice bounds exceed the argument size.
        """
        size = self._arg.size
        if self._start >= size or self._stop > size:
            raise ValueError('bad slice rest %s %s %s' %
                             (size, self._start, self._stop))

        # NOTE(review): __init__ asserts start < stop, so this branch is
        # only reachable when asserts are disabled.
        if self._start == self._stop:
            return [(0, size)]

        rest = []
        # Bits below the slice
        if self._start != 0:
            rest.append((0, self._start))
        # Bits above the slice
        if self._stop < size:
            rest.append((self._stop, size))

        return rest

    def graph_recursive(self, graph):
        """Add this node, its argument sub-graph and the linking edge to @graph"""
        graph.add_node(self)
        self._arg.graph_recursive(graph)
        graph.add_uniq_edge(self, self._arg)

    def is_slice(self, start=None, stop=None):
        """Return True, unless @start and/or @stop are given and differ
        from the bounds of the slice"""
        if start is not None and self._start != start:
            return False
        if stop is not None and self._stop != stop:
            return False
        return True
def compare_exprs(expr1, expr2):
    """Compare 2 expressions for canonization
    @expr1: Expr
    @expr2: Expr
    0  => ==
    1  => expr1 > expr2
    -1 => expr1 < expr2

    Expressions of different classes are ordered by EXPR_ORDER_DICT.
    Expressions of the same class are ordered field by field, as detailed
    in each branch below.
    Raise NotImplementedError for ExprAssign and for unknown classes.
    """
    cls1 = expr1.__class__
    cls2 = expr2.__class__
    if cls1 != cls2:
        return cmp_elts(EXPR_ORDER_DICT[cls1], EXPR_ORDER_DICT[cls2])
    if expr1 == expr2:
        return 0

    # From here, both expressions share the same class: only cls1 is tested
    if cls1 == ExprInt:
        # Size first, then value
        ret = cmp_elts(expr1.size, expr2.size)
        if ret != 0:
            return ret
        return cmp_elts(expr1.arg, expr2.arg)
    elif cls1 == ExprId:
        # Name (as bytes) first, then size
        name1 = force_bytes(expr1.name)
        name2 = force_bytes(expr2.name)
        ret = cmp_elts(name1, name2)
        if ret:
            return ret
        return cmp_elts(expr1.size, expr2.size)
    elif cls1 == ExprLoc:
        # Location key first, then size
        ret = cmp_elts(expr1.loc_key, expr2.loc_key)
        if ret:
            return ret
        return cmp_elts(expr1.size, expr2.size)
    elif cls1 == ExprAssign:
        raise NotImplementedError(
            "Comparison from an ExprAssign not yet implemented"
        )
    elif cls1 == ExprCond:
        # Condition, then both sources
        ret = compare_exprs(expr1.cond, expr2.cond)
        if ret:
            return ret
        ret = compare_exprs(expr1.src1, expr2.src1)
        if ret:
            return ret
        ret = compare_exprs(expr1.src2, expr2.src2)
        return ret
    elif cls1 == ExprMem:
        # Pointer first, then size
        ret = compare_exprs(expr1.ptr, expr2.ptr)
        if ret:
            return ret
        return cmp_elts(expr1.size, expr2.size)
    elif cls1 == ExprOp:
        # Operator first, then arguments
        if expr1.op != expr2.op:
            return cmp_elts(expr1.op, expr2.op)
        return compare_expr_list(expr1.args, expr2.args)
    elif cls1 == ExprSlice:
        # Sliced expression, then start bit, then stop bit
        ret = compare_exprs(expr1.arg, expr2.arg)
        if ret:
            return ret
        ret = cmp_elts(expr1.start, expr2.start)
        if ret:
            return ret
        ret = cmp_elts(expr1.stop, expr2.stop)
        return ret
    elif cls1 == ExprCompose:
        return compare_expr_list_compose(expr1.args, expr2.args)
    raise NotImplementedError(
        "Comparison between %r %r not implemented" % (expr1, expr2)
    )
def match_expr(expr, pattern, tks, result=None):
    """Try to match the expression @expr against @pattern, where the
    elements of @tks act as jokers.
    @expr : Expr, expression to test
    @pattern : Expr, pattern to match; may contain jokers from @tks
    @tks : list of ExprId, available jokers
    @result : dictionary of ExprId -> Expr, output matching context

    Return False if @expr does not match. Otherwise return a non-False
    value, usually the @result dictionary filled with the joker values.
    The dictionary may be empty on success, so callers must test the
    return value with `is False`.
    Raise NotImplementedError on an unknown expression type.
    """

    if result is None:
        result = {}

    if pattern in tks:
        # pattern is a Joker
        return test_set(expr, pattern, tks, result)

    if expr.is_int():
        return test_set(expr, pattern, tks, result)

    elif expr.is_id():
        return test_set(expr, pattern, tks, result)

    elif expr.is_loc():
        return test_set(expr, pattern, tks, result)

    elif expr.is_op():

        # expr need to be the same operation than pattern
        if not pattern.is_op():
            return False
        if expr.op != pattern.op:
            return False
        if len(expr.args) != len(pattern.args):
            return False

        # Perform permutation only if the current operation is commutative
        if expr.is_commutative():
            permutations = itertools.permutations(expr.args)
        else:
            permutations = [expr.args]

        # For each permutations of arguments
        for permut in permutations:
            good = True
            # We need to use a copy of result to not override it
            myresult = dict(result)
            for sub_expr, sub_pattern in zip(permut, pattern.args):
                ret = match_expr(sub_expr, sub_pattern, tks, myresult)
                # If the current permutation do not match EVERY terms
                if ret is False:
                    good = False
                    break
            if good is True:
                # We found a possibility
                # Updating result in place (to keep pointer in recursion)
                result.update(myresult)
                return result
        return False

    # Recursive tests

    elif expr.is_mem():
        if not pattern.is_mem():
            return False
        if expr.size != pattern.size:
            return False
        return match_expr(expr.ptr, pattern.ptr, tks, result)

    elif expr.is_slice():
        if not pattern.is_slice():
            return False
        if expr.start != pattern.start or expr.stop != pattern.stop:
            return False
        return match_expr(expr.arg, pattern.arg, tks, result)

    elif expr.is_cond():
        if not pattern.is_cond():
            return False
        if match_expr(expr.cond, pattern.cond, tks, result) is False:
            return False
        if match_expr(expr.src1, pattern.src1, tks, result) is False:
            return False
        if match_expr(expr.src2, pattern.src2, tks, result) is False:
            return False
        return result

    elif expr.is_compose():
        if not pattern.is_compose():
            return False
        # zip() silently truncates to the shortest list: without this check
        # a compose could match a pattern with a different number of parts
        if len(expr.args) != len(pattern.args):
            return False
        for sub_expr, sub_pattern in zip(expr.args, pattern.args):
            if match_expr(sub_expr, sub_pattern, tks, result) is False:
                return False
        return result

    elif expr.is_assign():
        if not pattern.is_assign():
            return False
        if match_expr(expr.src, pattern.src, tks, result) is False:
            return False
        if match_expr(expr.dst, pattern.dst, tks, result) is False:
            return False
        return result

    else:
        raise NotImplementedError("match_expr: Unknown type: %s" % type(expr))


def MatchExpr(expr, pattern, tks, result=None):
    """Deprecated alias of match_expr: warn, then forward the call"""
    warnings.warn('DEPRECATION WARNING: use match_expr instead of MatchExpr')
    return match_expr(expr, pattern, tks, result)


def get_rw(exprs):
    """Return the (read, written) sets of elements over all @exprs
    @exprs: iterable of expressions
    Memory accesses are walked through for the read set (mem_read=True).
    """
    o_r = set()
    o_w = set()
    # Single pass, so that one-shot iterables (generators) are supported
    for expr in exprs:
        o_r.update(expr.get_r(mem_read=True))
        o_w.update(expr.get_w())
    return o_r, o_w


def get_list_rw(exprs, mem_read=False, cst_read=True):
    """Return list of read/write reg/cst/mem for each @exprs
    @exprs: list of expressions
    @mem_read: walk though memory accesses
    @cst_read: retrieve constants

    Return one (read set, written set) couple per expression, in order.
    """
    list_rw = []
    for expr in exprs:
        # Elements read by the expression itself
        o_r = set(expr.get_r(mem_read=mem_read, cst_read=cst_read))
        if expr.dst.is_mem():
            # A memory destination also reads its address; use `ptr`, as
            # `arg` is a deprecated alias which warns on each access
            o_r.update(
                expr.dst.ptr.get_r(mem_read=mem_read, cst_read=cst_read)
            )
        o_w = set(expr.get_w())
        list_rw.append((o_r, o_w))

    return list_rw
def get_expr_mem(expr):
    """Retrieve memory accesses of an @expr
    @expr: Expr
    Return the set of sub-expressions of @expr (itself included) which are
    memory accesses.
    """
    mems = set()

    def collect_mem(sub_expr):
        # Visitor callback: record memory accesses, leave the expression
        # untouched
        if sub_expr.is_mem():
            mems.add(sub_expr)
        return sub_expr

    expr.visit(collect_mem)
    return mems
def expr_is_signed_greater_or_equal(op1, op2):
    """
    Signed cmp
    if op1 >= op2:
       Return ExprInt(1, 1)
    else:
       Return ExprInt(0, 1)
    """

    # Flags of @op1 - @op2
    nf = _expr_compute_nf(op1, op2)
    of = _expr_compute_of(op1, op2)
    # Result is 1 when the sign flag equals the overflow flag
    return ~(nf ^ of)
def expr_is_qNaN(expr):
    """Return 1 or 0 on 1 bit if expr represent a qNaN (quiet) value according to
    IEEE754
    """
    info = size_to_IEEE754_info[expr.size]
    # The significand occupies bits [0, significand): its most significant
    # bit (the "quiet" bit) is at index significand - 1. Index significand is
    # already the lowest exponent bit, which is always set for a NaN.
    significand_top = expr[info["significand"] - 1: info["significand"]]
    return expr_is_NaN(expr) & significand_top


def expr_is_sNaN(expr):
    """Return 1 or 0 on 1 bit if expr represent a sNaN (signalling) value according
    to IEEE754
    """
    info = size_to_IEEE754_info[expr.size]
    # Same bit as in expr_is_qNaN: a signalling NaN has its quiet bit cleared
    significand_top = expr[info["significand"] - 1: info["significand"]]
    return expr_is_NaN(expr) & ~significand_top
def expr_is_float_equal(op1, op2):
    """Return 1 on 1 bit if @op1 == @op2, 0 otherwise.
    [!] Assume @op1 and @op2 are not NaN
    Comparison is the floating point one, defined in IEEE754
    """
    # Split each operand into its sign bit (msb) and the remaining bits
    sign1, sign2 = op1.msb(), op2.msb()
    magn1, magn2 = op1[:-1], op2[:-1]
    return ExprCond(magn1 ^ magn2,
                    # Magnitudes differ: not equal
                    ExprInt(0, 1),
                    ExprCond(magn1,
                             # magn1 == magn2, are the signs equal?
                             ~(sign1 ^ sign2),
                             # Special case: -0.0 == +0.0
                             ExprInt(1, 1))
                    )
def parity(a):
    """Return the parity flag of the low byte of @a:
    1 if it holds an even number of set bits, 0 otherwise
    """
    low_byte = a & 0xFF
    flag = 1
    # Flip the flag once per set bit, consuming the byte bit by bit
    while low_byte != 0:
        flag ^= low_byte & 1
        low_byte >>= 1
    return flag
def get_missing_interval(all_intervals, i_min=0, i_max=32):
    """Return a list of missing interval in @all_intervals
    @all_intervals: iterable of (int, int), left unmodified
    @i_min: int, minimal missing interval bound
    @i_max: int, maximal missing interval bound"""

    # A final empty interval at @i_max closes the last gap, if any
    bounds = sorted(all_intervals)
    bounds.append((i_max, i_max))

    missing_i = []
    last_pos = i_min
    for start, stop in bounds:
        if last_pos != start:
            missing_i.append((last_pos, start))
        last_pos = stop
    return missing_i
+ Returns: + - variables with their corresponding values + - original expression with variables translated + """ + + def __init__(self, expr, var_prefix="v"): + """Set the expression @expr to handle and launch variable identification + process + @expr: Expr instance + @var_prefix: (optional) prefix of the variable name, default is 'v'""" + + # Init + self.var_indice = itertools.count() + self.var_asked = set() + self._vars = {} # VarID -> Expr + self.var_prefix = var_prefix + + # Launch recurrence + self.find_variables_rec(expr) + + # Compute inter-variable dependencies + has_change = True + while has_change: + has_change = False + for var_id, var_value in list(viewitems(self._vars)): + cur = var_value + + # Do not replace with itself + to_replace = { + v_val:v_id + for v_id, v_val in viewitems(self._vars) + if v_id != var_id + } + var_value = var_value.replace_expr(to_replace) + + if cur != var_value: + # Force @self._vars update + has_change = True + self._vars[var_id] = var_value + break + + # Replace in the original equation + self._equation = expr.replace_expr( + { + v_val: v_id for v_id, v_val + in viewitems(self._vars) + } + ) + + # Compute variables dependencies + self._vars_ordered = collections.OrderedDict() + todo = set(self._vars) + needs = {} + + ## Build initial needs + for var_id, var_expr in viewitems(self._vars): + ### Handle corner cases while using Variable Identifier on an + ### already computed equation + needs[var_id] = [ + var_name + for var_name in var_expr.get_r(mem_read=True) + if self.is_var_identifier(var_name) and \ + var_name in todo and \ + var_name != var_id + ] + + ## Build order list + while todo: + done = set() + for var_id in todo: + all_met = True + for need in needs[var_id]: + if need not in self._vars_ordered: + # A dependency is not met + all_met = False + break + if not all_met: + continue + + # All dependencies are already met, add current + self._vars_ordered[var_id] = self._vars[var_id] + done.add(var_id) + + # Update the 
todo list + for element_done in done: + todo.remove(element_done) + + def is_var_identifier(self, expr): + "Return True iff @expr is a variable identifier" + if not isinstance(expr, m2_expr.ExprId): + return False + return expr in self._vars + + def find_variables_rec(self, expr): + """Recursive method called by find_variable to expand @expr. + Set @var_names and @var_values. + This implementation is faster than an expression visitor because + we do not rebuild each expression. + """ + + if (expr in self.var_asked): + # Expr has already been asked + if expr not in viewvalues(self._vars): + # Create var + identifier = m2_expr.ExprId( + "%s%s" % ( + self.var_prefix, + next(self.var_indice) + ), + size = expr.size + ) + self._vars[identifier] = expr + + # Recursion stop case + return + else: + # First time for @expr + self.var_asked.add(expr) + + if isinstance(expr, m2_expr.ExprOp): + for a in expr.args: + self.find_variables_rec(a) + + elif isinstance(expr, m2_expr.ExprInt): + pass + + elif isinstance(expr, m2_expr.ExprId): + pass + + elif isinstance(expr, m2_expr.ExprLoc): + pass + + elif isinstance(expr, m2_expr.ExprMem): + self.find_variables_rec(expr.ptr) + + elif isinstance(expr, m2_expr.ExprCompose): + for arg in expr.args: + self.find_variables_rec(arg) + + elif isinstance(expr, m2_expr.ExprSlice): + self.find_variables_rec(expr.arg) + + elif isinstance(expr, m2_expr.ExprCond): + self.find_variables_rec(expr.cond) + self.find_variables_rec(expr.src1) + self.find_variables_rec(expr.src2) + + else: + raise NotImplementedError("Type not handled: %s" % expr) + + @property + def vars(self): + return self._vars_ordered + + @property + def equation(self): + return self._equation + + def __str__(self): + "Display variables and final equation" + out = "" + for var_id, var_expr in viewitems(self.vars): + out += "%s = %s\n" % (var_id, var_expr) + out += "Final: %s" % self.equation + return out + + +class ExprRandom(object): + """Return an expression randomly generated""" 
+ + # Identifiers length + identifier_len = 5 + # Identifiers' name charset + identifier_charset = string.ascii_letters + # Number max value + number_max = 0xFFFFFFFF + # Available operations + operations_by_args_number = {1: ["-"], + 2: ["<<", "<<<", ">>", ">>>"], + "2+": ["+", "*", "&", "|", "^"], + } + # Maximum number of argument for operations + operations_max_args_number = 5 + # If set, output expression is a perfect tree + perfect_tree = True + # Max argument size in slice, relative to slice size + slice_add_size = 10 + # Maximum number of layer in compose + compose_max_layer = 5 + # Maximum size of memory address in bits + memory_max_address_size = 32 + # Reuse already generated elements to mimic a more realistic behavior + reuse_element = True + generated_elements = {} # (depth, size) -> [Expr] + + @classmethod + def identifier(cls, size=32): + """Return a random identifier + @size: (optional) identifier size + """ + return m2_expr.ExprId("".join([random.choice(cls.identifier_charset) + for _ in range(cls.identifier_len)]), + size=size) + + @classmethod + def number(cls, size=32): + """Return a random number + @size: (optional) number max bits + """ + num = random.randint(0, cls.number_max % (2**size)) + return m2_expr.ExprInt(num, size) + + @classmethod + def atomic(cls, size=32): + """Return an atomic Expression + @size: (optional) Expr size + """ + available_funcs = [cls.identifier, cls.number] + return random.choice(available_funcs)(size=size) + + @classmethod + def operation(cls, size=32, depth=1): + """Return an ExprOp + @size: (optional) Operation size + @depth: (optional) Expression depth + """ + operand_type = random.choice(list(cls.operations_by_args_number)) + if isinstance(operand_type, str) and "+" in operand_type: + number_args = random.randint( + int(operand_type[:-1]), + cls.operations_max_args_number + ) + else: + number_args = operand_type + + args = [cls._gen(size=size, depth=depth - 1) + for _ in range(number_args)] + operand = 
random.choice(cls.operations_by_args_number[operand_type]) + return m2_expr.ExprOp(operand, + *args) + + @classmethod + def slice(cls, size=32, depth=1): + """Return an ExprSlice + @size: (optional) Operation size + @depth: (optional) Expression depth + """ + start = random.randint(0, size) + stop = start + size + return cls._gen(size=random.randint(stop, stop + cls.slice_add_size), + depth=depth - 1)[start:stop] + + @classmethod + def compose(cls, size=32, depth=1): + """Return an ExprCompose + @size: (optional) Operation size + @depth: (optional) Expression depth + """ + # First layer + upper_bound = random.randint(1, size) + args = [cls._gen(size=upper_bound, depth=depth - 1)] + + # Next layers + while (upper_bound < size): + if len(args) == (cls.compose_max_layer - 1): + # We reach the maximum size + new_upper_bound = size + else: + new_upper_bound = random.randint(upper_bound + 1, size) + + args.append(cls._gen(size=new_upper_bound - upper_bound)) + upper_bound = new_upper_bound + return m2_expr.ExprCompose(*args) + + @classmethod + def memory(cls, size=32, depth=1): + """Return an ExprMem + @size: (optional) Operation size + @depth: (optional) Expression depth + """ + + address_size = random.randint(1, cls.memory_max_address_size) + return m2_expr.ExprMem(cls._gen(size=address_size, + depth=depth - 1), + size=size) + + @classmethod + def _gen(cls, size=32, depth=1): + """Internal function for generating sub-expression according to options + @size: (optional) Operation size + @depth: (optional) Expression depth + [!] 
def expr_cmpu(arg1, arg2):
    """
    Returns a one bit long Expression:
    * 1 if @arg1 is strictly greater than @arg2 (unsigned)
    * 0 otherwise.

    Deprecated shim: emits a warning, then delegates to
    m2_expr.expr_is_unsigned_greater.
    """
    # The original message ended with a stray, unbalanced double quote.
    warnings.warn('DEPRECATION WARNING: use "expr_is_unsigned_greater" instead')
    return m2_expr.expr_is_unsigned_greater(arg1, arg2)
def possible_values(expr):
    """Return possible values for expression @expr, associated with their
    condition constraint as a ConstrainedValues instance
    @expr: Expr instance

    Raises RuntimeError when @expr is of an unsupported expression type.
    """
    solutions = ConstrainedValues()

    # Terminal expression: a single value, without any constraint
    if isinstance(expr, (m2_expr.ExprInt, m2_expr.ExprId, m2_expr.ExprLoc)):
        solutions.add(ConstrainedValue(frozenset(), expr))
        return solutions

    # Unary expressions: re-wrap each possible value of the child
    if isinstance(expr, m2_expr.ExprSlice):
        for sub in possible_values(expr.arg):
            solutions.add(ConstrainedValue(sub.constraints,
                                           sub.value[expr.start:expr.stop]))
        return solutions

    if isinstance(expr, m2_expr.ExprMem):
        for sub in possible_values(expr.ptr):
            solutions.add(ConstrainedValue(sub.constraints,
                                           m2_expr.ExprMem(sub.value,
                                                           expr.size)))
        return solutions

    if isinstance(expr, m2_expr.ExprAssign):
        solutions.update(possible_values(expr.src))
        return solutions

    # Special case: constraint insertion, one constraint per branch
    if isinstance(expr, m2_expr.ExprCond):
        branches = (
            (CondConstraintNotZero(expr.cond), expr.src1),
            (CondConstraintZero(expr.cond), expr.src2),
        )
        for guard, branch in branches:
            for sub in possible_values(branch):
                solutions.add(ConstrainedValue(sub.constraints.union([guard]),
                                               sub.value))
        return solutions

    # N-ary expressions: only the way the result is rebuilt differs
    if isinstance(expr, m2_expr.ExprOp):
        def rebuild(values):
            return m2_expr.ExprOp(expr.op, *values)
    elif isinstance(expr, m2_expr.ExprCompose):
        def rebuild(values):
            return m2_expr.ExprCompose(*values)
    else:
        raise RuntimeError("Unsupported type for expr: %s" % type(expr))

    # Combine every possibility of every argument, merging their constraints
    per_argument = [list(possible_values(arg)) for arg in expr.args]
    for combination in itertools.product(*per_argument):
        merged = frozenset(itertools.chain.from_iterable(
            sub.constraints for sub in combination
        ))
        values = [sub.value for sub in combination]
        solutions.add(ConstrainedValue(merged, rebuild(values)))

    return solutions
class ExprNode(object):
    """Clone of Expression object with additional information"""

    def __init__(self, expr):
        self.expr = expr
        # Information attached to the node by the reduction rules. It starts
        # as None so that the __repr__ of the node subclasses, which all read
        # self.info, also works on a node that has not been categorized yet.
        self.info = None
class ExprReducer(object):
    """Apply reduction rules to an expr

    reduction_rules: list of ordered reduction rules

    List of function representing reduction rules
    Function API:
    reduction_xxx(self, node, lvl=0)
    with:
    * node: the ExprNode to qualify
    * lvl: [optional] the recursion level
    Returns:
    * None if the reduction rule is not applied
    * the resulting information to store in the ExprNode.info

    allow_none_result: allow missing reduction rules
    """

    reduction_rules = []
    allow_none_result = False

    def expr2node(self, expr):
        """Build ExprNode mirror of @expr

        @expr: Expression to analyze
        Raises TypeError if @expr is not a known expression type
        """

        if isinstance(expr, ExprId):
            node = ExprNodeId(expr)
        elif isinstance(expr, ExprLoc):
            node = ExprNodeLoc(expr)
        elif isinstance(expr, ExprInt):
            node = ExprNodeInt(expr)
        elif isinstance(expr, ExprMem):
            son = self.expr2node(expr.ptr)
            node = ExprNodeMem(expr)
            node.ptr = son
        elif isinstance(expr, ExprSlice):
            son = self.expr2node(expr.arg)
            node = ExprNodeSlice(expr)
            node.arg = son
        elif isinstance(expr, ExprOp):
            sons = [self.expr2node(arg) for arg in expr.args]
            node = ExprNodeOp(expr)
            node.args = sons
        elif isinstance(expr, ExprCompose):
            sons = [self.expr2node(arg) for arg in expr.args]
            node = ExprNodeCompose(expr)
            node.args = sons
        elif isinstance(expr, ExprCond):
            node = ExprNodeCond(expr)
            node.cond = self.expr2node(expr.cond)
            node.src1 = self.expr2node(expr.src1)
            node.src2 = self.expr2node(expr.src2)
        else:
            # Format the message here: passing the type as a second argument
            # to TypeError leaves the "%r" placeholder unexpanded.
            raise TypeError("Unknown Expr Type %r" % (type(expr),))
        return node

    def reduce(self, expr, **kwargs):
        """Returns an ExprNode tree mirroring @expr tree. The ExprNode is
        computed by applying reduction rules to the expression @expr

        @expr: an Expression
        Extra keyword arguments are forwarded to every reduction rule.
        """

        node = self.expr2node(expr)
        return self.categorize(node, lvl=0, **kwargs)

    def categorize(self, node, lvl=0, **kwargs):
        """Recursively apply rules to @node

        @node: ExprNode to analyze
        @lvl: actual recursion level
        Returns a new ExprNode whose children have been categorized first
        and whose info attribute holds the result of apply_rules.
        Raises TypeError if the node wraps an unknown expression type.
        """

        expr = node.expr
        log_reduce.debug("\t" * lvl + "Reduce...: %s", node.expr)
        if isinstance(expr, ExprId):
            node = ExprNodeId(expr)
        elif isinstance(expr, ExprInt):
            node = ExprNodeInt(expr)
        elif isinstance(expr, ExprLoc):
            node = ExprNodeLoc(expr)
        elif isinstance(expr, ExprMem):
            ptr = self.categorize(node.ptr, lvl=lvl + 1, **kwargs)
            node = ExprNodeMem(ExprMem(ptr.expr, expr.size))
            node.ptr = ptr
        elif isinstance(expr, ExprSlice):
            arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
            node = ExprNodeSlice(ExprSlice(arg.expr, expr.start, expr.stop))
            node.arg = arg
        elif isinstance(expr, ExprOp):
            new_args = []
            for arg in node.args:
                new_a = self.categorize(arg, lvl=lvl + 1, **kwargs)
                # Categorizing an argument must not change its size
                assert new_a.expr.size == arg.expr.size
                new_args.append(new_a)
            node = ExprNodeOp(ExprOp(expr.op, *[x.expr for x in new_args]))
            node.args = new_args
        elif isinstance(expr, ExprCompose):
            new_args = []
            new_expr_args = []
            for arg in node.args:
                arg = self.categorize(arg, lvl=lvl + 1, **kwargs)
                new_args.append(arg)
                new_expr_args.append(arg.expr)
            new_expr = ExprCompose(*new_expr_args)
            node = ExprNodeCompose(new_expr)
            node.args = new_args
        elif isinstance(expr, ExprCond):
            cond = self.categorize(node.cond, lvl=lvl + 1, **kwargs)
            src1 = self.categorize(node.src1, lvl=lvl + 1, **kwargs)
            src2 = self.categorize(node.src2, lvl=lvl + 1, **kwargs)
            node = ExprNodeCond(ExprCond(cond.expr, src1.expr, src2.expr))
            node.cond, node.src1, node.src2 = cond, src1, src2
        else:
            raise TypeError("Unknown Expr Type %r" % (type(expr),))

        node.info = self.apply_rules(node, lvl=lvl, **kwargs)
        log_reduce.debug("\t" * lvl + "Reduce result: %s %r",
                         node.expr, node.info)
        return node

    def apply_rules(self, node, lvl=0, **kwargs):
        """Find and apply reduction rules to @node

        @node: ExprNode to analyse
        @lvl: actual recursion level
        Returns the result of the first rule which does not return None.
        If no rule applies, returns None when allow_none_result is set and
        raises RuntimeError otherwise.
        """

        for rule in self.reduction_rules:
            ret = rule(self, node, lvl=lvl, **kwargs)

            if ret is not None:
                log_reduce.debug("\t" * lvl + "Rule found: %r", rule)
                return ret
        if not self.allow_none_result:
            raise RuntimeError('Missing reduction rule for %r' % node.expr)
        return None
pyparsing.Combine(hex_word).setParseAction(lambda t: + int(t[0], 16)) + +str_int_pos = (hex_int | integer) +str_int_neg = (pyparsing.Suppress('-') + \ + (hex_int | integer)).setParseAction(lambda t: -t[0]) + +str_int = str_int_pos | str_int_neg + +STR_EXPRINT = pyparsing.Suppress("ExprInt") +STR_EXPRID = pyparsing.Suppress("ExprId") +STR_EXPRLOC = pyparsing.Suppress("ExprLoc") +STR_EXPRSLICE = pyparsing.Suppress("ExprSlice") +STR_EXPRMEM = pyparsing.Suppress("ExprMem") +STR_EXPRCOND = pyparsing.Suppress("ExprCond") +STR_EXPRCOMPOSE = pyparsing.Suppress("ExprCompose") +STR_EXPROP = pyparsing.Suppress("ExprOp") +STR_EXPRASSIGN = pyparsing.Suppress("ExprAssign") + +LOCKEY = pyparsing.Suppress("LocKey") + +STR_COMMA = pyparsing.Suppress(",") +LPARENTHESIS = pyparsing.Suppress("(") +RPARENTHESIS = pyparsing.Suppress(")") + + +T_INF = pyparsing.Suppress("<") +T_SUP = pyparsing.Suppress(">") + + +string_quote = pyparsing.QuotedString(quoteChar="'", escChar='\\', escQuote='\\') +string_dquote = pyparsing.QuotedString(quoteChar='"', escChar='\\', escQuote='\\') + + +string = string_quote | string_dquote + +expr = pyparsing.Forward() + +expr_int = STR_EXPRINT + LPARENTHESIS + str_int + STR_COMMA + str_int + RPARENTHESIS +expr_id = STR_EXPRID + LPARENTHESIS + string + STR_COMMA + str_int + RPARENTHESIS +expr_loc = STR_EXPRLOC + LPARENTHESIS + T_INF + LOCKEY + str_int + T_SUP + STR_COMMA + str_int + RPARENTHESIS +expr_slice = STR_EXPRSLICE + LPARENTHESIS + expr + STR_COMMA + str_int + STR_COMMA + str_int + RPARENTHESIS +expr_mem = STR_EXPRMEM + LPARENTHESIS + expr + STR_COMMA + str_int + RPARENTHESIS +expr_cond = STR_EXPRCOND + LPARENTHESIS + expr + STR_COMMA + expr + STR_COMMA + expr + RPARENTHESIS +expr_compose = STR_EXPRCOMPOSE + LPARENTHESIS + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS +expr_op = STR_EXPROP + LPARENTHESIS + string + STR_COMMA + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS +expr_aff = STR_EXPRASSIGN + LPARENTHESIS + expr + STR_COMMA 
def str_to_expr(str_in):
    """Parse the @str_in and return the corresponding Expression
    @str_in: repr string of an Expression

    Raises RuntimeError if @str_in cannot be parsed.
    """

    try:
        value = expr.parseString(str_in)
    except Exception:
        # Any failure while parsing or while running the parse actions is
        # reported as RuntimeError. Unlike a bare "except:", this lets
        # KeyboardInterrupt and SystemExit propagate.
        raise RuntimeError("Cannot parse expression %s" % str_in)
    assert len(value) == 1
    return value[0]
ExpressionSimplifier(ExprVisitorCallbackBottomToTop): + + """Wrapper on expression simplification passes. + + Instance handle passes lists. + + Available passes lists are: + - commons: common passes such as constant folding + - heavy : rare passes (for instance, in case of obfuscation) + """ + + # Common passes + PASS_COMMONS = { + m2_expr.ExprOp: [ + simplifications_common.simp_cst_propagation, + simplifications_common.simp_cond_op_int, + simplifications_common.simp_cond_factor, + simplifications_common.simp_add_multiple, + # CC op + simplifications_common.simp_cc_conds, + simplifications_common.simp_subwc_cf, + simplifications_common.simp_subwc_of, + simplifications_common.simp_sign_subwc_cf, + simplifications_common.simp_double_zeroext, + simplifications_common.simp_double_signext, + simplifications_common.simp_zeroext_eq_cst, + simplifications_common.simp_ext_eq_ext, + simplifications_common.simp_ext_cond_int, + simplifications_common.simp_sub_cf_zero, + + simplifications_common.simp_cmp_int, + simplifications_common.simp_cmp_bijective_op, + simplifications_common.simp_sign_inf_zeroext, + simplifications_common.simp_cmp_int_int, + simplifications_common.simp_ext_cst, + simplifications_common.simp_zeroext_and_cst_eq_cst, + simplifications_common.simp_test_signext_inf, + simplifications_common.simp_test_zeroext_inf, + simplifications_common.simp_cond_inf_eq_unsigned_zero, + simplifications_common.simp_compose_and_mask, + simplifications_common.simp_bcdadd_cf, + simplifications_common.simp_bcdadd, + simplifications_common.simp_smod_sext, + simplifications_common.simp_flag_cst, + ], + + m2_expr.ExprSlice: [ + simplifications_common.simp_slice, + simplifications_common.simp_slice_of_ext, + simplifications_common.simp_slice_of_sext, + simplifications_common.simp_slice_of_op_ext, + ], + m2_expr.ExprCompose: [simplifications_common.simp_compose], + m2_expr.ExprCond: [ + simplifications_common.simp_cond, + simplifications_common.simp_cond_zeroext, + 
simplifications_common.simp_cond_add, + # CC op + simplifications_common.simp_cond_flag, + simplifications_common.simp_cmp_int_arg, + + simplifications_common.simp_cond_eq_zero, + simplifications_common.simp_x_and_cst_eq_cst, + simplifications_common.simp_cond_logic_ext, + simplifications_common.simp_cond_sign_bit, + simplifications_common.simp_cond_eq_1_0, + simplifications_common.simp_cond_cc_flag, + simplifications_common.simp_cond_sub_cf, + ], + m2_expr.ExprMem: [simplifications_common.simp_mem], + + } + + + # Heavy passes + PASS_HEAVY = {} + + # Cond passes + PASS_COND = { + m2_expr.ExprSlice: [ + simplifications_cond.expr_simp_inf_signed, + simplifications_cond.expr_simp_inf_unsigned_inversed + ], + m2_expr.ExprOp: [ + simplifications_cond.expr_simp_inverse, + ], + m2_expr.ExprCond: [ + simplifications_cond.expr_simp_equal + ] + } + + + # Available passes lists are: + # - highlevel: transform high level operators to explicit computations + PASS_HIGH_TO_EXPLICIT = { + m2_expr.ExprOp: [ + simplifications_explicit.simp_flags, + simplifications_explicit.simp_ext, + ], + } + + + def __init__(self): + super(ExpressionSimplifier, self).__init__(self.expr_simp_inner) + self.expr_simp_cb = {} + + def enable_passes(self, passes): + """Add passes from @passes + @passes: dict(Expr class : list(callback)) + + Callback signature: Expr callback(ExpressionSimplifier, Expr) + """ + + # Clear cache of simplifiied expressions when adding a new pass + self.cache.clear() + + for k, v in viewitems(passes): + self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v) + + def apply_simp(self, expression): + """Apply enabled simplifications on expression + @expression: Expr instance + Return an Expr instance""" + + cls = expression.__class__ + debug_level = log_exprsimp.level >= logging.DEBUG + for simp_func in self.expr_simp_cb.get(cls, []): + # Apply simplifications + before = expression + expression = simp_func(self, expression) + after = expression + + if debug_level and 
before != after: + log_exprsimp.debug("[%s] %s => %s", simp_func, before, after) + + # If class changes, stop to prevent wrong simplifications + if expression.__class__ is not cls: + break + + return expression + + def expr_simp_inner(self, expression): + """Apply enabled simplifications on expression and find a stable state + @expression: Expr instance + Return an Expr instance""" + + # Find a stable state + while True: + # Canonize and simplify + new_expr = self.apply_simp(expression.canonize()) + if new_expr == expression: + return new_expr + # Run recursively simplification on fresh new expression + new_expr = self.visit(new_expr) + expression = new_expr + return new_expr + + def expr_simp(self, expression): + "Call simplification recursively" + return self.visit(expression) + + def __call__(self, expression): + "Call simplification recursively" + return self.visit(expression) + + +# Public ExprSimplificationPass instance with commons passes +expr_simp = ExpressionSimplifier() +expr_simp.enable_passes(ExpressionSimplifier.PASS_COMMONS) + +expr_simp_high_to_explicit = ExpressionSimplifier() +expr_simp_high_to_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT) + +expr_simp_explicit = ExpressionSimplifier() +expr_simp_explicit.enable_passes(ExpressionSimplifier.PASS_COMMONS) +expr_simp_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT) diff --git a/src/miasm/expression/simplifications_common.py b/src/miasm/expression/simplifications_common.py new file mode 100644 index 00000000..9156ee67 --- /dev/null +++ b/src/miasm/expression/simplifications_common.py @@ -0,0 +1,1868 @@ +# ----------------------------- # +# Common simplifications passes # +# ----------------------------- # + +from future.utils import viewitems + +from miasm.core.modint import mod_size2int, mod_size2uint +from miasm.expression.expression import ExprInt, ExprSlice, ExprMem, \ + ExprCond, ExprOp, ExprCompose, TOK_INF_SIGNED, TOK_INF_UNSIGNED, \ + 
TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, TOK_EQUAL +from miasm.expression.expression_helper import parity, op_propag_cst, \ + merge_sliceto_slice +from miasm.expression.simplifications_explicit import simp_flags + +def simp_cst_propagation(e_s, expr): + """This passe includes: + - Constant folding + - Common logical identities + - Common binary identities + """ + + # merge associatif op + args = list(expr.args) + op_name = expr.op + # simpl integer manip + # int OP int => int + # TODO: <<< >>> << >> are architecture dependent + if op_name in op_propag_cst: + while (len(args) >= 2 and + args[-1].is_int() and + args[-2].is_int()): + int2 = args.pop() + int1 = args.pop() + if op_name == '+': + out = mod_size2uint[int1.size](int(int1) + int(int2)) + elif op_name == '*': + out = mod_size2uint[int1.size](int(int1) * int(int2)) + elif op_name == '**': + out = mod_size2uint[int1.size](int(int1) ** int(int2)) + elif op_name == '^': + out = mod_size2uint[int1.size](int(int1) ^ int(int2)) + elif op_name == '&': + out = mod_size2uint[int1.size](int(int1) & int(int2)) + elif op_name == '|': + out = mod_size2uint[int1.size](int(int1) | int(int2)) + elif op_name == '>>': + if int(int2) > int1.size: + out = 0 + else: + out = mod_size2uint[int1.size](int(int1) >> int(int2)) + elif op_name == '<<': + if int(int2) > int1.size: + out = 0 + else: + out = mod_size2uint[int1.size](int(int1) << int(int2)) + elif op_name == 'a>>': + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + if tmp2 > int1.size: + is_signed = int(int1) & (1 << (int1.size - 1)) + if is_signed: + out = -1 + else: + out = 0 + else: + out = mod_size2uint[int1.size](tmp1 >> tmp2) + elif op_name == '>>>': + shifter = int(int2) % int2.size + out = (int(int1) >> shifter) | (int(int1) << (int2.size - shifter)) + elif op_name == '<<<': + shifter = int(int2) % int2.size + out = (int(int1) << shifter) | (int(int1) >> (int2.size - shifter)) + elif op_name == '/': + if int(int2) == 0: + 
return expr + out = int(int1) // int(int2) + elif op_name == '%': + if int(int2) == 0: + return expr + out = int(int1) % int(int2) + elif op_name == 'sdiv': + if int(int2) == 0: + return expr + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2int[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 // tmp2) + elif op_name == 'smod': + if int(int2) == 0: + return expr + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2int[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 % tmp2) + elif op_name == 'umod': + if int(int2) == 0: + return expr + tmp1 = mod_size2uint[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 % tmp2) + elif op_name == 'udiv': + if int(int2) == 0: + return expr + tmp1 = mod_size2uint[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 // tmp2) + + + + args.append(ExprInt(int(out), int1.size)) + + # cnttrailzeros(int) => int + if op_name == "cnttrailzeros" and args[0].is_int(): + i = 0 + while int(args[0]) & (1 << i) == 0 and i < args[0].size: + i += 1 + return ExprInt(i, args[0].size) + + # cntleadzeros(int) => int + if op_name == "cntleadzeros" and args[0].is_int(): + if int(args[0]) == 0: + return ExprInt(args[0].size, args[0].size) + i = args[0].size - 1 + while int(args[0]) & (1 << i) == 0: + i -= 1 + return ExprInt(expr.size - (i + 1), args[0].size) + + # -(-(A)) => A + if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and + len(args[0].args) == 1): + return args[0].args[0] + + + # -(int) => -int + if op_name == '-' and len(args) == 1 and args[0].is_int(): + return ExprInt(-int(args[0]), expr.size) + # A op 0 =>A + if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1: + if args[-1].is_int(0): + args.pop() + # A - 0 =>A + if op_name == '-' and len(args) > 1 and args[-1].is_int(0): + assert len(args) == 2 # Op '-' with more than 2 args: SantityCheckError + return args[0] 
+ + # A * 1 =>A + if op_name == "*" and len(args) > 1 and args[-1].is_int(1): + args.pop() + + # for cannon form + # A * -1 => - A + if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask: + args.pop() + args[-1] = - args[-1] + + # op A => A + if op_name in ['+', '*', '^', '&', '|', '>>', '<<', + 'a>>', '<<<', '>>>', 'sdiv', 'smod', 'umod', 'udiv'] and len(args) == 1: + return args[0] + + # A-B => A + (-B) + if op_name == '-' and len(args) > 1: + if len(args) > 2: + raise ValueError( + 'sanity check fail on expr -: should have one or 2 args ' + + '%r %s' % (expr, expr) + ) + return ExprOp('+', args[0], -args[1]) + + # A op 0 => 0 + if op_name in ['&', "*"] and args[-1].is_int(0): + return ExprInt(0, expr.size) + + # - (A + B +...) => -A + -B + -C + if op_name == '-' and len(args) == 1 and args[0].is_op('+'): + args = [-a for a in args[0].args] + return ExprOp('+', *args) + + # -(a?int1:int2) => (a?-int1:-int2) + if (op_name == '-' and len(args) == 1 and + args[0].is_cond() and + args[0].src1.is_int() and args[0].src2.is_int()): + int1 = args[0].src1 + int2 = args[0].src2 + int1 = ExprInt(-int1.arg, int1.size) + int2 = ExprInt(-int2.arg, int2.size) + return ExprCond(args[0].cond, int1, int2) + + i = 0 + while i < len(args) - 1: + j = i + 1 + while j < len(args): + # A ^ A => 0 + if op_name == '^' and args[i] == args[j]: + args[i] = ExprInt(0, args[i].size) + del args[j] + continue + # A + (- A) => 0 + if op_name == '+' and args[j].is_op("-"): + if len(args[j].args) == 1 and args[i] == args[j].args[0]: + args[i] = ExprInt(0, args[i].size) + del args[j] + continue + # (- A) + A => 0 + if op_name == '+' and args[i].is_op("-"): + if len(args[i].args) == 1 and args[j] == args[i].args[0]: + args[i] = ExprInt(0, args[i].size) + del args[j] + continue + # A | A => A + if op_name == '|' and args[i] == args[j]: + del args[j] + continue + # A & A => A + if op_name == '&' and args[i] == args[j]: + del args[j] + continue + j += 1 + i += 1 + + if op_name in ['+', '^', 
'|', '&', '%', '/', '**'] and len(args) == 1: + return args[0] + + # A <<< A.size => A + if (op_name in ['<<<', '>>>'] and + args[1].is_int() and + int(args[1]) == args[0].size): + return args[0] + + # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow + if (op_name in ['<<<', '>>>'] and + args[0].is_op() and + args[0].op in ['<<<', '>>>']): + A = args[0].args[0] + X = args[0].args[1] + Y = args[1] + if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size): + return args[0].args[0] + elif X.is_int() and Y.is_int(): + new_X = int(X) % expr.size + new_Y = int(Y) % expr.size + if op_name == args[0].op: + rot = (new_X + new_Y) % expr.size + op = op_name + else: + rot = new_Y - new_X + op = op_name + if rot < 0: + rot = - rot + op = {">>>": "<<<", "<<<": ">>>"}[op_name] + args = [A, ExprInt(rot, expr.size)] + op_name = op + + else: + # Do not consider this case, too tricky (overflow on addition / + # subtraction) + pass + + # A >> X >> Y => A >> (X+Y) if X + Y does not overflow + # To be sure, only consider the simplification when X.msb and Y.msb are 0 + if (op_name in ['<<', '>>'] and + args[0].is_op(op_name)): + X = args[0].args[1] + Y = args[1] + if (e_s(X.msb()) == ExprInt(0, 1) and + e_s(Y.msb()) == ExprInt(0, 1)): + args = [args[0].args[0], X + Y] + + # ((var >> int1) << int1) => var & mask + # ((var << int1) >> int1) => var & mask + if (op_name in ['<<', '>>'] and + args[0].is_op() and + args[0].op in ['<<', '>>'] and + op_name != args[0]): + var = args[0].args[0] + int1 = args[0].args[1] + int2 = args[1] + if int1 == int2 and int1.is_int() and int(int1) < expr.size: + if op_name == '>>': + mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size) + else: + mask = ExprInt( + ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1), + expr.size + ) + ret = var & mask + return ret + + # ((A & A.mask) + if op_name == "&" and args[-1] == expr.mask: + args = args[:-1] + if len(args) == 1: + return args[0] + return ExprOp('&', *args) + + # ((A | 
A.mask) + if op_name == "|" and args[-1] == expr.mask: + return args[-1] + + # ! (!X + int) => X - int + # TODO + + # ((A & mask) >> shift) with mask < 2**shift => 0 + if op_name == ">>" and args[1].is_int() and args[0].is_op("&"): + if (args[0].args[1].is_int() and + 2 ** int(args[1]) > int(args[0].args[1])): + return ExprInt(0, args[0].size) + + # parity(int) => int + if op_name == 'parity' and args[0].is_int(): + return ExprInt(parity(int(args[0])), 1) + + # (-a) * b * (-c) * (-d) => (-a) * b * c * d + if op_name == "*" and len(args) > 1: + new_args = [] + counter = 0 + for arg in args: + if arg.is_op('-') and len(arg.args) == 1: + new_args.append(arg.args[0]) + counter += 1 + else: + new_args.append(arg) + if counter % 2: + return -ExprOp(op_name, *new_args) + args = new_args + + # -(a * b * int) => a * b * (-int) + if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int(): + args = args[0].args + return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)])) + + # A << int with A ExprCompose => move index + if (op_name == "<<" and args[0].is_compose() and + args[1].is_int() and int(args[1]) != 0): + final_size = args[0].size + shift = int(args[1]) + new_args = [] + # shift indexes + for index, arg in args[0].iter_args(): + new_args.append((arg, index+shift, index+shift+arg.size)) + # filter out expression + filter_args = [] + min_index = final_size + for tmp, start, stop in new_args: + if start >= final_size: + continue + if stop > final_size: + tmp = tmp[:tmp.size - (stop - final_size)] + filter_args.append(tmp) + min_index = min(start, min_index) + # create entry 0 + assert min_index != 0 + tmp = ExprInt(0, min_index) + args = [tmp] + filter_args + return ExprCompose(*args) + + # A >> int with A ExprCompose => move index + if op_name == ">>" and args[0].is_compose() and args[1].is_int(): + final_size = args[0].size + shift = int(args[1]) + new_args = [] + # shift indexes + for index, arg in args[0].iter_args(): + 
def simp_cond_op_int(_, expr):
    """
    Hoist the single conditional operand out of an n-ary operation.

    x?a:b + c + d => x?(a + c + d):(b + c + d)

    Only fires when the operands contain exactly one distinct ExprCond
    expression; otherwise @expr is returned untouched.
    """
    if expr.op not in ("+", "|", "^", "&", "*", "<<", ">>", "a>>"):
        return expr
    if len(expr.args) < 2:
        return expr

    # Distinct conditional operands (compared as whole expressions)
    cond_operands = {operand for operand in expr.args if operand.is_cond()}
    if len(cond_operands) != 1:
        return expr
    (selector,) = cond_operands

    # Build both variants of the operand list, one per branch
    when_true = [
        operand.src1 if operand.is_cond() else operand
        for operand in expr.args
    ]
    when_false = [
        operand.src2 if operand.is_cond() else operand
        for operand in expr.args
    ]
    return ExprCond(
        selector.cond,
        ExprOp(expr.op, *when_true),
        ExprOp(expr.op, *when_false),
    )
Otherwise, the order is not conserved + + # Regroup sub-expression by similar conditions + conds = {} + not_conds = [] + multi_cond = False + for arg in expr.args: + if not arg.is_cond(): + not_conds.append(arg) + continue + cond = arg.cond + if not cond in conds: + conds[cond] = [] + else: + multi_cond = True + conds[cond].append(arg) + if not multi_cond: + return expr + + # Rebuild the new expression + c_out = not_conds + for cond, vals in viewitems(conds): + new_src1 = [x.src1 for x in vals] + new_src2 = [x.src2 for x in vals] + src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1)) + src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2)) + c_out.append(ExprCond(cond, src1, src2)) + + if len(c_out) == 1: + new_e = c_out[0] + else: + new_e = ExprOp(expr.op, *c_out) + return new_e + + +def simp_slice(e_s, expr): + "Slice optimization" + + # slice(A, 0, a.size) => A + if expr.start == 0 and expr.stop == expr.arg.size: + return expr.arg + # Slice(int) => int + if expr.arg.is_int(): + total_bit = expr.stop - expr.start + mask = (1 << (expr.stop - expr.start)) - 1 + return ExprInt(int((int(expr.arg) >> expr.start) & mask), total_bit) + # Slice(Slice(A, x), y) => Slice(A, z) + if expr.arg.is_slice(): + if expr.stop - expr.start > expr.arg.stop - expr.arg.start: + raise ValueError('slice in slice: getting more val', str(expr)) + + return ExprSlice(expr.arg.arg, expr.start + expr.arg.start, + expr.start + expr.arg.start + (expr.stop - expr.start)) + if expr.arg.is_compose(): + # Slice(Compose(A), x) => Slice(A, y) + for index, arg in expr.arg.iter_args(): + if index <= expr.start and index+arg.size >= expr.stop: + return arg[expr.start - index:expr.stop - index] + # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C + out = [] + for index, arg in expr.arg.iter_args(): + # arg is before slice start + if expr.start >= index + arg.size: + continue + # arg is after slice stop + elif expr.stop <= index: + continue + # arg is fully included in slice + elif expr.start <= 
index and index + arg.size <= expr.stop: + out.append(arg) + continue + # arg is truncated at start + if expr.start > index: + slice_start = expr.start - index + else: + # arg is not truncated at start + slice_start = 0 + # a is truncated at stop + if expr.stop < index + arg.size: + slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start + else: + slice_stop = arg.size + out.append(arg[slice_start:slice_stop]) + + return ExprCompose(*out) + + # ExprMem(x, size)[:A] => ExprMem(x, a) + # XXXX todo hum, is it safe? + if (expr.arg.is_mem() and + expr.start == 0 and + expr.arg.size > expr.stop and expr.stop % 8 == 0): + return ExprMem(expr.arg.ptr, size=expr.stop) + # distributivity of slice and & + # (a & int)[x:y] => 0 if int[x:y] == 0 + if expr.arg.is_op("&") and expr.arg.args[-1].is_int(): + tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop]) + if tmp.is_int(0): + return tmp + # distributivity of slice and exprcond + # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y]) + # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y]) + if (expr.arg.is_cond() and + (expr.arg.src1.is_int() or expr.arg.src1.is_compose()) and + (expr.arg.src2.is_int() or expr.arg.src2.is_compose())): + src1 = expr.arg.src1[expr.start:expr.stop] + src2 = expr.arg.src2[expr.start:expr.stop] + return ExprCond(expr.arg.cond, src1, src2) + + # (a * int)[0:y] => (a[0:y] * int[0:y]) + if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int(): + args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args] + return ExprOp(expr.arg.op, *args) + + # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size + # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0 + if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and + expr.arg.args[1].is_int()): + arg, shift = expr.arg.args + shift = int(shift) + if expr.arg.op == ">>": + if shift + expr.stop <= arg.size: + return arg[expr.start + shift:expr.stop + shift] + elif expr.arg.op == "<<": + if expr.start - shift 
def simp_compose(e_s, expr):
    """
    Common simplifications on ExprCompose; the first matching rule wins.

    @e_s: expression simplifier, called on rebuilt sub-expressions
    @expr: the ExprCompose to simplify
    """
    # Merge adjacent slices, then flatten compose of compose
    parts = []
    for piece in merge_sliceto_slice(expr):
        if piece.is_compose():
            parts.extend(piece.args)
        else:
            parts.append(piece)

    # Compose(a) with a.size = compose.size => a
    if len(parts) == 1 and parts[0].size == expr.size:
        return parts[0]

    # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z)
    if len(parts) == 2 and parts[1].is_int(0):
        low, high = parts
        if (low.is_slice() and
                low.stop == low.arg.size and
                low.size + high.size == low.arg.size):
            return low.arg >> ExprInt(low.start, low.arg.size)

    # {@X[base + i] 0 X, @Y[base + i + X] X (X + Y)} => @(X+Y)[base + i]
    for pos in range(len(parts) - 1):
        current, follower = parts[pos], parts[pos + 1]
        if not (current.is_mem() and follower.is_mem()):
            continue
        # Distance between both pointers, expected in bytes
        gap = e_s(follower.ptr - current.ptr)
        if (gap.is_int() and
                current.size % 8 == 0 and
                int(gap) == current.size // 8):
            merged = ExprMem(current.ptr, current.size + follower.size)
            return ExprCompose(*(parts[:pos] + [merged] + parts[pos + 2:]))

    # {A, signext(A)[32:64]} => signext(A)
    if len(parts) == 2 and parts[0].size == parts[1].size:
        low, high = parts
        width = low.size
        extended = low.signExtend(low.size * 2)
        if high == extended[width:2 * width]:
            return extended

    # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
    # Note: this rule works on the original operands of @expr
    selectors = set(arg.cond for arg in expr.args if arg.is_cond())
    if len(selectors) == 1:
        selector = list(selectors)[0]
        taken = [arg.src1 if arg.is_cond() else arg for arg in expr.args]
        skipped = [arg.src2 if arg.is_cond() else arg for arg in expr.args]
        return ExprCond(
            selector,
            e_s(ExprCompose(*taken)),
            e_s(ExprCompose(*skipped)),
        )
    return ExprCompose(*parts)
def simp_mem(_, expr):
    """
    Common simplifications on ExprMem.

    A memory access through a conditional pointer becomes a conditional
    between two memory accesses of the same size:
    @32[x?a:b] => x?@32[a]:@32[b]
    """
    pointer = expr.ptr
    if not pointer.is_cond():
        return expr
    load_if_true = ExprMem(pointer.src1, expr.size)
    load_if_false = ExprMem(pointer.src2, expr.size)
    return ExprCond(pointer.cond, load_if_true, load_if_false)
expr.size) + ) + + elif (expr.is_op("CC_EQ") and + test_cc_eq_args( + expr, + "FLAG_EQ_CMP" + )): + expr = ExprOp(TOK_EQUAL, *expr.args[0].args) + + elif (expr.is_op("CC_NE") and + test_cc_eq_args( + expr, + "FLAG_EQ_AND" + )): + expr = ExprOp("&", *expr.args[0].args) + + elif (expr.is_op("CC_EQ") and + test_cc_eq_args( + expr, + "FLAG_EQ_AND" + )): + expr = ExprCond( + ExprOp("&", *expr.args[0].args), + ExprInt(0, expr.size), + ExprInt(1, expr.size) + ) + + elif (expr.is_op("CC_S>") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF", + "FLAG_EQ_CMP", + )): + expr = ExprCond( + ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), + ExprInt(0, expr.size), + ExprInt(1, expr.size) + ) + + elif (expr.is_op("CC_S>") and + len(expr.args) == 3 and + expr.args[0].is_op("FLAG_SIGN_SUB") and + expr.args[2].is_op("FLAG_EQ_CMP") and + expr.args[0].args == expr.args[2].args and + expr.args[1].is_int(0)): + expr = ExprCond( + ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), + ExprInt(0, expr.size), + ExprInt(1, expr.size) + ) + + + + elif (expr.is_op("CC_S>=") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF" + )): + expr = ExprCond( + ExprOp(TOK_INF_SIGNED, *expr.args[0].args), + ExprInt(0, expr.size), + ExprInt(1, expr.size) + ) + + elif (expr.is_op("CC_S>=") and + len(expr.args) == 2 and + expr.args[0].is_op("FLAG_SIGN_SUB") and + expr.args[0].args[1].is_int(0) and + expr.args[1].is_int(0)): + expr = ExprOp( + TOK_INF_EQUAL_SIGNED, + expr.args[0].args[1], + expr.args[0].args[0], + ) + + elif (expr.is_op("CC_S<") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF" + )): + expr = ExprOp(TOK_INF_SIGNED, *expr.args[0].args) + + elif (expr.is_op("CC_S<=") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF", + "FLAG_EQ_CMP", + )): + expr = ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args) + + elif (expr.is_op("CC_S<=") and + len(expr.args) == 3 and + expr.args[0].is_op("FLAG_SIGN_SUB") and + expr.args[2].is_op("FLAG_EQ_CMP") 
def simp_cond_cc_flag(expr_simp, expr):
    """
    Drop a CC_U< / CC_U>= wrapper applied to a single-bit condition.

    ExprCond(CC_U<(bit), X, Y) => ExprCond(bit, X, Y)
    ExprCond(CC_U>=(bit), X, Y) => ExprCond(bit, Y, X)
    """
    if not (expr.is_cond() and expr.cond.is_op()):
        return expr
    flag_test = expr.cond

    # Operator name -> must the two branches be swapped?
    swap_branches = {"CC_U<": False, "CC_U>=": True}
    if flag_test.op not in swap_branches:
        return expr

    bit = flag_test.args[0]
    if bit.size != 1:
        return expr
    if swap_branches[flag_test.op]:
        return ExprCond(bit, expr.src2, expr.src1)
    return ExprCond(bit, expr.src1, expr.src2)
def simp_cmp_int(expr_simp, expr):
    """
    Simplify an equality test against an integer constant.

    ({X, 0} == int) => X == int[:]       (int fits in X.size bits)
    ({X, 0} == int) => 0                 (int has bits above X.size)
    X + int1 == int2 => X == int2-int1
    X ^ int1 == int2 => X == int1^int2

    @expr_simp: expression simplifier, called on the rebuilt expressions
    @expr: expression to simplify
    Return the simplified expression, or @expr when no rule applies.
    """
    if not expr.is_op(TOK_EQUAL):
        return expr

    if (expr.args[1].is_int() and
            expr.args[0].is_compose() and
            len(expr.args[0].args) == 2 and
            expr.args[0].args[1].is_int(0)):
        # ({X, 0} == int) => X == int[:]
        src = expr.args[0].args[0]
        int_val = int(expr.args[1])
        if int_val >> src.size:
            # The upper part of the compose is 0: a constant with bits set
            # above X can never match, so do not truncate it
            return ExprInt(0, expr.size)
        new_int = ExprInt(int_val, src.size)
        expr = expr_simp(
            ExprOp(TOK_EQUAL, src, new_int)
        )
        if not expr.is_op(TOK_EQUAL):
            # The simplifier may have folded the new test into something
            # which is not an equality anymore: nothing more to do here
            return expr
    assert len(expr.args) == 2

    left, right = expr.args
    if left.is_int() and not right.is_int():
        left, right = right, left
    if not right.is_int():
        return expr
    if not (left.is_op() and left.op in ['+', '^']):
        return expr
    if not left.args[-1].is_int():
        return expr
    # X + int1 == int2 => X == int2-int1
    # WARNING:
    # X - 0x10 <=u 0x20 gives X in [0x10 0x30]
    # which is not equivalent to A <=u 0x10
    # (this is why only the equality is handled)

    left_orig = left
    left, last_int = left.args[:-1], left.args[-1]

    if len(left) == 1:
        left = left[0]
    else:
        left = ExprOp(left_orig.op, *left)

    # Move the trailing constant to the other side of the equality
    if left_orig.op == "+":
        new_int = expr_simp(right - last_int)
    elif left_orig.op == '^':
        new_int = expr_simp(right ^ last_int)
    else:
        raise RuntimeError("Unsupported operator")

    expr = expr_simp(
        ExprOp(TOK_EQUAL, left, new_int),
    )
    return expr
B:A + """ + cond = expr.cond + if not cond.is_op(): + return expr + op = cond.op + if op not in [ + TOK_EQUAL, + TOK_INF_SIGNED, + TOK_INF_EQUAL_SIGNED, + TOK_INF_UNSIGNED, + TOK_INF_EQUAL_UNSIGNED + ]: + return expr + arg1, arg2 = cond.args + if arg2.is_int(): + return expr + if not arg1.is_int(): + return expr + src1, src2 = expr.src1, expr.src2 + if op == TOK_EQUAL: + return ExprCond(ExprOp(TOK_EQUAL, arg2, arg1), src1, src2) + + arg1, arg2 = arg2, arg1 + src1, src2 = src2, src1 + if op == TOK_INF_SIGNED: + op = TOK_INF_EQUAL_SIGNED + elif op == TOK_INF_EQUAL_SIGNED: + op = TOK_INF_SIGNED + elif op == TOK_INF_UNSIGNED: + op = TOK_INF_EQUAL_UNSIGNED + elif op == TOK_INF_EQUAL_UNSIGNED: + op = TOK_INF_UNSIGNED + return ExprCond(ExprOp(op, arg1, arg2), src1, src2) + + + +def simp_cmp_bijective_op(expr_simp, expr): + """ + A + B == A => A == 0 + + X + A == X + B => A == B + X ^ A == X ^ B => A == B + + TODO: + 3 * A + B == A + C => 2 * A + B == C + """ + + if not expr.is_op(TOK_EQUAL): + return expr + op_a = expr.args[0] + op_b = expr.args[1] + + # a == a + if op_a == op_b: + return ExprInt(1, 1) + + # Case: + # a + b + c == a + if op_a.is_op() and op_a.op in ["+", "^"]: + args = list(op_a.args) + if op_b in args: + args.remove(op_b) + if not args: + raise ValueError("Can be here") + elif len(args) == 1: + op_a = args[0] + else: + op_a = ExprOp(op_a.op, *args) + return ExprOp(TOK_EQUAL, op_a, ExprInt(0, args[0].size)) + # a == a + b + c + if op_b.is_op() and op_b.op in ["+", "^"]: + args = list(op_b.args) + if op_a in args: + args.remove(op_a) + if not args: + raise ValueError("Can be here") + elif len(args) == 1: + op_b = args[0] + else: + op_b = ExprOp(op_b.op, *args) + return ExprOp(TOK_EQUAL, op_b, ExprInt(0, args[0].size)) + + if not (op_a.is_op() and op_b.is_op()): + return expr + if op_a.op != op_b.op: + return expr + op = op_a.op + if op not in ["+", "^"]: + return expr + common = set(op_a.args).intersection(op_b.args) + if not common: + return expr + + 
def simp_double_zeroext(_, expr):
    """
    Collapse two nested zero extensions into the outer one.

    A.zeroExt(X).zeroExt(Y) => A.zeroExt(Y)
    """
    def is_zero_ext(node):
        return node.is_op() and node.op.startswith("zeroExt")

    if is_zero_ext(expr):
        inner = expr.args[0]
        if is_zero_ext(inner):
            # Keep the outer operator, applied to the innermost source
            return ExprOp(expr.op, inner.args[0])
    return expr
def simp_zeroext_eq_cst(_, expr):
    """
    Compare the source of a zero extension instead of the extended value.

    A.zeroExt(X) == int => A == int[:A.size]   (int fits in A.size bits)
    A.zeroExt(X) == int => 0                   (int does not fit)
    """
    if not expr.is_op(TOK_EQUAL):
        return expr
    arg1, arg2 = expr.args
    if not arg2.is_int():
        return expr
    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
        return expr
    src = arg1.args[0]
    # The largest value A can hold is (1 << A.size) - 1, so the constant
    # (1 << A.size) itself is already out of reach
    if int(arg2) >= (1 << src.size):
        # Always false
        return ExprInt(0, expr.size)
    return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size))
Ensure before: X.zeroExt(X.size) => X + + X.zeroExt() <s 0 => 0 + X.zeroExt() <=s 0 => X == 0 + + X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive) + X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive) + + X.zeroExt() <s cst => 0 (cst negative) + X.zeroExt() <=s cst => 0 (cst negative) + + """ + if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)): + return expr + arg1, arg2 = expr.args + if not arg2.is_int(): + return expr + if not (arg1.is_op() and arg1.op.startswith("zeroExt")): + return expr + src = arg1.args[0] + assert src.size < arg1.size + + # If cst is zero + if arg2.is_int(0): + if expr.is_op(TOK_INF_SIGNED): + # X.zeroExt() <s 0 => 0 + return ExprInt(0, expr.size) + else: + # X.zeroExt() <=s 0 => X == 0 + return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size)) + + # cst is not zero + cst = int(arg2) + if cst & (1 << (arg2.size - 1)): + # cst is negative + return ExprInt(0, expr.size) + # cst is positive + if expr.is_op(TOK_INF_SIGNED): + # X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive) + return ExprOp(TOK_INF_UNSIGNED, src, expr_s(arg2[:src.size])) + # X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive) + return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size])) + + +def simp_zeroext_and_cst_eq_cst(expr_s, expr): + """ + A.zeroExt(X) & ... & int == int => A & ... 
def simp_x_and_cst_eq_cst(_, expr):
    """
    Testing a single bit for equality with itself is testing it for non zero.

    (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B
    """
    test = expr.cond
    if not test.is_op(TOK_EQUAL):
        return expr
    masked, wanted = test.args

    # The compared constant must have exactly one bit set
    if not (wanted.is_int() and test_one_bit_set(int(wanted))):
        return expr
    # ... and must be the last operand of the '&'
    if not masked.is_op('&') or masked.args[-1] != wanted:
        return expr
    return ExprCond(ExprOp('&', *masked.args), expr.src1, expr.src2)
def simp_ext_cst(_, expr):
    """
    Fold the zero / sign extension of an integer constant.

    Int.zeroExt(X) => Int
    Int.signExt(X) => Int
    """
    if not expr.op.startswith(("zeroExt", "signExt")):
        return expr
    operand = expr.args[0]
    if not operand.is_int():
        return expr

    value = int(operand)
    if not expr.op.startswith("zeroExt"):
        # Sign extension: go through the sized integer type of the operand
        value = int(mod_size2int[operand.size](value))
    return ExprInt(value, expr.size)
def simp_cond_logic_ext(expr_s, expr):
    """
    Evaluate a logic condition on the sources of its zero extensions.

    (X.zeroExt() & ... & Int) ? A:B => (X & ... & int[:X.size]) ? A:B
    (X.zeroExt() | ... | Int) ? A:B => (X | ... | int[:X.size]) ? A:B
    (X.zeroExt() ^ ... ^ Int) ? A:B => (X ^ ... ^ int[:X.size]) ? A:B

    For '|' and '^', the rewrite is only done when every Int fits in X.size
    bits: a constant with upper bits set changes whether the condition is
    zero, and truncating it would lose that.

    @expr_s: expression simplifier, called on each truncated operand
    @expr: the ExprCond to simplify
    """
    cond = expr.cond
    if not cond.is_op():
        return expr
    if cond.op not in ["&", "^", "|"]:
        return expr

    sizes = set()
    constants = []
    for arg in cond.args:
        if arg.is_int():
            constants.append(int(arg))
        elif arg.is_op() and arg.op.startswith("zeroExt"):
            sizes.add(arg.args[0].size)
        else:
            # Only integers and zero extensions are supported
            return expr
    # All the zero extended sources must share the same size
    if len(sizes) != 1:
        return expr
    size = list(sizes)[0]

    # With '&', the zero extension already clears the upper bits of the
    # result. With '|' and '^', upper bits of a constant stay in the result.
    if cond.op != "&" and any(value >> size for value in constants):
        return expr

    args = [expr_s(arg[:size]) for arg in cond.args]
    cond = ExprOp(cond.op, *args)
    return ExprCond(cond, expr.src1, expr.src2)
def simp_cond_add(expr_s, expr):
    """
    Turn a two-operand sum / xor used as a condition into an equality test,
    swapping both branches.

    (a+b)?X:Y => (a == -b)?Y:X
    (a^b)?X:Y => (a == b)?Y:X
    """
    test = expr.cond
    if not test.is_op() or test.op not in ('+', '^') or len(test.args) != 2:
        return expr

    lhs, rhs = test.args
    if test.is_op('+'):
        # a + b is zero when a equals -b
        rhs = expr_s(-rhs)
    elif not test.is_op('^'):
        raise ValueError('Bad case')
    return ExprCond(ExprOp('==', lhs, rhs), expr.src2, expr.src1)
def simp_test_signext_inf(expr_s, expr):
    """
    Push a signed comparison against a constant below a sign extension.

    A.signExt() <s int  => A <s int[:A.size]    (int in A signed range)
    A.signExt() <=s int => A <=s int[:A.size]   (int in A signed range)
    A.signExt() <s int, A.signExt() <=s int => 1  (int above A signed range)
    A.signExt() <s int, A.signExt() <=s int => 0  (int below A signed range)

    @expr_s: expression simplifier, called on the truncated constant
    @expr: expression to simplify
    """
    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
        return expr
    arg, cst = expr.args
    if not (arg.is_op() and arg.op.startswith("signExt")):
        return expr
    if not cst.is_int():
        return expr
    base = arg.args[0]
    # Constant seen through the sized integer type used for signed compares
    tmp = int(mod_size2int[cst.size](int(cst)))
    lower = -(1 << (base.size - 1))
    upper = 1 << (base.size - 1)
    if lower <= tmp < upper:
        # Can trunc integer
        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
    if tmp >= upper:
        # Every value of A is below the constant: always true
        return ExprInt(1, 1)
    # The constant is below every value of A: always false
    return ExprInt(0, 1)
Extract each argument and its counter + operands = {} + for arg in expr.args: + if arg.is_op('*') and arg.args[1].is_int(): + base_expr, factor = arg.args + operands[base_expr] = operands.get(base_expr, 0) + int(factor) + elif arg.is_op('<<') and arg.args[1].is_int(): + base_expr, factor = arg.args + operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor) + elif arg.is_op("-"): + arg = arg.args[0] + if arg.is_op('<<') and arg.args[1].is_int(): + base_expr, factor = arg.args + operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor)) + else: + operands[arg] = operands.get(arg, 0) - 1 + else: + operands[arg] = operands.get(arg, 0) + 1 + out = [] + + # Best effort to factor common args: + # (a + b) * 3 + a + b => (a + b) * 4 + # Does not factor: + # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a + modified = True + while modified: + modified = False + for arg, count in list(viewitems(operands)): + if not arg.is_op('+'): + continue + components = arg.args + if not all(component in operands for component in components): + continue + counters = set(operands[component] for component in components) + if len(counters) != 1: + continue + counter = counters.pop() + for component in components: + del operands[component] + operands[arg] += counter + modified = True + break + + for arg, count in viewitems(operands): + if count == 0: + continue + if count == 1: + out.append(arg) + continue + out.append(arg * ExprInt(count, expr.size)) + + if len(out) == len(expr.args): + # No reductions + return expr + if not out: + return ExprInt(0, expr.size) + if len(out) == 1: + return out[0] + return ExprOp('+', *out) + +def simp_compose_and_mask(_, expr): + """ + {X 0 8, Y 8 32} & 0xFF => zeroExt(X) + {X 0 8, Y 8 16, Z 16 32} & 0xFFFF => {X 0 8, Y 8 16, 0x0 16 32} + {X 0 8, 0x123456 8 32} & 0xFFFFFF => {X 0 8, 0x1234 8 24, 0x0 24 32} + """ + if not expr.is_op('&'): + return expr + # handle the case where arg2 = arg1.mask + if len(expr.args) != 2: + return expr + arg1, 
arg2 = expr.args + if not arg1.is_compose(): + return expr + if not arg2.is_int(): + return expr + int2 = int(arg2) + if (int2 + 1) & int2 != 0: + return expr + mask_size = int2.bit_length() + 7 // 8 + out = [] + for offset, arg in arg1.iter_args(): + if offset == mask_size: + return ExprCompose(*out).zeroExtend(expr.size) + elif mask_size > offset and mask_size < offset+arg.size and arg.is_int(): + out.append(ExprSlice(arg, 0, mask_size-offset)) + return ExprCompose(*out).zeroExtend(expr.size) + else: + out.append(arg) + return expr + +def simp_bcdadd_cf(_, expr): + """bcdadd(const, const) => decimal""" + if not(expr.is_op('bcdadd_cf')): + return expr + arg1 = expr.args[0] + arg2 = expr.args[1] + if not(arg1.is_int() and arg2.is_int()): + return expr + + carry = 0 + res = 0 + nib_1, nib_2 = 0, 0 + for i in range(0,16,4): + nib_1 = (arg1.arg >> i) & (0xF) + nib_2 = (arg2.arg >> i) & (0xF) + + j = (carry + nib_1 + nib_2) + if (j >= 10): + carry = 1 + j -= 10 + j &= 0xF + else: + carry = 0 + return ExprInt(carry, 1) + +def simp_bcdadd(_, expr): + """bcdadd(const, const) => decimal""" + if not(expr.is_op('bcdadd')): + return expr + arg1 = expr.args[0] + arg2 = expr.args[1] + if not(arg1.is_int() and arg2.is_int()): + return expr + + carry = 0 + res = 0 + nib_1, nib_2 = 0, 0 + for i in range(0,16,4): + nib_1 = (arg1.arg >> i) & (0xF) + nib_2 = (arg2.arg >> i) & (0xF) + + j = (carry + nib_1 + nib_2) + if (j >= 10): + carry = 1 + j -= 10 + j &= 0xF + else: + carry = 0 + res += j << i + return ExprInt(res, arg1.size) + + +def simp_smod_sext(expr_s, expr): + """ + a.size == b.size + smod(a.signExtend(X), b.signExtend(X)) => smod(a, b).signExtend(X) + """ + if not expr.is_op("smod"): + return expr + arg1, arg2 = expr.args + if arg1.is_op() and arg1.op.startswith("signExt"): + src1 = arg1.args[0] + if arg2.is_op() and arg2.op.startswith("signExt"): + src2 = arg2.args[0] + if src1.size == src2.size: + # Case: a.signext(), b.signext() + return ExprOp("smod", src1, 
src2).signExtend(expr.size) + return expr + elif arg2.is_int(): + src2 = expr_s.expr_simp(arg2[:src1.size]) + if expr_s.expr_simp(src2.signExtend(arg2.size)) == arg2: + # Case: a.signext(), int + return ExprOp("smod", src1, src2).signExtend(expr.size) + return expr + # Case: int , b.signext() + if arg2.is_op() and arg2.op.startswith("signExt"): + src2 = arg2.args[0] + if arg1.is_int(): + src1 = expr_s.expr_simp(arg1[:src2.size]) + if expr_s.expr_simp(src1.signExtend(arg1.size)) == arg1: + # Case: int, b.signext() + return ExprOp("smod", src1, src2).signExtend(expr.size) + return expr + +# FLAG_SUB_OF(CST1, CST2) => CST +def simp_flag_cst(expr_simp, expr): + if expr.op not in [ + "FLAG_EQ", "FLAG_EQ_AND", "FLAG_SIGN_SUB", "FLAG_EQ_CMP", "FLAG_ADD_CF", + "FLAG_SUB_CF", "FLAG_ADD_OF", "FLAG_SUB_OF", "FLAG_EQ_ADDWC", "FLAG_ADDWC_OF", + "FLAG_SUBWC_OF", "FLAG_ADDWC_CF", "FLAG_SUBWC_CF", "FLAG_SIGN_ADDWC", + "FLAG_SIGN_SUBWC", "FLAG_EQ_SUBWC", + "CC_U<=", "CC_U>=", "CC_S<", "CC_S>", "CC_S<=", "CC_S>=", "CC_U>", + "CC_U<", "CC_NEG", "CC_EQ", "CC_NE", "CC_POS" + ]: + return expr + if not all(arg.is_int() for arg in expr.args): + return expr + new_expr = expr_simp(simp_flags(expr_simp, expr)) + return new_expr diff --git a/src/miasm/expression/simplifications_cond.py b/src/miasm/expression/simplifications_cond.py new file mode 100644 index 00000000..6167cb4d --- /dev/null +++ b/src/miasm/expression/simplifications_cond.py @@ -0,0 +1,178 @@ +################################################################################ +# +# By choice, Miasm2 does not handle comparison as a single operation, but with +# operations corresponding to comparison computation. 
+# One may want to detect those comparison; this library is designed to add them +# in Miasm2 engine thanks to : +# - Conditions computation in ExprOp +# - Simplifications to catch known condition forms +# +# Conditions currently supported : +# <u, <s, == +# +# Authors : Fabrice DESCLAUX (CEA/DAM), Camille MOUGEY (CEA/DAM) +# +################################################################################ + +import miasm.expression.expression as m2_expr + + +# Jokers for expression matching + +jok1 = m2_expr.ExprId("jok1", 32) +jok2 = m2_expr.ExprId("jok2", 32) +jok3 = m2_expr.ExprId("jok3", 32) +jok_small = m2_expr.ExprId("jok_small", 1) + + +# Constructors + +def __ExprOp_cond(op, arg1, arg2): + "Return an ExprOp standing for arg1 op arg2 with size to 1" + ec = m2_expr.ExprOp(op, arg1, arg2) + return ec + + +def ExprOp_inf_signed(arg1, arg2): + "Return an ExprOp standing for arg1 <s arg2" + return __ExprOp_cond(m2_expr.TOK_INF_SIGNED, arg1, arg2) + + +def ExprOp_inf_unsigned(arg1, arg2): + "Return an ExprOp standing for arg1 <s arg2" + return __ExprOp_cond(m2_expr.TOK_INF_UNSIGNED, arg1, arg2) + +def ExprOp_equal(arg1, arg2): + "Return an ExprOp standing for arg1 == arg2" + return __ExprOp_cond(m2_expr.TOK_EQUAL, arg1, arg2) + + +# Catching conditions forms + +def __check_msb(e): + """If @e stand for the most significant bit of its arg, return the arg; + False otherwise""" + + if not isinstance(e, m2_expr.ExprSlice): + return False + + arg = e.arg + if e.start != (arg.size - 1) or e.stop != arg.size: + return False + + return arg + +def __match_expr_wrap(e, to_match, jok_list): + "Wrapper around match_expr to canonize pattern" + + to_match = to_match.canonize() + + r = m2_expr.match_expr(e, to_match, jok_list) + if r is False: + return False + + if r == {}: + return False + + return r + +def expr_simp_inf_signed(expr_simp, e): + "((x - y) ^ ((x ^ y) & ((x - y) ^ x))) [31:32] == x <s y" + + arg = __check_msb(e) + if arg is False: + return e + # We want jok3 = 
jok1 - jok2 + to_match = jok3 ^ ((jok1 ^ jok2) & (jok3 ^ jok1)) + r = __match_expr_wrap(arg, + to_match, + [jok1, jok2, jok3]) + + if r is False: + return e + + new_j3 = expr_simp(r[jok3]) + sub = expr_simp(r[jok1] - r[jok2]) + + if new_j3 == sub: + return ExprOp_inf_signed(r[jok1], r[jok2]) + else: + return e + +def expr_simp_inf_unsigned_inversed(expr_simp, e): + "((x - y) ^ ((x ^ y) & ((x - y) ^ x))) ^ x ^ y [31:32] == x <u y" + + arg = __check_msb(e) + if arg is False: + return e + + # We want jok3 = jok1 - jok2 + to_match = jok3 ^ ((jok1 ^ jok2) & (jok3 ^ jok1)) ^ jok1 ^ jok2 + r = __match_expr_wrap(arg, + to_match, + [jok1, jok2, jok3]) + + if r is False: + return e + + new_j3 = expr_simp(r[jok3]) + sub = expr_simp(r[jok1] - r[jok2]) + + if new_j3 == sub: + return ExprOp_inf_unsigned(r[jok1], r[jok2]) + else: + return e + +def expr_simp_inverse(expr_simp, e): + """(x <u y) ^ ((x ^ y) [31:32]) == x <s y, + (x <s y) ^ ((x ^ y) [31:32]) == x <u y""" + + to_match = (ExprOp_inf_unsigned(jok1, jok2) ^ jok_small) + r = __match_expr_wrap(e, + to_match, + [jok1, jok2, jok_small]) + + # Check for 2 symmetric cases + if r is False: + to_match = (ExprOp_inf_signed(jok1, jok2) ^ jok_small) + r = __match_expr_wrap(e, + to_match, + [jok1, jok2, jok_small]) + + if r is False: + return e + cur_sig = m2_expr.TOK_INF_SIGNED + else: + cur_sig = m2_expr.TOK_INF_UNSIGNED + + + arg = __check_msb(r[jok_small]) + if arg is False: + return e + + if not isinstance(arg, m2_expr.ExprOp) or arg.op != "^": + return e + + op_args = arg.args + if len(op_args) != 2: + return e + + if r[jok1] not in op_args or r[jok2] not in op_args: + return e + + if cur_sig == m2_expr.TOK_INF_UNSIGNED: + return ExprOp_inf_signed(r[jok1], r[jok2]) + else: + return ExprOp_inf_unsigned(r[jok1], r[jok2]) + +def expr_simp_equal(expr_simp, e): + """(x - y)?(0:1) == (x == y)""" + + to_match = m2_expr.ExprCond(jok1 + jok2, m2_expr.ExprInt(0, 1), m2_expr.ExprInt(1, 1)) + r = __match_expr_wrap(e, + to_match, + [jok1, 
jok2]) + if r is False: + return e + + return ExprOp_equal(r[jok1], expr_simp(-r[jok2])) diff --git a/src/miasm/expression/simplifications_explicit.py b/src/miasm/expression/simplifications_explicit.py new file mode 100644 index 00000000..1f9d2dbe --- /dev/null +++ b/src/miasm/expression/simplifications_explicit.py @@ -0,0 +1,159 @@ +from miasm.core.utils import size2mask +from miasm.expression.expression import ExprInt, ExprCond, ExprCompose, \ + TOK_EQUAL + + +def simp_ext(_, expr): + if expr.op.startswith('zeroExt_'): + arg = expr.args[0] + if expr.size == arg.size: + return arg + return ExprCompose(arg, ExprInt(0, expr.size - arg.size)) + + if expr.op.startswith("signExt_"): + arg = expr.args[0] + add_size = expr.size - arg.size + new_expr = ExprCompose( + arg, + ExprCond( + arg.msb(), + ExprInt(size2mask(add_size), add_size), + ExprInt(0, add_size) + ) + ) + return new_expr + return expr + + +def simp_flags(_, expr): + args = expr.args + + if expr.is_op("FLAG_EQ"): + return ExprCond(args[0], ExprInt(0, 1), ExprInt(1, 1)) + + elif expr.is_op("FLAG_EQ_AND"): + op1, op2 = args + return ExprCond(op1 & op2, ExprInt(0, 1), ExprInt(1, 1)) + + elif expr.is_op("FLAG_SIGN_SUB"): + return (args[0] - args[1]).msb() + + elif expr.is_op("FLAG_EQ_CMP"): + return ExprCond( + args[0] - args[1], + ExprInt(0, 1), + ExprInt(1, 1), + ) + + elif expr.is_op("FLAG_ADD_CF"): + op1, op2 = args + res = op1 + op2 + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUB_CF"): + op1, op2 = args + res = op1 - op2 + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_ADD_OF"): + op1, op2 = args + res = op1 + op2 + return (((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUB_OF"): + op1, op2 = args + res = op1 - op2 + return (((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_EQ_ADDWC"): + op1, op2, op3 = args + return ExprCond( + op1 + op2 + op3.zeroExtend(op1.size), + ExprInt(0, 1), + 
ExprInt(1, 1), + ) + + elif expr.is_op("FLAG_ADDWC_OF"): + op1, op2, op3 = args + res = op1 + op2 + op3.zeroExtend(op1.size) + return (((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUBWC_OF"): + op1, op2, op3 = args + res = op1 - (op2 + op3.zeroExtend(op1.size)) + return (((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_ADDWC_CF"): + op1, op2, op3 = args + res = op1 + op2 + op3.zeroExtend(op1.size) + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUBWC_CF"): + op1, op2, op3 = args + res = op1 - (op2 + op3.zeroExtend(op1.size)) + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_SIGN_ADDWC"): + op1, op2, op3 = args + return (op1 + op2 + op3.zeroExtend(op1.size)).msb() + + elif expr.is_op("FLAG_SIGN_SUBWC"): + op1, op2, op3 = args + return (op1 - (op2 + op3.zeroExtend(op1.size))).msb() + + + elif expr.is_op("FLAG_EQ_SUBWC"): + op1, op2, op3 = args + res = op1 - (op2 + op3.zeroExtend(op1.size)) + return ExprCond(res, ExprInt(0, 1), ExprInt(1, 1)) + + elif expr.is_op("CC_U<="): + op_cf, op_zf = args + return op_cf | op_zf + + elif expr.is_op("CC_U>="): + op_cf, = args + return ~op_cf + + elif expr.is_op("CC_S<"): + op_nf, op_of = args + return op_nf ^ op_of + + elif expr.is_op("CC_S>"): + op_nf, op_of, op_zf = args + return ~(op_zf | (op_nf ^ op_of)) + + elif expr.is_op("CC_S<="): + op_nf, op_of, op_zf = args + return op_zf | (op_nf ^ op_of) + + elif expr.is_op("CC_S>="): + op_nf, op_of = args + return ~(op_nf ^ op_of) + + elif expr.is_op("CC_U>"): + op_cf, op_zf = args + return ~(op_cf | op_zf) + + elif expr.is_op("CC_U<"): + op_cf, = args + return op_cf + + elif expr.is_op("CC_NEG"): + op_nf, = args + return op_nf + + elif expr.is_op("CC_EQ"): + op_zf, = args + return op_zf + + elif expr.is_op("CC_NE"): + op_zf, = args + return ~op_zf + + elif expr.is_op("CC_POS"): + op_nf, = args + return ~op_nf + + return expr + diff --git 
a/src/miasm/expression/smt2_helper.py b/src/miasm/expression/smt2_helper.py new file mode 100644 index 00000000..53d323e8 --- /dev/null +++ b/src/miasm/expression/smt2_helper.py @@ -0,0 +1,296 @@ +# Helper functions for the generation of SMT2 expressions +# The SMT2 expressions will be returned as a string. +# The expressions are divided as follows +# +# - generic SMT2 operations +# - definitions of SMT2 structures +# - bit vector operations +# - array operations + +# generic SMT2 operations + +def smt2_eq(a, b): + """ + Assignment: a = b + """ + return "(= {} {})".format(a, b) + + +def smt2_implies(a, b): + """ + Implication: a => b + """ + return "(=> {} {})".format(a, b) + + +def smt2_and(*args): + """ + Conjunction: a and b and c ... + """ + # transform args into strings + args = [str(arg) for arg in args] + return "(and {})".format(' '.join(args)) + + +def smt2_or(*args): + """ + Disjunction: a or b or c ... + """ + # transform args into strings + args = [str(arg) for arg in args] + return "(or {})".format(' '.join(args)) + + +def smt2_ite(cond, a, b): + """ + If-then-else: cond ? a : b + """ + return "(ite {} {} {})".format(cond, a, b) + + +def smt2_distinct(*args): + """ + Distinction: a != b != c != ... + """ + # transform args into strings + args = [str(arg) for arg in args] + return "(distinct {})".format(' '.join(args)) + + +def smt2_assert(expr): + """ + Assertion that @expr holds + """ + return "(assert {})".format(expr) + + +# definitions + +def declare_bv(bv, size): + """ + Declares an bit vector @bv of size @size + """ + return "(declare-fun {} () {})".format(bv, bit_vec(size)) + + +def declare_array(a, bv1, bv2): + """ + Declares an SMT2 array represented as a map + from a bit vector to another bit vector. 
+ :param a: array name + :param bv1: SMT2 bit vector + :param bv2: SMT2 bit vector + """ + return "(declare-fun {} () (Array {} {}))".format(a, bv1, bv2) + + +def bit_vec_val(v, size): + """ + Declares a bit vector value + :param v: int, value of the bit vector + :param size: size of the bit vector + """ + return "(_ bv{} {})".format(v, size) + + +def bit_vec(size): + """ + Returns a bit vector of size @size + """ + return "(_ BitVec {})".format(size) + + +# bit vector operations + +def bvadd(a, b): + """ + Addition: a + b + """ + return "(bvadd {} {})".format(a, b) + + +def bvsub(a, b): + """ + Subtraction: a - b + """ + return "(bvsub {} {})".format(a, b) + + +def bvmul(a, b): + """ + Multiplication: a * b + """ + return "(bvmul {} {})".format(a, b) + + +def bvand(a, b): + """ + Bitwise AND: a & b + """ + return "(bvand {} {})".format(a, b) + + +def bvor(a, b): + """ + Bitwise OR: a | b + """ + return "(bvor {} {})".format(a, b) + + +def bvxor(a, b): + """ + Bitwise XOR: a ^ b + """ + return "(bvxor {} {})".format(a, b) + + +def bvneg(bv): + """ + Unary minus: - bv + """ + return "(bvneg {})".format(bv) + + +def bvsdiv(a, b): + """ + Signed division: a / b + """ + return "(bvsdiv {} {})".format(a, b) + + +def bvudiv(a, b): + """ + Unsigned division: a / b + """ + return "(bvudiv {} {})".format(a, b) + + +def bvsmod(a, b): + """ + Signed modulo: a mod b + """ + return "(bvsmod {} {})".format(a, b) + + +def bvurem(a, b): + """ + Unsigned modulo: a mod b + """ + return "(bvurem {} {})".format(a, b) + + +def bvshl(a, b): + """ + Shift left: a << b + """ + return "(bvshl {} {})".format(a, b) + + +def bvlshr(a, b): + """ + Logical shift right: a >> b + """ + return "(bvlshr {} {})".format(a, b) + + +def bvashr(a, b): + """ + Arithmetic shift right: a a>> b + """ + return "(bvashr {} {})".format(a, b) + + +def bv_rotate_left(a, b, size): + """ + Rotates bits of a to the left b times: a <<< b + + Since ((_ rotate_left b) a) does not support + symbolic values for b, the 
implementation is + based on a C implementation. + + Therefore, the rotation will be computed as + a << (b & (size - 1))) | (a >> (size - (b & (size - 1)))) + + :param a: bit vector + :param b: bit vector + :param size: size of a + """ + + # define constant + s = bit_vec_val(size, size) + + # shift = b & (size - 1) + shift = bvand(b, bvsub(s, bit_vec_val(1, size))) + + # (a << shift) | (a >> size - shift) + rotate = bvor(bvshl(a, shift), + bvlshr(a, bvsub(s, shift))) + + return rotate + + +def bv_rotate_right(a, b, size): + """ + Rotates bits of a to the right b times: a >>> b + + Since ((_ rotate_right b) a) does not support + symbolic values for b, the implementation is + based on a C implementation. + + Therefore, the rotation will be computed as + a >> (b & (size - 1))) | (a << (size - (b & (size - 1)))) + + :param a: bit vector + :param b: bit vector + :param size: size of a + """ + + # define constant + s = bit_vec_val(size, size) + + # shift = b & (size - 1) + shift = bvand(b, bvsub(s, bit_vec_val(1, size))) + + # (a >> shift) | (a << size - shift) + rotate = bvor(bvlshr(a, shift), + bvshl(a, bvsub(s, shift))) + + return rotate + + +def bv_extract(high, low, bv): + """ + Extracts bits from a bit vector + :param high: end bit + :param low: start bit + :param bv: bit vector + """ + return "((_ extract {} {}) {})".format(high, low, bv) + + +def bv_concat(a, b): + """ + Concatenation of two SMT2 expressions + """ + return "(concat {} {})".format(a, b) + + +# array operations + +def array_select(array, index): + """ + Reads from an SMT2 array at index @index + :param array: SMT2 array + :param index: SMT2 expression, index of the array + """ + return "(select {} {})".format(array, index) + + +def array_store(array, index, value): + """ + Writes an value into an SMT2 array at address @index + :param array: SMT array + :param index: SMT2 expression, index of the array + :param value: SMT2 expression, value to write + """ + return "(store {} {} {})".format(array, 
index, value) |