about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--.appveyor.yml2
-rw-r--r--.travis.yml4
-rw-r--r--miasm2/expression/simplifications.py6
-rw-r--r--miasm2/expression/simplifications_common.py112
-rw-r--r--miasm2/expression/simplifications_explicit.py8
-rw-r--r--miasm2/os_dep/win_api_x86_32.py8
-rw-r--r--test/expression/simplifications.py46
7 files changed, 146 insertions, 40 deletions
diff --git a/.appveyor.yml b/.appveyor.yml
index 5a6b3b38..fb565570 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -33,7 +33,7 @@ build_script:
 
 test_script:
   - cmd: cd c:\projects\miasm\test
-  - "%PYTHON%\\python.exe test_all.py"
+  - "%PYTHON%\\python.exe -W error test_all.py"
 
 after_test:
   - cmd: chdir
diff --git a/.travis.yml b/.travis.yml
index f5c55368..ee4f0ed5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -26,9 +26,7 @@ before_script:
 - pip install -r optional_requirements.txt
 # codespell
 - "pip install codespell && git ls-files | xargs codespell --ignore-words=.codespell_ignore 2>/dev/null"
-# turn deprecation warning into RuntimeError
-- "find . -name '*.py' | xargs sed -i 's/warnings\\.warn(/raise RuntimeError(/g'"
 # install
 - python setup.py build build_ext
 - python setup.py install
-script: cd test && python test_all.py $MIASM_TEST_EXTRA_ARG && git ls-files -o --exclude-standard
+script: cd test && python -W error test_all.py $MIASM_TEST_EXTRA_ARG && git ls-files -o --exclude-standard
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index 8ea9c41f..483331a6 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -55,6 +55,7 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_zeroext_and_cst_eq_cst,
             simplifications_common.simp_test_signext_inf,
             simplifications_common.simp_test_zeroext_inf,
+            simplifications_common.simp_cond_inf_eq_unsigned_zero,
 
         ],
 
@@ -67,6 +68,7 @@ class ExpressionSimplifier(object):
         m2_expr.ExprCond: [
             simplifications_common.simp_cond,
             simplifications_common.simp_cond_zeroext,
+            simplifications_common.simp_cond_add,
             # CC op
             simplifications_common.simp_cond_flag,
             simplifications_common.simp_cmp_int_arg,
@@ -75,11 +77,13 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_x_and_cst_eq_cst,
             simplifications_common.simp_cond_logic_ext,
             simplifications_common.simp_cond_sign_bit,
+            simplifications_common.simp_cond_eq_1_0,
         ],
         m2_expr.ExprMem: [simplifications_common.simp_mem],
 
     }
 
+
     # Heavy passes
     PASS_HEAVY = {}
 
@@ -193,8 +197,6 @@ class ExpressionSimplifier(object):
 expr_simp = ExpressionSimplifier()
 expr_simp.enable_passes(ExpressionSimplifier.PASS_COMMONS)
 
-
-
 expr_simp_high_to_explicit = ExpressionSimplifier()
 expr_simp_high_to_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT)
 
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 00b14554..a4b7c61e 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -909,6 +909,7 @@ def simp_cmp_int(expr_simp, expr):
     """
     ({X, 0} == int) => X == int[:]
     X + int1 == int2 => X == int2-int1
+    X ^ int1 == int2 => X == int1^int2
     """
     if (expr.is_op(TOK_EQUAL) and
           expr.args[1].is_int() and
@@ -922,28 +923,42 @@ def simp_cmp_int(expr_simp, expr):
         expr = expr_simp(
             ExprOp(TOK_EQUAL, src, new_int)
         )
-    elif (expr.is_op() and
-          expr.op in [
-              TOK_EQUAL,
-          ] and
-          expr.args[1].is_int() and
-          expr.args[0].is_op("+") and
-          expr.args[0].args[-1].is_int()):
-        # X + int1 == int2 => X == int2-int1
-        # WARNING:
-        # X - 0x10 <=u 0x20 gives X in [0x10 0x30]
-        # which is not equivalet to A <=u 0x10
-
-        left, right = expr.args
-        left, int_diff = left.args[:-1], left.args[-1]
-        if len(left) == 1:
-            left = left[0]
-        else:
-            left = ExprOp('+', *left)
-        new_int = expr_simp(right - int_diff)
-        expr = expr_simp(
-            ExprOp(expr.op, left, new_int),
-        )
+    elif not expr.is_op(TOK_EQUAL):
+        return expr
+    assert len(expr.args) == 2
+
+    left, right = expr.args
+    if left.is_int() and not right.is_int():
+        left, right = right, left
+    if not right.is_int():
+        return expr
+    if not (left.is_op() and left.op in ['+', '^']):
+        return expr
+    if not left.args[-1].is_int():
+        return expr
+    # X + int1 == int2 => X == int2-int1
+    # WARNING:
+    # X - 0x10 <=u 0x20 gives X in [0x10 0x30]
+    # which is not equivalet to A <=u 0x10
+
+    left_orig = left
+    left, last_int = left.args[:-1], left.args[-1]
+
+    if len(left) == 1:
+        left = left[0]
+    else:
+        left = ExprOp(left.op, *left)
+
+    if left_orig.op == "+":
+        new_int = expr_simp(right - last_int)
+    elif left_orig.op == '^':
+        new_int = expr_simp(right ^ last_int)
+    else:
+        raise RuntimeError("Unsupported operator")
+
+    expr = expr_simp(
+        ExprOp(TOK_EQUAL, left, new_int),
+    )
     return expr
 
 
@@ -1375,6 +1390,59 @@ def simp_cond_sign_bit(_, expr):
     return ExprCond(cond, expr.src1, expr.src2)
 
 
+def simp_cond_add(expr_s, expr):
+    """
+    (a+b)?X:Y => (a == b)?Y:X
+    (a^b)?X:Y => (a == b)?Y:X
+    """
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in ['+', '^']:
+        return expr
+    if len(cond.args) != 2:
+        return expr
+    arg1, arg2 = cond.args
+    if cond.is_op('+'):
+        new_cond = ExprOp('==', arg1, expr_s(-arg2))
+    elif cond.is_op('^'):
+        new_cond = ExprOp('==', arg1, arg2)
+    else:
+        raise ValueError('Bad case')
+    return ExprCond(new_cond, expr.src2, expr.src1)
+
+
+def simp_cond_eq_1_0(expr_s, expr):
+    """
+    (a == b)?ExprInt(1, 1):ExprInt(0, 1) => a == b
+    (a <s b)?ExprInt(1, 1):ExprInt(0, 1) => a == b
+    ...
+    """
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in [
+            TOK_EQUAL,
+            TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED,
+            TOK_INF_UNSIGNED, TOK_INF_EQUAL_UNSIGNED
+            ]:
+        return expr
+    if expr.src1 != ExprInt(1, 1) or expr.src2 != ExprInt(0, 1):
+        return expr
+    return cond
+
+
+def simp_cond_inf_eq_unsigned_zero(expr_s, expr):
+    """
+    (a <=u 0) => a == 0
+    """
+    if not expr.is_op(TOK_INF_EQUAL_UNSIGNED):
+        return expr
+    if not expr.args[1].is_int(0):
+        return expr
+    return ExprOp(TOK_EQUAL, expr.args[0], expr.args[1])
+
+
 def simp_test_signext_inf(expr_s, expr):
     """A.signExt() <s int => A <s int[:]"""
     if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
diff --git a/miasm2/expression/simplifications_explicit.py b/miasm2/expression/simplifications_explicit.py
index 4c5dde3e..00892201 100644
--- a/miasm2/expression/simplifications_explicit.py
+++ b/miasm2/expression/simplifications_explicit.py
@@ -155,13 +155,5 @@ def simp_flags(_, expr):
         op_nf, = args
         return ~op_nf
 
-    elif expr.is_op(TOK_EQUAL):
-        arg1, arg2 = args
-        return ExprCond(
-            arg1 - arg2,
-            ExprInt(0, expr.size),
-            ExprInt(1, expr.size),
-        )
-
     return expr
 
diff --git a/miasm2/os_dep/win_api_x86_32.py b/miasm2/os_dep/win_api_x86_32.py
index 5d6e4765..df679074 100644
--- a/miasm2/os_dep/win_api_x86_32.py
+++ b/miasm2/os_dep/win_api_x86_32.py
@@ -1490,20 +1490,20 @@ def kernel32_lstrlen(jitter):
     my_strlen(jitter, whoami(), jitter.get_str_ansi, len)
 
 
-def my_lstrcat(jitter, funcname, get_str):
+def my_lstrcat(jitter, funcname, get_str, set_str):
     ret_ad, args = jitter.func_args_stdcall(['ptr_str1', 'ptr_str2'])
     s1 = get_str(args.ptr_str1)
     s2 = get_str(args.ptr_str2)
-    jitter.vm.set_mem(args.ptr_str1, s1 + s2)
+    set_str(args.ptr_str1, s1 + s2)
     jitter.func_ret_stdcall(ret_ad, args.ptr_str1)
 
 
 def kernel32_lstrcatA(jitter):
-    my_lstrcat(jitter, whoami(), jitter.get_str_ansi)
+    my_lstrcat(jitter, whoami(), jitter.get_str_ansi, jitter.set_str_ansi)
 
 
 def kernel32_lstrcatW(jitter):
-    my_lstrcat(jitter, whoami(), jitter.get_str_unic)
+    my_lstrcat(jitter, whoami(), jitter.get_str_unic, jitter.set_str_unic)
 
 
 def kernel32_GetUserGeoID(jitter):
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index 5bca3fa9..cc33fc54 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -101,6 +101,10 @@ i3 = ExprInt(3, 32)
 im1 = ExprInt(-1, 32)
 im2 = ExprInt(-2, 32)
 
+bi0 = ExprInt(0, 1)
+bi1 = ExprInt(1, 1)
+
+
 icustom = ExprInt(0x12345678, 32)
 cc = ExprCond(a, b, c)
 
@@ -490,6 +494,21 @@ to_test = [
         ExprOp(TOK_EQUAL, a8, ExprInt(0xFF, 8))
     ),
 
+    (
+        ExprOp(TOK_EQUAL, i2, a + i1),
+        ExprOp(TOK_EQUAL, a , i1)
+    ),
+
+    (
+        ExprOp(TOK_EQUAL, a ^ i1, i2),
+        ExprOp(TOK_EQUAL, a , i3)
+    ),
+
+    (
+        ExprOp(TOK_EQUAL, i2, a ^ i1),
+        ExprOp(TOK_EQUAL, a , i3)
+    ),
+
     (ExprOp(TOK_INF_SIGNED, i1, i2), ExprInt(1, 1)),
     (ExprOp(TOK_INF_UNSIGNED, i1, i2), ExprInt(1, 1)),
     (ExprOp(TOK_INF_EQUAL_SIGNED, i1, i2), ExprInt(1, 1)),
@@ -692,6 +711,33 @@ to_test = [
 
     (a8.zeroExtend(32)[2:5], a8[2:5]),
 
+
+    (
+        ExprCond(a + b, a, b),
+        ExprCond(ExprOp(TOK_EQUAL, a, -b), b, a)
+    ),
+
+    (
+        ExprCond(a + i1, a, b),
+        ExprCond(ExprOp(TOK_EQUAL, a, im1), b, a)
+    ),
+
+
+    (
+        ExprCond(ExprOp(TOK_EQUAL, a, i1), bi1, bi0),
+        ExprOp(TOK_EQUAL, a, i1)
+    ),
+
+    (
+        ExprCond(ExprOp(TOK_INF_SIGNED, a, i1), bi1, bi0),
+        ExprOp(TOK_INF_SIGNED, a, i1)
+    ),
+
+    (
+        ExprOp(TOK_INF_EQUAL_UNSIGNED, a, i0),
+        ExprOp(TOK_EQUAL, a, i0)
+    ),
+
 ]
 
 for e_input, e_check in to_test: