Diffstat (limited to 'miasm2/arch/x86/sem.py')
-rw-r--r--  miasm2/arch/x86/sem.py  494
1 file changed, 406 insertions, 88 deletions
diff --git a/miasm2/arch/x86/sem.py b/miasm2/arch/x86/sem.py
index 3cbf5526..589c2eb9 100644
--- a/miasm2/arch/x86/sem.py
+++ b/miasm2/arch/x86/sem.py
@@ -3319,62 +3319,104 @@ def vec_op_clip(op, size):
 # Generic vertical operation
 
 
-def vec_vertical_sem(op, elt_size, reg_size, dst, src):
+def vec_vertical_sem(op, elt_size, reg_size, dst, src, apply_on_output):
     assert reg_size % elt_size == 0
     n = reg_size / elt_size
     if op == '-':
         ops = [
-            (dst[i * elt_size:(i + 1) * elt_size]
-             - src[i * elt_size:(i + 1) * elt_size]) for i in xrange(0, n)]
+            apply_on_output((dst[i * elt_size:(i + 1) * elt_size]
+                             - src[i * elt_size:(i + 1) * elt_size]))
+            for i in xrange(0, n)
+        ]
     else:
-        ops = [m2_expr.ExprOp(op, dst[i * elt_size:(i + 1) * elt_size],
-                              src[i * elt_size:(i + 1) * elt_size]) for i in xrange(0, n)]
+        ops = [
+            apply_on_output(m2_expr.ExprOp(op, dst[i * elt_size:(i + 1) * elt_size],
+                                           src[i * elt_size:(i + 1) * elt_size]))
+            for i in xrange(0, n)
+        ]
 
     return m2_expr.ExprCompose(*ops)
 
 
-def float_vec_vertical_sem(op, elt_size, reg_size, dst, src):
+def float_vec_vertical_sem(op, elt_size, reg_size, dst, src, apply_on_output):
     assert reg_size % elt_size == 0
     n = reg_size / elt_size
 
     x_to_int, int_to_x = {32: ('float_to_int_%d', 'int_%d_to_float'),
                           64: ('double_to_int_%d', 'int_%d_to_double')}[elt_size]
     if op == '-':
-        ops = [m2_expr.ExprOp(x_to_int % elt_size,
-                              m2_expr.ExprOp(int_to_x % elt_size, dst[i * elt_size:(i + 1) * elt_size]) -
-                              m2_expr.ExprOp(
-                                  int_to_x % elt_size, src[i * elt_size:(
-                                      i + 1) * elt_size])) for i in xrange(0, n)]
+        ops = [
+            apply_on_output(m2_expr.ExprOp(
+                x_to_int % elt_size,
+                m2_expr.ExprOp(int_to_x % elt_size, dst[i * elt_size:(i + 1) * elt_size]) -
+                m2_expr.ExprOp(
+                    int_to_x % elt_size, src[i * elt_size:(
+                        i + 1) * elt_size])))
+            for i in xrange(0, n)
+        ]
     else:
-        ops = [m2_expr.ExprOp(x_to_int % elt_size,
-                              m2_expr.ExprOp(op,
-                                             m2_expr.ExprOp(
-                                                 int_to_x % elt_size, dst[i * elt_size:(
-                                                     i + 1) * elt_size]),
-                                             m2_expr.ExprOp(
-                                                 int_to_x % elt_size, src[i * elt_size:(
-                                                     i + 1) * elt_size]))) for i in xrange(0, n)]
+        ops = [
+            apply_on_output(m2_expr.ExprOp(
+                x_to_int % elt_size,
+                m2_expr.ExprOp(op,
+                               m2_expr.ExprOp(
+                                   int_to_x % elt_size, dst[i * elt_size:(
+                                       i + 1) * elt_size]),
+                               m2_expr.ExprOp(
+                                   int_to_x % elt_size, src[i * elt_size:(
+                                       i + 1) * elt_size]))))
+            for i in xrange(0, n)]
 
     return m2_expr.ExprCompose(*ops)
 
 
-def __vec_vertical_instr_gen(op, elt_size, sem):
+def __vec_vertical_instr_gen(op, elt_size, sem, apply_on_output):
     def vec_instr(ir, instr, dst, src):
         e = []
         if isinstance(src, m2_expr.ExprMem):
             src = ir.ExprMem(src.arg, dst.size)
         reg_size = dst.size
-        e.append(m2_expr.ExprAff(dst, sem(op, elt_size, reg_size, dst, src)))
+        e.append(m2_expr.ExprAff(dst, sem(op, elt_size, reg_size, dst, src,
+                                          apply_on_output)))
         return e, []
     return vec_instr
 
 
-def vec_vertical_instr(op, elt_size):
-    return __vec_vertical_instr_gen(op, elt_size, vec_vertical_sem)
+def vec_vertical_instr(op, elt_size, apply_on_output=lambda x: x):
+    return __vec_vertical_instr_gen(op, elt_size, vec_vertical_sem,
+                                    apply_on_output)
+
 
+def float_vec_vertical_instr(op, elt_size, apply_on_output=lambda x: x):
+    return __vec_vertical_instr_gen(op, elt_size, float_vec_vertical_sem,
+                                    apply_on_output)
 
-def float_vec_vertical_instr(op, elt_size):
-    return __vec_vertical_instr_gen(op, elt_size, float_vec_vertical_sem)
+
+def _keep_mul_high(expr, signed=False):
+    assert expr.is_op("*") and len(expr.args) == 2
+
+    if signed:
+        arg1 = expr.args[0].signExtend(expr.size * 2)
+        arg2 = expr.args[1].signExtend(expr.size * 2)
+    else:
+        arg1 = expr.args[0].zeroExtend(expr.size * 2)
+        arg2 = expr.args[1].zeroExtend(expr.size * 2)
+    return m2_expr.ExprOp("*", arg1, arg2)[expr.size:]
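+# e.g. for PMULHW (elt_size=16, signed=True): 0x8000 * 0x8000 is widened to
+# 32 bit as (-32768) * (-32768) = 0x40000000 and the high word 0x4000 is
+# kept; for PMULHUW (unsigned), 0xffff * 0xffff = 0xfffe0001 yields 0xfffe.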
+
+# Op, signed => associated comparison
+_min_max_func = {
+    ("min", False): m2_expr.expr_is_unsigned_lower,
+    ("min", True): m2_expr.expr_is_signed_lower,
+    ("max", False): m2_expr.expr_is_unsigned_greater,
+    ("max", True): m2_expr.expr_is_signed_greater,
+}
+def _min_max(expr, signed):
+    assert (expr.is_op("min") or expr.is_op("max")) and len(expr.args) == 2
+    return m2_expr.ExprCond(
+        _min_max_func[(expr.op, signed)](expr.args[1], expr.args[0]),
+        expr.args[1],
+        expr.args[0],
+    )
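+# e.g. on bytes, unsigned min(0x01, 0xff) selects 0x01, while signed
+# min(0x01, 0xff) selects 0xff, as 0xff is -1 in two's complement.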
 
 
 # Integer arithmetic
@@ -3398,6 +3440,109 @@ psubw = vec_vertical_instr('-', 16)
 psubd = vec_vertical_instr('-', 32)
 psubq = vec_vertical_instr('-', 64)
 
+# Multiplications
+#
+
+# SSE
+pmullb = vec_vertical_instr('*', 8)
+pmullw = vec_vertical_instr('*', 16)
+pmulld = vec_vertical_instr('*', 32)
+pmullq = vec_vertical_instr('*', 64)
+pmulhub = vec_vertical_instr('*', 8, _keep_mul_high)
+pmulhuw = vec_vertical_instr('*', 16, _keep_mul_high)
+pmulhud = vec_vertical_instr('*', 32, _keep_mul_high)
+pmulhuq = vec_vertical_instr('*', 64, _keep_mul_high)
+pmulhb = vec_vertical_instr('*', 8, lambda x: _keep_mul_high(x, signed=True))
+pmulhw = vec_vertical_instr('*', 16, lambda x: _keep_mul_high(x, signed=True))
+pmulhd = vec_vertical_instr('*', 32, lambda x: _keep_mul_high(x, signed=True))
+pmulhq = vec_vertical_instr('*', 64, lambda x: _keep_mul_high(x, signed=True))
+
+def pmuludq(ir, instr, dst, src):
+    e = []
+    if dst.size == 64:
+        e.append(m2_expr.ExprAff(
+            dst,
+            src[:32].zeroExtend(64) * dst[:32].zeroExtend(64)
+        ))
+    elif dst.size == 128:
+        e.append(m2_expr.ExprAff(
+            dst[:64],
+            src[:32].zeroExtend(64) * dst[:32].zeroExtend(64)
+        ))
+        e.append(m2_expr.ExprAff(
+            dst[64:],
+            src[64:96].zeroExtend(64) * dst[64:96].zeroExtend(64)
+        ))
+    else:
+        raise RuntimeError("Unsupported size %d" % dst.size)
+    return e, []
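+# e.g. with dst[:32] = 2 and src[:32] = 0x80000000, the low lane becomes
+# 2 * 0x80000000 = 0x100000000, a product that only fits on 64 bit.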
+
+# Mix
+#
+
+# SSE
+def pmaddwd(ir, instr, dst, src):
+    sizedst = 32
+    sizesrc = 16
+    out = []
+    for start in xrange(0, dst.size, sizedst):
+        base = start
+        mul1 = src[base: base + sizesrc].signExtend(sizedst) * dst[base: base + sizesrc].signExtend(sizedst)
+        base += sizesrc
+        mul2 = src[base: base + sizesrc].signExtend(sizedst) * dst[base: base + sizesrc].signExtend(sizedst)
+        out.append(mul1 + mul2)
+    return [m2_expr.ExprAff(dst, m2_expr.ExprCompose(*out))], []
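+# e.g. with dst words (1, 2) and src words (3, 4) in a 32-bit lane, the
+# lane receives 1*3 + 2*4 = 11 as a single signed dword.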
+
+
+def _absolute(expr):
+    """Return abs(@expr)"""
+    signed = expr.msb()
+    value_unsigned = (expr ^ expr.mask) + m2_expr.ExprInt(1, expr.size)
+    return m2_expr.ExprCond(signed, value_unsigned, expr)
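+# e.g. _absolute of 0xff84 (-124 on 16 bit) is (0xff84 ^ 0xffff) + 1 =
+# 0x007c, i.e. 124; non-negative values are returned unchanged.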
+
+
+def psadbw(ir, instr, dst, src):
+    sizedst = 16
+    sizesrc = 8
+    out_dst = []
+    for start in xrange(0, dst.size, 64):
+        out = []
+        for src_start in xrange(0, 64, sizesrc):
+            beg = start + src_start
+            end = beg + sizesrc
+            # The equations in the doc are unclear, but the text states
+            # that src and dst are "8 unsigned byte integers"
+            out.append(_absolute(dst[beg: end].zeroExtend(sizedst) - src[beg: end].zeroExtend(sizedst)))
+        out_dst.append(m2_expr.ExprOp("+", *out))
+        out_dst.append(m2_expr.ExprInt(0, 64 - sizedst))
+
+    return [m2_expr.ExprAff(dst, m2_expr.ExprCompose(*out_dst))], []
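+# e.g. with dst bytes (9, 2) and src bytes (2, 9) in a lane, the low word
+# receives |9 - 2| + |2 - 9| = 14; the maximum sum 8 * 0xff = 0x7f8 always
+# fits the 16-bit field, and the upper 48 bits of each lane are zeroed.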
+
+def _average(expr):
+    assert expr.is_op("avg") and len(expr.args) == 2
+
+    arg1 = expr.args[0].zeroExtend(expr.size * 2)
+    arg2 = expr.args[1].zeroExtend(expr.size * 2)
+    one = m2_expr.ExprInt(1, arg1.size)
+    # avg(unsigned) = (a + b + 1) >> 1, the addition being computed on at
+    # least one more bit
+    return ((arg1 + arg2 + one) >> one)[:expr.size]
+
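+# e.g. _average of 0xfe and 0x01 on 8 bit: (0x0fe + 0x001 + 0x001) >> 1 =
+# 0x080; the 9-bit widening keeps the carry of sums such as 0xff + 0xff.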
+pavgb = vec_vertical_instr('avg', 8, _average)
+pavgw = vec_vertical_instr('avg', 16, _average)
+
+# Comparisons
+#
+
+# SSE
+pminsw = vec_vertical_instr('min', 16, lambda x: _min_max(x, signed=True))
+pminub = vec_vertical_instr('min', 8, lambda x: _min_max(x, signed=False))
+pminuw = vec_vertical_instr('min', 16, lambda x: _min_max(x, signed=False))
+pminud = vec_vertical_instr('min', 32, lambda x: _min_max(x, signed=False))
+pmaxub = vec_vertical_instr('max', 8, lambda x: _min_max(x, signed=False))
+pmaxuw = vec_vertical_instr('max', 16, lambda x: _min_max(x, signed=False))
+pmaxud = vec_vertical_instr('max', 32, lambda x: _min_max(x, signed=False))
+pmaxsw = vec_vertical_instr('max', 16, lambda x: _min_max(x, signed=True))
+
 # Floating-point arithmetic
 #
 
@@ -3448,12 +3593,6 @@ def por(_, instr, dst, src):
     return e, []
 
 
-def pminsw(_, instr, dst, src):
-    e = []
-    e.append(m2_expr.ExprAff(dst, m2_expr.ExprCond((dst - src).msb(), dst, src)))
-    return e, []
-
-
 def cvtdq2pd(_, instr, dst, src):
     e = []
     e.append(
@@ -3819,62 +3958,6 @@ def iret(ir, instr):
     return exprs, []
 
 
-def pmaxu(_, instr, dst, src, size):
-    e = []
-    for i in xrange(0, dst.size, size):
-        op1 = dst[i:i + size]
-        op2 = src[i:i + size]
-        res = op1 - op2
-        # Compote CF in @res = @op1 - @op2
-        ret = (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb()
-
-        e.append(m2_expr.ExprAff(dst[i:i + size],
-                                 m2_expr.ExprCond(ret,
-                                                  src[i:i + size],
-                                                  dst[i:i + size])))
-    return e, []
-
-
-def pmaxub(ir, instr, dst, src):
-    return pmaxu(ir, instr, dst, src, 8)
-
-
-def pmaxuw(ir, instr, dst, src):
-    return pmaxu(ir, instr, dst, src, 16)
-
-
-def pmaxud(ir, instr, dst, src):
-    return pmaxu(ir, instr, dst, src, 32)
-
-
-def pminu(_, instr, dst, src, size):
-    e = []
-    for i in xrange(0, dst.size, size):
-        op1 = dst[i:i + size]
-        op2 = src[i:i + size]
-        res = op1 - op2
-        # Compote CF in @res = @op1 - @op2
-        ret = (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb()
-
-        e.append(m2_expr.ExprAff(dst[i:i + size],
-                                 m2_expr.ExprCond(ret,
-                                                  dst[i:i + size],
-                                                  src[i:i + size])))
-    return e, []
-
-
-def pminub(ir, instr, dst, src):
-    return pminu(ir, instr, dst, src, 8)
-
-
-def pminuw(ir, instr, dst, src):
-    return pminu(ir, instr, dst, src, 16)
-
-
-def pminud(ir, instr, dst, src):
-    return pminu(ir, instr, dst, src, 32)
-
-
 def pcmpeq(_, instr, dst, src, size):
     e = []
     for i in xrange(0, dst.size, size):
@@ -4173,6 +4256,202 @@ def palignr(ir, instr, dst, src, imm):
     return [m2_expr.ExprAff(dst, result)], []
 
 
+def _signed_saturation(expr, dst_size):
+    """Saturate the expr @expr for @dst_size bit
+    Signed saturation return MAX_INT / MIN_INT or value depending on the value
+    """
+    assert expr.size > dst_size
+
+    median = 1 << (dst_size - 1)
+    min_int = m2_expr.ExprInt(- median, dst_size)
+    max_int = m2_expr.ExprInt(median - 1, dst_size)
+    signed = expr.msb()
+    value_unsigned = (expr ^ expr.mask) + m2_expr.ExprInt(1, expr.size)
+    # Re-use the sign bit
+    value = m2_expr.ExprCompose(expr[:dst_size - 1], signed)
+
+    # Bit hack: to avoid a double signed comparison, use a mask,
+    # i.e., in unsigned, 0xXY > 0x0f iff X is non-zero
+
+    # if expr >s 0
+    #    if expr[dst_size - 1:] > 0: # bigger than max_int
+    #        -> max_int
+    #    else
+    #        -> value
+    # else # negative
+    #    if abs(expr)[dst_size - 1:] > 0: # smaller than min_int
+    #        -> min_int
+    #    else
+    #        -> value
+
+    return m2_expr.ExprCond(
+        signed,
+        m2_expr.ExprCond(value_unsigned[dst_size - 1:],
+                         min_int,
+                         value),
+        m2_expr.ExprCond(expr[dst_size - 1:],
+                         max_int,
+                         value),
+    )
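+# e.g. saturating 16 bit to 8 bit: 0x0123 (291 > 127) gives 0x7f, 0xff00
+# (-256 < -128) gives 0x80, and 0xfff6 (-10) keeps its value, 0xf6.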
+
+
+def _unsigned_saturation(expr, dst_size):
+    """Saturate the expr @expr for @dst_size bit
+    Unsigned saturation return MAX_INT or value depending on the value
+    """
+    assert expr.size > dst_size
+
+    zero = m2_expr.ExprInt(0, dst_size)
+    max_int = m2_expr.ExprInt(-1, dst_size)
+    value = expr[:dst_size]
+    signed = expr.msb()
+
+    # Bit hack: to avoid a double signed comparison, use a mask,
+    # i.e., in unsigned, 0xXY > 0x0f iff X is non-zero
+
+    return m2_expr.ExprCond(
+        signed,
+        zero,
+        m2_expr.ExprCond(expr[dst_size:],
+                         max_int,
+                         value),
+    )
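+# e.g. saturating 16 bit to 8 bit: 0x0123 (291 > 255) gives 0xff, 0xff00
+# (negative) gives 0x00, and 0x007c keeps its value, 0x7c.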
+
+
+def packsswb(ir, instr, dst, src):
+    out = []
+    for source in [dst, src]:
+        for start in xrange(0, dst.size, 16):
+            out.append(_signed_saturation(source[start:start + 16], 8))
+    return [m2_expr.ExprAff(dst, m2_expr.ExprCompose(*out))], []
+
+
+def packssdw(ir, instr, dst, src):
+    out = []
+    for source in [dst, src]:
+        for start in xrange(0, dst.size, 32):
+            out.append(_signed_saturation(source[start:start + 32], 16))
+    return [m2_expr.ExprAff(dst, m2_expr.ExprCompose(*out))], []
+
+
+def packuswb(ir, instr, dst, src):
+    out = []
+    for source in [dst, src]:
+        for start in xrange(0, dst.size, 16):
+            out.append(_unsigned_saturation(source[start:start + 16], 8))
+    return [m2_expr.ExprAff(dst, m2_expr.ExprCompose(*out))], []
+
+
+def _saturation_sub_unsigned(expr):
+    assert expr.is_op("+") and len(expr.args) == 2 and expr.args[-1].is_op("-")
+
+    # Compute the subtraction on one more bit to be able to distinguish
+    # the cases: 0x48 - 0xd7 in 8 bit should saturate
+    arg1 = expr.args[0].zeroExtend(expr.size + 1)
+    arg2 = expr.args[1].args[0].zeroExtend(expr.size + 1)
+    return _unsigned_saturation(arg1 - arg2, expr.size)
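+# e.g. for PSUBUSB, 0x48 - 0xd7 on 9 bits is 0x171: the msb (the borrow)
+# is set, so _unsigned_saturation returns 0x00.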
+
+def _saturation_sub_signed(expr):
+    assert expr.is_op("+") and len(expr.args) == 2 and expr.args[-1].is_op("-")
+
+    # Compute the subtraction on two more bits, see _saturation_sub_unsigned
+    arg1 = expr.args[0].signExtend(expr.size + 2)
+    arg2 = expr.args[1].args[0].signExtend(expr.size + 2)
+    return _signed_saturation(arg1 - arg2, expr.size)
+
+def _saturation_add(expr):
+    assert expr.is_op("+") and len(expr.args) == 2
+
+    # Compute the addition on one more bit to be able to distinguish the
+    # cases: 0x48 + 0xd7 in 8 bit should saturate
+
+    arg1 = expr.args[0].zeroExtend(expr.size + 1)
+    arg2 = expr.args[1].zeroExtend(expr.size + 1)
+
+    # We could also use _unsigned_saturation with two additional bits (to
+    # distinguish the negative and overflow cases), but the resulting
+    # expression would be more complicated and contain an impossible case
+    # (signed=True), so the rule is rewritten here
+
+    return m2_expr.ExprCond((arg1 + arg2).msb(), m2_expr.ExprInt(-1, expr.size),
+                            expr)
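+# e.g. for PADDUSB, 0x48 + 0xd7 on 9 bits is 0x11f: the msb (the carry) is
+# set, so the result saturates to 0xff; 0x10 + 0x20 stays 0x30.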
+
+def _saturation_add_signed(expr):
+    assert expr.is_op("+") and len(expr.args) == 2
+
+    # Compute the addition on two more bits, see _saturation_add
+
+    arg1 = expr.args[0].signExtend(expr.size + 2)
+    arg2 = expr.args[1].signExtend(expr.size + 2)
+
+    return _signed_saturation(arg1 + arg2, expr.size)
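+# e.g. for PADDSB, 0x70 + 0x70 widens to 0x0e0 (224 > 127) and saturates to
+# 0x7f, while 0x90 + 0x90 (-112 - 112 = -224 < -128) saturates to 0x80.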
+
+
+# Saturate SSE operations
+
+psubusb = vec_vertical_instr('-', 8, _saturation_sub_unsigned)
+psubusw = vec_vertical_instr('-', 16, _saturation_sub_unsigned)
+paddusb = vec_vertical_instr('+', 8, _saturation_add)
+paddusw = vec_vertical_instr('+', 16, _saturation_add)
+psubsb = vec_vertical_instr('-', 8, _saturation_sub_signed)
+psubsw = vec_vertical_instr('-', 16, _saturation_sub_signed)
+paddsb = vec_vertical_instr('+', 8, _saturation_add_signed)
+paddsw = vec_vertical_instr('+', 16, _saturation_add_signed)
+
+
+# Other SSE operations
+
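+# MASKMOVQ / MASKMOVDQU perform a byte-granular conditional store: each byte
+# of src whose corresponding mask byte has its msb set is written to
+# @8[(E/R)DI + byte index]. This is modelled below with one check block and
+# one write block per byte.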
+def maskmovq(ir, instr, src, mask):
+    lbl_next = m2_expr.ExprId(ir.get_next_label(instr), ir.IRDst.size)
+    blks = []
+
+    # For each possibility, check if a write is necessary
+    check_labels = [m2_expr.ExprId(ir.gen_label(), ir.IRDst.size)
+                    for _ in xrange(0, mask.size, 8)]
+    # If the write has to be done, do it (otherwise, nothing happens)
+    write_labels = [m2_expr.ExprId(ir.gen_label(), ir.IRDst.size)
+                    for _ in xrange(0, mask.size, 8)]
+
+    # Build check blocks
+    for i, start in enumerate(xrange(0, mask.size, 8)):
+        bit = mask[start + 7: start + 8]
+        cur_label = check_labels[i]
+        next_check_label = check_labels[i + 1] if (i + 1) < len(check_labels) else lbl_next
+        write_label = write_labels[i]
+        check = m2_expr.ExprAff(ir.IRDst,
+                                m2_expr.ExprCond(bit,
+                                                 write_label,
+                                                 next_check_label))
+        blks.append(IRBlock(cur_label.name, [AssignBlock([check], instr)]))
+
+    # Build write blocks
+    dst_addr = mRDI[instr.mode]
+    for i, start in enumerate(xrange(0, mask.size, 8)):
+        bit = mask[start + 7: start + 8]
+        cur_label = write_labels[i]
+        next_check_label = check_labels[i + 1] if (i + 1) < len(check_labels) else lbl_next
+        write_addr = dst_addr + m2_expr.ExprInt(i, dst_addr.size)
+
+        # @8[DI/EDI/RDI + i] = src[byte i]
+        write_mem = m2_expr.ExprAff(m2_expr.ExprMem(write_addr, 8),
+                                    src[start: start + 8])
+        jump = m2_expr.ExprAff(ir.IRDst, next_check_label)
+        blks.append(IRBlock(cur_label.name, [AssignBlock([write_mem, jump], instr)]))
+
+    # If the mask is zero, bypass all checks and writes
+    e = [m2_expr.ExprAff(ir.IRDst, m2_expr.ExprCond(mask,
+                                                    check_labels[0],
+                                                    lbl_next))]
+    return e, blks
+
+
+def emms(ir, instr):
+    # EMMS empties the MMX/FPU tag state; implemented here as a NOP
+    return [], []
+
+
 mnemo_func = {'mov': mov,
               'xchg': xchg,
               'movzx': movzx,
@@ -4557,6 +4836,29 @@ mnemo_func = {'mov': mov,
               "psubd": psubd,
               "psubq": psubq,
 
+              # Multiplications
+              # SSE
+              "pmullb": pmullb,
+              "pmullw": pmullw,
+              "pmulld": pmulld,
+              "pmullq": pmullq,
+              "pmulhub": pmulhub,
+              "pmulhuw": pmulhuw,
+              "pmulhud": pmulhud,
+              "pmulhuq": pmulhuq,
+              "pmulhb": pmulhb,
+              "pmulhw": pmulhw,
+              "pmulhd": pmulhd,
+              "pmulhq": pmulhq,
+              "pmuludq": pmuludq,
+
+              # Mix
+              # SSE
+              "pmaddwd": pmaddwd,
+              "psadbw": psadbw,
+              "pavgb": pavgb,
+              "pavgw": pavgw,
+
               # Arithmetic (floating-point)
               #
 
@@ -4614,6 +4916,7 @@ mnemo_func = {'mov': mov,
               "pmaxub": pmaxub,
               "pmaxuw": pmaxuw,
               "pmaxud": pmaxud,
+              "pmaxsw": pmaxsw,
 
               "pminub": pminub,
               "pminuw": pminuw,
@@ -4670,8 +4973,23 @@ mnemo_func = {'mov': mov,
 
               "pmovmskb": pmovmskb,
 
-              "smsw": smsw,
+              "packsswb": packsswb,
+              "packssdw": packssdw,
+              "packuswb": packuswb,
+
+              "psubusb": psubusb,
+              "psubusw": psubusw,
+              "paddusb": paddusb,
+              "paddusw": paddusw,
+              "psubsb": psubsb,
+              "psubsw": psubsw,
+              "paddsb": paddsb,
+              "paddsw": paddsw,
 
+              "smsw": smsw,
+              "maskmovq": maskmovq,
+              "maskmovdqu": maskmovq,
+              "emms": emms,
               }