about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/smt2_helper.py296
-rw-r--r--miasm2/ir/translators/smt2.py283
-rw-r--r--test/ir/translators/smt2.py40
-rw-r--r--test/test_all.py2
4 files changed, 621 insertions, 0 deletions
diff --git a/miasm2/expression/smt2_helper.py b/miasm2/expression/smt2_helper.py
new file mode 100644
index 00000000..53d323e8
--- /dev/null
+++ b/miasm2/expression/smt2_helper.py
@@ -0,0 +1,296 @@
+# Helper functions for the generation of SMT2 expressions
+# The SMT2 expressions will be returned as a string.
+# The expressions are divided as follows
+#
+# - generic SMT2 operations
+# - definitions of SMT2 structures
+# - bit vector operations
+# - array operations
+
+# generic SMT2 operations
+
+def smt2_eq(a, b):
+    """
+    Assignment: a = b
+    """
+    return "(= {} {})".format(a, b)
+
+
+def smt2_implies(a, b):
+    """
+    Implication: a => b
+    """
+    return "(=> {} {})".format(a, b)
+
+
+def smt2_and(*args):
+    """
+    Conjunction: a and b and c ...
+    """
+    # transform args into strings
+    args = [str(arg) for arg in args]
+    return "(and {})".format(' '.join(args))
+
+
+def smt2_or(*args):
+    """
+    Disjunction: a or b or c ...
+    """
+    # transform args into strings
+    args = [str(arg) for arg in args]
+    return "(or {})".format(' '.join(args))
+
+
+def smt2_ite(cond, a, b):
+    """
+    If-then-else: cond ? a : b
+    """
+    return "(ite {} {} {})".format(cond, a, b)
+
+
+def smt2_distinct(*args):
+    """
+    Distinction: a != b != c != ...
+    """
+    # transform args into strings
+    args = [str(arg) for arg in args]
+    return "(distinct {})".format(' '.join(args))
+
+
+def smt2_assert(expr):
+    """
+    Assertion that @expr holds
+    """
+    return "(assert {})".format(expr)
+
+
+# definitions
+
+def declare_bv(bv, size):
+    """
+    Declares an bit vector @bv of size @size
+    """
+    return "(declare-fun {} () {})".format(bv, bit_vec(size))
+
+
+def declare_array(a, bv1, bv2):
+    """
+    Declares an SMT2 array represented as a map
+    from a bit vector to another bit vector.
+    :param a: array name
+    :param bv1: SMT2 bit vector
+    :param bv2: SMT2 bit vector
+    """
+    return "(declare-fun {} () (Array {} {}))".format(a, bv1, bv2)
+
+
+def bit_vec_val(v, size):
+    """
+    Declares a bit vector value
+    :param v: int, value of the bit vector
+    :param size: size of the bit vector
+    """
+    return "(_ bv{} {})".format(v, size)
+
+
+def bit_vec(size):
+    """
+    Returns a bit vector of size @size
+    """
+    return "(_ BitVec {})".format(size)
+
+
+# bit vector operations
+
+def bvadd(a, b):
+    """
+    Addition: a + b
+    """
+    return "(bvadd {} {})".format(a, b)
+
+
+def bvsub(a, b):
+    """
+    Subtraction: a - b
+    """
+    return "(bvsub {} {})".format(a, b)
+
+
+def bvmul(a, b):
+    """
+    Multiplication: a * b
+    """
+    return "(bvmul {} {})".format(a, b)
+
+
+def bvand(a, b):
+    """
+    Bitwise AND: a & b
+    """
+    return "(bvand {} {})".format(a, b)
+
+
+def bvor(a, b):
+    """
+    Bitwise OR: a | b
+    """
+    return "(bvor {} {})".format(a, b)
+
+
+def bvxor(a, b):
+    """
+    Bitwise XOR: a ^ b
+    """
+    return "(bvxor {} {})".format(a, b)
+
+
+def bvneg(bv):
+    """
+    Unary minus: - bv
+    """
+    return "(bvneg {})".format(bv)
+
+
+def bvsdiv(a, b):
+    """
+    Signed division: a / b
+    """
+    return "(bvsdiv {} {})".format(a, b)
+
+
+def bvudiv(a, b):
+    """
+    Unsigned division: a / b
+    """
+    return "(bvudiv {} {})".format(a, b)
+
+
+def bvsmod(a, b):
+    """
+    Signed modulo: a mod b
+    """
+    return "(bvsmod {} {})".format(a, b)
+
+
+def bvurem(a, b):
+    """
+    Unsigned modulo: a mod b
+    """
+    return "(bvurem {} {})".format(a, b)
+
+
+def bvshl(a, b):
+    """
+    Shift left: a << b
+    """
+    return "(bvshl {} {})".format(a, b)
+
+
+def bvlshr(a, b):
+    """
+    Logical shift right: a >> b
+    """
+    return "(bvlshr {} {})".format(a, b)
+
+
+def bvashr(a, b):
+    """
+    Arithmetic shift right: a a>> b
+    """
+    return "(bvashr {} {})".format(a, b)
+
+
+def bv_rotate_left(a, b, size):
+    """
+    Rotates bits of a to the left b times: a <<< b
+
+    Since ((_ rotate_left b) a) does not support
+    symbolic values for b, the implementation is
+    based on a C implementation.
+
+    Therefore, the rotation will be computed as
+    a << (b & (size - 1))) | (a >> (size - (b & (size - 1))))
+
+    :param a: bit vector
+    :param b: bit vector
+    :param size: size of a
+    """
+
+    # define constant
+    s = bit_vec_val(size, size)
+
+    # shift = b & (size  - 1)
+    shift = bvand(b, bvsub(s, bit_vec_val(1, size)))
+
+    # (a << shift) | (a >> size - shift)
+    rotate = bvor(bvshl(a, shift),
+                  bvlshr(a, bvsub(s, shift)))
+
+    return rotate
+
+
+def bv_rotate_right(a, b, size):
+    """
+    Rotates bits of a to the right b times: a >>> b
+
+    Since ((_ rotate_right b) a) does not support
+    symbolic values for b, the implementation is
+    based on a C implementation.
+
+    Therefore, the rotation will be computed as
+    a >> (b & (size - 1))) | (a << (size - (b & (size - 1))))
+
+    :param a: bit vector
+    :param b: bit vector
+    :param size: size of a
+    """
+
+    # define constant
+    s = bit_vec_val(size, size)
+
+    # shift = b & (size  - 1)
+    shift = bvand(b, bvsub(s, bit_vec_val(1, size)))
+
+    # (a >> shift) | (a << size - shift)
+    rotate = bvor(bvlshr(a, shift),
+                  bvshl(a, bvsub(s, shift)))
+
+    return rotate
+
+
+def bv_extract(high, low, bv):
+    """
+    Extracts bits from a bit vector
+    :param high: end bit
+    :param low: start bit
+    :param bv: bit vector
+    """
+    return "((_ extract {} {}) {})".format(high, low, bv)
+
+
+def bv_concat(a, b):
+    """
+    Concatenation of two SMT2 expressions
+    """
+    return "(concat {} {})".format(a, b)
+
+
+# array operations
+
+def array_select(array, index):
+    """
+    Reads from an SMT2 array at index @index
+    :param array: SMT2 array
+    :param index: SMT2 expression, index of the array
+    """
+    return "(select {} {})".format(array, index)
+
+
+def array_store(array, index, value):
+    """
+    Writes an value into an SMT2 array at address @index
+    :param array: SMT array
+    :param index: SMT2 expression, index of the array
+    :param value: SMT2 expression, value to write
+    """
+    return "(store {} {} {})".format(array, index, value)
diff --git a/miasm2/ir/translators/smt2.py b/miasm2/ir/translators/smt2.py
new file mode 100644
index 00000000..96f8dab3
--- /dev/null
+++ b/miasm2/ir/translators/smt2.py
@@ -0,0 +1,283 @@
+import logging
+import operator
+
+from miasm2.core.asmbloc import asm_label
+from miasm2.ir.translators.translator import Translator
+from miasm2.expression.smt2_helper import *
+
+log = logging.getLogger("translator_smt2")
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
+log.addHandler(console_handler)
+log.setLevel(logging.WARNING)
+
+class SMT2Mem(object):
+    """
+    Memory abstraction for TranslatorSMT2. Memory elements are only accessed,
+    never written. To give a concrete value for a given memory cell in a solver,
+    add "mem32.get(address, size) == <value>" constraints to your equation.
+    The endianness of memory accesses is handled accordingly to the "endianness"
+    attribute.
+    Note: Will have one memory space for each addressing size used.
+    For example, if memory is accessed via 32 bits values and 16 bits values,
+    these access will not occur in the same address space.
+
+    Adapted from Z3Mem
+    """
+
+    def __init__(self, endianness="<", name="mem"):
+        """Initializes an SMT2Mem object with a given @name and @endianness.
+        @endianness: Endianness of memory representation. '<' for little endian,
+            '>' for big endian.
+        @name: name of memory Arrays generated. They will be named
+            name+str(address size) (for example mem32, mem16...).
+        """
+        if endianness not in ['<', '>']:
+            raise ValueError("Endianness should be '>' (big) or '<' (little)")
+        self.endianness = endianness
+        self.mems = {} # Address size -> SMT2 memory array
+        self.name = name
+        # initialise address size
+        self.addr_size = 0
+
+    def get_mem_array(self, size):
+        """Returns an SMT Array used internally to represent memory for addresses
+        of size @size.
+        @size: integer, size in bit of addresses in the memory to get.
+        Return an string with the name of the SMT array..
+        """
+        try:
+            mem = self.mems[size]
+        except KeyError:
+            # Lazy instanciation
+            self.mems[size] = self.name + str(size)
+            mem = self.mems[size]
+        return mem
+
+    def __getitem__(self, addr):
+        """One byte memory access. Different address sizes with the same value
+        will result in different memory accesses.
+        @addr: an SMT2 expression, the address to read.
+        Return an SMT2 expression of size 8 bits representing a memory access.
+        """
+        size = self.addr_size
+        mem = self.get_mem_array(size)
+        return array_select(mem, addr)
+
+    def get(self, addr, size, addr_size):
+        """ Memory access at address @addr of size @size with
+        address size @addr_size.
+        @addr: an SMT2 expression, the address to read.
+        @size: int, size of the read in bits.
+        @addr_size: int, size of the address
+        Return a SMT2 expression representing a memory access.
+        """
+        # set address size per read access
+        self.addr_size = addr_size
+
+        original_size = size
+        if original_size % 8 != 0:
+            # Size not aligned on 8bits -> read more than size and extract after
+            size = ((original_size / 8) + 1) * 8
+        res = self[addr]
+        if self.is_little_endian():
+            for i in xrange(1, size/8):
+                index = bvadd(addr, bit_vec_val(i, addr_size))
+                res = bv_concat(self[index], res)
+        else:
+            for i in xrange(1, size/8):
+                res = bv_concat(res, self[index])
+        if size == original_size:
+            return res
+        else:
+            # Size not aligned, extract right sized result
+            return bv_extract(original_size-1, 0, res)
+
+    def is_little_endian(self):
+        """True if this memory is little endian."""
+        return self.endianness == "<"
+
+    def is_big_endian(self):
+        """True if this memory is big endian."""
+        return not self.is_little_endian()
+
+
+class TranslatorSMT2(Translator):
+    """Translate a Miasm expression into an equivalent SMT2
+    expression. Memory is abstracted via SMT2Mem.
+    The result of from_expr will be an SMT2 expression.
+
+    If you want to interract with the memory abstraction after the translation,
+    you can instantiate your own SMT2Mem that will be equivalent to the one
+    used by TranslatorSMT2.
+
+    TranslatorSMT2 provides the creation of a valid SMT2 file. For this,
+    it keeps track of the translated bit vectors.
+
+    Adapted from TranslatorZ3
+    """
+
+    # Implemented language
+    __LANG__ = "smt2"
+
+    def __init__(self, endianness="<", **kwargs):
+        """Instance a SMT2 translator
+        @endianness: (optional) memory endianness
+        """
+        super(TranslatorSMT2, self).__init__(**kwargs)
+        # memory abstraction
+        self._mem = SMT2Mem(endianness)
+        # map of translated bit vectors
+        self._bitvectors = dict()
+
+    def from_ExprInt(self, expr):
+        return bit_vec_val(expr.arg.arg, expr.size)
+
+    def from_ExprId(self, expr):
+        if isinstance(expr.name, asm_label):
+            if expr.name.offset is not None:
+                return bit_vec_val(str(expr.name.offset), expr.size)
+            else:
+                # SMT2-escape expression name
+                name = "|{}|".format(str(expr.name))
+                if name not in self._bitvectors:
+                    self._bitvectors[name] = expr.size
+                return name
+        else:
+            if str(expr) not in self._bitvectors:
+                self._bitvectors[str(expr)] = expr.size
+            return str(expr)
+
+    def from_ExprMem(self, expr):
+        addr = self.from_expr(expr.arg)
+        # size to read from memory
+        size = expr.size
+        # size of memory address
+        addr_size = expr.arg.size
+        return self._mem.get(addr, size, addr_size)
+
+    def from_ExprSlice(self, expr):
+        res = self.from_expr(expr.arg)
+        res = bv_extract(expr.stop-1, expr.start, res)
+        return res
+
+    def from_ExprCompose(self, expr):
+        res = None
+        args = sorted(expr.args, key=operator.itemgetter(2))  # sort by start off
+        for subexpr, start, stop in args:
+            sube = self.from_expr(subexpr)
+            e = bv_extract(stop-start-1, 0, sube)
+            if res:
+                res = bv_concat(e, res)
+            else:
+                res = e
+        return res
+
+    def from_ExprCond(self, expr):
+        cond = self.from_expr(expr.cond)
+        src1 = self.from_expr(expr.src1)
+        src2 = self.from_expr(expr.src2)
+
+        # (and (distinct cond (_ bv0 <size>)) true)
+        zero = bit_vec_val(0, expr.cond.size)
+        distinct = smt2_distinct(cond, zero)
+        distinct_and = smt2_and(distinct, "true")
+
+        # (ite ((and (distinct cond (_ bv0 <size>)) true) src1 src2))
+        return smt2_ite(distinct_and, src1, src2)
+
+    def from_ExprOp(self, expr):
+        args = map(self.from_expr, expr.args)
+        res = args[0]
+
+        if len(args) > 1:
+            for arg in args[1:]:
+                if expr.op == "+":
+                    res = bvadd(res, arg)
+                elif expr.op == "-":
+                    res = bvsub(res, arg)
+                elif expr.op == "*":
+                    res = bvmul(res, arg)
+                elif expr.op == "/":
+                    res = bvsdiv(res, arg)
+                elif expr.op == "idiv":
+                    res = bvsdiv(res, arg)
+                elif expr.op == "udiv":
+                    res = bvudiv(res, arg)
+                elif expr.op == "%":
+                    res = bvsmod(res, arg)
+                elif expr.op == "imod":
+                    res = bvsmod(res, arg)
+                elif expr.op == "umod":
+                    res = bvurem(res, arg)
+                elif expr.op == "&":
+                    res = bvand(res, arg)
+                elif expr.op == "^":
+                    res = bvxor(res, arg)
+                elif expr.op == "|":
+                    res = bvor(res, arg)
+                elif expr.op == "<<":
+                    res = bvshl(res, arg)
+                elif expr.op == ">>":
+                    res = bvlshr(res, arg)
+                elif expr.op == "a<<":
+                    res = bvshl(res, arg)
+                elif expr.op == "a>>":
+                    res = bvashr(res, arg)
+                elif expr.op == "<<<":
+                    res = bv_rotate_left(res, arg, expr.size)
+                elif expr.op == ">>>":
+                    res = bv_rotate_right(res, arg, expr.size)
+                else:
+                    raise NotImplementedError("Unsupported OP yet: %s" % expr.op)
+        elif expr.op == 'parity':
+            arg = bv_extract(7, 0, res)
+            res = bit_vec_val(1, 1)
+            for i in xrange(8):
+                res = bvxor(res, bv_extract(i, i, arg))
+        elif expr.op == '-':
+            res = bvneg(res)
+        else:
+            raise NotImplementedError("Unsupported OP yet: %s" % expr.op)
+
+        return res
+
+    def from_ExprAff(self, expr):
+        src = self.from_expr(expr.src)
+        dst = self.from_expr(expr.dst)
+        return smt2_assert(smt2_eq(src, dst))
+
+    def to_smt2(self, exprs, logic="QF_ABV"):
+        """
+        Converts a valid SMT2 file for a given list of
+        SMT2 expressions.
+
+        :param exprs: list of SMT2 expressions
+        :param logic: SMT2 logic
+        :return: String of the SMT2 file
+        """
+        ret = ""
+        ret += "(set-logic {})\n".format(logic)
+
+        # define bit vectors
+        for bv in self._bitvectors:
+            size = self._bitvectors[bv]
+            ret += "{}\n".format(declare_bv(bv, size))
+
+        # define memory arrays
+        for size in self._mem.mems:
+            mem = self._mem.mems[size]
+            ret += "{}\n".format(declare_array(mem, bit_vec(size), bit_vec(8)))
+
+        # merge SMT2 expressions
+        for expr in exprs:
+            ret += expr + "\n"
+
+        # define action
+        ret += "(check-sat)\n"
+
+        return ret
+
+
+# Register the class
+Translator.register(TranslatorSMT2)
diff --git a/test/ir/translators/smt2.py b/test/ir/translators/smt2.py
new file mode 100644
index 00000000..97877a3b
--- /dev/null
+++ b/test/ir/translators/smt2.py
@@ -0,0 +1,40 @@
+from z3 import Solver, unsat, parse_smt2_string
+from miasm2.expression.expression import *
+from miasm2.ir.translators.smt2 import TranslatorSMT2
+from miasm2.ir.translators.z3_ir import TranslatorZ3
+
+# create nested expression
+a = ExprId("a", 64)
+b = ExprId('b', 32)
+c = ExprId('c', 16)
+d = ExprId('d', 8)
+e = ExprId('e', 1)
+
+left = ExprCond(e + ExprOp('parity', a),
+                ExprMem(a * a, 64),
+                ExprMem(a, 64))
+
+cond = ExprSlice(ExprSlice(ExprSlice(a, 0, 32) + b, 0, 16) * c, 0, 8) << ExprOp('>>>', d, ExprInt(uint8(0x5L)))
+right = ExprCond(cond,
+                 a + ExprInt(uint64(0x64L)),
+                 ExprInt(uint64(0x16L)))
+
+e = ExprAff(left, right)
+
+# initialise translators
+t_z3 = TranslatorZ3()
+t_smt2 = TranslatorSMT2()
+
+# translate to z3
+e_z3 = t_z3.from_expr(e)
+# translate to smt2
+smt2 = t_smt2.to_smt2([t_smt2.from_expr(e)])
+
+# parse smt2 string with z3
+smt2_z3 = parse_smt2_string(smt2)
+# initialise SMT solver
+s = Solver()
+
+# prove equivalence of z3 and smt2 translation
+s.add(e_z3 != smt2_z3)
+assert (s.check() == unsat)
diff --git a/test/test_all.py b/test/test_all.py
index e52123ea..bc019104 100644
--- a/test/test_all.py
+++ b/test/test_all.py
@@ -213,6 +213,8 @@ testset += RegressionTest(["analysis.py"], base_dir="ir",
                                     for fname in fnames])
 testset += RegressionTest(["z3_ir.py"], base_dir="ir/translators",
                           tags=[TAGS["z3"]])
+testset += RegressionTest(["smt2.py"], base_dir="ir/translators",
+                          tags=[TAGS["z3"]])
 ## OS_DEP
 for script in ["win_api_x86_32.py",
                ]: