diff options
| author | serpilliere <serpilliere@users.noreply.github.com> | 2016-01-12 22:54:18 +0100 |
|---|---|---|
| committer | serpilliere <serpilliere@users.noreply.github.com> | 2016-01-12 22:54:18 +0100 |
| commit | 6461a40e5eaf4bf39aadfee29ac72fe9afac4f9e (patch) | |
| tree | 34f798fb066ba21370e00ff5549439be6225c0de | |
| parent | c7d3d40ba1489ceffc1949f721c85a5aba8005aa (diff) | |
| parent | 491034906747e44e48bd4d75104e95f0b2d0dbe8 (diff) | |
| download | miasm-6461a40e5eaf4bf39aadfee29ac72fe9afac4f9e.tar.gz miasm-6461a40e5eaf4bf39aadfee29ac72fe9afac4f9e.zip | |
Merge pull request #298 from mrphrazer/smt2_translator
SMT2 translator
| -rw-r--r-- | miasm2/expression/smt2_helper.py | 296 | ||||
| -rw-r--r-- | miasm2/ir/translators/smt2.py | 283 | ||||
| -rw-r--r-- | test/ir/translators/smt2.py | 40 | ||||
| -rw-r--r-- | test/test_all.py | 2 |
4 files changed, 621 insertions, 0 deletions
diff --git a/miasm2/expression/smt2_helper.py b/miasm2/expression/smt2_helper.py new file mode 100644 index 00000000..53d323e8 --- /dev/null +++ b/miasm2/expression/smt2_helper.py @@ -0,0 +1,296 @@ +# Helper functions for the generation of SMT2 expressions +# The SMT2 expressions will be returned as a string. +# The expressions are divided as follows +# +# - generic SMT2 operations +# - definitions of SMT2 structures +# - bit vector operations +# - array operations + +# generic SMT2 operations + +def smt2_eq(a, b): + """ + Assignment: a = b + """ + return "(= {} {})".format(a, b) + + +def smt2_implies(a, b): + """ + Implication: a => b + """ + return "(=> {} {})".format(a, b) + + +def smt2_and(*args): + """ + Conjunction: a and b and c ... + """ + # transform args into strings + args = [str(arg) for arg in args] + return "(and {})".format(' '.join(args)) + + +def smt2_or(*args): + """ + Disjunction: a or b or c ... + """ + # transform args into strings + args = [str(arg) for arg in args] + return "(or {})".format(' '.join(args)) + + +def smt2_ite(cond, a, b): + """ + If-then-else: cond ? a : b + """ + return "(ite {} {} {})".format(cond, a, b) + + +def smt2_distinct(*args): + """ + Distinction: a != b != c != ... + """ + # transform args into strings + args = [str(arg) for arg in args] + return "(distinct {})".format(' '.join(args)) + + +def smt2_assert(expr): + """ + Assertion that @expr holds + """ + return "(assert {})".format(expr) + + +# definitions + +def declare_bv(bv, size): + """ + Declares an bit vector @bv of size @size + """ + return "(declare-fun {} () {})".format(bv, bit_vec(size)) + + +def declare_array(a, bv1, bv2): + """ + Declares an SMT2 array represented as a map + from a bit vector to another bit vector. + :param a: array name + :param bv1: SMT2 bit vector + :param bv2: SMT2 bit vector + """ + return "(declare-fun {} () (Array {} {}))".format(a, bv1, bv2) + + +def bit_vec_val(v, size): + """ + Declares a bit vector value + :param v: int, value of the bit vector + :param size: size of the bit vector + """ + return "(_ bv{} {})".format(v, size) + + +def bit_vec(size): + """ + Returns a bit vector of size @size + """ + return "(_ BitVec {})".format(size) + + +# bit vector operations + +def bvadd(a, b): + """ + Addition: a + b + """ + return "(bvadd {} {})".format(a, b) + + +def bvsub(a, b): + """ + Subtraction: a - b + """ + return "(bvsub {} {})".format(a, b) + + +def bvmul(a, b): + """ + Multiplication: a * b + """ + return "(bvmul {} {})".format(a, b) + + +def bvand(a, b): + """ + Bitwise AND: a & b + """ + return "(bvand {} {})".format(a, b) + + +def bvor(a, b): + """ + Bitwise OR: a | b + """ + return "(bvor {} {})".format(a, b) + + +def bvxor(a, b): + """ + Bitwise XOR: a ^ b + """ + return "(bvxor {} {})".format(a, b) + + +def bvneg(bv): + """ + Unary minus: - bv + """ + return "(bvneg {})".format(bv) + + +def bvsdiv(a, b): + """ + Signed division: a / b + """ + return "(bvsdiv {} {})".format(a, b) + + +def bvudiv(a, b): + """ + Unsigned division: a / b + """ + return "(bvudiv {} {})".format(a, b) + + +def bvsmod(a, b): + """ + Signed modulo: a mod b + """ + return "(bvsmod {} {})".format(a, b) + + +def bvurem(a, b): + """ + Unsigned modulo: a mod b + """ + return "(bvurem {} {})".format(a, b) + + +def bvshl(a, b): + """ + Shift left: a << b + """ + return "(bvshl {} {})".format(a, b) + + +def bvlshr(a, b): + """ + Logical shift right: a >> b + """ + return "(bvlshr {} {})".format(a, b) + + +def bvashr(a, b): + """ + Arithmetic shift right: a a>> b + """ + return "(bvashr {} {})".format(a, b) + + +def bv_rotate_left(a, b, size): + """ + Rotates bits of a to the left b times: a <<< b + + Since ((_ rotate_left b) a) does not support + symbolic values for b, the implementation is + based on a C implementation. + + Therefore, the rotation will be computed as + a << (b & (size - 1))) | (a >> (size - (b & (size - 1)))) + + :param a: bit vector + :param b: bit vector + :param size: size of a + """ + + # define constant + s = bit_vec_val(size, size) + + # shift = b & (size - 1) + shift = bvand(b, bvsub(s, bit_vec_val(1, size))) + + # (a << shift) | (a >> size - shift) + rotate = bvor(bvshl(a, shift), + bvlshr(a, bvsub(s, shift))) + + return rotate + + +def bv_rotate_right(a, b, size): + """ + Rotates bits of a to the right b times: a >>> b + + Since ((_ rotate_right b) a) does not support + symbolic values for b, the implementation is + based on a C implementation. + + Therefore, the rotation will be computed as + a >> (b & (size - 1))) | (a << (size - (b & (size - 1)))) + + :param a: bit vector + :param b: bit vector + :param size: size of a + """ + + # define constant + s = bit_vec_val(size, size) + + # shift = b & (size - 1) + shift = bvand(b, bvsub(s, bit_vec_val(1, size))) + + # (a >> shift) | (a << size - shift) + rotate = bvor(bvlshr(a, shift), + bvshl(a, bvsub(s, shift))) + + return rotate + + +def bv_extract(high, low, bv): + """ + Extracts bits from a bit vector + :param high: end bit + :param low: start bit + :param bv: bit vector + """ + return "((_ extract {} {}) {})".format(high, low, bv) + + +def bv_concat(a, b): + """ + Concatenation of two SMT2 expressions + """ + return "(concat {} {})".format(a, b) + + +# array operations + +def array_select(array, index): + """ + Reads from an SMT2 array at index @index + :param array: SMT2 array + :param index: SMT2 expression, index of the array + """ + return "(select {} {})".format(array, index) + + +def array_store(array, index, value): + """ + Writes an value into an SMT2 array at address @index + :param array: SMT array + :param index: SMT2 expression, index of the array + :param value: SMT2 expression, value to write + """ + return "(store {} {} {})".format(array, index, value) diff --git a/miasm2/ir/translators/smt2.py b/miasm2/ir/translators/smt2.py new file mode 100644 index 00000000..96f8dab3 --- /dev/null +++ b/miasm2/ir/translators/smt2.py @@ -0,0 +1,283 @@ +import logging +import operator + +from miasm2.core.asmbloc import asm_label +from miasm2.ir.translators.translator import Translator +from miasm2.expression.smt2_helper import * + +log = logging.getLogger("translator_smt2") +console_handler = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) +log.addHandler(console_handler) +log.setLevel(logging.WARNING) + +class SMT2Mem(object): + """ + Memory abstraction for TranslatorSMT2. Memory elements are only accessed, + never written. To give a concrete value for a given memory cell in a solver, + add "mem32.get(address, size) == <value>" constraints to your equation. + The endianness of memory accesses is handled accordingly to the "endianness" + attribute. + Note: Will have one memory space for each addressing size used. + For example, if memory is accessed via 32 bits values and 16 bits values, + these access will not occur in the same address space. + + Adapted from Z3Mem + """ + + def __init__(self, endianness="<", name="mem"): + """Initializes an SMT2Mem object with a given @name and @endianness. + @endianness: Endianness of memory representation. '<' for little endian, + '>' for big endian. + @name: name of memory Arrays generated. They will be named + name+str(address size) (for example mem32, mem16...). + """ + if endianness not in ['<', '>']: + raise ValueError("Endianness should be '>' (big) or '<' (little)") + self.endianness = endianness + self.mems = {} # Address size -> SMT2 memory array + self.name = name + # initialise address size + self.addr_size = 0 + + def get_mem_array(self, size): + """Returns an SMT Array used internally to represent memory for addresses + of size @size. + @size: integer, size in bit of addresses in the memory to get. + Return an string with the name of the SMT array.. + """ + try: + mem = self.mems[size] + except KeyError: + # Lazy instanciation + self.mems[size] = self.name + str(size) + mem = self.mems[size] + return mem + + def __getitem__(self, addr): + """One byte memory access. Different address sizes with the same value + will result in different memory accesses. + @addr: an SMT2 expression, the address to read. + Return an SMT2 expression of size 8 bits representing a memory access. + """ + size = self.addr_size + mem = self.get_mem_array(size) + return array_select(mem, addr) + + def get(self, addr, size, addr_size): + """ Memory access at address @addr of size @size with + address size @addr_size. + @addr: an SMT2 expression, the address to read. + @size: int, size of the read in bits. + @addr_size: int, size of the address + Return a SMT2 expression representing a memory access. + """ + # set address size per read access + self.addr_size = addr_size + + original_size = size + if original_size % 8 != 0: + # Size not aligned on 8bits -> read more than size and extract after + size = ((original_size / 8) + 1) * 8 + res = self[addr] + if self.is_little_endian(): + for i in xrange(1, size/8): + index = bvadd(addr, bit_vec_val(i, addr_size)) + res = bv_concat(self[index], res) + else: + for i in xrange(1, size/8): + res = bv_concat(res, self[index]) + if size == original_size: + return res + else: + # Size not aligned, extract right sized result + return bv_extract(original_size-1, 0, res) + + def is_little_endian(self): + """True if this memory is little endian.""" + return self.endianness == "<" + + def is_big_endian(self): + """True if this memory is big endian.""" + return not self.is_little_endian() + + +class TranslatorSMT2(Translator): + """Translate a Miasm expression into an equivalent SMT2 + expression. Memory is abstracted via SMT2Mem. + The result of from_expr will be an SMT2 expression. + + If you want to interract with the memory abstraction after the translation, + you can instantiate your own SMT2Mem that will be equivalent to the one + used by TranslatorSMT2. + + TranslatorSMT2 provides the creation of a valid SMT2 file. For this, + it keeps track of the translated bit vectors. + + Adapted from TranslatorZ3 + """ + + # Implemented language + __LANG__ = "smt2" + + def __init__(self, endianness="<", **kwargs): + """Instance a SMT2 translator + @endianness: (optional) memory endianness + """ + super(TranslatorSMT2, self).__init__(**kwargs) + # memory abstraction + self._mem = SMT2Mem(endianness) + # map of translated bit vectors + self._bitvectors = dict() + + def from_ExprInt(self, expr): + return bit_vec_val(expr.arg.arg, expr.size) + + def from_ExprId(self, expr): + if isinstance(expr.name, asm_label): + if expr.name.offset is not None: + return bit_vec_val(str(expr.name.offset), expr.size) + else: + # SMT2-escape expression name + name = "|{}|".format(str(expr.name)) + if name not in self._bitvectors: + self._bitvectors[name] = expr.size + return name + else: + if str(expr) not in self._bitvectors: + self._bitvectors[str(expr)] = expr.size + return str(expr) + + def from_ExprMem(self, expr): + addr = self.from_expr(expr.arg) + # size to read from memory + size = expr.size + # size of memory address + addr_size = expr.arg.size + return self._mem.get(addr, size, addr_size) + + def from_ExprSlice(self, expr): + res = self.from_expr(expr.arg) + res = bv_extract(expr.stop-1, expr.start, res) + return res + + def from_ExprCompose(self, expr): + res = None + args = sorted(expr.args, key=operator.itemgetter(2)) # sort by start off + for subexpr, start, stop in args: + sube = self.from_expr(subexpr) + e = bv_extract(stop-start-1, 0, sube) + if res: + res = bv_concat(e, res) + else: + res = e + return res + + def from_ExprCond(self, expr): + cond = self.from_expr(expr.cond) + src1 = self.from_expr(expr.src1) + src2 = self.from_expr(expr.src2) + + # (and (distinct cond (_ bv0 <size>)) true) + zero = bit_vec_val(0, expr.cond.size) + distinct = smt2_distinct(cond, zero) + distinct_and = smt2_and(distinct, "true") + + # (ite ((and (distinct cond (_ bv0 <size>)) true) src1 src2)) + return smt2_ite(distinct_and, src1, src2) + + def from_ExprOp(self, expr): + args = map(self.from_expr, expr.args) + res = args[0] + + if len(args) > 1: + for arg in args[1:]: + if expr.op == "+": + res = bvadd(res, arg) + elif expr.op == "-": + res = bvsub(res, arg) + elif expr.op == "*": + res = bvmul(res, arg) + elif expr.op == "/": + res = bvsdiv(res, arg) + elif expr.op == "idiv": + res = bvsdiv(res, arg) + elif expr.op == "udiv": + res = bvudiv(res, arg) + elif expr.op == "%": + res = bvsmod(res, arg) + elif expr.op == "imod": + res = bvsmod(res, arg) + elif expr.op == "umod": + res = bvurem(res, arg) + elif expr.op == "&": + res = bvand(res, arg) + elif expr.op == "^": + res = bvxor(res, arg) + elif expr.op == "|": + res = bvor(res, arg) + elif expr.op == "<<": + res = bvshl(res, arg) + elif expr.op == ">>": + res = bvlshr(res, arg) + elif expr.op == "a<<": + res = bvshl(res, arg) + elif expr.op == "a>>": + res = bvashr(res, arg) + elif expr.op == "<<<": + res = bv_rotate_left(res, arg, expr.size) + elif expr.op == ">>>": + res = bv_rotate_right(res, arg, expr.size) + else: + raise NotImplementedError("Unsupported OP yet: %s" % expr.op) + elif expr.op == 'parity': + arg = bv_extract(7, 0, res) + res = bit_vec_val(1, 1) + for i in xrange(8): + res = bvxor(res, bv_extract(i, i, arg)) + elif expr.op == '-': + res = bvneg(res) + else: + raise NotImplementedError("Unsupported OP yet: %s" % expr.op) + + return res + + def from_ExprAff(self, expr): + src = self.from_expr(expr.src) + dst = self.from_expr(expr.dst) + return smt2_assert(smt2_eq(src, dst)) + + def to_smt2(self, exprs, logic="QF_ABV"): + """ + Converts a valid SMT2 file for a given list of + SMT2 expressions. + + :param exprs: list of SMT2 expressions + :param logic: SMT2 logic + :return: String of the SMT2 file + """ + ret = "" + ret += "(set-logic {})\n".format(logic) + + # define bit vectors + for bv in self._bitvectors: + size = self._bitvectors[bv] + ret += "{}\n".format(declare_bv(bv, size)) + + # define memory arrays + for size in self._mem.mems: + mem = self._mem.mems[size] + ret += "{}\n".format(declare_array(mem, bit_vec(size), bit_vec(8))) + + # merge SMT2 expressions + for expr in exprs: + ret += expr + "\n" + + # define action + ret += "(check-sat)\n" + + return ret + + +# Register the class +Translator.register(TranslatorSMT2) diff --git a/test/ir/translators/smt2.py b/test/ir/translators/smt2.py new file mode 100644 index 00000000..97877a3b --- /dev/null +++ b/test/ir/translators/smt2.py @@ -0,0 +1,40 @@ +from z3 import Solver, unsat, parse_smt2_string +from miasm2.expression.expression import * +from miasm2.ir.translators.smt2 import TranslatorSMT2 +from miasm2.ir.translators.z3_ir import TranslatorZ3 + +# create nested expression +a = ExprId("a", 64) +b = ExprId('b', 32) +c = ExprId('c', 16) +d = ExprId('d', 8) +e = ExprId('e', 1) + +left = ExprCond(e + ExprOp('parity', a), + ExprMem(a * a, 64), + ExprMem(a, 64)) + +cond = ExprSlice(ExprSlice(ExprSlice(a, 0, 32) + b, 0, 16) * c, 0, 8) << ExprOp('>>>', d, ExprInt(uint8(0x5L))) +right = ExprCond(cond, + a + ExprInt(uint64(0x64L)), + ExprInt(uint64(0x16L))) + +e = ExprAff(left, right) + +# initialise translators +t_z3 = TranslatorZ3() +t_smt2 = TranslatorSMT2() + +# translate to z3 +e_z3 = t_z3.from_expr(e) +# translate to smt2 +smt2 = t_smt2.to_smt2([t_smt2.from_expr(e)]) + +# parse smt2 string with z3 +smt2_z3 = parse_smt2_string(smt2) +# initialise SMT solver +s = Solver() + +# prove equivalence of z3 and smt2 translation +s.add(e_z3 != smt2_z3) +assert (s.check() == unsat) diff --git a/test/test_all.py b/test/test_all.py index e52123ea..bc019104 100644 --- a/test/test_all.py +++ b/test/test_all.py @@ -213,6 +213,8 @@ testset += RegressionTest(["analysis.py"], base_dir="ir", for fname in fnames]) testset += RegressionTest(["z3_ir.py"], base_dir="ir/translators", tags=[TAGS["z3"]]) +testset += RegressionTest(["smt2.py"], base_dir="ir/translators", + tags=[TAGS["z3"]]) ## OS_DEP for script in ["win_api_x86_32.py", ]: |