about summary refs log tree commit diff stats
path: root/example/expression/solve_condition_stp.py
diff options
context:
space:
mode:
authorFabrice Desclaux <fabrice.desclaux@cea.fr>2018-01-11 09:56:26 +0100
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2018-01-11 10:29:36 +0100
commit79a39f095ad05c1d719ea841986c6d2642d184c7 (patch)
treeae0294d654368885d33cad47725d2f86e251d709 /example/expression/solve_condition_stp.py
parent046109f04276755fea9470fe62c8eb37ee329f86 (diff)
downloadfocaccia-miasm-79a39f095ad05c1d719ea841986c6d2642d184c7.tar.gz
focaccia-miasm-79a39f095ad05c1d719ea841986c6d2642d184c7.zip
Expression: use stp translator
Diffstat (limited to 'example/expression/solve_condition_stp.py')
-rw-r--r--example/expression/solve_condition_stp.py213
1 files changed, 95 insertions, 118 deletions
diff --git a/example/expression/solve_condition_stp.py b/example/expression/solve_condition_stp.py
index 438188ab..b3ee6938 100644
--- a/example/expression/solve_condition_stp.py
+++ b/example/expression/solve_condition_stp.py
@@ -1,24 +1,23 @@
 import sys
 import subprocess
-from collections import defaultdict
 from optparse import OptionParser
 from pdb import pm
 
-from miasm2.arch.x86.arch import *
-from miasm2.arch.x86.regs import *
-from miasm2.arch.x86.sem import *
+from miasm2.analysis.machine import Machine
+from miasm2.expression.expression import ExprInt, ExprCond, ExprId, \
+    get_expr_ids, ExprAff
+from miasm2.arch.x86.arch import ParseAst
 from miasm2.core.bin_stream import bin_stream_str
 from miasm2.core import asmblock
-from miasm2.expression.expression import get_rw
-from miasm2.expression.modint import uint32
 from miasm2.ir.symbexec import SymbolicExecutionEngine, get_block
 from miasm2.expression.simplifications import expr_simp
-from miasm2.expression import stp
 from miasm2.core import parse_asm
 from miasm2.arch.x86.disasm import dis_x86_32 as dis_engine
+from miasm2.ir.translators.translator  import Translator
 
 
-mn = mn_x86
+machine = Machine("x86_32")
+
 
 parser = OptionParser(usage="usage: %prog [options] file")
 parser.add_option('-a', "--address", dest="address", metavar="ADDRESS",
@@ -32,99 +31,81 @@ if not args:
 
 def emul_symb(ir_arch, mdis, states_todo, states_done):
     while states_todo:
-        ad, symbols, conds = states_todo.pop()
-        print '*' * 40, "addr", ad, '*' * 40
-        if (ad, symbols, conds) in states_done:
-            print 'skip', ad
+        addr, symbols, conds = states_todo.pop()
+        print '*' * 40, "addr", addr, '*' * 40
+        if (addr, symbols, conds) in states_done:
+            print 'Known state, skipping', addr
             continue
-        states_done.add((ad, symbols, conds))
-        sb = SymbolicExecutionEngine(ir_arch, {})
-        sb.symbols = symbols.copy()
-        if ir_arch.pc in sb.symbols:
-            del(sb.symbols[ir_arch.pc])
-        b = get_block(ir_arch, mdis, ad)
-
-        print 'run block'
-        print b
-        # print blocks[ad]
-        ad = sb.emulbloc(b)
-        print 'final state'
-        sb.dump_id()
-        print 'dataflow'
-        # data_flow_graph_from_expr(sb)
-
-        assert(ad is not None)
-        print "DST", ad
-
-        if isinstance(ad, ExprCond):
+        states_done.add((addr, symbols, conds))
+        symbexec = SymbolicExecutionEngine(ir_arch, {})
+        symbexec.symbols = symbols.copy()
+        if ir_arch.pc in symbexec.symbols:
+            del symbexec.symbols[ir_arch.pc]
+        irblock = get_block(ir_arch, mdis, addr)
+
+        print 'Run block:'
+        print irblock
+        addr = symbexec.emulbloc(irblock)
+        print 'Final state:'
+        symbexec.dump_id()
+
+        assert addr is not None
+
+        if isinstance(addr, ExprCond):
             # Create 2 states, each including complementary conditions
-            p1 = sb.symbols.copy()
-            p2 = sb.symbols.copy()
-            c1 = {ad.cond: ExprInt(0, ad.cond.size)}
-            c2 = {ad.cond: ExprInt(1, ad.cond.size)}
-            print ad.cond
-            p1[ad.cond] = ExprInt(0, ad.cond.size)
-            p2[ad.cond] = ExprInt(1, ad.cond.size)
-            ad1 = expr_simp(sb.eval_expr(ad.replace_expr(c1), {}))
-            ad2 = expr_simp(sb.eval_expr(ad.replace_expr(c2), {}))
-            if not (isinstance(ad1, ExprInt) or (isinstance(ad1, ExprId) and isinstance(ad1.name, asmblock.AsmLabel)) and
-                    isinstance(ad2, ExprInt) or (isinstance(ad2, ExprId) and isinstance(ad2.name, asmblock.AsmLabel))):
-                print str(ad1), str(ad2)
-                raise ValueError("zarb condition")
-            conds1 = list(conds) + c1.items()
-            conds2 = list(conds) + c2.items()
-            if isinstance(ad1, ExprId):
-                ad1 = ad1.name
-            if isinstance(ad2, ExprId):
-                ad2 = ad2.name
-            if isinstance(ad1, ExprInt):
-                ad1 = ad1.arg
-            if isinstance(ad2, ExprInt):
-                ad2 = ad2.arg
-            states_todo.add((ad1, p1, tuple(conds1)))
-            states_todo.add((ad2, p2, tuple(conds2)))
-        elif isinstance(ad, ExprInt):
-            ad = int(ad.arg)
-            states_todo.add((ad, sb.symbols.copy(), tuple(conds)))
-        elif isinstance(ad, ExprId) and isinstance(ad.name, asmblock.AsmLabel):
-            if isinstance(ad, ExprId):
-                ad = ad.name
-            states_todo.add((ad, sb.symbols.copy(), tuple(conds)))
-        elif ad == ret_addr:
-            print 'ret reached'
+            cond_group_a = {addr.cond: ExprInt(0, addr.cond.size)}
+            cond_group_b = {addr.cond: ExprInt(1, addr.cond.size)}
+            addr_a = expr_simp(symbexec.eval_expr(addr.replace_expr(cond_group_a), {}))
+            addr_b = expr_simp(symbexec.eval_expr(addr.replace_expr(cond_group_b), {}))
+            if not (addr_a.is_int() or asmblock.expr_is_label(addr_a) and
+                    addr_b.is_int() or asmblock.expr_is_label(addr_b)):
+                print str(addr_a), str(addr_b)
+                raise ValueError("Unsupported condition")
+            if isinstance(addr_a, ExprInt):
+                addr_a = int(addr_a.arg)
+            if isinstance(addr_b, ExprInt):
+                addr_b = int(addr_b.arg)
+            states_todo.add((addr_a, symbexec.symbols.copy(), tuple(list(conds) + cond_group_a.items())))
+            states_todo.add((addr_b, symbexec.symbols.copy(), tuple(list(conds) + cond_group_b.items())))
+        elif isinstance(addr, ExprInt):
+            addr = int(addr.arg)
+            states_todo.add((addr, symbexec.symbols.copy(), tuple(conds)))
+        elif asmblock.expr_is_label(addr):
+            addr = addr.name
+            states_todo.add((addr, symbexec.symbols.copy(), tuple(conds)))
+        elif addr == ret_addr:
+            print 'Return address reached'
             continue
         else:
-            raise ValueError("zarb eip")
+            raise ValueError("Unsupported destination")
 
 
 if __name__ == '__main__':
 
+    translator_smt2 = Translator.to_language("smt2")
     data = open(args[0]).read()
     bs = bin_stream_str(data)
 
     mdis = dis_engine(bs)
 
-    ad = int(options.address, 16)
+    addr = int(options.address, 16)
 
-    symbols_init = {}
-    for i, r in enumerate(all_regs_ids):
-        symbols_init[r] = all_regs_ids_init[i]
+    symbols_init = dict(machine.mn.regs.regs_init)
 
     # config parser for 32 bit
-    reg_and_id = dict(mn_x86.regs.all_regs_ids_byname)
+    reg_and_id = dict(machine.mn.regs.all_regs_ids_byname)
 
-    def my_ast_int2expr(a):
-        return ExprInt(a, 32)
+    def my_ast_int2expr(name):
+        return ExprInt(name, 32)
 
     # Modifify parser to avoid label creation in PUSH argc
     def my_ast_id2expr(string_parsed):
         if string_parsed in reg_and_id:
             return reg_and_id[string_parsed]
-        else:
-            return ExprId(string_parsed, size=32)
+        return ExprId(string_parsed, size=32)
 
     my_var_parser = ParseAst(my_ast_id2expr, my_ast_int2expr)
-    base_expr.setParseAction(my_var_parser)
+    machine.base_expr.setParseAction(my_var_parser)
 
     argc = ExprId('argc', 32)
     argv = ExprId('argv', 32)
@@ -135,13 +116,13 @@ if __name__ == '__main__':
 
     my_symbols = [argc, argv, ret_addr]
     my_symbols = dict([(x.name, x) for x in my_symbols])
-    my_symbols.update(mn_x86.regs.all_regs_ids_byname)
+    my_symbols.update(machine.mn.regs.all_regs_ids_byname)
 
-    ir_arch = ir_x86_32(mdis.symbol_pool)
+    ir_arch = machine.ir(mdis.symbol_pool)
 
-    sb = SymbolicExecutionEngine(ir_arch, symbols_init)
+    symbexec = SymbolicExecutionEngine(ir_arch, symbols_init)
 
-    blocks, symbol_pool = parse_asm.parse_txt(mn_x86, 32, '''
+    blocks, symbol_pool = parse_asm.parse_txt(machine.mn, 32, '''
     PUSH argv
     PUSH argc
     PUSH ret_addr
@@ -155,15 +136,15 @@ if __name__ == '__main__':
         line.offset, line.l = i, 1
     ir_arch.add_block(b)
     irb = get_block(ir_arch, mdis, 0)
-    sb.emulbloc(irb)
-    sb.dump_mem()
+    symbexec.emulbloc(irb)
+    symbexec.dump_mem()
 
     # reset ir_arch blocks
     ir_arch.blocks = {}
 
     states_todo = set()
     states_done = set()
-    states_todo.add((uint32(ad), sb.symbols, ()))
+    states_todo.add((addr, symbexec.symbols, ()))
 
     # emul blocks, propagate states
     emul_symb(ir_arch, mdis, states_todo, states_done)
@@ -171,57 +152,53 @@ if __name__ == '__main__':
     all_info = []
 
     print '*' * 40, 'conditions to match', '*' * 40
-    for ad, symbols, conds in sorted(states_done):
-        print '*' * 40, ad, '*' * 40
+    for addr, symbols, conds in sorted(states_done):
+        print '*' * 40, addr, '*' * 40
         reqs = []
         for k, v in conds:
             print k, v
             reqs.append((k, v))
-        all_info.append((ad, reqs))
+        all_info.append((addr, reqs))
 
     all_cases = set()
 
-    sb = SymbolicExecutionEngine(ir_arch, symbols_init)
-    for ad, reqs_cond in all_info:
+    symbexec = SymbolicExecutionEngine(ir_arch, symbols_init)
+    for addr, reqs_cond in all_info:
+        out = ['(set-logic QF_ABV)',
+               '(set-info :smt-lib-version 2.0)']
+
+        conditions = []
         all_ids = set()
-        for k, v in reqs_cond:
-            all_ids.update(get_expr_ids(k))
-
-        out = []
-
-        # declare variables
-        for v in all_ids:
-            out.append(str(v) + ":" + "BITVECTOR(%d);" % v.size)
-
-        all_csts = []
-        for k, v in reqs_cond:
-            cst = k.strcst()
-            val = v.arg
-            assert(val in [0, 1])
-            inv = ""
-            if val == 1:
-                inv = "NOT "
-            val = "0" * v.size
-            all_csts.append("(%s%s=0bin%s)" % (inv, cst, val))
-        if not all_csts:
+        for expr, value in reqs_cond:
+
+            all_ids.update(get_expr_ids(expr))
+            expr_test = ExprCond(expr,
+                                 ExprInt(1, value.size),
+                                 ExprInt(0, value.size))
+            cond = translator_smt2.from_expr(ExprAff(expr_test, value))
+            conditions.append(cond)
+
+        for name in all_ids:
+            out.append("(declare-fun %s () (_ BitVec %d))" % (name, name.size))
+        if not out:
             continue
-        rez = " AND ".join(all_csts)
-        out.append("QUERY(NOT (%s));" % rez)
-        end = "\n".join(out)
-        open('out.dot', 'w').write(end)
+
+        out += conditions
+        out.append('(check-sat)')
+        open('out.dot', 'w').write('\n'.join(out))
         try:
             cases = subprocess.check_output(["/home/serpilliere/tools/stp/stp",
-                                             "-p",
+                                             "-p", '--SMTLIB2',
                                              "out.dot"])
         except OSError:
-            print "ERF, cannot find stp"
+            print "Cannot find stp binary!"
             break
         for c in cases.split('\n'):
             if c.startswith('ASSERT'):
-                all_cases.add((ad, c))
+                all_cases.add((addr, c))
 
     print '*' * 40, 'ALL COND', '*' * 40
     all_cases = list(all_cases)
     all_cases.sort(key=lambda x: (x[0], x[1]))
-    for ad, val in all_cases:
-        print 'address', ad, 'is reachable using argc', val
+    for addr, val in all_cases:
+        print 'Address:', addr, 'is reachable using argc', val