diff options
| author | serpilliere <devnull@localhost> | 2014-06-03 10:27:56 +0200 |
|---|---|---|
| committer | serpilliere <devnull@localhost> | 2014-06-03 10:27:56 +0200 |
| commit | ed5c3668cc9f545b52674ad699fc2b0ed1ccb575 (patch) | |
| tree | 07faf97d7e4d083173a1f7e1bfd249baed2d74f9 /miasm2/expression | |
| parent | a183e1ebd525453710306695daa8c410fd0cb2af (diff) | |
| download | miasm-ed5c3668cc9f545b52674ad699fc2b0ed1ccb575.tar.gz miasm-ed5c3668cc9f545b52674ad699fc2b0ed1ccb575.zip | |
Miasm v2
* API has changed, so old scripts need updates * See example for API usage * Use tcc or llvm for jit emulation * Go to test and run test_all.py to check install Enjoy !
Diffstat (limited to '')
| -rw-r--r-- | miasm2/expression/__init__.py | 18 | ||||
| -rw-r--r-- | miasm2/expression/expression.py | 1253 | ||||
| -rw-r--r-- | miasm2/expression/expression_helper.py | 196 | ||||
| -rw-r--r-- | miasm2/expression/modint.py | 224 | ||||
| -rw-r--r-- | miasm2/expression/simplifications.py | 605 | ||||
| -rw-r--r-- | miasm2/expression/stp.py | 68 |
6 files changed, 2364 insertions, 0 deletions
diff --git a/miasm2/expression/__init__.py b/miasm2/expression/__init__.py new file mode 100644 index 00000000..fbabaacf --- /dev/null +++ b/miasm2/expression/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py new file mode 100644 index 00000000..3d73ee10 --- /dev/null +++ b/miasm2/expression/expression.py @@ -0,0 +1,1253 @@ +# +# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# These module implements Miasm IR components and basic operations related. +# IR components are : +# - ExprInt +# - ExprId +# - ExprAff +# - ExprCond +# - ExprMem +# - ExprOp +# - ExprSlice +# - ExprCompose +# + + +import itertools +from miasm2.expression.modint import * +from miasm2.core.graph import DiGraph + + +def visit_chk(visitor): + "Function decorator launching callback on Expression visit" + def wrapped(e, cb, test_visit=lambda x: True): + if (test_visit is not None) and (not test_visit(e)): + return e + e_new = visitor(e, cb, test_visit) + if e_new is None: + return None + e_new2 = cb(e_new) + return e_new2 + return wrapped + +# Hashing constants +EXPRINT = 1 +EXPRID = 2 +EXPRAFF = 3 +EXPRCOND = 4 +EXPRMEM = 5 +EXPROP = 6 +EXPRSLICE = 5 +EXPRCOMPOSE = 5 + +# Expression display + + +class DiGraphExpr(DiGraph): + + """Enhanced graph for Expression diplay + Expression are displayed as a tree with node and edge labeled + with only relevant information""" + + def node2str(self, node): + if isinstance(node, ExprOp): + return node.op + elif isinstance(node, ExprId): + return node.name + elif isinstance(node, ExprMem): + return "@%d" % node.size + elif isinstance(node, ExprCompose): + return "{ %d }" % node.size + elif isinstance(node, ExprCond): + return "? %d" % node.size + elif isinstance(node, ExprSlice): + return "[%d:%d]" % (node.start, node.stop) + return str(node) + + def edge2str(self, nfrom, nto): + if isinstance(nfrom, ExprCompose): + for i in nfrom.args: + if i[0] == nto: + return "[%s, %s]" % (i[1], i[2]) + elif isinstance(nfrom, ExprCond): + if nfrom.cond == nto: + return "?" + elif nfrom.src1 == nto: + return "True" + elif nfrom.src2 == nto: + return "False" + + return "" + + +# IR definitions + +class Expr(object): + + "Parent class for Miasm Expressions" + + is_term = False # Terminal expression + is_simp = False # Expression already simplified + is_canon = False # Expression already canonised + is_eval = False # Expression already evalued + + def set_size(self, value): + raise ValueError('size is not mutable') + size = property(lambda self: self._size) + + def __init__(self, arg): + self.arg = arg + + # Common operations + def __str__(self): + return str(self.arg) + + def __getitem__(self, i): + if not isinstance(i, slice): + raise TypeError("Expression: Bad slice: %s" % i) + start, stop, step = i.indices(self.size) + if step != 1: + raise ValueError("Expression: Bad slice: %s" % i) + return ExprSlice(self, start, stop) + + def get_size(self): + raise DeprecationWarning("use X.size instead of X.get_size()") + + def get_r(self, mem_read=False, cst_read=False): + return self.arg.get_r(mem_read, cst_read) + + def get_w(self): + return self.arg.get_w() + + def __repr__(self): + return "<%s_%d_0x%x>" % (self.__class__.__name__, self.size, id(self)) + + def __hash__(self): + return self._hash + + def __eq__(self, a): + if isinstance(a, Expr): + return self._hash == a._hash + else: + return False + + def __ne__(self, a): + return not self.__eq__(a) + + def __add__(self, a): + return ExprOp('+', self, a) + + def __sub__(self, a): + return ExprOp('+', self, ExprOp('-', a)) + + def __div__(self, a): + return ExprOp('/', self, a) + + def __mod__(self, a): + return ExprOp('%', self, a) + + def __mul__(self, a): + return ExprOp('*', self, a) + + def __lshift__(self, a): + return ExprOp('<<', self, a) + + def __rshift__(self, a): + return ExprOp('>>', self, a) + + def __xor__(self, a): + return ExprOp('^', self, a) + + def __or__(self, a): + return ExprOp('|', self, a) + + def __and__(self, a): + return ExprOp('&', self, a) + + def __neg__(self): + return ExprOp('-', self) + + def __invert__(self): + s = self.size + return ExprOp('^', self, ExprInt(mod_size2uint[s](size2mask(s)))) + + def copy(self): + "Deep copy of the expression" + return self.visit(lambda x: x) + + def replace_expr(self, dct=None): + """Find and replace sub expression using dct + @dct: dictionnary of Expr -> * + """ + if dct is None: + dct = {} + + def my_replace(e, dct): + if e in dct: + return dct[e] + return e + return self.visit(lambda e: my_replace(e, dct)) + + def canonize(self): + "Canonize the Expression" + + def must_canon(e): + # print 'test VISIT', e + return not e.is_simp + + def my_canon(e): + if e.is_simp: + return e + if isinstance(e, ExprOp): + if e.is_associative(): + # ((a+b) + c) => (a + b + c) + args = [] + for a in e.args: + if isinstance(a, ExprOp) and e.op == a.op: + args += a.args + else: + args.append(a) + args = canonize_expr_list(args) + new_e = ExprOp(e.op, *args) + else: + new_e = e + elif isinstance(e, ExprCompose): + new_e = ExprCompose(canonize_expr_list_compose(e.args)) + else: + new_e = e + return new_e + return self.visit(my_canon, must_canon) + + def msb(self): + "Return the Most Significant Bit" + s = self.size + return self[s - 1:s] + + def zeroExtend(self, size): + """Zero extend to size + @size: int + """ + assert(self.size <= size) + if self.size == size: + return self + ad_size = size - self.size + n = ExprInt_fromsize(ad_size, 0) + return ExprCompose([(self, 0, self.size), + (n, self.size, size)]) + + def signExtend(self, size): + """Sign extend to size + @size: int + """ + assert(self.size <= size) + if self.size == size: + return self + ad_size = size - self.size + c = ExprCompose([(self, 0, self.size), + (ExprCond(self.msb(), + ExprInt_fromsize( + ad_size, size2mask(ad_size)), + ExprInt_fromsize(ad_size, 0)), + self.size, size) + ]) + return c + + def graph_recursive(self, graph): + """Recursive method used by graph + @graph: miasm2.core.graph.DiGraph instance + Update @graph instance to include sons + This is an Abstract method""" + + raise ValueError("Abstract method") + + def graph(self): + """Return a DiGraph instance standing for Expr tree + Instance's display functions have been override for better visibility + Wrapper on graph_recursive""" + + # Create recursively the graph + graph = DiGraphExpr() + self.graph_recursive(graph) + + return graph + + def set_mask(self, value): + raise ValueError('mask is not mutable') + mask = property(lambda self: ExprInt_fromsize(self.size, -1)) + + +class ExprInt(Expr): + + """An ExprInt represent a constant in Miasm IR. + + Some use cases: + - Constant 0x42 + - Constant -0x30 + - Constant 0x12345678 on 32bits + """ + + def __init__(self, arg): + """Create an ExprInt from a numpy int + @arg: numpy int""" + + if not is_modint(arg): + raise ValueError('arg must by numpy int! %s' % arg) + + self.arg = arg + self._size = self.arg.size + self._hash = self.myhash() + + def __get_int(self): + "Return self integer representation" + return int(self.arg & size2mask(self.size)) + + def __str__(self): + if self.arg < 0: + return str("-0x%X" % (- self.__get_int())) + else: + return str("0x%X" % self.__get_int()) + + def get_r(self, mem_read=False, cst_read=False): + if cst_read: + return set([self]) + else: + return set() + + def get_w(self): + return set() + + def __contains__(self, e): + return self == e + + def myhash(self): + return hash((EXPRINT, self.arg, self.size)) + + def __repr__(self): + return Expr.__repr__(self)[:-1] + " 0x%X>" % self.__get_int() + + @visit_chk + def visit(self, cb, tv=None): + return self + + def copy(self): + return ExprInt(self.arg) + + def depth(self): + return 1 + + def graph_recursive(self, graph): + graph.add_node(self) + + +class ExprId(Expr): + + """An ExprId represent an identifier in Miasm IR. + + Some use cases: + - EAX register + - 'start' offset + - variable v1 + """ + + def __init__(self, name, size=32, is_term=False): + """Create an identifier + @name: str, identifier's name + @size: int, identifier's size + @is_term: boolean, is the identifier a terminal expression ? + """ + + self.name, self._size = name, size + self.is_term = is_term + self._hash = self.myhash() + + def __str__(self): + return str(self.name) + + def get_r(self, mem_read=False, cst_read=False): + return set([self]) + + def get_w(self): + return set([self]) + + def __contains__(self, e): + return self == e + + def myhash(self): + # TODO XXX: hash size ?? + return hash((EXPRID, self.name, self._size)) + + def __repr__(self): + return Expr.__repr__(self)[:-1] + " %s>" % self.name + + @visit_chk + def visit(self, cb, tv=None): + return self + + def copy(self): + return ExprId(self.name, self._size) + + def depth(self): + return 1 + + def graph_recursive(self, graph): + graph.add_node(self) + + +class ExprAff(Expr): + + """An ExprAff represent an affection from an Expression to another one. + + Some use cases: + - var1 <- 2 + """ + + def __init__(self, dst, src): + """Create an ExprAff for dst <- src + @dst: Expr, affectation destination + @src: Expr, affectation source + """ + + if dst.size != src.size: + raise ValueError( + "sanitycheck: ExprAff args must have same size! %s" % + ([(str(x), x.size) for x in [dst, src]])) + + if isinstance(dst, ExprSlice): + # Complete the source with missing slice parts + self.dst = dst.arg + rest = [(ExprSlice(dst.arg, r[0], r[1]), r[0], r[1]) + for r in dst.slice_rest()] + all_a = [(src, dst.start, dst.stop)] + rest + all_a.sort(key=lambda x: x[1]) + self.src = ExprCompose(all_a) + + else: + self.dst, self.src = dst, src + + self._hash = self.myhash() + self._size = self.dst.size + + def __str__(self): + return "%s = %s" % (str(self.dst), str(self.src)) + + def get_r(self, mem_read=False, cst_read=False): + r = self.src.get_r(mem_read, cst_read) + if isinstance(self.dst, ExprMem): + r.update(self.dst.arg.get_r(mem_read, cst_read)) + return r + + def get_w(self): + if isinstance(self.dst, ExprMem): + return set([self.dst]) # [memreg] + else: + return self.dst.get_w() + + def __contains__(self, e): + return self == e or self.src.__contains__(e) or self.dst.__contains__(e) + + def myhash(self): + return hash((EXPRAFF, self.dst._hash, self.src._hash)) + + # XXX /!\ for hackish expraff to slice + def get_modified_slice(self): + """Return an Expr list of extra expressions needed during the + object instanciation""" + + dst = self.dst + if not isinstance(self.src, ExprCompose): + raise ValueError("Get mod slice not on expraff slice", str(self)) + modified_s = [] + for x in self.src.args: + if (not isinstance(x[0], ExprSlice) or + x[0].arg != dst or + x[1] != x[0].start or + x[2] != x[0].stop): + # If x is not the initial expression + modified_s.append(x) + return modified_s + + @visit_chk + def visit(self, cb, tv=None): + dst, src = self.dst.visit(cb, tv), self.src.visit(cb, tv) + if dst == self.dst and src == self.src: + return self + else: + return ExprAff(dst, src) + + def copy(self): + return ExprAff(self.dst.copy(), self.src.copy()) + + def depth(self): + return max(self.src.depth(), self.dst.depth()) + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + for a in [self.src, self.dst]: + a.graph_recursive(graph) + graph.add_uniq_edge(self, a) + + +class ExprCond(Expr): + + """An ExprCond stand for a condition on an Expr + + Use cases: + - var1 < var2 + - min(var1, var2) + - if (cond) then ... else ... + """ + + def __init__(self, cond, src1, src2): + """Create an ExprCond + @cond: Expr, condition + @src1: Expr, value if condition is evaled to not zero + @src2: Expr, value if condition is evaled zero + """ + + self.cond, self.src1, self.src2 = cond, src1, src2 + assert(src1.size == src2.size) + self._hash = self.myhash() + self._size = self.src1.size + + def __str__(self): + return "%s?(%s,%s)" % (str(self.cond), str(self.src1), str(self.src2)) + + def get_r(self, mem_read=False, cst_read=False): + out_src1 = self.src1.get_r(mem_read, cst_read) + out_src2 = self.src2.get_r(mem_read, cst_read) + return self.cond.get_r(mem_read, + cst_read).union(out_src1).union(out_src2) + + def get_w(self): + return set() + + def __contains__(self, e): + return (self == e or + self.cond.__contains__(e) or + self.src1.__contains__(e) or + self.src2.__contains__(e)) + + def myhash(self): + return hash((EXPRCOND, self.cond._hash, + self.src1._hash, self.src2._hash)) + + @visit_chk + def visit(self, cb, tv=None): + cond = self.cond.visit(cb, tv) + src1 = self.src1.visit(cb, tv) + src2 = self.src2.visit(cb, tv) + if cond == self.cond and \ + src1 == self.src1 and \ + src2 == self.src2: + return self + return ExprCond(cond, src1, src2) + + def copy(self): + return ExprCond(self.cond.copy(), + self.src1.copy(), + self.src2.copy()) + + def depth(self): + return max(self.cond.depth(), + self.src1.depth(), + self.src2.depth()) + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + for a in [self.cond, self.src1, self.src2]: + a.graph_recursive(graph) + graph.add_uniq_edge(self, a) + + +class ExprMem(Expr): + + """An ExprMem stand for a memory access + + Use cases: + - Memory read + - Memory write + """ + + def __init__(self, arg, size=32): + """Create an ExprMem + @arg: Expr, memory access address + @size: int, memory access size + """ + if not isinstance(arg, Expr): + raise ValueError( + 'ExprMem: arg must be an Expr (not %s)' % type(arg)) + + self.arg, self._size = arg, size + self._hash = self.myhash() + + def __str__(self): + return "@%d[%s]" % (self._size, str(self.arg)) + + def get_r(self, mem_read=False, cst_read=False): + if mem_read: + return set(self.arg.get_r(mem_read, cst_read).union(set([self]))) + else: + return set([self]) + + def get_w(self): + return set([self]) # [memreg] + + def __contains__(self, e): + return self == e or self.arg.__contains__(e) + + def myhash(self): + return hash((EXPRMEM, self.arg._hash, self._size)) + + @visit_chk + def visit(self, cb, tv=None): + arg = self.arg.visit(cb, tv) + if arg == self.arg: + return self + return ExprMem(arg, self._size) + + def copy(self): + arg = self.arg.copy() + return ExprMem(arg, size=self._size) + + def is_op_segm(self): + return isinstance(self.arg, ExprOp) and self.arg.op == 'segm' + + def depth(self): + return self.arg.depth() + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + self.arg.graph_recursive(graph) + graph.add_uniq_edge(self, self.arg) + + +class ExprOp(Expr): + + """An ExprOp stand for an operation between Expr + + Use cases: + - var1 XOR var2 + - var1 + var2 + var3 + - parity bit(var1) + """ + + def __init__(self, op, *args): + """Create an ExprOp + @op: str, operation + @*args: Expr, operand list + """ + + sizes = set([x.size for x in args]) + + if None not in sizes and len(sizes) != 1: + # Special cases : operande sizes can differ + if op not in ["segm"]: + raise ValueError( + "sanitycheck: ExprOp args must have same size! %s" % + ([(str(x), x.size) for x in args])) + + if not isinstance(op, str): + raise ValueError("ExprOp: 'op' argument must be a string") + + self.op, self.args = op, tuple(args) + self._hash = self.myhash() + + # Set size for special cases + if self.op in [ + '==', 'parity', 'fcom_c0', 'fcom_c1', 'fcom_c2', 'fcom_c3', + "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf", + "ucomiss_zf", "ucomiss_pf", "ucomiss_cf"]: + sz = 1 + elif self.op in ['mem_16_to_double', 'mem_32_to_double', + 'mem_64_to_double', 'mem_80_to_double', + 'int_16_to_double', 'int_32_to_double', + 'int_64_to_double', 'int_80_to_double']: + sz = 64 + elif self.op in ['double_to_mem_16', 'double_to_int_16']: + sz = 16 + elif self.op in ['double_to_mem_32', 'double_to_int_32']: + sz = 32 + elif self.op in ['double_to_mem_64', 'double_to_int_64']: + sz = 64 + elif self.op in ['double_to_mem_80', 'double_to_int_80']: + sz = 80 + elif self.op in ['segm']: + sz = self.args[1].size + else: + if None in sizes: + sz = None + else: + # All arguments have the same size + sz = list(sizes)[0] + + self._size = sz + + def __str__(self): + if self.is_associative(): + return '(' + self.op.join([str(x) for x in self.args]) + ')' + if len(self.args) == 2: + return '(' + str(self.args[0]) + \ + ' ' + self.op + ' ' + str(self.args[1]) + ')' + elif len(self.args) > 2: + return self.op + '(' + ', '.join([str(x) for x in self.args]) + ')' + else: + return reduce(lambda x, y: x + ' ' + str(y), + self.args, + '(' + str(self.op)) + ')' + + def get_r(self, mem_read=False, cst_read=False): + return reduce(lambda x, y: + x.union(y.get_r(mem_read, cst_read)), self.args, set()) + + def get_w(self): + raise ValueError('op cannot be written!', self) + + def __contains__(self, e): + if self == e: + return True + for a in self.args: + if a.__contains__(e): + return True + return False + + def myhash(self): + h_hargs = [x._hash for x in self.args] + return hash((EXPROP, self.op, tuple(h_hargs))) + + def is_associative(self): + "Return True iff current operation is associative" + return (self.op in ['+', '*', '^', '&', '|']) + + def is_commutative(self): + "Return True iff current operation is commutative" + return (self.op in ['+', '*', '^', '&', '|']) + + @visit_chk + def visit(self, cb, tv=None): + args = [a.visit(cb, tv) for a in self.args] + modified = any([x[0] != x[1] for x in zip(self.args, args)]) + if modified: + return ExprOp(self.op, *args) + return self + + def copy(self): + args = [a.copy() for a in self.args] + return ExprOp(self.op, *args) + + def depth(self): + depth = [a.depth() for a in self.args] + return max(depth) + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + for a in self.args: + a.graph_recursive(graph) + graph.add_uniq_edge(self, a) + + +class ExprSlice(Expr): + + def __init__(self, arg, start, stop): + assert(start < stop) + self.arg, self.start, self.stop = arg, start, stop + self._hash = self.myhash() + self._size = self.stop - self.start + + def __str__(self): + return "%s[%d:%d]" % (str(self.arg), self.start, self.stop) + + def get_r(self, mem_read=False, cst_read=False): + return self.arg.get_r(mem_read, cst_read) + + def get_w(self): + return self.arg.get_w() + + def __contains__(self, e): + if self == e: + return True + return self.arg.__contains__(e) + + def myhash(self): + return hash((EXPRSLICE, self.arg._hash, self.start, self.stop)) + + @visit_chk + def visit(self, cb, tv=None): + arg = self.arg.visit(cb, tv) + if arg == self.arg: + return self + return ExprSlice(arg, self.start, self.stop) + + def copy(self): + return ExprSlice(self.arg.copy(), self.start, self.stop) + + def depth(self): + return self.arg.depth() + 1 + + def slice_rest(self): + "Return the completion of the current slice" + size = self.arg.size + if self.start >= size or self.stop > size: + raise ValueError('bad slice rest %s %s %s' % + (size, self.start, self.stop)) + + if self.start == self.stop: + return [(0, size)] + + rest = [] + if self.start != 0: + rest.append((0, self.start)) + if self.stop < size: + rest.append((self.stop, size)) + + return rest + + def graph_recursive(self, graph): + graph.add_node(self) + self.arg.graph_recursive(graph) + graph.add_uniq_edge(self, self.arg) + + +class ExprCompose(Expr): + + """ + Compose is like a hambuger. + It's arguments are tuple of: (Expression, start, stop) + start and stop are intergers, determining Expression position in the compose. + + Burger Example: + ExprCompose([(salad, 0, 3), (cheese, 3, 10), (beacon, 10, 16)]) + In the example, salad.size == 3. + """ + + def __init__(self, args): + """Create an ExprCompose + @args: tuple(Expr, int, int) + """ + + for e, start, stop in args: + if e.size != stop - start: + raise ValueError( + "sanitycheck: ExprCompose args must have correct size!" + + " %r %r %r" % (e, e.size, stop - start)) + + # Transform args to lists + o = [] + for e, a, b in args: + assert(a >= 0 and b >= 0) + o.append(tuple([e, a, b])) + self.args = tuple(o) + + self._hash = self.myhash() + self._size = max([x[2] + for x in self.args]) - min([x[1] for x in self.args]) + + def __str__(self): + return '{' + ', '.join(['%s,%d,%d' % + (str(x[0]), x[1], x[2]) for x in self.args]) + '}' + + def get_r(self, mem_read=False, cst_read=False): + return reduce(lambda x, y: + x.union(y[0].get_r(mem_read, cst_read)), self.args, set()) + + def get_w(self): + return reduce(lambda x, y: + x.union(y[0].get_r(mem_read, cst_read)), self.args, set()) + + def __contains__(self, e): + if self == e: + return True + for a in self.args: + if a == e: + return True + if a[0].__contains__(e): + return True + return False + + def myhash(self): + h_args = [EXPRCOMPOSE] + [(x[0]._hash, x[1], x[2]) for x in self.args] + return hash(tuple(h_args)) + + @visit_chk + def visit(self, cb, tv=None): + args = [(a[0].visit(cb, tv), a[1], a[2]) for a in self.args] + modified = any([x[0] != x[1] for x in zip(self.args, args)]) + if modified: + return ExprCompose(args) + return self + + def copy(self): + args = [(a[0].copy(), a[1], a[2]) for a in self.args] + return ExprCompose(args) + + def depth(self): + depth = [a[0].depth() for a in self.args] + return max(depth) + 1 + + def graph_recursive(self, graph): + graph.add_node(self) + for a in self.args: + a[0].graph_recursive(graph) + graph.add_uniq_edge(self, a[0]) + + +# Expression order for comparaison +expr_order_dict = {ExprId: 1, + ExprCond: 2, + ExprMem: 3, + ExprOp: 4, + ExprSlice: 5, + ExprCompose: 7, + ExprInt: 8, + } + + +def compare_exprs_compose(e1, e2): + # Sort by start bit address, then expr, then stop but address + x = cmp(e1[1], e2[1]) + if x: + return x + x = compare_exprs(e1[0], e2[0]) + if x: + return x + x = cmp(e1[2], e2[2]) + return x + + +def compare_expr_list_compose(l1_e, l2_e): + # Sort by list elements in incremental order, then by list size + for i in xrange(min(len(l1_e), len(l2_e))): + x = compare_exprs_compose(l1_e[i], l2_e[i]) + if x: + return x + return cmp(len(l1_e), len(l2_e)) + + +def compare_expr_list(l1_e, l2_e): + # Sort by list elements in incremental order, then by list size + for i in xrange(min(len(l1_e), len(l2_e))): + x = compare_exprs(l1_e[i], l2_e[i]) + if x: + return x + return cmp(len(l1_e), len(l2_e)) + + +def compare_exprs(e1, e2): + """Compare 2 expressions for canonization + @e1: Expr + @e2: Expr + 0 => == + 1 => e1 > e2 + -1 => e1 < e2 + """ + c1 = e1.__class__ + c2 = e2.__class__ + if c1 != c2: + return cmp(expr_order_dict[c1], expr_order_dict[c2]) + if e1 == e2: + return 0 + if c1 == ExprInt: + return cmp(e1.arg, e2.arg) + elif c1 == ExprId: + x = cmp(e1.name, e2.name) + if x: + return x + return cmp(e1._size, e2._size) + elif c1 == ExprAff: + raise NotImplementedError( + "Comparaison from an ExprAff not yet implemented") + elif c2 == ExprCond: + x = compare_exprs(e1.cond, e2.cond) + if x: + return x + x = compare_exprs(e1.src1, e2.src1) + if x: + return x + x = compare_exprs(e1.src2, e2.src2) + return x + elif c1 == ExprMem: + x = compare_exprs(e1.arg, e2.arg) + if x: + return x + return cmp(e1._size, e2._size) + elif c1 == ExprOp: + if e1.op != e2.op: + return cmp(e1.op, e2.op) + return compare_expr_list(e1.args, e2.args) + elif c1 == ExprSlice: + x = compare_exprs(e1.arg, e2.arg) + if x: + return x + x = cmp(e1.start, e2.start) + if x: + return x + x = cmp(e1.stop, e2.stop) + return x + elif c1 == ExprCompose: + return compare_expr_list_compose(e1.args, e2.args) + raise NotImplementedError( + "Comparaison between %r %r not implemented" % (e1, e2)) + + +def canonize_expr_list(l): + l = list(l) + l.sort(cmp=compare_exprs) + return l + + +def canonize_expr_list_compose(l): + l = list(l) + l.sort(cmp=compare_exprs_compose) + return l + +# Generate ExprInt with common size + + +def ExprInt1(i): + return ExprInt(uint1(i)) + + +def ExprInt8(i): + return ExprInt(uint8(i)) + + +def ExprInt16(i): + return ExprInt(uint16(i)) + + +def ExprInt32(i): + return ExprInt(uint32(i)) + + +def ExprInt64(i): + return ExprInt(uint64(i)) + + +def ExprInt_from(e, i): + "Generate ExprInt with size equal to expression" + return ExprInt(mod_size2uint[e.size](i)) + + +def ExprInt_fromsize(size, i): + "Generate ExprInt with a given size" + return ExprInt(mod_size2uint[size](i)) + + +def get_expr_ids_visit(e, ids): + if isinstance(e, ExprId): + ids.add(e) + return e + + +def get_expr_ids(e): + ids = set() + e.visit(lambda x: get_expr_ids_visit(x, ids)) + return ids + + +def test_set(e, v, tks, result): + """Test if v can correspond to e. If so, update the context in result. + Otherwise, return False + @e : Expr + @v : Expr + @tks : list of ExprId, available jokers + @result : dictionnary of ExprId -> Expr, current context + """ + + if not v in tks: + return e == v + if v in result and result[v] != e: + return False + result[v] = e + return result + + +def MatchExpr(e, m, tks, result=None): + """Try to match m expression with e expression with tks jokers. + Result is output dictionnary with matching joker values. + @e : Expr to test + @m : Targetted Expr + @tks : list of ExprId, available jokers + @result : dictionnary of ExprId -> Expr, output matching context + """ + + if result is None: + result = {} + + if m in tks: + # m is a Joker + return test_set(e, m, tks, result) + + if isinstance(e, ExprInt): + return test_set(e, m, tks, result) + + elif isinstance(e, ExprId): + return test_set(e, m, tks, result) + + elif isinstance(e, ExprOp): + + # e need to be the same operation than m + if not isinstance(m, ExprOp): + return False + if e.op != m.op: + return False + if len(e.args) != len(m.args): + return False + + # Perform permutation only if the current operation is commutative + if e.is_commutative(): + permutations = itertools.permutations(e.args) + else: + permutations = [e.args] + + # For each permutations of arguments + for permut in permutations: + good = True + # We need to use a copy of result to not override it + myresult = dict(result) + for a1, a2 in zip(permut, m.args): + r = MatchExpr(a1, a2, tks, myresult) + # If the current permutation do not match EVERY terms + if r is False: + good = False + break + if good is True: + # We found a possibility + for k, v in myresult.items(): + # Updating result in place (to keep pointer in recursion) + result[k] = v + return result + return False + + # Recursive tests + + elif isinstance(e, ExprMem): + if not isinstance(m, ExprMem): + return False + if e._size != m._size: + return False + return MatchExpr(e.arg, m.arg, tks, result) + + elif isinstance(e, ExprSlice): + if not isinstance(m, ExprSlice): + return False + if e.start != m.start or e.stop != m.stop: + return False + return MatchExpr(e.arg, m.arg, tks, result) + + elif isinstance(e, ExprCond): + if not isinstance(m, ExprCond): + return False + r = MatchExpr(e.cond, m.cond, tks, result) + if r is False: + return False + r = MatchExpr(e.src1, m.src1, tks, result) + if r is False: + return False + r = MatchExpr(e.src2, m.src2, tks, result) + if r is False: + return False + return result + + elif isinstance(e, ExprCompose): + if not isinstance(m, ExprCompose): + return False + for a1, a2 in zip(e.args, m.args): + if a1[1] != a2[1] or a1[2] != a2[2]: + return False + r = MatchExpr(a1[0], a2[0], tks, result) + if r is False: + return False + return result + + else: + raise NotImplementedError("MatchExpr: Unknown type: %s" % type(e)) + + +def SearchExpr(e, m, tks, result=None): + # TODO XXX: to test + if result is None: + result = set() + + def visit_search(e, m, tks, result): + r = {} + MatchExpr(e, m, tks, r) + if r: + result.add(tuple(r.items())) + return e + e.visit(lambda x: visit_search(x, m, tks, result)) + + +def get_rw(exprs): + o_r = set() + o_w = set() + for e in exprs: + o_r.update(e.get_r(mem_read=True)) + for e in exprs: + o_w.update(e.get_w()) + return o_r, o_w + + +def get_list_rw(exprs, mem_read=False, cst_read=True): + """ + return list of read/write reg/cst/mem for each expressions + """ + list_rw = [] + # cst_num = 0 + for e in exprs: + o_r = set() + o_w = set() + # get r/w + o_r.update(e.get_r(mem_read=mem_read, cst_read=cst_read)) + if isinstance(e.dst, ExprMem): + o_r.update(e.dst.arg.get_r(mem_read=mem_read, cst_read=cst_read)) + o_w.update(e.get_w()) + # each cst is indexed + o_r_rw = set() + for r in o_r: + # if isinstance(r, ExprInt): + # r = ExprOp('cst_%d'%cst_num, r) + # cst_num += 1 + o_r_rw.add(r) + o_r = o_r_rw + list_rw.append((o_r, o_w)) + + return list_rw + + +def get_expr_ops(e): + def visit_getops(e, out=None): + if out is None: + out = set() + if isinstance(e, ExprOp): + out.add(e.op) + return e + ops = set() + e.visit(lambda x: visit_getops(x, ops)) + return ops + + +def get_expr_mem(e): + def visit_getmem(e, out=None): + if out is None: + out = set() + if isinstance(e, ExprMem): + out.add(e) + return e + ops = set() + e.visit(lambda x: visit_getmem(x, ops)) + return ops diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py new file mode 100644 index 00000000..cd59730b --- /dev/null +++ b/miasm2/expression/expression_helper.py @@ -0,0 +1,196 @@ +# +# Copyright (C) 2011 EADS France, Fabrice Desclaux <fabrice.desclaux@eads.net> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# + +# Expressions manipulation functions +import miasm2.expression.expression as m2_expr + + +def parity(a): + tmp = (a) & 0xFFL + cpt = 1 + while tmp != 0: + cpt ^= tmp & 1 + tmp >>= 1 + return cpt + + +def merge_sliceto_slice(args): + sources = {} + non_slice = {} + sources_int = {} + for a in args: + if isinstance(a[0], m2_expr.ExprInt): + # sources_int[a.start] = a + # copy ExprInt because we will inplace modify arg just below + # /!\ TODO XXX never ever modify inplace args... + sources_int[a[1]] = (m2_expr.ExprInt_fromsize(a[2] - a[1], + a[0].arg.__class__( + a[0].arg)), + a[1], + a[2]) + elif isinstance(a[0], m2_expr.ExprSlice): + if not a[0].arg in sources: + sources[a[0].arg] = [] + sources[a[0].arg].append(a) + else: + non_slice[a[1]] = a + # find max stop to determine size + max_size = None + for a in args: + if max_size is None or max_size < a[2]: + max_size = a[2] + + # first simplify all num slices + final_sources = [] + sorted_s = [] + for x in sources_int.values(): + # mask int + v = x[0].arg & ((1 << (x[2] - x[1])) - 1) + x[0].arg = v + sorted_s.append((x[1], x)) + sorted_s.sort() + while sorted_s: + start, v = sorted_s.pop() + out = [m2_expr.ExprInt(v[0].arg), v[1], v[2]] + size = v[2] - v[1] + while sorted_s: + if sorted_s[-1][1][2] != start: + break + s_start, s_stop = sorted_s[-1][1][1], sorted_s[-1][1][2] + size += s_stop - s_start + a = m2_expr.mod_size2uint[size]( + (int(out[0].arg) << (out[1] - s_start)) + + int(sorted_s[-1][1][0].arg)) + out[0].arg = a + sorted_s.pop() + out[1] = s_start + out[0] = m2_expr.ExprInt_fromsize(size, out[0].arg) + final_sources.append((start, out)) + + final_sources_int = final_sources + # check if same sources have corresponding start/stop + # is slice AND is sliceto + simp_sources = [] + for args in sources.values(): + final_sources = [] + sorted_s = [] + for x in args: + sorted_s.append((x[1], x)) + sorted_s.sort() + while sorted_s: + start, v = sorted_s.pop() + ee = v[0].arg[v[0].start:v[0].stop] + out = ee, v[1], v[2] + while sorted_s: + if sorted_s[-1][1][2] != start: + break + if sorted_s[-1][1][0].stop != out[0].start: + break + + start = sorted_s[-1][1][1] + # out[0].start = sorted_s[-1][1][0].start + o_e, _, o_stop = out + o1, o2 = sorted_s[-1][1][0].start, o_e.stop + o_e = o_e.arg[o1:o2] + out = o_e, start, o_stop + # update _size + # out[0]._size = out[0].stop-out[0].start + sorted_s.pop() + out = out[0], start, out[2] + + final_sources.append((start, out)) + + simp_sources += final_sources + + simp_sources += final_sources_int + + for i, v in non_slice.items(): + simp_sources.append((i, v)) + + simp_sources.sort() + simp_sources = [x[1] for x in simp_sources] + return simp_sources + + +op_propag_cst = ['+', '*', '^', '&', '|', '>>', + '<<', "a>>", ">>>", "/", "%", 'idiv', 'irem'] + + +def is_pure_int(e): + """ + return True if expr is only composed with integers + /!\ ExprCond returns True is src1 and src2 are integers + """ + def modify_cond(e): + if isinstance(e, m2_expr.ExprCond): + return e.src1 | e.src2 + return e + + def find_int(e, s): + if isinstance(e, m2_expr.ExprId) or isinstance(e, m2_expr.ExprMem): + s.add(e) + return e + s = set() + new_e = e.visit(modify_cond) + new_e.visit(lambda x: find_int(x, s)) + if s: + return False + return True + + +def is_int_or_cond_src_int(e): + if isinstance(e, m2_expr.ExprInt): + return True + if isinstance(e, m2_expr.ExprCond): + return (isinstance(e.src1, m2_expr.ExprInt) and + isinstance(e.src2, m2_expr.ExprInt)) + return False + + +def fast_unify(seq, idfun=None): + # order preserving unifying list function + if idfun is None: + idfun = lambda x: x + seen = {} + result = [] + for item in seq: + marker = idfun(item) + + if marker in seen: + continue + seen[marker] = 1 + result.append(item) + return result + +def get_missing_interval(all_intervals, i_min=0, i_max=32): + """Return a list of missing interval in all_interval + @all_interval: list of (int, int) + @i_min: int, minimal missing interval bound + @i_max: int, maximal missing interval bound""" + + my_intervals = all_intervals[:] + my_intervals.sort() + my_intervals.append((i_max, i_max)) + + missing_i = [] + last_pos = i_min + for start, stop in my_intervals: + if last_pos != start: + missing_i.append((last_pos, start)) + last_pos = stop + return missing_i diff --git a/miasm2/expression/modint.py b/miasm2/expression/modint.py new file mode 100644 index 00000000..ffe1574c --- /dev/null +++ b/miasm2/expression/modint.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +#-*- coding:utf-8 -*- + +class moduint(object): + + def __init__(self, arg): + if isinstance(arg, moduint): + arg = arg.arg + self.arg = arg % self.__class__.limit + assert(self.arg >= 0 and self.arg < self.__class__.limit) + + def __repr__(self): + return self.__class__.__name__ + '(' + hex(self.arg) + ')' + + def __hash__(self): + return hash(self.arg) + + @classmethod + def maxcast(cls, c2): + c2 = c2.__class__ + if cls.size > c2.size: + return cls + else: + return c2 + + def __cmp__(self, y): + if isinstance(y, moduint): + return cmp(self.arg, y.arg) + else: + return cmp(self.arg, y) + + def __add__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg + y.arg) + else: + return self.__class__(self.arg + y) + + def __and__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg & y.arg) + else: + return self.__class__(self.arg & y) + + def __div__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg / y.arg) + else: + return self.__class__(self.arg / y) + + def __int__(self): + return int(self.arg) + + def __long__(self): + return long(self.arg) + + def __invert__(self): + return self.__class__(~self.arg) + + def __lshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg << y.arg) + else: + return self.__class__(self.arg << y) + + def __mod__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg % y.arg) + else: + return self.__class__(self.arg % y) + + def __mul__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg * y.arg) + else: + return self.__class__(self.arg * y) + + def __neg__(self): + return self.__class__(-self.arg) + + def __or__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg | y.arg) + else: + return self.__class__(self.arg | y) + + def __radd__(self, y): + return self.__add__(y) + + def __rand__(self, y): + return self.__and__(y) + + def __rdiv__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg / self.arg) + else: + return self.__class__(y / self.arg) + + def __rlshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg << self.arg) + else: + return self.__class__(y << self.arg) + + def __rmod__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg % self.arg) + else: + return self.__class__(y % self.arg) + + def __rmul__(self, y): + return self.__mul__(y) + + def __ror__(self, y): + return self.__or__(y) + + def __rrshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg >> self.arg) + else: + return self.__class__(y >> self.arg) + + def __rshift__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg >> y.arg) + else: + return self.__class__(self.arg >> y) + + def __rsub__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(y.arg - self.arg) + else: + return self.__class__(y - self.arg) + + def __rxor__(self, y): + return self.__xor__(y) + + def __sub__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg - y.arg) + else: + return self.__class__(self.arg - y) + + def __xor__(self, y): + if isinstance(y, moduint): + cls = self.maxcast(y) + return cls(self.arg ^ y.arg) + else: + return self.__class__(self.arg ^ y) + + def __hex__(self): + return hex(self.arg) + + def __abs__(self): + return abs(self.arg) + + def __rpow__(self, v): + return v ** self.arg + + def __pow__(self, v): + return self.__class__(self.arg ** v) + + +class modint(moduint): + + def __init__(self, arg): + if isinstance(arg, moduint): + arg = arg.arg + a = arg % self.__class__.limit + if a >= self.__class__.limit / 2: + a -= self.__class__.limit + self.arg = a + assert(self.arg >= -self.__class__.limit / + 2 and self.arg < self.__class__.limit) + + +def is_modint(a): + return isinstance(a, moduint) + + +def size2mask(size): + return (1 << size) - 1 + +mod_size2uint = {} +mod_size2int = {} + +mod_uint2size = {} +mod_int2size = {} + + +def define_common_int(): + "Define common int: ExprInt1, ExprInt2, .." + global mod_size2int, mod_int2size, mod_size2uint, mod_uint2size + + common_int = xrange(1, 257) + + for i in common_int: + name = 'uint%d' % i + c = type(name, (moduint,), {"size": i, "limit": 1 << i}) + globals()[name] = c + mod_size2uint[i] = c + mod_uint2size[c] = i + + for i in common_int: + name = 'int%d' % i + c = type(name, (modint,), {"size": i, "limit": 1 << i}) + globals()[name] = c + mod_size2int[i] = c + mod_int2size[c] = i + +define_common_int() diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py new file mode 100644 index 00000000..29d19614 --- /dev/null +++ b/miasm2/expression/simplifications.py @@ -0,0 +1,605 @@ +# +# Simplification methods library # +# + +from miasm2.expression.expression import * +from miasm2.expression.expression_helper import * + +# Common passes +# ------------- + + +def simp_cst_propagation(e_s, e): + """This passe includes: + - Constant folding + - Common logical identities + - Common binary identities + """ + + # merge associatif op + if not isinstance(e, ExprOp): + return e + args = list(e.args) + op = e.op + # simpl integer manip + # int OP int => int + if op in op_propag_cst: + while (len(args) >= 2 and + isinstance(args[-1], ExprInt) and + isinstance(args[-2], ExprInt)): + i2 = args.pop() + i1 = args.pop() + if op == '+': + o = i1.arg + i2.arg + elif op == '*': + o = i1.arg * i2.arg + elif op == '^': + o = i1.arg ^ i2.arg + elif op == '&': + o = i1.arg & i2.arg + elif op == '|': + o = i1.arg | i2.arg + elif op == '>>': + o = i1.arg >> i2.arg + elif op == '<<': + o = i1.arg << i2.arg + elif op == 'a>>': + x1 = mod_size2int[i1.arg.size](i1.arg) + x2 = mod_size2int[i2.arg.size](i2.arg) + o = mod_size2uint[i1.arg.size](x1 >> x2) + elif op == '>>>': + o = i1.arg >> i2.arg | i1.arg << (i1.size - i2.arg) + elif op == '/': + o = i1.arg / i2.arg + elif op == '%': + o = i1.arg % i2.arg + elif op == 'idiv': + assert(i2.arg) + x1 = mod_size2int[i1.arg.size](i1.arg) + x2 = mod_size2int[i2.arg.size](i2.arg) + o = mod_size2uint[i1.arg.size](x1 / x2) + elif op == 'irem': + assert(i2.arg) + x1 = mod_size2int[i1.arg.size](i1.arg) + x2 = mod_size2int[i2.arg.size](i2.arg) + o = mod_size2uint[i1.arg.size](x1 % x2) + + o = ExprInt_fromsize(i1.size, o) + args.append(o) + + # bsf(int) => int + if op == "bsf" and isinstance(args[0], ExprInt) and args[0].arg != 0: + i = 0 + while args[0].arg & (1 << i) == 0: + i += 1 + return ExprInt_from(args[0], i) + + # bsr(int) => int + if op == "bsr" and isinstance(args[0], ExprInt) and args[0].arg != 0: + i = args[0].size - 1 + while args[0].arg & (1 << i) == 0: + i -= 1 + return ExprInt_from(args[0], i) + + # -(-(A)) => A + if op == '-' and len(args) == 1 and isinstance(args[0], ExprOp) and \ + args[0].op == '-' and len(args[0].args) == 1: + return args[0].args[0] + + # -(int) => -int + if op == '-' and len(args) == 1 and isinstance(args[0], ExprInt): + return ExprInt(-args[0].arg) + # A op 0 =>A + if op in ['+', '-', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1: + if isinstance(args[-1], ExprInt) and args[-1].arg == 0: + args.pop() + # A * 1 =>A + if op == "*" and len(args) > 1: + if isinstance(args[-1], ExprInt) and args[-1].arg == 1: + args.pop() + + # for cannon form + # A * -1 => - A + if op == "*" and len(args) > 1: + if (isinstance(args[-1], ExprInt) and + args[-1].arg == (1 << args[-1].size) - 1): + args.pop() + args[-1] = - args[-1] + + # op A => A + if op in ['+', '*', '^', '&', '|', '>>', '<<', + 'a>>', '<<<', '>>>', 'idiv', 'irem'] and len(args) == 1: + return args[0] + + # A-B => A + (-B) + if op == '-' and len(args) > 1: + if len(args) > 2: + raise ValueError( + 'sanity check fail on expr -: should have one or 2 args ' + + '%r %s' % (e, e)) + return ExprOp('+', args[0], -args[1]) + + # A op 0 => 0 + if op in ['&', "*"] and isinstance(args[1], ExprInt) and args[1].arg == 0: + return ExprInt_from(e, 0) + + # - (A + B +...) => -A + -B + -C + if (op == '-' and + len(args) == 1 and + isinstance(args[0], ExprOp) and + args[0].op == '+'): + args = [-a for a in args[0].args] + e = ExprOp('+', *args) + return e + + # -(a?int1:int2) => (a?-int1:-int2) + if (op == '-' and + len(args) == 1 and + isinstance(args[0], ExprCond) and + isinstance(args[0].src1, ExprInt) and + isinstance(args[0].src2, ExprInt)): + i1 = args[0].src1 + i2 = args[0].src2 + i1 = ExprInt_from(i1, -i1.arg) + i2 = ExprInt_from(i2, -i2.arg) + return ExprCond(args[0].cond, i1, i2) + + i = 0 + while i < len(args) - 1: + j = i + 1 + while j < len(args): + # A ^ A => 0 + if op == '^' and args[i] == args[j]: + args[i] = ExprInt_from(args[i], 0) + del(args[j]) + continue + # A + (- A) => 0 + if op == '+' and isinstance(args[j], ExprOp) and args[j].op == "-": + if len(args[j].args) == 1 and args[i] == args[j].args[0]: + args[i] = ExprInt_from(args[i], 0) + del(args[j]) + continue + # (- A) + A => 0 + if op == '+' and isinstance(args[i], ExprOp) and args[i].op == "-": + if len(args[i].args) == 1 and args[j] == args[i].args[0]: + args[i] = ExprInt_from(args[i], 0) + del(args[j]) + continue + # A | A => A + if op == '|' and args[i] == args[j]: + del(args[j]) + continue + # A & A => A + if op == '&' and args[i] == args[j]: + del(args[j]) + continue + j += 1 + i += 1 + + if op in ['|', '&', '%', '/'] and len(args) == 1: + return args[0] + + # A <<< A.size => A + if (op in ['<<<', '>>>'] and + isinstance(args[1], ExprInt) and + args[1].arg == args[0].size): + return args[0] + + # A <<< X <<< Y => A <<< (X+Y) (ou <<< >>>) + if (op in ['<<<', '>>>'] and + isinstance(args[0], ExprOp) and + args[0].op in ['<<<', '>>>']): + op1 = op + op2 = args[0].op + if op1 == op2: + op = op1 + args1 = args[0].args[1] + args[1] + else: + op = op2 + args1 = args[0].args[1] - args[1] + + args0 = args[0].args[0] + args = [args0, args1] + + # ((A & A.mask) + if op == "&" and args[-1] == e.mask: + return ExprOp('&', *args[:-1]) + + # ((A | A.mask) + if op == "|" and args[-1] == e.mask: + return args[-1] + + # ! (!X + int) => X - int + # TODO + + # ((A & mask) >> shift) whith mask < 2**shift => 0 + if (op == ">>" and + isinstance(args[1], ExprInt) and + isinstance(args[0], ExprOp) and args[0].op == "&"): + if (isinstance(args[0].args[1], ExprInt) and + 2 ** args[1].arg >= args[0].args[1].arg): + return ExprInt_from(args[0], 0) + + # int == int => 0 or 1 + if (op == '==' and + isinstance(args[0], ExprInt) and + isinstance(args[1], ExprInt)): + if args[0].arg == args[1].arg: + return ExprInt_from(args[0], 1) + else: + return ExprInt_from(args[0], 0) + #(A|int == 0) => 0 with int != 0 + if op == '==' and isinstance(args[1], ExprInt) and args[1].arg == 0: + if isinstance(args[0], ExprOp) and args[0].op == '|' and\ + isinstance(args[0].args[1], ExprInt) and \ + args[0].args[1].arg != 0: + return ExprInt_from(args[0], 0) + + # parity(int) => int + if op == 'parity' and isinstance(args[0], ExprInt): + return ExprInt1(parity(args[0].arg)) + + # (-a) * b * (-c) * (-d) => (-a) * b * c * d + if op == "*" and len(args) > 1: + new_args = [] + counter = 0 + for a in args: + if isinstance(a, ExprOp) and a.op == '-' and len(a.args) == 1: + new_args.append(a.args[0]) + counter += 1 + else: + new_args.append(a) + if counter % 2: + return -ExprOp(op, *new_args) + args = new_args + + return ExprOp(op, *args) + + +def simp_cond_op_int(e_s, e): + "Extract conditions from operations" + + if not isinstance(e, ExprOp): + return e + if not e.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']: + return e + if len(e.args) < 2: + return e + if not isinstance(e.args[-1], ExprInt): + return e + a_int = e.args[-1] + conds = [] + for a in e.args[:-1]: + if not isinstance(a, ExprCond): + return e + conds.append(a) + if not conds: + return e + c = conds.pop() + c = ExprCond(c.cond, + ExprOp(e.op, c.src1, a_int), + ExprOp(e.op, c.src2, a_int)) + conds.append(c) + new_e = ExprOp(e.op, *conds) + return new_e + + +def simp_cond_factor(e_s, e): + "Merge similar conditions" + if not isinstance(e, ExprOp): + return e + if not e.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']: + return e + if len(e.args) < 2: + return e + conds = {} + not_conds = [] + multi_cond = False + for a in e.args: + if not isinstance(a, ExprCond): + not_conds.append(a) + continue + c = a.cond + if not c in conds: + conds[c] = [] + else: + multi_cond = True + conds[c].append(a) + if not multi_cond: + return e + c_out = not_conds[:] + for c, vals in conds.items(): + new_src1 = [x.src1 for x in vals] + new_src2 = [x.src2 for x in vals] + src1 = e_s.expr_simp_wrapper(ExprOp(e.op, *new_src1)) + src2 = e_s.expr_simp_wrapper(ExprOp(e.op, *new_src2)) + c_out.append(ExprCond(c, src1, src2)) + + if len(c_out) == 1: + new_e = c_out[0] + else: + new_e = ExprOp(e.op, *c_out) + return new_e + + +def simp_slice(e_s, e): + "Slice optimization" + + # slice(A, 0, a.size) => A + if e.start == 0 and e.stop == e.arg.size: + return e.arg + # Slice(int) => int + elif isinstance(e.arg, ExprInt): + total_bit = e.stop - e.start + mask = (1 << (e.stop - e.start)) - 1 + return ExprInt_fromsize(total_bit, (e.arg.arg >> e.start) & mask) + # Slice(Slice(A, x), y) => Slice(A, z) + elif isinstance(e.arg, ExprSlice): + if e.stop - e.start > e.arg.stop - e.arg.start: + raise ValueError('slice in slice: getting more val', str(e)) + + new_e = ExprSlice(e.arg.arg, e.start + e.arg.start, + e.start + e.arg.start + (e.stop - e.start)) + return new_e + # Slice(Compose(A), x) => Slice(A, y) + elif isinstance(e.arg, ExprCompose): + for a in e.arg.args: + if a[1] <= e.start and a[2] >= e.stop: + new_e = a[0][e.start - a[1]:e.stop - a[1]] + return new_e + # ExprMem(x, size)[:A] => ExprMem(x, a) + # XXXX todo hum, is it safe? + elif (isinstance(e.arg, ExprMem) and + e.start == 0 and + e.arg.size > e.stop and e.stop % 8 == 0): + e = ExprMem(e.arg.arg, size=e.stop) + return e + # distributivity of slice and & + # (a & int)[x:y] => 0 if int[x:y] == 0 + elif (isinstance(e.arg, ExprOp) and + e.arg.op == "&" and + isinstance(e.arg.args[-1], ExprInt)): + tmp = e_s.expr_simp_wrapper(e.arg.args[-1][e.start:e.stop]) + if isinstance(tmp, ExprInt) and tmp.arg == 0: + return tmp + # distributivity of slice and exprcond + # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y]) + elif (isinstance(e.arg, ExprCond) and + isinstance(e.arg.src1, ExprInt) and + isinstance(e.arg.src2, ExprInt)): + src1 = e.arg.src1[e.start:e.stop] + src2 = e.arg.src2[e.start:e.stop] + e = ExprCond(e.arg.cond, src1, src2) + + # (a * int)[0:y] => (a[0:y] * int[0:y]) + elif (isinstance(e.arg, ExprOp) and + e.arg.op == "*" and + isinstance(e.arg.args[-1], ExprInt)): + args = [e_s.expr_simp_wrapper(a[e.start:e.stop]) for a in e.arg.args] + e = ExprOp(e.arg.op, *args) + + return e + + +def simp_compose(e_s, e): + "Commons simplification on ExprCompose" + args = merge_sliceto_slice(e.args) + out = [] + # compose of compose + for a in args: + if isinstance(a[0], ExprCompose): + for x, start, stop in a[0].args: + out.append((x, start + a[1], stop + a[1])) + else: + out.append(a) + args = out + # Compose(a) with a.size = compose.size => a + if len(args) == 1 and args[0][1] == 0 and args[0][2] == e.size: + return args[0][0] + + # {(X[X.size-z, 0, z), (0, z, X.size)} => (X >> x) + if (len(args) == 2 and + isinstance(args[1][0], ExprInt) and + args[1][0].arg == 0): + a1 = args[0] + a2 = args[1] + if (isinstance(a1[0], ExprSlice) and + a1[1] == 0 and a1[0].stop == a1[0].arg.size): + if a2[1] == a1[0].size and a2[2] == a1[0].arg.size: + new_e = a1[0].arg >> ExprInt_fromsize( + a1[0].arg.size, a1[0].start) + return new_e + + # Compose with ExprCond with integers for src1/src2 and intergers => + # propagage integers + # {XXX?(0x0,0x1)?(0x0,0x1),0,8, 0x0,8,32} => XXX?(int1, int2) + + ok = True + expr_cond = None + expr_ints = [] + for i, a in enumerate(args): + if not is_int_or_cond_src_int(a[0]): + ok = False + break + expr_ints.append(a) + if isinstance(a[0], ExprCond): + if expr_cond is not None: + ok = False + expr_cond = i + cond = a[0] + + if ok and expr_cond is not None: + src1 = [] + src2 = [] + for i, a in enumerate(expr_ints): + if i == expr_cond: + src1.append((a[0].src1, a[1], a[2])) + src2.append((a[0].src2, a[1], a[2])) + else: + src1.append(a) + src2.append(a) + src1 = e_s.apply_simp(ExprCompose(src1)) + src2 = e_s.apply_simp(ExprCompose(src2)) + if isinstance(src1, ExprInt) and isinstance(src2, ExprInt): + return ExprCond(cond.cond, src1, src2) + return ExprCompose(args) + + +def simp_cond(e_s, e): + "Common simplifications on ExprCond" + if not isinstance(e, ExprCond): + return e + # eval exprcond src1/src2 with satifiable/unsatisfiable condition + # propagation + if (not isinstance(e.cond, ExprInt)) and e.cond.size == 1: + src1 = e.src1.replace_expr({e.cond: ExprInt1(1)}) + src2 = e.src2.replace_expr({e.cond: ExprInt1(0)}) + if src1 != e.src1 or src2 != e.src2: + return ExprCond(e.cond, src1, src2) + + # -A ? B:C => A ? B:C + if (isinstance(e.cond, ExprOp) and + e.cond.op == '-' and + len(e.cond.args) == 1): + e = ExprCond(e.cond.args[0], e.src1, e.src2) + # a?x:x + elif e.src1 == e.src2: + e = e.src1 + # int ? A:B => A or B + elif isinstance(e.cond, ExprInt): + if e.cond.arg == 0: + e = e.src2 + else: + e = e.src1 + # a?(a?b:c):x => a?b:x + elif isinstance(e.src1, ExprCond) and e.cond == e.src1.cond: + e = ExprCond(e.cond, e.src1.src1, e.src2) + # a?x:(a?b:c) => a?x:c + elif isinstance(e.src2, ExprCond) and e.cond == e.src2.cond: + e = ExprCond(e.cond, e.src1, e.src2.src2) + # a|int ? b:c => b with int != 0 + elif (isinstance(e.cond, ExprOp) and + e.cond.op == '|' and + isinstance(e.cond.args[1], ExprInt) and + e.cond.args[1].arg != 0): + return e.src1 + + # (C?int1:int2)?(A:B) => + elif (isinstance(e.cond, ExprCond) and + isinstance(e.cond.src1, ExprInt) and + isinstance(e.cond.src2, ExprInt)): + int1 = e.cond.src1.arg.arg + int2 = e.cond.src2.arg.arg + if int1 and int2: + e = e.src1 + elif int1 == 0 and int2 == 0: + e = e.src2 + elif int1 == 0 and int2: + e = ExprCond(e.cond.cond, e.src2, e.src1) + elif int1 and int2 == 0: + e = ExprCond(e.cond.cond, e.src1, e.src2) + return e + + +# Expression Simplifier +# --------------------- + + +class ExpressionSimplifier(object): + + """Wrapper on expression simplification passes. + + Instance handle passes lists. + + Available passes lists are: + - commons: common passes such as constant folding + - heavy : rare passes (for instance, in case of obfuscation) + """ + + # Common passes + PASS_COMMONS = { + m2_expr.ExprOp: [simp_cst_propagation, + simp_cond_op_int, + simp_cond_factor], + m2_expr.ExprSlice: [simp_slice], + m2_expr.ExprCompose: [simp_compose], + m2_expr.ExprCond: [simp_cond], + } + + # Heavy passes + PASS_HEAVY = {} + + def __init__(self): + self.expr_simp_cb = {} + + def enable_passes(self, passes): + """Add passes from @passes + @passes: dict(Expr class : list(callback)) + + Callback signature: Expr callback(ExpressionSimplifier, Expr) + """ + + for k, v in passes.items(): + self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v) + + def apply_simp(self, expression): + """Apply enabled simplifications on expression + @expression: Expr instance + Return an Expr instance""" + + cls = expression.__class__ + for simp_func in self.expr_simp_cb.get(cls, []): + # Apply simplifications + expression = simp_func(self, expression) + + # If class changes, stop to prevent wrong simplifications + if expression.__class__ is not cls: + break + + return expression + + def expr_simp(self, expression): + """Apply enabled simplifications on expression and find a stable state + @expression: Expr instance + Return an Expr instance""" + + if expression.is_simp: + return expression + + # Find a stable state + while True: + # Canonize and simplify + e_new = self.apply_simp(expression.canonize()) + if e_new == expression: + break + + # Launch recursivity + expression = self.expr_simp_wrapper(e_new) + expression.is_simp = True + + # Mark expression as simplified + e_new.is_simp = True + return e_new + + def expr_simp_wrapper(self, expression, callback=None): + """Apply enabled simplifications on expression + @expression: Expr instance + @manual_callback: If set, call this function instead of normal one + Return an Expr instance""" + + if expression.is_simp: + return expression + + if callback is None: + callback = self.expr_simp + + return expression.visit(callback, lambda e: not(e.is_simp)) + + def __call__(self, expression, callback=None): + "Wrapper on expr_simp_wrapper" + return self.expr_simp_wrapper(expression, callback) + + +# Public ExprSimplificationPass instance with commons passes +expr_simp = ExpressionSimplifier() +expr_simp.enable_passes(ExpressionSimplifier.PASS_COMMONS) diff --git a/miasm2/expression/stp.py b/miasm2/expression/stp.py new file mode 100644 index 00000000..7ef96166 --- /dev/null +++ b/miasm2/expression/stp.py @@ -0,0 +1,68 @@ +from miasm2.expression.expression import * + + +""" +Quick implementation of miasm traduction to stp langage +TODO XXX: finish +""" + + +def ExprInt_strcst(self): + b = bin(int(self.arg))[2::][::-1] + b += "0" * self.size + b = b[:self.size][::-1] + return "0bin" + b + + +def ExprId_strcst(self): + return self.name + + +def genop(op, size, a, b): + return op + '(' + str(size) + ',' + a + ', ' + b + ')' + + +def genop_nosize(op, size, a, b): + return op + '(' + a + ', ' + b + ')' + + +def ExprOp_strcst(self): + op = self.op + op_dct = {"|": " | ", + "&": " & "} + if op in op_dct: + return '(' + op_dct[op].join([x.strcst() for x in self.args]) + ')' + op_dct = {"-": "BVUMINUS"} + if op in op_dct: + return op_dct[op] + '(' + self.args[0].strcst() + ')' + op_dct = {"^": ("BVXOR", genop_nosize), + "+": ("BVPLUS", genop)} + if not op in op_dct: + raise ValueError('implement op', op) + op, f = op_dct[op] + args = [x.strcst() for x in self.args][::-1] + a = args.pop() + b = args.pop() + size = self.args[0].size + out = f(op, size, a, b) + while args: + out = f(op, size, out, args.pop()) + return out + + +def ExprSlice_strcst(self): + return '(' + self.arg.strcst() + ')[%d:%d]' % (self.stop - 1, self.start) + + +def ExprCond_strcst(self): + cond = self.cond.strcst() + src1 = self.src1.strcst() + src2 = self.src2.strcst() + return "(IF %s=(%s) THEN %s ELSE %s ENDIF)" % ( + "0bin%s" % ('0' * self.cond.size), cond, src2, src1) + +ExprInt.strcst = ExprInt_strcst +ExprId.strcst = ExprId_strcst +ExprOp.strcst = ExprOp_strcst +ExprCond.strcst = ExprCond_strcst +ExprSlice.strcst = ExprSlice_strcst |