about summary refs log tree commit diff stats
path: root/miasm2/expression
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/expression')
-rw-r--r--miasm2/expression/expression.py71
-rw-r--r--miasm2/expression/expression_helper.py4
-rw-r--r--miasm2/expression/expression_reduce.py14
-rw-r--r--miasm2/expression/simplifications.py34
-rw-r--r--miasm2/expression/simplifications_common.py142
-rw-r--r--miasm2/expression/simplifications_cond.py54
6 files changed, 184 insertions, 135 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py
index 58ebda60..d06b7e21 100644
--- a/miasm2/expression/expression.py
+++ b/miasm2/expression/expression.py
@@ -733,7 +733,7 @@ class ExprAssign(Expr):
     def get_r(self, mem_read=False, cst_read=False):
         elements = self._src.get_r(mem_read, cst_read)
         if isinstance(self._dst, ExprMem) and mem_read:
-            elements.update(self._dst.arg.get_r(mem_read, cst_read))
+            elements.update(self._dst.ptr.get_r(mem_read, cst_read))
         return elements
 
     def get_w(self):
@@ -891,47 +891,56 @@ class ExprMem(Expr):
      - Memory write
     """
 
-    __slots__ = Expr.__slots__ + ["_arg"]
+    __slots__ = Expr.__slots__ + ["_ptr"]
 
-    def __init__(self, arg, size=None):
+    def __init__(self, ptr, size=None):
         """Create an ExprMem
-        @arg: Expr, memory access address
+        @ptr: Expr, memory access address
         @size: int, memory access size
         """
         if size is None:
-            warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(arg, SIZE)')
+            warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(ptr, SIZE)')
             size = 32
 
-        # arg must be Expr
-        assert isinstance(arg, Expr)
+        # ptr must be Expr
+        assert isinstance(ptr, Expr)
         assert isinstance(size, (int, long))
 
-        if not isinstance(arg, Expr):
+        if not isinstance(ptr, Expr):
             raise ValueError(
-                'ExprMem: arg must be an Expr (not %s)' % type(arg))
+                'ExprMem: ptr must be an Expr (not %s)' % type(ptr))
 
         super(ExprMem, self).__init__(size)
-        self._arg = arg
+        self._ptr = ptr
 
-    arg = property(lambda self: self._arg)
+    def get_arg(self):
+        warnings.warn('DEPRECATION WARNING: use exprmem.ptr instead of exprmem.arg')
+        return self.ptr
+
+    def set_arg(self, value):
+        warnings.warn('DEPRECATION WARNING: use exprmem.ptr instead of exprmem.arg')
+        self.ptr = value
+
+    ptr = property(lambda self: self._ptr)
+    arg = property(get_arg, set_arg)
 
     def __reduce__(self):
-        state = self._arg, self._size
+        state = self._ptr, self._size
         return self.__class__, state
 
-    def __new__(cls, arg, size=None):
+    def __new__(cls, ptr, size=None):
         if size is None:
-            warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(arg, SIZE)')
+            warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprMem(ptr, SIZE)')
             size = 32
 
-        return Expr.get_object(cls, (arg, size))
+        return Expr.get_object(cls, (ptr, size))
 
     def __str__(self):
-        return "@%d[%s]" % (self.size, str(self.arg))
+        return "@%d[%s]" % (self.size, str(self.ptr))
 
     def get_r(self, mem_read=False, cst_read=False):
         if mem_read:
-            return set(self._arg.get_r(mem_read, cst_read).union(set([self])))
+            return set(self._ptr.get_r(mem_read, cst_read).union(set([self])))
         else:
             return set([self])
 
@@ -939,37 +948,37 @@ class ExprMem(Expr):
         return set([self])  # [memreg]
 
     def _exprhash(self):
-        return hash((EXPRMEM, hash(self._arg), self._size))
+        return hash((EXPRMEM, hash(self._ptr), self._size))
 
     def _exprrepr(self):
         return "%s(%r, %r)" % (self.__class__.__name__,
-                               self._arg, self._size)
+                               self._ptr, self._size)
 
     def __contains__(self, expr):
-        return self == expr or self._arg.__contains__(expr)
+        return self == expr or self._ptr.__contains__(expr)
 
     @visit_chk
     def visit(self, callback, test_visit=None):
-        arg = self._arg.visit(callback, test_visit)
-        if arg == self._arg:
+        ptr = self._ptr.visit(callback, test_visit)
+        if ptr == self._ptr:
             return self
-        return ExprMem(arg, self.size)
+        return ExprMem(ptr, self.size)
 
     def copy(self):
-        arg = self.arg.copy()
-        return ExprMem(arg, size=self.size)
+        ptr = self.ptr.copy()
+        return ExprMem(ptr, size=self.size)
 
     def is_mem_segm(self):
         """Returns True if is ExprMem and ptr is_op_segm"""
-        return self._arg.is_op_segm()
+        return self._ptr.is_op_segm()
 
     def depth(self):
-        return self._arg.depth() + 1
+        return self._ptr.depth() + 1
 
     def graph_recursive(self, graph):
         graph.add_node(self)
-        self._arg.graph_recursive(graph)
-        graph.add_uniq_edge(self, self._arg)
+        self._ptr.graph_recursive(graph)
+        graph.add_uniq_edge(self, self._ptr)
 
     def is_mem(self):
         return True
@@ -1426,7 +1435,7 @@ def compare_exprs(expr1, expr2):
         ret = compare_exprs(expr1.src2, expr2.src2)
         return ret
     elif cls1 == ExprMem:
-        ret = compare_exprs(expr1.arg, expr2.arg)
+        ret = compare_exprs(expr1.ptr, expr2.ptr)
         if ret:
             return ret
         return cmp(expr1.size, expr2.size)
@@ -1616,7 +1625,7 @@ def match_expr(expr, pattern, tks, result=None):
             return False
         if expr.size != pattern.size:
             return False
-        return match_expr(expr.arg, pattern.arg, tks, result)
+        return match_expr(expr.ptr, pattern.ptr, tks, result)
 
     elif expr.is_slice():
         if not pattern.is_slice():
diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py
index 7db41394..c503ebfc 100644
--- a/miasm2/expression/expression_helper.py
+++ b/miasm2/expression/expression_helper.py
@@ -272,7 +272,7 @@ class Variables_Identifier(object):
             pass
 
         elif isinstance(expr, m2_expr.ExprMem):
-            self.find_variables_rec(expr.arg)
+            self.find_variables_rec(expr.ptr)
 
         elif isinstance(expr, m2_expr.ExprCompose):
             for arg in expr.args:
@@ -567,7 +567,7 @@ def possible_values(expr):
         consvals.update(ConstrainedValue(consval.constraints,
                                          m2_expr.ExprMem(consval.value,
                                                          expr.size))
-                        for consval in possible_values(expr.arg))
+                        for consval in possible_values(expr.ptr))
     elif isinstance(expr, m2_expr.ExprAssign):
         consvals.update(possible_values(expr.src))
     # Special case: constraint insertion
diff --git a/miasm2/expression/expression_reduce.py b/miasm2/expression/expression_reduce.py
index ab38dfdb..0099dd78 100644
--- a/miasm2/expression/expression_reduce.py
+++ b/miasm2/expression/expression_reduce.py
@@ -68,13 +68,13 @@ class ExprNodeMem(ExprNode):
     def __init__(self, expr):
         assert expr.is_mem()
         super(ExprNodeMem, self).__init__(expr)
-        self.arg = None
+        self.ptr = None
 
     def __repr__(self):
         if self.info is not None:
             out = repr(self.info)
         else:
-            out = "@%d[%r]" % (self.expr.size, self.arg)
+            out = "@%d[%r]" % (self.expr.size, self.ptr)
         return out
 
 
@@ -173,9 +173,9 @@ class ExprReducer(object):
         elif isinstance(expr, ExprInt):
             node = ExprNodeInt(expr)
         elif isinstance(expr, ExprMem):
-            son = self.expr2node(expr.arg)
+            son = self.expr2node(expr.ptr)
             node = ExprNodeMem(expr)
-            node.arg = son
+            node.ptr = son
         elif isinstance(expr, ExprSlice):
             son = self.expr2node(expr.arg)
             node = ExprNodeSlice(expr)
@@ -223,9 +223,9 @@ class ExprReducer(object):
         elif isinstance(expr, ExprLoc):
             node = ExprNodeLoc(expr)
         elif isinstance(expr, ExprMem):
-            arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
-            node = ExprNodeMem(ExprMem(arg.expr, expr.size))
-            node.arg = arg
+            ptr = self.categorize(node.ptr, lvl=lvl + 1, **kwargs)
+            node = ExprNodeMem(ExprMem(ptr.expr, expr.size))
+            node.ptr = ptr
         elif isinstance(expr, ExprSlice):
             arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
             node = ExprNodeSlice(ExprSlice(arg.expr, expr.start, expr.stop))
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index e090d806..a237a57e 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -45,16 +45,23 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_double_zeroext,
             simplifications_common.simp_double_signext,
             simplifications_common.simp_zeroext_eq_cst,
+            simplifications_common.simp_ext_eq_ext,
+
+            simplifications_common.simp_cmp_int,
+            simplifications_common.simp_cmp_int_int,
+            simplifications_common.simp_ext_cst,
 
         ],
 
-        m2_expr.ExprSlice: [simplifications_common.simp_slice],
+        m2_expr.ExprSlice: [
+            simplifications_common.simp_slice,
+            simplifications_common.simp_slice_of_ext,
+        ],
         m2_expr.ExprCompose: [simplifications_common.simp_compose],
         m2_expr.ExprCond: [
             simplifications_common.simp_cond,
             # CC op
             simplifications_common.simp_cond_flag,
-            simplifications_common.simp_cond_int,
             simplifications_common.simp_cmp_int_arg,
 
             simplifications_common.simp_cond_eq_zero,
@@ -68,14 +75,18 @@ class ExpressionSimplifier(object):
     PASS_HEAVY = {}
 
     # Cond passes
-    PASS_COND = {m2_expr.ExprSlice: [simplifications_cond.expr_simp_inf_signed,
-                                     simplifications_cond.expr_simp_inf_unsigned_inversed],
-                 m2_expr.ExprOp: [simplifications_cond.exec_inf_unsigned,
-                                  simplifications_cond.exec_inf_signed,
-                                  simplifications_cond.expr_simp_inverse,
-                                  simplifications_cond.exec_equal],
-                 m2_expr.ExprCond: [simplifications_cond.expr_simp_equal]
-                 }
+    PASS_COND = {
+        m2_expr.ExprSlice: [
+            simplifications_cond.expr_simp_inf_signed,
+            simplifications_cond.expr_simp_inf_unsigned_inversed
+        ],
+        m2_expr.ExprOp: [
+            simplifications_cond.expr_simp_inverse,
+        ],
+        m2_expr.ExprCond: [
+            simplifications_cond.expr_simp_equal
+        ]
+    }
 
 
     # Available passes lists are:
@@ -99,6 +110,9 @@ class ExpressionSimplifier(object):
         Callback signature: Expr callback(ExpressionSimplifier, Expr)
         """
 
+        # Clear cache of simplifiied expressions when adding a new pass
+        self.simplified_exprs.clear()
+
         for k, v in passes.items():
             self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v)
 
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index e7dacc91..7db4e819 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -509,7 +509,7 @@ def simp_slice(e_s, expr):
     if (expr.arg.is_mem() and
           expr.start == 0 and
           expr.arg.size > expr.stop and expr.stop % 8 == 0):
-        return ExprMem(expr.arg.arg, size=expr.stop)
+        return ExprMem(expr.arg.ptr, size=expr.stop)
     # distributivity of slice and &
     # (a & int)[x:y] => 0 if int[x:y] == 0
     if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
@@ -576,9 +576,9 @@ def simp_compose(e_s, expr):
     for i, arg in enumerate(args[:-1]):
         nxt = args[i + 1]
         if arg.is_mem() and nxt.is_mem():
-            gap = e_s(nxt.arg - arg.arg)
+            gap = e_s(nxt.ptr - arg.ptr)
             if gap.is_int() and arg.size % 8 == 0 and int(gap) == arg.size / 8:
-                args = args[:i] + [ExprMem(arg.arg,
+                args = args[:i] + [ExprMem(arg.ptr,
                                           arg.size + nxt.size)] + args[i + 2:]
                 return ExprCompose(*args)
 
@@ -664,8 +664,8 @@ def simp_mem(e_s, expr):
     "Common simplifications on ExprMem"
 
     # @32[x?a:b] => x?@32[a]:@32[b]
-    if expr.arg.is_cond():
-        cond = expr.arg
+    if expr.ptr.is_cond():
+        cond = expr.ptr
         ret = ExprCond(cond.cond,
                        ExprMem(cond.src1, expr.size),
                        ExprMem(cond.src2, expr.size))
@@ -884,35 +884,34 @@ def simp_cond_flag(expr_simp, expr):
     return expr
 
 
-def simp_cond_int(expr_simp, expr):
-    if (expr.cond.is_op(TOK_EQUAL) and
-          expr.cond.args[1].is_int() and
-          expr.cond.args[0].is_compose() and
-          len(expr.cond.args[0].args) == 2 and
-          expr.cond.args[0].args[1].is_int(0)):
+def simp_cmp_int(expr_simp, expr):
+    # ({X, 0} == int) => X == int[:]
+    # X + int1 == int2 => X == int2-int1
+    if (expr.is_op(TOK_EQUAL) and
+          expr.args[1].is_int() and
+          expr.args[0].is_compose() and
+          len(expr.args[0].args) == 2 and
+          expr.args[0].args[1].is_int(0)):
         # ({X, 0} == int) => X == int[:]
-        src = expr.cond.args[0].args[0]
-        int_val = int(expr.cond.args[1])
+        src = expr.args[0].args[0]
+        int_val = int(expr.args[1])
         new_int = ExprInt(int_val, src.size)
         expr = expr_simp(
-            ExprCond(
-                ExprOp(TOK_EQUAL, src, new_int),
-                expr.src1,
-                expr.src2)
+            ExprOp(TOK_EQUAL, src, new_int)
         )
-    elif (expr.cond.is_op() and
-          expr.cond.op in [
+    elif (expr.is_op() and
+          expr.op in [
               TOK_EQUAL,
-              TOK_INF_SIGNED,
-              TOK_INF_EQUAL_SIGNED,
-              TOK_INF_UNSIGNED,
-              TOK_INF_EQUAL_UNSIGNED
           ] and
-          expr.cond.args[1].is_int() and
-          expr.cond.args[0].is_op("+") and
-          expr.cond.args[0].args[-1].is_int()):
+          expr.args[1].is_int() and
+          expr.args[0].is_op("+") and
+          expr.args[0].args[-1].is_int()):
         # X + int1 == int2 => X == int2-int1
-        left, right = expr.cond.args
+        # WARNING:
+        # X - 0x10 <=u 0x20 gives X in [0x10 0x30]
+        # which is not equivalet to A <=u 0x10
+
+        left, right = expr.args
         left, int_diff = left.args[:-1], left.args[-1]
         if len(left) == 1:
             left = left[0]
@@ -920,10 +919,7 @@ def simp_cond_int(expr_simp, expr):
             left = ExprOp('+', *left)
         new_int = expr_simp(right - int_diff)
         expr = expr_simp(
-            ExprCond(
-                ExprOp(expr.cond.op, left, new_int),
-                expr.src1,
-                expr.src2)
+            ExprOp(expr.op, left, new_int),
         )
     return expr
 
@@ -1047,6 +1043,20 @@ def simp_zeroext_eq_cst(expr_s, expr):
         return ExprInt(0, 1)
     return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size))
 
+def simp_ext_eq_ext(expr_s, expr):
+    # A.zeroExt(X) == B.zeroExt(X) => A == B
+    # A.signExt(X) == B.signExt(X) => A == B
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if (not ((arg1.is_op() and arg1.op.startswith("zeroExt") and
+              arg2.is_op() and arg2.op.startswith("zeroExt")) or
+             (arg1.is_op() and arg1.op.startswith("signExt") and
+               arg2.is_op() and arg2.op.startswith("signExt")))):
+        return expr
+    if arg1.args[0].size != arg2.args[0].size:
+        return expr
+    return ExprOp(TOK_EQUAL, arg1.args[0], arg2.args[0])
 
 def simp_cond_eq_zero(expr_s, expr):
     # (X == 0)?(A:B) => X?(B:A)
@@ -1058,3 +1068,73 @@ def simp_cond_eq_zero(expr_s, expr):
         return expr
     new_expr = ExprCond(arg1, expr.src2, expr.src1)
     return new_expr
+
+
+def simp_cmp_int_int(expr_s, expr):
+    # IntA <s IntB => int
+    # IntA <u IntB => int
+    # IntA <=s IntB => int
+    # IntA <=u IntB => int
+    # IntA == IntB => int
+    if expr.op not in [
+            TOK_EQUAL,
+            TOK_INF_SIGNED, TOK_INF_UNSIGNED,
+            TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
+    ]:
+        return expr
+    if not all(arg.is_int() for arg in expr.args):
+        return expr
+    int_a, int_b = expr.args
+    if expr.is_op(TOK_EQUAL):
+        if int_a == int_b:
+            return ExprInt(1, 1)
+        else:
+            return ExprInt(0, 1)
+
+    if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]:
+        int_a = int(mod_size2int[int_a.size](int(int_a)))
+        int_b = int(mod_size2int[int_b.size](int(int_b)))
+    else:
+        int_a = int(mod_size2uint[int_a.size](int(int_a)))
+        int_b = int(mod_size2uint[int_b.size](int(int_b)))
+
+    if expr.op in [TOK_INF_SIGNED, TOK_INF_UNSIGNED]:
+        ret = int_a < int_b
+    else:
+        ret = int_a <= int_b
+
+    if ret:
+        ret = 1
+    else:
+        ret = 0
+    return ExprInt(ret, 1)
+
+
+def simp_ext_cst(expr_s, expr):
+    # Int.zeroExt(X) => Int
+    # Int.signExt(X) => Int
+    if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")):
+        return expr
+    arg = expr.args[0]
+    if not arg.is_int():
+        return expr
+    if expr.op.startswith("zeroExt"):
+        ret = int(arg)
+    else:
+        ret = int(mod_size2int[arg.size](int(arg)))
+    ret = ExprInt(ret, expr.size)
+    return ret
+
+
+def simp_slice_of_ext(expr_s, expr):
+    # zeroExt(X)[0:size(X)] => X
+    if expr.start != 0:
+        return expr
+    if not expr.arg.is_op():
+        return expr
+    if not expr.arg.op.startswith("zeroExt"):
+        return expr
+    arg = expr.arg.args[0]
+    if arg.size != expr.size:
+        return expr
+    return arg
diff --git a/miasm2/expression/simplifications_cond.py b/miasm2/expression/simplifications_cond.py
index 6bdc810f..f6b1ea8b 100644
--- a/miasm2/expression/simplifications_cond.py
+++ b/miasm2/expression/simplifications_cond.py
@@ -176,57 +176,3 @@ def expr_simp_equal(expr_simp, e):
         return e
 
     return ExprOp_equal(r[jok1], expr_simp(-r[jok2]))
-
-# Compute conditions
-
-def exec_inf_unsigned(expr_simp, e):
-    "Compute x <u y"
-    if e.op != m2_expr.TOK_INF_UNSIGNED:
-        return e
-
-    arg1, arg2 = e.args
-
-    if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt):
-        return m2_expr.ExprInt(1, 1) if (arg1.arg < arg2.arg) else m2_expr.ExprInt(0, 1)
-    else:
-        return e
-
-
-def __comp_signed(arg1, arg2):
-    """Return ExprInt(1, 1) if arg1 <s arg2 else ExprInt(0, 1)
-    @arg1, @arg2: ExprInt"""
-
-    val1 = int(arg1)
-    if val1 >> (arg1.size - 1) == 1:
-        val1 = - ((int(arg1.mask) ^ val1) + 1)
-
-    val2 = int(arg2)
-    if val2 >> (arg2.size - 1) == 1:
-        val2 = - ((int(arg2.mask) ^ val2) + 1)
-
-    return m2_expr.ExprInt(1, 1) if (val1 < val2) else m2_expr.ExprInt(0, 1)
-
-def exec_inf_signed(expr_simp, e):
-    "Compute x <s y"
-
-    if e.op != m2_expr.TOK_INF_SIGNED:
-        return e
-
-    arg1, arg2 = e.args
-
-    if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt):
-        return __comp_signed(arg1, arg2)
-    else:
-        return e
-
-def exec_equal(expr_simp, e):
-    "Compute x == y"
-
-    if e.op != m2_expr.TOK_EQUAL:
-        return e
-
-    arg1, arg2 = e.args
-    if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt):
-        return m2_expr.ExprInt(1, 1) if (arg1.arg == arg2.arg) else m2_expr.ExprInt(0, 1)
-    else:
-        return e