diff options
| -rw-r--r-- | miasm2/arch/ppc/sem.py | 4 | ||||
| -rw-r--r-- | miasm2/arch/x86/sem.py | 14 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 14 | ||||
| -rw-r--r-- | miasm2/ir/translators/C.py | 4 | ||||
| -rw-r--r-- | miasm2/ir/translators/smt2.py | 7 | ||||
| -rw-r--r-- | miasm2/ir/translators/z3_ir.py | 15 | ||||
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 28 | ||||
| -rw-r--r-- | miasm2/jitter/vm_mngr.c | 28 | ||||
| -rw-r--r-- | miasm2/jitter/vm_mngr.h | 4 | ||||
| -rw-r--r-- | test/expression/simplifications.py | 9 | ||||
| -rwxr-xr-x | test/ir/ir2C.py | 2 | ||||
| -rw-r--r-- | test/ir/translators/z3_ir.py | 32 |
12 files changed, 100 insertions, 61 deletions
diff --git a/miasm2/arch/ppc/sem.py b/miasm2/arch/ppc/sem.py index 4923e3a8..741ae24b 100644 --- a/miasm2/arch/ppc/sem.py +++ b/miasm2/arch/ppc/sem.py @@ -98,9 +98,7 @@ def mn_do_and(ir, instr, ra, rs, arg2): return ret, [] def mn_do_cntlzw(ir, instr, ra, rs): - rvalue = ExprCond(rs, ExprInt(31, 32) - ExprOp('bsr', rs), ExprInt(32, 32)) - - ret = [ ExprAff(ra, rvalue) ] + ret = [ ExprAff(ra, ExprOp('cntleadzeros'), rs) ] if instr.name[-1] == '.': ret += mn_compute_flags(rvalue) diff --git a/miasm2/arch/x86/sem.py b/miasm2/arch/x86/sem.py index 7682192f..5de58c15 100644 --- a/miasm2/arch/x86/sem.py +++ b/miasm2/arch/x86/sem.py @@ -2867,14 +2867,14 @@ def aas(ir, instr): return _tpl_aaa(ir, instr, "-") -def bsr_bsf(ir, instr, dst, src, op_name): +def bsr_bsf(ir, instr, dst, src, op_func): """ IF SRC == 0 ZF = 1 DEST is left unchanged ELSE ZF = 0 - DEST = @op_name(SRC) + DEST = @op_func(SRC) """ lbl_src_null = m2_expr.ExprId(ir.gen_label(), ir.IRDst.size) lbl_src_not_null = m2_expr.ExprId(ir.gen_label(), ir.IRDst.size) @@ -2891,7 +2891,7 @@ def bsr_bsf(ir, instr, dst, src, op_name): e_src_not_null = [] e_src_not_null.append(m2_expr.ExprAff(zf, m2_expr.ExprInt(0, zf.size))) - e_src_not_null.append(m2_expr.ExprAff(dst, m2_expr.ExprOp(op_name, src))) + e_src_not_null.append(m2_expr.ExprAff(dst, op_func(src))) e_src_not_null.append(aff_dst) return e, [IRBlock(lbl_src_null.name, [AssignBlock(e_src_null, instr)]), @@ -2899,11 +2899,15 @@ def bsr_bsf(ir, instr, dst, src, op_name): def bsf(ir, instr, dst, src): - return bsr_bsf(ir, instr, dst, src, "bsf") + return bsr_bsf(ir, instr, dst, src, + lambda src: m2_expr.ExprOp("cnttrailzeros", src)) def bsr(ir, instr, dst, src): - return bsr_bsf(ir, instr, dst, src, "bsr") + return bsr_bsf( + ir, instr, dst, src, + lambda src: m2_expr.ExprInt(src.size - 1, src.size) - m2_expr.ExprOp("cntleadzeros", src) + ) def arpl(_, instr, dst, src): diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index a1301cba..13b25ce2 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -95,19 +95,21 @@ def simp_cst_propagation(e_s, expr): args.append(ExprInt(out, int1.size)) - # bsf(int) => int - if op_name == "bsf" and args[0].is_int() and args[0].arg != 0: + # cnttrailzeros(int) => int + if op_name == "cnttrailzeros" and args[0].is_int(): i = 0 - while args[0].arg & (1 << i) == 0: + while args[0].arg & (1 << i) == 0 and i < args[0].size: i += 1 return ExprInt(i, args[0].size) - # bsr(int) => int - if op_name == "bsr" and args[0].is_int() and args[0].arg != 0: + # cntleadzeros(int) => int + if op_name == "cntleadzeros" and args[0].is_int(): + if args[0].arg == 0: + return ExprInt(args[0].size, args[0].size) i = args[0].size - 1 while args[0].arg & (1 << i) == 0: i -= 1 - return ExprInt(i, args[0].size) + return ExprInt(expr.size - (i + 1), args[0].size) # -(-(A)) => A if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and diff --git a/miasm2/ir/translators/C.py b/miasm2/ir/translators/C.py index 0e285669..099f1420 100644 --- a/miasm2/ir/translators/C.py +++ b/miasm2/ir/translators/C.py @@ -43,8 +43,8 @@ class TranslatorC(Translator): if expr.op == 'parity': return "parity(%s&0x%x)" % (self.from_expr(expr.args[0]), size2mask(expr.args[0].size)) - elif expr.op in ['bsr', 'bsf']: - return "x86_%s(0x%x, %s)" % (expr.op, + elif expr.op in ['cntleadzeros', 'cnttrailzeros']: + return "%s(0x%x, %s)" % (expr.op, expr.args[0].size, self.from_expr(expr.args[0])) elif expr.op in ['clz']: diff --git a/miasm2/ir/translators/smt2.py b/miasm2/ir/translators/smt2.py index 26ff9127..18bcb9bd 100644 --- a/miasm2/ir/translators/smt2.py +++ b/miasm2/ir/translators/smt2.py @@ -233,7 +233,7 @@ class TranslatorSMT2(Translator): res = bvxor(res, bv_extract(i, i, arg)) elif expr.op == '-': res = bvneg(res) - elif expr.op == "bsf": + elif expr.op == "cnttrailzeros": src = res size = expr.size size_smt2 = bit_vec_val(size, size) @@ -254,7 +254,7 @@ class TranslatorSMT2(Translator): cond = smt2_distinct(op, zero_smt2) # ite(cond, i, res) res = smt2_ite(cond, i_smt2, res) - elif expr.op == "bsr": + elif expr.op == "cntleadzeros": src = res size = expr.size one_smt2 = bit_vec_val(1, size) @@ -271,7 +271,8 @@ class TranslatorSMT2(Translator): # op != 0 cond = smt2_distinct(op, zero_smt2) # ite(cond, index, res) - res = smt2_ite(cond, index_smt2, res) + value_smt2 = bit_vec_val(size - (index + 1), size) + res = smt2_ite(cond, value_smt2, res) else: raise NotImplementedError("Unsupported OP yet: %s" % expr.op) diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py index 74bdd79c..536daff1 100644 --- a/miasm2/ir/translators/z3_ir.py +++ b/miasm2/ir/translators/z3_ir.py @@ -207,19 +207,20 @@ class TranslatorZ3(Translator): res = res ^ z3.Extract(i, i, arg) elif expr.op == '-': res = -res - elif expr.op == "bsf": + elif expr.op == "cnttrailzeros": size = expr.size src = res - res = z3.If((src & (1 << (size - 1))) != 0, size - 1, src) - for i in xrange(size - 2, -1, -1): + res = z3.If(src == 0, size, src) + for i in xrange(size - 1, -1, -1): res = z3.If((src & (1 << i)) != 0, i, res) - elif expr.op == "bsr": + elif expr.op == "cntleadzeros": size = expr.size src = res - res = z3.If((src & 1) != 0, 0, src) - for i in xrange(size - 1, 0, -1): + res = z3.If(src == 0, size, src) + for i in xrange(size, 0, -1): index = - i % size - res = z3.If((src & (1 << index)) != 0, index, res) + out = size - (index + 1) + res = z3.If((src & (1 << index)) != 0, out, res) else: raise NotImplementedError("Unsupported OP yet: %s" % expr.op) diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index 35db1538..eef34c16 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -227,12 +227,6 @@ class LLVMContext_JIT(LLVMContext): itype = LLVMType.IntType(64) fc = {"llvm.ctpop.i8": {"ret": i8, "args": [i8]}, - "x86_bsr": {"ret": itype, - "args": [itype, - itype]}, - "x86_bsf": {"ret": itype, - "args": [itype, - itype]}, "segm2addr": {"ret": itype, "args": [p8, itype, @@ -377,9 +371,7 @@ class LLVMFunction(): op_translate = {'cpuid': 'cpuid', } ## Add the size as first argument - op_translate_with_size = {'bsr': 'x86_bsr', - 'bsf': 'x86_bsf', - } + op_translate_with_size = {} ## Add the size as suffix op_translate_with_suffix_size = {'bcdadd': 'bcdadd', 'bcdadd_cf': 'bcdadd_cf', @@ -714,6 +706,24 @@ class LLVMFunction(): self.update_cache(expr, ret) return ret + if op in ["cntleadzeros", "cnttrailzeros"]: + assert len(expr.args) == 1 + arg = self.add_ir(expr.args[0]) + func_name = { + "cntleadzeros": "ctlz", + "cnttrailzeros": "cttz", + }[op] + func_llvm_name = "llvm.%s.i%d" % (func_name, expr.size) + func_sig = {func_llvm_name: { + "ret": LLVMType.IntType(expr.size), + "args": [LLVMType.IntType(expr.args[0].size)] + }} + self.llvm_context.add_fc(func_sig, readonly=True) + ret = builder.call(self.mod.get_global(func_llvm_name), + [arg]) + self.update_cache(expr, ret) + return ret + if op == "segm": fc_ptr = self.mod.get_global("segm2addr") diff --git a/miasm2/jitter/vm_mngr.c b/miasm2/jitter/vm_mngr.c index 3a0e51d3..4331a2ac 100644 --- a/miasm2/jitter/vm_mngr.c +++ b/miasm2/jitter/vm_mngr.c @@ -832,27 +832,41 @@ uint64_t rot_right(uint64_t size, uint64_t a, uint64_t b) } } -unsigned int x86_bsr(uint64_t size, uint64_t src) +/* + * Count leading zeros - count the number of zero starting at the most + * significant bit + * + * Example: + * - cntleadzeros(size=32, src=2): 30 + * - cntleadzeros(size=32, src=0): 32 + */ +unsigned int cntleadzeros(uint64_t size, uint64_t src) { int64_t i; for (i=(int64_t)size-1; i>=0; i--){ if (src & (1ull << i)) - return i; + return size - (i + 1); } - fprintf(stderr, "sanity check error bsr\n"); - exit(EXIT_FAILURE); + return size; } -unsigned int x86_bsf(uint64_t size, uint64_t src) +/* + * Count trailing zeros - count the number of zero starting at the least + * significant bit + * + * Example: + * - cnttrailzeros(size=32, src=2): 1 + * - cnttrailzeros(size=32, src=0): 32 + */ +unsigned int cnttrailzeros(uint64_t size, uint64_t src) { uint64_t i; for (i=0; i<size; i++){ if (src & (1ull << i)) return i; } - fprintf(stderr, "sanity check error bsf\n"); - exit(EXIT_FAILURE); + return size; } diff --git a/miasm2/jitter/vm_mngr.h b/miasm2/jitter/vm_mngr.h index f050f7c0..b101b6ca 100644 --- a/miasm2/jitter/vm_mngr.h +++ b/miasm2/jitter/vm_mngr.h @@ -219,8 +219,8 @@ unsigned int umul16_hi(unsigned short a, unsigned short b); uint64_t rot_left(uint64_t size, uint64_t a, uint64_t b); uint64_t rot_right(uint64_t size, uint64_t a, uint64_t b); -unsigned int x86_bsr(uint64_t size, uint64_t src); -unsigned int x86_bsf(uint64_t size, uint64_t src); +unsigned int cntleadzeros(uint64_t size, uint64_t src); +unsigned int cnttrailzeros(uint64_t size, uint64_t src); #define UDIV(sizeA) \ uint ## sizeA ## _t udiv ## sizeA (vm_cpu_t* vmcpu, uint ## sizeA ## _t a, uint ## sizeA ## _t b) \ diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py index 3e2e5177..a4e839cf 100644 --- a/test/expression/simplifications.py +++ b/test/expression/simplifications.py @@ -395,6 +395,15 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), ExprInt(0xc6, 8)), (ExprOp("imod", ExprInt(0x0123, 16), ExprInt(0xfffb, 16))[:8], ExprInt(0x01, 8)), + (ExprOp("cnttrailzeros", ExprInt(0x2, 32)), + ExprInt(0x1, 32)), + (ExprOp("cnttrailzeros", ExprInt(0x0, 32)), + ExprInt(0x20, 32)), + (ExprOp("cntleadzeros", ExprInt(0x2, 32)), + ExprInt(30, 32)), + (ExprOp("cntleadzeros", ExprInt(0x0, 32)), + ExprInt(0x20, 32)), + (ExprCompose(ExprInt(0x0123, 16), ExprMem(a + ExprInt(0x40, a.size), 16), ExprMem(a + ExprInt(0x42, a.size), 16), ExprInt(0x0321, 16)), diff --git a/test/ir/ir2C.py b/test/ir/ir2C.py index c84473c3..20ade999 100755 --- a/test/ir/ir2C.py +++ b/test/ir/ir2C.py @@ -38,7 +38,7 @@ class TestIrIr2C(unittest.TestCase): self.translationTest( ExprOp('-', *args[:2]), r'(((0x0&0xffffffff) - (0x1&0xffffffff))&0xffffffff)') self.translationTest( - ExprOp('bsr', *args[:1]), r'x86_bsr(0x0, 0x20)') + ExprOp('cntleadzeros', *args[:1]), r'cntleadzeros(0x0, 0x20)') self.translationTest( ExprOp('cpuid', *args[:2]), r'cpuid(0x0, 0x1)') self.translationTest( diff --git a/test/ir/translators/z3_ir.py b/test/ir/translators/z3_ir.py index 83744786..6ae2dcd0 100644 --- a/test/ir/translators/z3_ir.py +++ b/test/ir/translators/z3_ir.py @@ -150,22 +150,22 @@ ez3 = Translator.to_language('z3').from_expr(e8) assert not equiv(ez3, z3_e7) # -------------------------------------------------------------------------- -# bsr, bsf - -# bsf(0x1138) == 3 -bsf1 = Translator.to_language('z3').from_expr(ExprOp("bsf", ExprInt(0x1138, 32))) -bsf2 = z3.BitVecVal(3, 32) -assert(equiv(bsf1, bsf2)) - -# bsr(0x11300) == 0x10 -bsr1 = Translator.to_language('z3').from_expr(ExprOp("bsr", ExprInt(0x11300, 32))) -bsr2 = z3.BitVecVal(0x10, 32) -assert(equiv(bsr1, bsr2)) - -# bsf(0x80000) == bsr(0x80000) -bsf3 = Translator.to_language('z3').from_expr(ExprOp("bsf", ExprInt(0x80000, 32))) -bsr3 = Translator.to_language('z3').from_expr(ExprOp("bsr", ExprInt(0x80000, 32))) -assert(equiv(bsf3, bsr3)) +# cntleadzeros, cnttrailzeros + +# cnttrailzeros(0x1138) == 3 +cnttrailzeros1 = Translator.to_language('z3').from_expr(ExprOp("cnttrailzeros", ExprInt(0x1138, 32))) +cnttrailzeros2 = z3.BitVecVal(3, 32) +assert(equiv(cnttrailzeros1, cnttrailzeros2)) + +# cntleadzeros(0x11300) == 0xf +cntleadzeros1 = Translator.to_language('z3').from_expr(ExprOp("cntleadzeros", ExprInt(0x11300, 32))) +cntleadzeros2 = z3.BitVecVal(0xf, 32) +assert(equiv(cntleadzeros1, cntleadzeros2)) + +# cnttrailzeros(0x8000) + 1 == cntleadzeros(0x8000) +cnttrailzeros3 = Translator.to_language('z3').from_expr(ExprOp("cnttrailzeros", ExprInt(0x8000, 32)) + ExprInt(1, 32)) +cntleadzeros3 = Translator.to_language('z3').from_expr(ExprOp("cntleadzeros", ExprInt(0x8000, 32))) +assert(equiv(cnttrailzeros3, cntleadzeros3)) print "TranslatorZ3 tests are OK." |