about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm/expression/simplifications.py1
-rw-r--r--miasm/expression/simplifications_common.py70
-rw-r--r--test/expression/simplifications.py32
3 files changed, 103 insertions, 0 deletions
diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py
index 1d64456f..d8a11382 100644
--- a/miasm/expression/simplifications.py
+++ b/miasm/expression/simplifications.py
@@ -51,6 +51,7 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_ext_eq_ext,
 
             simplifications_common.simp_cmp_int,
+            simplifications_common.simp_cmp_bijective_op,
             simplifications_common.simp_sign_inf_zeroext,
             simplifications_common.simp_cmp_int_int,
             simplifications_common.simp_ext_cst,
diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py
index 69d56997..9db766d8 100644
--- a/miasm/expression/simplifications_common.py
+++ b/miasm/expression/simplifications_common.py
@@ -1004,6 +1004,76 @@ def simp_cmp_int_arg(_, expr):
     return ExprCond(ExprOp(op, arg1, arg2), src1, src2)
 
 
+
+def simp_cmp_bijective_op(expr_simp, expr):
+    """
+    A + B == A => A == 0
+
+    X + A == X + B => A == B
+    X ^ A == X ^ B => A == B
+
+    TODO:
+    3 * A + B == A + C => 2 * A + B == C
+    """
+
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    op_a = expr.args[0]
+    op_b = expr.args[1]
+
+    # a == a
+    if op_a == op_b:
+        return ExprInt(1, 1)
+
+    # Case:
+    # a + b + c == a
+    if op_a.is_op() and op_a.op in ["+", "^"]:
+        args = list(op_a.args)
+        if op_b in args:
+            args.remove(op_b)
+            if not args:
+                raise ValueError("Can be here")
+            elif len(args) == 1:
+                op_a = args[0]
+            else:
+                op_a = ExprOp(op_a.op, *args)
+            return ExprOp(TOK_EQUAL, op_a, ExprInt(0, args[0].size))
+    # a == a + b + c
+    if op_b.is_op() and op_b.op in ["+", "^"]:
+        args = list(op_b.args)
+        if op_a in args:
+            args.remove(op_a)
+            if not args:
+                raise ValueError("Can be here")
+            elif len(args) == 1:
+                op_b = args[0]
+            else:
+                op_b = ExprOp(op_b.op, *args)
+            return ExprOp(TOK_EQUAL, op_b, ExprInt(0, args[0].size))
+
+    if not (op_a.is_op() and op_b.is_op()):
+        return expr
+    if op_a.op != op_b.op:
+        return expr
+    op = op_a.op
+    if op not in ["+", "^"]:
+        return expr
+    common = set(op_a.args).intersection(op_b.args)
+    if not common:
+        return expr
+
+    args_a = list(op_a.args)
+    args_b = list(op_b.args)
+    for value in common:
+        while value in args_a and value in args_b:
+            args_a.remove(value)
+            args_b.remove(value)
+
+    arg_a = ExprOp(op, *args_a)
+    arg_b = ExprOp(op, *args_b)
+    return ExprOp(TOK_EQUAL, arg_a, arg_b)
+
+
 def simp_subwc_cf(_, expr):
     """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D})"""
     if not expr.is_op('FLAG_SUBWC_CF'):
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index de059075..f36a7b4d 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -522,6 +522,38 @@ to_test = [
     ),
 
 
+    (
+        ExprOp(TOK_EQUAL, a ^ b, a ^ c),
+        ExprOp(TOK_EQUAL, b , c)
+    ),
+
+    (
+        ExprOp(TOK_EQUAL, a + b, a + c),
+        ExprOp(TOK_EQUAL, b , c)
+    ),
+
+    (
+        ExprOp(TOK_EQUAL, a + b, a),
+        ExprOp(TOK_EQUAL, b , i0)
+    ),
+
+    (
+        ExprOp(TOK_EQUAL, a, a + b),
+        ExprOp(TOK_EQUAL, b , i0)
+    ),
+
+
+    (
+        ExprOp(TOK_EQUAL, ExprOp("+", a, b, c), a),
+        ExprOp(TOK_EQUAL, b+c , i0)
+    ),
+
+    (
+        ExprOp(TOK_EQUAL, a, ExprOp("+", a, b, c)),
+        ExprOp(TOK_EQUAL, b+c , i0)
+    ),
+
+
     (ExprOp(TOK_INF_SIGNED, i1, i2), ExprInt(1, 1)),
     (ExprOp(TOK_INF_UNSIGNED, i1, i2), ExprInt(1, 1)),
     (ExprOp(TOK_INF_EQUAL_SIGNED, i1, i2), ExprInt(1, 1)),