about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorFabrice Desclaux <fabrice.desclaux@cea.fr>2018-01-11 09:56:26 +0100
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2018-01-11 10:29:36 +0100
commit79a39f095ad05c1d719ea841986c6d2642d184c7 (patch)
treeae0294d654368885d33cad47725d2f86e251d709
parent046109f04276755fea9470fe62c8eb37ee329f86 (diff)
downloadmiasm-79a39f095ad05c1d719ea841986c6d2642d184c7.tar.gz
miasm-79a39f095ad05c1d719ea841986c6d2642d184c7.zip
Expression: use stp translator
-rw-r--r--example/expression/solve_condition_stp.py213
-rw-r--r--miasm2/expression/stp.py68
-rwxr-xr-xtest/expression/stp.py20
3 files changed, 107 insertions, 194 deletions
diff --git a/example/expression/solve_condition_stp.py b/example/expression/solve_condition_stp.py
index 438188ab..b3ee6938 100644
--- a/example/expression/solve_condition_stp.py
+++ b/example/expression/solve_condition_stp.py
@@ -1,24 +1,23 @@
 import sys
 import subprocess
-from collections import defaultdict
 from optparse import OptionParser
 from pdb import pm
 
-from miasm2.arch.x86.arch import *
-from miasm2.arch.x86.regs import *
-from miasm2.arch.x86.sem import *
+from miasm2.analysis.machine import Machine
+from miasm2.expression.expression import ExprInt, ExprCond, ExprId, \
+    get_expr_ids, ExprAff
+from miasm2.arch.x86.arch import ParseAst
 from miasm2.core.bin_stream import bin_stream_str
 from miasm2.core import asmblock
-from miasm2.expression.expression import get_rw
-from miasm2.expression.modint import uint32
 from miasm2.ir.symbexec import SymbolicExecutionEngine, get_block
 from miasm2.expression.simplifications import expr_simp
-from miasm2.expression import stp
 from miasm2.core import parse_asm
 from miasm2.arch.x86.disasm import dis_x86_32 as dis_engine
+from miasm2.ir.translators.translator  import Translator
 
 
-mn = mn_x86
+machine = Machine("x86_32")
+
 
 parser = OptionParser(usage="usage: %prog [options] file")
 parser.add_option('-a', "--address", dest="address", metavar="ADDRESS",
@@ -32,99 +31,81 @@ if not args:
 
 def emul_symb(ir_arch, mdis, states_todo, states_done):
     while states_todo:
-        ad, symbols, conds = states_todo.pop()
-        print '*' * 40, "addr", ad, '*' * 40
-        if (ad, symbols, conds) in states_done:
-            print 'skip', ad
+        addr, symbols, conds = states_todo.pop()
+        print '*' * 40, "addr", addr, '*' * 40
+        if (addr, symbols, conds) in states_done:
+            print 'Known state, skipping', addr
             continue
-        states_done.add((ad, symbols, conds))
-        sb = SymbolicExecutionEngine(ir_arch, {})
-        sb.symbols = symbols.copy()
-        if ir_arch.pc in sb.symbols:
-            del(sb.symbols[ir_arch.pc])
-        b = get_block(ir_arch, mdis, ad)
-
-        print 'run block'
-        print b
-        # print blocks[ad]
-        ad = sb.emulbloc(b)
-        print 'final state'
-        sb.dump_id()
-        print 'dataflow'
-        # data_flow_graph_from_expr(sb)
-
-        assert(ad is not None)
-        print "DST", ad
-
-        if isinstance(ad, ExprCond):
+        states_done.add((addr, symbols, conds))
+        symbexec = SymbolicExecutionEngine(ir_arch, {})
+        symbexec.symbols = symbols.copy()
+        if ir_arch.pc in symbexec.symbols:
+            del symbexec.symbols[ir_arch.pc]
+        irblock = get_block(ir_arch, mdis, addr)
+
+        print 'Run block:'
+        print irblock
+        addr = symbexec.emulbloc(irblock)
+        print 'Final state:'
+        symbexec.dump_id()
+
+        assert addr is not None
+
+        if isinstance(addr, ExprCond):
             # Create 2 states, each including complementary conditions
-            p1 = sb.symbols.copy()
-            p2 = sb.symbols.copy()
-            c1 = {ad.cond: ExprInt(0, ad.cond.size)}
-            c2 = {ad.cond: ExprInt(1, ad.cond.size)}
-            print ad.cond
-            p1[ad.cond] = ExprInt(0, ad.cond.size)
-            p2[ad.cond] = ExprInt(1, ad.cond.size)
-            ad1 = expr_simp(sb.eval_expr(ad.replace_expr(c1), {}))
-            ad2 = expr_simp(sb.eval_expr(ad.replace_expr(c2), {}))
-            if not (isinstance(ad1, ExprInt) or (isinstance(ad1, ExprId) and isinstance(ad1.name, asmblock.AsmLabel)) and
-                    isinstance(ad2, ExprInt) or (isinstance(ad2, ExprId) and isinstance(ad2.name, asmblock.AsmLabel))):
-                print str(ad1), str(ad2)
-                raise ValueError("zarb condition")
-            conds1 = list(conds) + c1.items()
-            conds2 = list(conds) + c2.items()
-            if isinstance(ad1, ExprId):
-                ad1 = ad1.name
-            if isinstance(ad2, ExprId):
-                ad2 = ad2.name
-            if isinstance(ad1, ExprInt):
-                ad1 = ad1.arg
-            if isinstance(ad2, ExprInt):
-                ad2 = ad2.arg
-            states_todo.add((ad1, p1, tuple(conds1)))
-            states_todo.add((ad2, p2, tuple(conds2)))
-        elif isinstance(ad, ExprInt):
-            ad = int(ad.arg)
-            states_todo.add((ad, sb.symbols.copy(), tuple(conds)))
-        elif isinstance(ad, ExprId) and isinstance(ad.name, asmblock.AsmLabel):
-            if isinstance(ad, ExprId):
-                ad = ad.name
-            states_todo.add((ad, sb.symbols.copy(), tuple(conds)))
-        elif ad == ret_addr:
-            print 'ret reached'
+            cond_group_a = {addr.cond: ExprInt(0, addr.cond.size)}
+            cond_group_b = {addr.cond: ExprInt(1, addr.cond.size)}
+            addr_a = expr_simp(symbexec.eval_expr(addr.replace_expr(cond_group_a), {}))
+            addr_b = expr_simp(symbexec.eval_expr(addr.replace_expr(cond_group_b), {}))
+            if not (addr_a.is_int() or asmblock.expr_is_label(addr_a) and
+                    addr_b.is_int() or asmblock.expr_is_label(addr_b)):
+                print str(addr_a), str(addr_b)
+                raise ValueError("Unsupported condition")
+            if isinstance(addr_a, ExprInt):
+                addr_a = int(addr_a.arg)
+            if isinstance(addr_b, ExprInt):
+                addr_b = int(addr_b.arg)
+            states_todo.add((addr_a, symbexec.symbols.copy(), tuple(list(conds) + cond_group_a.items())))
+            states_todo.add((addr_b, symbexec.symbols.copy(), tuple(list(conds) + cond_group_b.items())))
+        elif isinstance(addr, ExprInt):
+            addr = int(addr.arg)
+            states_todo.add((addr, symbexec.symbols.copy(), tuple(conds)))
+        elif asmblock.expr_is_label(addr):
+            addr = addr.name
+            states_todo.add((addr, symbexec.symbols.copy(), tuple(conds)))
+        elif addr == ret_addr:
+            print 'Return address reached'
             continue
         else:
-            raise ValueError("zarb eip")
+            raise ValueError("Unsupported destination")
 
 
 if __name__ == '__main__':
 
+    translator_smt2 = Translator.to_language("smt2")
     data = open(args[0]).read()
     bs = bin_stream_str(data)
 
     mdis = dis_engine(bs)
 
-    ad = int(options.address, 16)
+    addr = int(options.address, 16)
 
-    symbols_init = {}
-    for i, r in enumerate(all_regs_ids):
-        symbols_init[r] = all_regs_ids_init[i]
+    symbols_init = dict(machine.mn.regs.regs_init)
 
     # config parser for 32 bit
-    reg_and_id = dict(mn_x86.regs.all_regs_ids_byname)
+    reg_and_id = dict(machine.mn.regs.all_regs_ids_byname)
 
-    def my_ast_int2expr(a):
-        return ExprInt(a, 32)
+    def my_ast_int2expr(name):
+        return ExprInt(name, 32)
 
     # Modifify parser to avoid label creation in PUSH argc
     def my_ast_id2expr(string_parsed):
         if string_parsed in reg_and_id:
             return reg_and_id[string_parsed]
-        else:
-            return ExprId(string_parsed, size=32)
+        return ExprId(string_parsed, size=32)
 
     my_var_parser = ParseAst(my_ast_id2expr, my_ast_int2expr)
-    base_expr.setParseAction(my_var_parser)
+    machine.base_expr.setParseAction(my_var_parser)
 
     argc = ExprId('argc', 32)
     argv = ExprId('argv', 32)
@@ -135,13 +116,13 @@ if __name__ == '__main__':
 
     my_symbols = [argc, argv, ret_addr]
     my_symbols = dict([(x.name, x) for x in my_symbols])
-    my_symbols.update(mn_x86.regs.all_regs_ids_byname)
+    my_symbols.update(machine.mn.regs.all_regs_ids_byname)
 
-    ir_arch = ir_x86_32(mdis.symbol_pool)
+    ir_arch = machine.ir(mdis.symbol_pool)
 
-    sb = SymbolicExecutionEngine(ir_arch, symbols_init)
+    symbexec = SymbolicExecutionEngine(ir_arch, symbols_init)
 
-    blocks, symbol_pool = parse_asm.parse_txt(mn_x86, 32, '''
+    blocks, symbol_pool = parse_asm.parse_txt(machine.mn, 32, '''
     PUSH argv
     PUSH argc
     PUSH ret_addr
@@ -155,15 +136,15 @@ if __name__ == '__main__':
         line.offset, line.l = i, 1
     ir_arch.add_block(b)
     irb = get_block(ir_arch, mdis, 0)
-    sb.emulbloc(irb)
-    sb.dump_mem()
+    symbexec.emulbloc(irb)
+    symbexec.dump_mem()
 
     # reset ir_arch blocks
     ir_arch.blocks = {}
 
     states_todo = set()
     states_done = set()
-    states_todo.add((uint32(ad), sb.symbols, ()))
+    states_todo.add((addr, symbexec.symbols, ()))
 
     # emul blocks, propagate states
     emul_symb(ir_arch, mdis, states_todo, states_done)
@@ -171,57 +152,53 @@ if __name__ == '__main__':
     all_info = []
 
     print '*' * 40, 'conditions to match', '*' * 40
-    for ad, symbols, conds in sorted(states_done):
-        print '*' * 40, ad, '*' * 40
+    for addr, symbols, conds in sorted(states_done):
+        print '*' * 40, addr, '*' * 40
         reqs = []
         for k, v in conds:
             print k, v
             reqs.append((k, v))
-        all_info.append((ad, reqs))
+        all_info.append((addr, reqs))
 
     all_cases = set()
 
-    sb = SymbolicExecutionEngine(ir_arch, symbols_init)
-    for ad, reqs_cond in all_info:
+    symbexec = SymbolicExecutionEngine(ir_arch, symbols_init)
+    for addr, reqs_cond in all_info:
+        out = ['(set-logic QF_ABV)',
+               '(set-info :smt-lib-version 2.0)']
+
+        conditions = []
         all_ids = set()
-        for k, v in reqs_cond:
-            all_ids.update(get_expr_ids(k))
-
-        out = []
-
-        # declare variables
-        for v in all_ids:
-            out.append(str(v) + ":" + "BITVECTOR(%d);" % v.size)
-
-        all_csts = []
-        for k, v in reqs_cond:
-            cst = k.strcst()
-            val = v.arg
-            assert(val in [0, 1])
-            inv = ""
-            if val == 1:
-                inv = "NOT "
-            val = "0" * v.size
-            all_csts.append("(%s%s=0bin%s)" % (inv, cst, val))
-        if not all_csts:
+        for expr, value in reqs_cond:
+
+            all_ids.update(get_expr_ids(expr))
+            expr_test = ExprCond(expr,
+                                 ExprInt(1, value.size),
+                                 ExprInt(0, value.size))
+            cond = translator_smt2.from_expr(ExprAff(expr_test, value))
+            conditions.append(cond)
+
+        for name in all_ids:
+            out.append("(declare-fun %s () (_ BitVec %d))" % (name, name.size))
+        if not out:
             continue
-        rez = " AND ".join(all_csts)
-        out.append("QUERY(NOT (%s));" % rez)
-        end = "\n".join(out)
-        open('out.dot', 'w').write(end)
+
+        out += conditions
+        out.append('(check-sat)')
+        open('out.dot', 'w').write('\n'.join(out))
         try:
             cases = subprocess.check_output(["/home/serpilliere/tools/stp/stp",
-                                             "-p",
+                                             "-p", '--SMTLIB2',
                                              "out.dot"])
         except OSError:
-            print "ERF, cannot find stp"
+            print "Cannot find stp binary!"
             break
         for c in cases.split('\n'):
             if c.startswith('ASSERT'):
-                all_cases.add((ad, c))
+                all_cases.add((addr, c))
 
     print '*' * 40, 'ALL COND', '*' * 40
     all_cases = list(all_cases)
     all_cases.sort(key=lambda x: (x[0], x[1]))
-    for ad, val in all_cases:
-        print 'address', ad, 'is reachable using argc', val
+    for addr, val in all_cases:
+        print 'Address:', addr, 'is reachable using argc', val
diff --git a/miasm2/expression/stp.py b/miasm2/expression/stp.py
deleted file mode 100644
index c9b76e4c..00000000
--- a/miasm2/expression/stp.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from miasm2.expression.expression import *
-
-
-"""
-Quick implementation of miasm traduction to stp langage
-TODO XXX: finish
-"""
-
-
-def ExprInt_strcst(self):
-    b = bin(int(self))[2::][::-1]
-    b += "0" * self.size
-    b = b[:self.size][::-1]
-    return "0bin" + b
-
-
-def ExprId_strcst(self):
-    return self.name
-
-
-def genop(op, size, a, b):
-    return op + '(' + str(size) + ',' + a + ', ' + b + ')'
-
-
-def genop_nosize(op, size, a, b):
-    return op + '(' + a + ', ' + b + ')'
-
-
-def ExprOp_strcst(self):
-    op = self.op
-    op_dct = {"|": " | ",
-              "&": " & "}
-    if op in op_dct:
-        return '(' + op_dct[op].join([x.strcst() for x in self.args]) + ')'
-    op_dct = {"-": "BVUMINUS"}
-    if op in op_dct:
-        return op_dct[op] + '(' + self.args[0].strcst() + ')'
-    op_dct = {"^": ("BVXOR", genop_nosize),
-              "+": ("BVPLUS", genop)}
-    if not op in op_dct:
-        raise ValueError('implement op', op)
-    op, f = op_dct[op]
-    args = [x.strcst() for x in self.args][::-1]
-    a = args.pop()
-    b = args.pop()
-    size = self.args[0].size
-    out = f(op, size, a, b)
-    while args:
-        out = f(op, size, out, args.pop())
-    return out
-
-
-def ExprSlice_strcst(self):
-    return '(' + self.arg.strcst() + ')[%d:%d]' % (self.stop - 1, self.start)
-
-
-def ExprCond_strcst(self):
-    cond = self.cond.strcst()
-    src1 = self.src1.strcst()
-    src2 = self.src2.strcst()
-    return "(IF %s=(%s) THEN %s ELSE %s ENDIF)" % (
-        "0bin%s" % ('0' * self.cond.size), cond, src2, src1)
-
-ExprInt.strcst = ExprInt_strcst
-ExprId.strcst = ExprId_strcst
-ExprOp.strcst = ExprOp_strcst
-ExprCond.strcst = ExprCond_strcst
-ExprSlice.strcst = ExprSlice_strcst
diff --git a/test/expression/stp.py b/test/expression/stp.py
index a4b037de..38bbf9c8 100755
--- a/test/expression/stp.py
+++ b/test/expression/stp.py
@@ -8,24 +8,28 @@ class TestIrIr2STP(unittest.TestCase):
 
     def test_ExprOp_strcst(self):
         from miasm2.expression.expression import ExprInt, ExprOp
-        import miasm2.expression.stp   # /!\ REALLY DIRTY HACK
+        from miasm2.ir.translators.translator  import Translator
+        translator_smt2 = Translator.to_language("smt2")
+
         args = [ExprInt(i, 32) for i in xrange(9)]
 
         self.assertEqual(
-            ExprOp('|',  *args[:2]).strcst(), r'(0bin00000000000000000000000000000000 | 0bin00000000000000000000000000000001)')
+            translator_smt2.from_expr(ExprOp('|',  *args[:2])), r'(bvor (_ bv0 32) (_ bv1 32))')
         self.assertEqual(
-            ExprOp('-',  *args[:2]).strcst(), r'BVUMINUS(0bin00000000000000000000000000000000)')
+            translator_smt2.from_expr(ExprOp('-',  *args[:2])), r'(bvsub (_ bv0 32) (_ bv1 32))')
         self.assertEqual(
-            ExprOp('+',  *args[:3]).strcst(), r'BVPLUS(32,BVPLUS(32,0bin00000000000000000000000000000000, 0bin00000000000000000000000000000001), 0bin00000000000000000000000000000010)')
-        self.assertRaises(ValueError, ExprOp('X', *args[:1]).strcst)
+            translator_smt2.from_expr(ExprOp('+',  *args[:3])), r'(bvadd (bvadd (_ bv0 32) (_ bv1 32)) (_ bv2 32))')
+        self.assertRaises(NotImplementedError, translator_smt2.from_expr, ExprOp('X', *args[:1]))
 
     def test_ExprSlice_strcst(self):
-        from miasm2.expression.expression import ExprInt, ExprSlice
-        import miasm2.expression.stp   # /!\ REALLY DIRTY HACK
+        from miasm2.expression.expression import ExprInt, ExprOp
+        from miasm2.ir.translators.translator  import Translator
+        translator_smt2 = Translator.to_language("smt2")
+
         args = [ExprInt(i, 32) for i in xrange(9)]
 
         self.assertEqual(
-            args[0][1:2].strcst(), r'(0bin00000000000000000000000000000000)[1:1]')
+            translator_smt2.from_expr(args[0][1:2]), r'((_ extract 1 1) (_ bv0 32))')
         self.assertRaises(ValueError, args[0].__getitem__, slice(1,7,2))
 
 if __name__ == '__main__':