about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorserpilliere <serpilliere@users.noreply.github.com>2020-12-04 06:59:35 +0100
committerGitHub <noreply@github.com>2020-12-04 06:59:35 +0100
commit34438feab8834e93deb57b51bd66a172be6e8135 (patch)
tree3b6e03207deebdb6dc0ff3d2f9b31895d937049a
parent72a8babc6ad2c13e49b12c0e79eeea61067ccc10 (diff)
parent73b6bc5f622941cc382ddb1e4c099029dd9ec3c4 (diff)
downloadfocaccia-miasm-34438feab8834e93deb57b51bd66a172be6e8135.tar.gz
focaccia-miasm-34438feab8834e93deb57b51bd66a172be6e8135.zip
Merge pull request #1319 from serpilliere/fix_z3_div_add_simpl
Fix z3 div; add simpl
-rw-r--r--example/ida/graph_ir.py44
-rw-r--r--miasm/core/modint.py14
-rw-r--r--miasm/expression/simplifications.py2
-rw-r--r--miasm/expression/simplifications_common.py58
-rw-r--r--miasm/ir/translators/z3_ir.py13
-rw-r--r--test/expression/z3_div.py37
-rwxr-xr-xtest/test_all.py4
7 files changed, 155 insertions, 17 deletions
diff --git a/example/ida/graph_ir.py b/example/ida/graph_ir.py
index b8afe5fc..d10e1ebd 100644
--- a/example/ida/graph_ir.py
+++ b/example/ida/graph_ir.py
@@ -16,6 +16,7 @@ from miasm.expression.simplifications import expr_simp
 from miasm.ir.ir import IRBlock, AssignBlock
 from miasm.analysis.data_flow import load_from_int
 from utils import guess_machine, expr2colorstr
+from miasm.expression.expression import ExprLoc, ExprInt, ExprOp, ExprAssign
 from miasm.analysis.simplifier import IRCFGSimplifierCommon, IRCFGSimplifierSSA
 from miasm.core.locationdb import LocationDB
 
@@ -26,8 +27,9 @@ TYPE_GRAPH_IRSSA = 1
 TYPE_GRAPH_IRSSAUNSSA = 2
 
 OPTION_GRAPH_CODESIMPLIFY = 1
-OPTION_GRAPH_DONTMODSTACK = 2
-OPTION_GRAPH_LOADMEMINT = 4
+OPTION_GRAPH_USE_IDA_STACK = 2
+OPTION_GRAPH_DONTMODSTACK = 4
+OPTION_GRAPH_LOADMEMINT = 8
 
 
 class GraphIRForm(ida_kernwin.Form):
@@ -47,6 +49,7 @@ Analysis:
 
 Options:
 <Simplify code:{rCodeSimplify}>
+<Use ida stack:{rUseIdaStack}>
 <Subcalls dont change stack:{rDontModStack}>
 <Load static memory:{rLoadMemInt}>{cOptions}>
 """,
@@ -62,6 +65,7 @@ Options:
                 'cOptions': ida_kernwin.Form.ChkGroupControl(
                     (
                         "rCodeSimplify",
+                        "rUseIdaStack",
                         "rDontModStack",
                         "rLoadMemInt"
                     )
@@ -70,6 +74,7 @@ Options:
         )
         form, _ = self.Compile()
         form.rCodeSimplify.checked = True
+        form.rUseIdaStack.checked = True
         form.rDontModStack.checked = False
         form.rLoadMemInt.checked = False
 
@@ -173,22 +178,36 @@ def is_addr_ro_variable(bs, addr, size):
     return True
 
 
-def build_graph(start_addr, type_graph, simplify=False, dontmodstack=True, loadint=False, verbose=False):
+def build_graph(start_addr, type_graph, simplify=False, use_ida_stack=True, dontmodstack=False, loadint=False, verbose=False):
     machine = guess_machine(addr=start_addr)
     dis_engine, ira = machine.dis_engine, machine.ira
 
     class IRADelModCallStack(ira):
         def call_effects(self, addr, instr):
             assignblks, extra = super(IRADelModCallStack, self).call_effects(addr, instr)
-            if not dontmodstack:
-                return assignblks, extra
-            out = []
-            for assignblk in assignblks:
-                dct = dict(assignblk)
-                dct = {
-                    dst:src for (dst, src) in viewitems(dct) if dst != self.sp
-                }
-                out.append(AssignBlock(dct, assignblk.instr))
+            if use_ida_stack:
+                stk_before = idc.get_spd(instr.offset)
+                stk_after = idc.get_spd(instr.offset + instr.l)
+                stk_diff = stk_after - stk_before
+                print(hex(stk_diff))
+                call_assignblk = AssignBlock(
+                    [
+                        ExprAssign(self.ret_reg, ExprOp('call_func_ret', addr)),
+                        ExprAssign(self.sp, self.sp + ExprInt(stk_diff, self.sp.size))
+                    ],
+                    instr
+                )
+                return [call_assignblk], []
+            else:
+                if not dontmodstack:
+                    return assignblks, extra
+                out = []
+                for assignblk in assignblks:
+                    dct = dict(assignblk)
+                    dct = {
+                        dst:src for (dst, src) in viewitems(dct) if dst != self.sp
+                    }
+                    out.append(AssignBlock(dct, assignblk.instr))
             return out, extra
 
 
@@ -338,6 +357,7 @@ def function_graph_ir():
         func_addr,
         settings.cScope.value,
         simplify=settings.cOptions.value & OPTION_GRAPH_CODESIMPLIFY,
+        use_ida_stack=settings.cOptions.value & OPTION_GRAPH_USE_IDA_STACK,
         dontmodstack=settings.cOptions.value & OPTION_GRAPH_DONTMODSTACK,
         loadint=settings.cOptions.value & OPTION_GRAPH_LOADMEMINT,
         verbose=False
diff --git a/miasm/core/modint.py b/miasm/core/modint.py
index 2ecefed1..14b4dc2c 100644
--- a/miasm/core/modint.py
+++ b/miasm/core/modint.py
@@ -55,6 +55,20 @@ class moduint(object):
     def __div__(self, y):
         # Python: 8 / -7 == -2 (C-like: -1)
         # int(float) trick cannot be used, due to information loss
+        # Examples:
+        #
+        # 42 / 10 => 4
+        # 42 % 10 => 2
+        #
+        # -42 / 10 => -4
+        # -42 % 10 => -2
+        #
+        # 42 / -10 => -4
+        # 42 % -10 => 2
+        #
+        # -42 / -10 => 4
+        # -42 % -10 => -2
+
         den = int(y)
         num = int(self)
         result_sign = 1 if (den * num) >= 0 else -1
diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py
index c65b2b7b..38c4cbf4 100644
--- a/miasm/expression/simplifications.py
+++ b/miasm/expression/simplifications.py
@@ -65,11 +65,13 @@ class ExpressionSimplifier(ExprVisitorCallbackBottomToTop):
             simplifications_common.simp_compose_and_mask,
             simplifications_common.simp_bcdadd_cf,
             simplifications_common.simp_bcdadd,
+            simplifications_common.simp_smod_sext,
         ],
 
         m2_expr.ExprSlice: [
             simplifications_common.simp_slice,
             simplifications_common.simp_slice_of_ext,
+            simplifications_common.simp_slice_of_sext,
             simplifications_common.simp_slice_of_op_ext,
         ],
         m2_expr.ExprCompose: [simplifications_common.simp_compose],
diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py
index fd45ef6d..85af9dc4 100644
--- a/miasm/expression/simplifications_common.py
+++ b/miasm/expression/simplifications_common.py
@@ -594,6 +594,14 @@ def simp_compose(e_s, expr):
                 args = args[:i] + [ExprMem(arg.ptr,
                                           arg.size + nxt.size)] + args[i + 2:]
                 return ExprCompose(*args)
+    # {A, signext(A)[32:64]} => signext(A)
+    if len(args) == 2 and args[0].size == args[1].size:
+        arg1, arg2 = args
+        size = arg1.size
+        sign_ext = arg1.signExtend(arg1.size*2)
+        if arg2 == sign_ext[size:2*size]:
+            return sign_ext
+
 
     # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
     conds = set(arg.cond for arg in expr.args if arg.is_cond())
@@ -1443,6 +1451,23 @@ def simp_slice_of_ext(_, expr):
         return arg.zeroExtend(expr.stop)
     return expr
 
+def simp_slice_of_sext(e_s, expr):
+    """
+    with Y <= size(A)
+    A.signExt(X)[0:Y] => A[0:Y]
+    """
+    if not expr.arg.is_op():
+        return expr
+    if not expr.arg.op.startswith("signExt"):
+        return expr
+    arg = expr.arg.args[0]
+    if expr.start != 0:
+        return expr
+    if expr.stop <= arg.size:
+        return e_s.expr_simp(arg[:expr.stop])
+    return expr
+
+
 def simp_slice_of_op_ext(expr_s, expr):
     """
     (X.zeroExt() + {Z, } + ... + Int)[0:8] => X + ... + int[:]
@@ -1763,3 +1788,36 @@ def simp_bcdadd(_, expr):
             carry = 0
         res += j << i
     return ExprInt(res, arg1.size)
+
+
+def simp_smod_sext(expr_s, expr):
+    """
+    a.size == b.size
+    smod(a.signExtend(X), b.signExtend(X)) => smod(a, b).signExtend(X)
+    """
+    if not expr.is_op("smod"):
+        return expr
+    arg1, arg2 = expr.args
+    if arg1.is_op() and arg1.op.startswith("signExt"):
+        src1 = arg1.args[0]
+        if arg2.is_op() and arg2.op.startswith("signExt"):
+            src2 = arg2.args[0]
+            if src1.size == src2.size:
+                # Case: a.signext(), b.signext()
+                return ExprOp("smod", src1, src2).signExtend(expr.size)
+            return expr
+        elif arg2.is_int():
+            src2 = expr_s.expr_simp(arg2[:src1.size])
+            if expr_s.expr_simp(src2.signExtend(arg2.size)) == arg2:
+                # Case: a.signext(), int
+                return ExprOp("smod", src1, src2).signExtend(expr.size)
+            return expr
+    # Case: int        , b.signext()
+    if arg2.is_op() and arg2.op.startswith("signExt"):
+        src2 = arg2.args[0]
+        if arg1.is_int():
+            src1 = expr_s.expr_simp(arg1[:src2.size])
+            if expr_s.expr_simp(src1.signExtend(arg1.size)) == arg1:
+                # Case: int, b.signext()
+                return ExprOp("smod", src1, src2).signExtend(expr.size)
+    return expr
diff --git a/miasm/ir/translators/z3_ir.py b/miasm/ir/translators/z3_ir.py
index 1a36e94e..4b674c4e 100644
--- a/miasm/ir/translators/z3_ir.py
+++ b/miasm/ir/translators/z3_ir.py
@@ -173,11 +173,14 @@ class TranslatorZ3(Translator):
     def _abs(self, z3_value):
         return z3.If(z3_value >= 0,z3_value,-z3_value)
 
-    def _sdivC(self, num, den):
-        """Divide (signed) @num by @den (z3 values) as C would
+    def _sdivC(self, num_expr, den_expr):
+        """Divide (signed) @num by @den (Expr) as C would
         See modint.__div__ for implementation choice
         """
-        result_sign = z3.If(num * den >= 0,
+        num, den = self.from_expr(num_expr), self.from_expr(den_expr)
+        num_s = self.from_expr(num_expr.signExtend(num_expr.size * 2))
+        den_s = self.from_expr(den_expr.signExtend(den_expr.size * 2))
+        result_sign = z3.If(num_s * den_s >= 0,
                             z3.BitVecVal(1, num.size()),
                             z3.BitVecVal(-1, num.size()),
         )
@@ -200,11 +203,11 @@ class TranslatorZ3(Translator):
                 elif expr.op == ">>>":
                     res = z3.RotateRight(res, arg)
                 elif expr.op == "sdiv":
-                    res = self._sdivC(res, arg)
+                    res = self._sdivC(expr.args[0], expr.args[1])
                 elif expr.op == "udiv":
                     res = z3.UDiv(res, arg)
                 elif expr.op == "smod":
-                    res = res - (arg * (self._sdivC(res, arg)))
+                    res = res - (arg * (self._sdivC(expr.args[0], expr.args[1])))
                 elif expr.op == "umod":
                     res = z3.URem(res, arg)
                 elif expr.op == "==":
diff --git a/test/expression/z3_div.py b/test/expression/z3_div.py
new file mode 100644
index 00000000..d436634b
--- /dev/null
+++ b/test/expression/z3_div.py
@@ -0,0 +1,37 @@
+import z3
+from miasm.ir.translators import Translator
+from miasm.expression.expression import *
+
+translator = Translator.to_language("z3")
+
+values = [
+    (42, 10, 4, 2),
+    (-42, 10, -4, -2),
+    (42, -10, -4, 2),
+    (-42, -10, 4, -2)
+]
+
+for a, b, c, d in values:
+    cst_a = ExprInt(a, 8)
+    cst_b = ExprInt(b, 8)
+
+    div_result = ExprInt(c, 8)
+    div = ExprOp("sdiv", cst_a, cst_b)
+    print("%d / %d == %d" % (a, b, div_result))
+    solver = z3.Solver()
+    print("%s == %s" %(div, div_result))
+    eq1 = translator.from_expr(div) != translator.from_expr(div_result)
+    solver.add(eq1)
+    result = solver.check()
+    assert result == z3.unsat
+
+    mod_result = ExprInt(d, 8)
+    print("%d %% %d == %d" % (a, b, mod_result))
+    res2 = ExprOp("smod", cst_a, cst_b)
+    solver = z3.Solver()
+    print("%s == %s" %(res2, mod_result))
+    eq2 = translator.from_expr(res2) != translator.from_expr(mod_result)
+    solver.add(eq2)
+    result = solver.check()
+    assert result == z3.unsat
+
diff --git a/test/test_all.py b/test/test_all.py
index 2670761b..c2391572 100755
--- a/test/test_all.py
+++ b/test/test_all.py
@@ -332,6 +332,10 @@ testset += RegressionTest(["simplifications.py", "--z3"],
                           base_dir="expression",
                           tags=[TAGS["z3"]])
 
+testset += RegressionTest(["z3_div.py"],
+                          base_dir="expression",
+                          tags=[TAGS["z3"]])
+
 ## ObjC/CHandler
 testset += RegressionTest(["test_chandler.py"], base_dir="expr_type",
                           tags=[TAGS["cparser"]])