about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/arch/ppc/sem.py4
-rw-r--r--miasm2/arch/x86/sem.py14
-rw-r--r--miasm2/expression/simplifications_common.py14
-rw-r--r--miasm2/ir/translators/C.py4
-rw-r--r--miasm2/ir/translators/smt2.py7
-rw-r--r--miasm2/ir/translators/z3_ir.py15
-rw-r--r--miasm2/jitter/llvmconvert.py28
-rw-r--r--miasm2/jitter/vm_mngr.c28
-rw-r--r--miasm2/jitter/vm_mngr.h4
-rw-r--r--test/expression/simplifications.py9
-rwxr-xr-xtest/ir/ir2C.py2
-rw-r--r--test/ir/translators/z3_ir.py32
12 files changed, 100 insertions, 61 deletions
diff --git a/miasm2/arch/ppc/sem.py b/miasm2/arch/ppc/sem.py
index 4923e3a8..741ae24b 100644
--- a/miasm2/arch/ppc/sem.py
+++ b/miasm2/arch/ppc/sem.py
@@ -98,9 +98,7 @@ def mn_do_and(ir, instr, ra, rs, arg2):
     return ret, []
 
 def mn_do_cntlzw(ir, instr, ra, rs):
-    rvalue = ExprCond(rs, ExprInt(31, 32) - ExprOp('bsr', rs), ExprInt(32, 32))
-
-    ret = [ ExprAff(ra, rvalue) ]
+    ret = [ ExprAff(ra, ExprOp('cntleadzeros'), rs) ]
 
     if instr.name[-1] == '.':
         ret += mn_compute_flags(rvalue)
diff --git a/miasm2/arch/x86/sem.py b/miasm2/arch/x86/sem.py
index 7682192f..5de58c15 100644
--- a/miasm2/arch/x86/sem.py
+++ b/miasm2/arch/x86/sem.py
@@ -2867,14 +2867,14 @@ def aas(ir, instr):
     return _tpl_aaa(ir, instr, "-")
 
 
-def bsr_bsf(ir, instr, dst, src, op_name):
+def bsr_bsf(ir, instr, dst, src, op_func):
     """
     IF SRC == 0
         ZF = 1
         DEST is left unchanged
     ELSE
         ZF = 0
-        DEST = @op_name(SRC)
+        DEST = @op_func(SRC)
     """
     lbl_src_null = m2_expr.ExprId(ir.gen_label(), ir.IRDst.size)
     lbl_src_not_null = m2_expr.ExprId(ir.gen_label(), ir.IRDst.size)
@@ -2891,7 +2891,7 @@ def bsr_bsf(ir, instr, dst, src, op_name):
 
     e_src_not_null = []
     e_src_not_null.append(m2_expr.ExprAff(zf, m2_expr.ExprInt(0, zf.size)))
-    e_src_not_null.append(m2_expr.ExprAff(dst, m2_expr.ExprOp(op_name, src)))
+    e_src_not_null.append(m2_expr.ExprAff(dst, op_func(src)))
     e_src_not_null.append(aff_dst)
 
     return e, [IRBlock(lbl_src_null.name, [AssignBlock(e_src_null, instr)]),
@@ -2899,11 +2899,15 @@ def bsr_bsf(ir, instr, dst, src, op_name):
 
 
 def bsf(ir, instr, dst, src):
-    return bsr_bsf(ir, instr, dst, src, "bsf")
+    return bsr_bsf(ir, instr, dst, src,
+                   lambda src: m2_expr.ExprOp("cnttrailzeros", src))
 
 
 def bsr(ir, instr, dst, src):
-    return bsr_bsf(ir, instr, dst, src, "bsr")
+    return bsr_bsf(
+        ir, instr, dst, src,
+        lambda src: m2_expr.ExprInt(src.size - 1, src.size) - m2_expr.ExprOp("cntleadzeros", src)
+    )
 
 
 def arpl(_, instr, dst, src):
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index a1301cba..13b25ce2 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -95,19 +95,21 @@ def simp_cst_propagation(e_s, expr):
 
             args.append(ExprInt(out, int1.size))
 
-    # bsf(int) => int
-    if op_name == "bsf" and args[0].is_int() and args[0].arg != 0:
+    # cnttrailzeros(int) => int
+    if op_name == "cnttrailzeros" and args[0].is_int():
         i = 0
-        while args[0].arg & (1 << i) == 0:
+        while args[0].arg & (1 << i) == 0 and i < args[0].size:
             i += 1
         return ExprInt(i, args[0].size)
 
-    # bsr(int) => int
-    if op_name == "bsr" and args[0].is_int() and args[0].arg != 0:
+    # cntleadzeros(int) => int
+    if op_name == "cntleadzeros" and args[0].is_int():
+        if args[0].arg == 0:
+            return ExprInt(args[0].size, args[0].size)
         i = args[0].size - 1
         while args[0].arg & (1 << i) == 0:
             i -= 1
-        return ExprInt(i, args[0].size)
+        return ExprInt(expr.size - (i + 1), args[0].size)
 
     # -(-(A)) => A
     if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
diff --git a/miasm2/ir/translators/C.py b/miasm2/ir/translators/C.py
index 0e285669..099f1420 100644
--- a/miasm2/ir/translators/C.py
+++ b/miasm2/ir/translators/C.py
@@ -43,8 +43,8 @@ class TranslatorC(Translator):
             if expr.op == 'parity':
                 return "parity(%s&0x%x)" % (self.from_expr(expr.args[0]),
                                             size2mask(expr.args[0].size))
-            elif expr.op in ['bsr', 'bsf']:
-                return "x86_%s(0x%x, %s)" % (expr.op,
+            elif expr.op in ['cntleadzeros', 'cnttrailzeros']:
+                return "%s(0x%x, %s)" % (expr.op,
                                              expr.args[0].size,
                                              self.from_expr(expr.args[0]))
             elif expr.op in ['clz']:
diff --git a/miasm2/ir/translators/smt2.py b/miasm2/ir/translators/smt2.py
index 26ff9127..18bcb9bd 100644
--- a/miasm2/ir/translators/smt2.py
+++ b/miasm2/ir/translators/smt2.py
@@ -233,7 +233,7 @@ class TranslatorSMT2(Translator):
                 res = bvxor(res, bv_extract(i, i, arg))
         elif expr.op == '-':
             res = bvneg(res)
-        elif expr.op == "bsf":
+        elif expr.op == "cnttrailzeros":
             src = res
             size = expr.size
             size_smt2 = bit_vec_val(size, size)
@@ -254,7 +254,7 @@ class TranslatorSMT2(Translator):
                 cond = smt2_distinct(op, zero_smt2)
                 # ite(cond, i, res)
                 res = smt2_ite(cond, i_smt2, res)
-        elif expr.op == "bsr":
+        elif expr.op == "cntleadzeros":
             src = res
             size = expr.size
             one_smt2 = bit_vec_val(1, size)
@@ -271,7 +271,8 @@ class TranslatorSMT2(Translator):
                 # op != 0
                 cond = smt2_distinct(op, zero_smt2)
                 # ite(cond, index, res)
-                res = smt2_ite(cond, index_smt2, res)
+                value_smt2 = bit_vec_val(size - (index + 1), size)
+                res = smt2_ite(cond, value_smt2, res)
         else:
             raise NotImplementedError("Unsupported OP yet: %s" % expr.op)
 
diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py
index 74bdd79c..536daff1 100644
--- a/miasm2/ir/translators/z3_ir.py
+++ b/miasm2/ir/translators/z3_ir.py
@@ -207,19 +207,20 @@ class TranslatorZ3(Translator):
                 res = res ^ z3.Extract(i, i, arg)
         elif expr.op == '-':
             res = -res
-        elif expr.op == "bsf":
+        elif expr.op == "cnttrailzeros":
             size = expr.size
             src = res
-            res = z3.If((src & (1 << (size - 1))) != 0, size - 1, src)
-            for i in xrange(size - 2, -1, -1):
+            res = z3.If(src == 0, size, src)
+            for i in xrange(size - 1, -1, -1):
                 res = z3.If((src & (1 << i)) != 0, i, res)
-        elif expr.op == "bsr":
+        elif expr.op == "cntleadzeros":
             size = expr.size
             src = res
-            res = z3.If((src & 1) != 0, 0, src)
-            for i in xrange(size - 1, 0, -1):
+            res = z3.If(src == 0, size, src)
+            for i in xrange(size, 0, -1):
                 index = - i % size
-                res = z3.If((src & (1 << index)) != 0, index, res)
+                out = size - (index + 1)
+                res = z3.If((src & (1 << index)) != 0, out, res)
         else:
             raise NotImplementedError("Unsupported OP yet: %s" % expr.op)
 
diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py
index 35db1538..eef34c16 100644
--- a/miasm2/jitter/llvmconvert.py
+++ b/miasm2/jitter/llvmconvert.py
@@ -227,12 +227,6 @@ class LLVMContext_JIT(LLVMContext):
         itype = LLVMType.IntType(64)
         fc = {"llvm.ctpop.i8": {"ret": i8,
                                 "args": [i8]},
-              "x86_bsr": {"ret": itype,
-                          "args": [itype,
-                                   itype]},
-              "x86_bsf": {"ret": itype,
-                          "args": [itype,
-                                   itype]},
               "segm2addr": {"ret": itype,
                             "args": [p8,
                                      itype,
@@ -377,9 +371,7 @@ class LLVMFunction():
     op_translate = {'cpuid': 'cpuid',
     }
     ## Add the size as first argument
-    op_translate_with_size = {'bsr': 'x86_bsr',
-                              'bsf': 'x86_bsf',
-    }
+    op_translate_with_size = {}
     ## Add the size as suffix
     op_translate_with_suffix_size = {'bcdadd': 'bcdadd',
                                      'bcdadd_cf': 'bcdadd_cf',
@@ -714,6 +706,24 @@ class LLVMFunction():
                 self.update_cache(expr, ret)
                 return ret
 
+            if op in ["cntleadzeros", "cnttrailzeros"]:
+                assert len(expr.args) == 1
+                arg = self.add_ir(expr.args[0])
+                func_name = {
+                    "cntleadzeros": "ctlz",
+                    "cnttrailzeros": "cttz",
+                }[op]
+                func_llvm_name = "llvm.%s.i%d" % (func_name, expr.size)
+                func_sig = {func_llvm_name: {
+                    "ret": LLVMType.IntType(expr.size),
+                    "args": [LLVMType.IntType(expr.args[0].size)]
+                }}
+                self.llvm_context.add_fc(func_sig, readonly=True)
+                ret = builder.call(self.mod.get_global(func_llvm_name),
+                                   [arg])
+                self.update_cache(expr, ret)
+                return ret
+
             if op == "segm":
                 fc_ptr = self.mod.get_global("segm2addr")
 
diff --git a/miasm2/jitter/vm_mngr.c b/miasm2/jitter/vm_mngr.c
index 3a0e51d3..4331a2ac 100644
--- a/miasm2/jitter/vm_mngr.c
+++ b/miasm2/jitter/vm_mngr.c
@@ -832,27 +832,41 @@ uint64_t rot_right(uint64_t size, uint64_t a, uint64_t b)
     }
 }
 
-unsigned int x86_bsr(uint64_t size, uint64_t src)
+/*
+ * Count leading zeros - count the number of zero starting at the most
+ * significant bit
+ *
+ * Example:
+ * - cntleadzeros(size=32, src=2): 30
+ * - cntleadzeros(size=32, src=0): 32
+ */
+unsigned int cntleadzeros(uint64_t size, uint64_t src)
 {
 	int64_t i;
 
 	for (i=(int64_t)size-1; i>=0; i--){
 		if (src & (1ull << i))
-			return i;
+			return size - (i + 1);
 	}
-	fprintf(stderr, "sanity check error bsr\n");
-	exit(EXIT_FAILURE);
+	return size;
 }
 
-unsigned int x86_bsf(uint64_t size, uint64_t src)
+/*
+ * Count trailing zeros - count the number of zero starting at the least
+ * significant bit
+ *
+ * Example:
+ * - cnttrailzeros(size=32, src=2): 1
+ * - cnttrailzeros(size=32, src=0): 32
+ */
+unsigned int cnttrailzeros(uint64_t size, uint64_t src)
 {
 	uint64_t i;
 	for (i=0; i<size; i++){
 		if (src & (1ull << i))
 			return i;
 	}
-	fprintf(stderr, "sanity check error bsf\n");
-	exit(EXIT_FAILURE);
+	return size;
 }
 
 
diff --git a/miasm2/jitter/vm_mngr.h b/miasm2/jitter/vm_mngr.h
index f050f7c0..b101b6ca 100644
--- a/miasm2/jitter/vm_mngr.h
+++ b/miasm2/jitter/vm_mngr.h
@@ -219,8 +219,8 @@ unsigned int umul16_hi(unsigned short a, unsigned short b);
 uint64_t rot_left(uint64_t size, uint64_t a, uint64_t b);
 uint64_t rot_right(uint64_t size, uint64_t a, uint64_t b);
 
-unsigned int x86_bsr(uint64_t size, uint64_t src);
-unsigned int x86_bsf(uint64_t size, uint64_t src);
+unsigned int cntleadzeros(uint64_t size, uint64_t src);
+unsigned int cnttrailzeros(uint64_t size, uint64_t src);
 
 #define UDIV(sizeA)						\
 	uint ## sizeA ## _t udiv ## sizeA (vm_cpu_t* vmcpu, uint ## sizeA ## _t a, uint ## sizeA ## _t b) \
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index 3e2e5177..a4e839cf 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -395,6 +395,15 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
      ExprInt(0xc6, 8)),
     (ExprOp("imod", ExprInt(0x0123, 16), ExprInt(0xfffb, 16))[:8],
      ExprInt(0x01, 8)),
+    (ExprOp("cnttrailzeros", ExprInt(0x2, 32)),
+     ExprInt(0x1, 32)),
+    (ExprOp("cnttrailzeros", ExprInt(0x0, 32)),
+     ExprInt(0x20, 32)),
+    (ExprOp("cntleadzeros", ExprInt(0x2, 32)),
+     ExprInt(30, 32)),
+    (ExprOp("cntleadzeros", ExprInt(0x0, 32)),
+     ExprInt(0x20, 32)),
+
 
     (ExprCompose(ExprInt(0x0123, 16), ExprMem(a + ExprInt(0x40, a.size), 16),
                  ExprMem(a + ExprInt(0x42, a.size), 16), ExprInt(0x0321, 16)),
diff --git a/test/ir/ir2C.py b/test/ir/ir2C.py
index c84473c3..20ade999 100755
--- a/test/ir/ir2C.py
+++ b/test/ir/ir2C.py
@@ -38,7 +38,7 @@ class TestIrIr2C(unittest.TestCase):
         self.translationTest(
             ExprOp('-',       *args[:2]), r'(((0x0&0xffffffff) - (0x1&0xffffffff))&0xffffffff)')
         self.translationTest(
-            ExprOp('bsr',     *args[:1]), r'x86_bsr(0x0, 0x20)')
+            ExprOp('cntleadzeros',     *args[:1]), r'cntleadzeros(0x0, 0x20)')
         self.translationTest(
             ExprOp('cpuid',  *args[:2]), r'cpuid(0x0, 0x1)')
         self.translationTest(
diff --git a/test/ir/translators/z3_ir.py b/test/ir/translators/z3_ir.py
index 83744786..6ae2dcd0 100644
--- a/test/ir/translators/z3_ir.py
+++ b/test/ir/translators/z3_ir.py
@@ -150,22 +150,22 @@ ez3 = Translator.to_language('z3').from_expr(e8)
 assert not equiv(ez3, z3_e7)
 
 # --------------------------------------------------------------------------
-# bsr, bsf
-
-# bsf(0x1138) == 3
-bsf1 = Translator.to_language('z3').from_expr(ExprOp("bsf", ExprInt(0x1138, 32)))
-bsf2 = z3.BitVecVal(3, 32)
-assert(equiv(bsf1, bsf2))
-
-# bsr(0x11300) == 0x10
-bsr1 = Translator.to_language('z3').from_expr(ExprOp("bsr", ExprInt(0x11300, 32)))
-bsr2 = z3.BitVecVal(0x10, 32)
-assert(equiv(bsr1, bsr2))
-
-# bsf(0x80000) == bsr(0x80000)
-bsf3 = Translator.to_language('z3').from_expr(ExprOp("bsf", ExprInt(0x80000, 32)))
-bsr3 = Translator.to_language('z3').from_expr(ExprOp("bsr", ExprInt(0x80000, 32)))
-assert(equiv(bsf3, bsr3))
+# cntleadzeros, cnttrailzeros
+
+# cnttrailzeros(0x1138) == 3
+cnttrailzeros1 = Translator.to_language('z3').from_expr(ExprOp("cnttrailzeros", ExprInt(0x1138, 32)))
+cnttrailzeros2 = z3.BitVecVal(3, 32)
+assert(equiv(cnttrailzeros1, cnttrailzeros2))
+
+# cntleadzeros(0x11300) == 0xf
+cntleadzeros1 = Translator.to_language('z3').from_expr(ExprOp("cntleadzeros", ExprInt(0x11300, 32)))
+cntleadzeros2 = z3.BitVecVal(0xf, 32)
+assert(equiv(cntleadzeros1, cntleadzeros2))
+
+# cnttrailzeros(0x8000) + 1 == cntleadzeros(0x8000)
+cnttrailzeros3 = Translator.to_language('z3').from_expr(ExprOp("cnttrailzeros", ExprInt(0x8000, 32)) + ExprInt(1, 32))
+cntleadzeros3 = Translator.to_language('z3').from_expr(ExprOp("cntleadzeros", ExprInt(0x8000, 32)))
+assert(equiv(cnttrailzeros3, cntleadzeros3))
 
 print "TranslatorZ3 tests are OK."