about summary refs log tree commit diff stats
path: root/src/miasm/expression/simplifications_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/miasm/expression/simplifications_common.py')
-rw-r--r--src/miasm/expression/simplifications_common.py1868
1 files changed, 1868 insertions, 0 deletions
diff --git a/src/miasm/expression/simplifications_common.py b/src/miasm/expression/simplifications_common.py
new file mode 100644
index 00000000..9156ee67
--- /dev/null
+++ b/src/miasm/expression/simplifications_common.py
@@ -0,0 +1,1868 @@
+# ----------------------------- #
+# Common simplifications passes #
+# ----------------------------- #
+
+from future.utils import viewitems
+
+from miasm.core.modint import mod_size2int, mod_size2uint
+from miasm.expression.expression import ExprInt, ExprSlice, ExprMem, \
+    ExprCond, ExprOp, ExprCompose, TOK_INF_SIGNED, TOK_INF_UNSIGNED, \
+    TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, TOK_EQUAL
+from miasm.expression.expression_helper import parity, op_propag_cst, \
+    merge_sliceto_slice
+from miasm.expression.simplifications_explicit import simp_flags
+
+def simp_cst_propagation(e_s, expr):
+    """This passe includes:
+     - Constant folding
+     - Common logical identities
+     - Common binary identities
+     """
+
+    # merge associatif op
+    args = list(expr.args)
+    op_name = expr.op
+    # simpl integer manip
+    # int OP int => int
+    # TODO: <<< >>> << >> are architecture dependent
+    if op_name in op_propag_cst:
+        while (len(args) >= 2 and
+            args[-1].is_int() and
+            args[-2].is_int()):
+            int2 = args.pop()
+            int1 = args.pop()
+            if op_name == '+':
+                out = mod_size2uint[int1.size](int(int1) + int(int2))
+            elif op_name == '*':
+                out = mod_size2uint[int1.size](int(int1) * int(int2))
+            elif op_name == '**':
+                out = mod_size2uint[int1.size](int(int1) ** int(int2))
+            elif op_name == '^':
+                out = mod_size2uint[int1.size](int(int1) ^ int(int2))
+            elif op_name == '&':
+                out = mod_size2uint[int1.size](int(int1) & int(int2))
+            elif op_name == '|':
+                out = mod_size2uint[int1.size](int(int1) | int(int2))
+            elif op_name == '>>':
+                if int(int2) > int1.size:
+                    out = 0
+                else:
+                    out = mod_size2uint[int1.size](int(int1) >> int(int2))
+            elif op_name == '<<':
+                if int(int2) > int1.size:
+                    out = 0
+                else:
+                    out = mod_size2uint[int1.size](int(int1) << int(int2))
+            elif op_name == 'a>>':
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                if tmp2 > int1.size:
+                    is_signed = int(int1) & (1 << (int1.size - 1))
+                    if is_signed:
+                        out = -1
+                    else:
+                        out = 0
+                else:
+                    out = mod_size2uint[int1.size](tmp1 >> tmp2)
+            elif op_name == '>>>':
+                shifter = int(int2) % int2.size
+                out = (int(int1) >> shifter) | (int(int1) << (int2.size - shifter))
+            elif op_name == '<<<':
+                shifter = int(int2) % int2.size
+                out = (int(int1) << shifter) | (int(int1) >> (int2.size - shifter))
+            elif op_name == '/':
+                if int(int2) == 0:
+                    return expr
+                out = int(int1) // int(int2)
+            elif op_name == '%':
+                if int(int2) == 0:
+                    return expr
+                out = int(int1) % int(int2)
+            elif op_name == 'sdiv':
+                if int(int2) == 0:
+                    return expr
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2int[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 // tmp2)
+            elif op_name == 'smod':
+                if int(int2) == 0:
+                    return expr
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2int[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 % tmp2)
+            elif op_name == 'umod':
+                if int(int2) == 0:
+                    return expr
+                tmp1 = mod_size2uint[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 % tmp2)
+            elif op_name == 'udiv':
+                if int(int2) == 0:
+                    return expr
+                tmp1 = mod_size2uint[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 // tmp2)
+
+
+
+            args.append(ExprInt(int(out), int1.size))
+
+    # cnttrailzeros(int) => int
+    if op_name == "cnttrailzeros" and args[0].is_int():
+        i = 0
+        while int(args[0]) & (1 << i) == 0 and i < args[0].size:
+            i += 1
+        return ExprInt(i, args[0].size)
+
+    # cntleadzeros(int) => int
+    if op_name == "cntleadzeros" and args[0].is_int():
+        if int(args[0]) == 0:
+            return ExprInt(args[0].size, args[0].size)
+        i = args[0].size - 1
+        while int(args[0]) & (1 << i) == 0:
+            i -= 1
+        return ExprInt(expr.size - (i + 1), args[0].size)
+
+    # -(-(A)) => A
+    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
+        len(args[0].args) == 1):
+        return args[0].args[0]
+
+
+    # -(int) => -int
+    if op_name == '-' and len(args) == 1 and args[0].is_int():
+        return ExprInt(-int(args[0]), expr.size)
+    # A op 0 =>A
+    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
+        if args[-1].is_int(0):
+            args.pop()
+    # A - 0 =>A
+    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
+        assert len(args) == 2 # Op '-' with more than 2 args: SantityCheckError
+        return args[0]
+
+    # A * 1 =>A
+    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
+        args.pop()
+
+    # for cannon form
+    # A * -1 => - A
+    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
+        args.pop()
+        args[-1] = - args[-1]
+
+    # op A => A
+    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
+              'a>>', '<<<', '>>>', 'sdiv', 'smod', 'umod', 'udiv'] and len(args) == 1:
+        return args[0]
+
+    # A-B => A + (-B)
+    if op_name == '-' and len(args) > 1:
+        if len(args) > 2:
+            raise ValueError(
+                'sanity check fail on expr -: should have one or 2 args ' +
+                '%r %s' % (expr, expr)
+            )
+        return ExprOp('+', args[0], -args[1])
+
+    # A op 0 => 0
+    if op_name in ['&', "*"] and args[-1].is_int(0):
+        return ExprInt(0, expr.size)
+
+    # - (A + B +...) => -A + -B + -C
+    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
+        args = [-a for a in args[0].args]
+        return ExprOp('+', *args)
+
+    # -(a?int1:int2) => (a?-int1:-int2)
+    if (op_name == '-' and len(args) == 1 and
+        args[0].is_cond() and
+        args[0].src1.is_int() and args[0].src2.is_int()):
+        int1 = args[0].src1
+        int2 = args[0].src2
+        int1 = ExprInt(-int1.arg, int1.size)
+        int2 = ExprInt(-int2.arg, int2.size)
+        return ExprCond(args[0].cond, int1, int2)
+
+    i = 0
+    while i < len(args) - 1:
+        j = i + 1
+        while j < len(args):
+            # A ^ A => 0
+            if op_name == '^' and args[i] == args[j]:
+                args[i] = ExprInt(0, args[i].size)
+                del args[j]
+                continue
+            # A + (- A) => 0
+            if op_name == '+' and args[j].is_op("-"):
+                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
+                    args[i] = ExprInt(0, args[i].size)
+                    del args[j]
+                    continue
+            # (- A) + A => 0
+            if op_name == '+' and args[i].is_op("-"):
+                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
+                    args[i] = ExprInt(0, args[i].size)
+                    del args[j]
+                    continue
+            # A | A => A
+            if op_name == '|' and args[i] == args[j]:
+                del args[j]
+                continue
+            # A & A => A
+            if op_name == '&' and args[i] == args[j]:
+                del args[j]
+                continue
+            j += 1
+        i += 1
+
+    if op_name in ['+', '^', '|', '&', '%', '/', '**'] and len(args) == 1:
+        return args[0]
+
+    # A <<< A.size => A
+    if (op_name in ['<<<', '>>>'] and
+        args[1].is_int() and
+        int(args[1]) == args[0].size):
+        return args[0]
+
+    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
+    if (op_name in ['<<<', '>>>'] and
+        args[0].is_op() and
+        args[0].op in ['<<<', '>>>']):
+        A = args[0].args[0]
+        X = args[0].args[1]
+        Y = args[1]
+        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
+            return args[0].args[0]
+        elif X.is_int() and Y.is_int():
+            new_X = int(X) % expr.size
+            new_Y = int(Y) % expr.size
+            if op_name == args[0].op:
+                rot = (new_X + new_Y) % expr.size
+                op = op_name
+            else:
+                rot = new_Y - new_X
+                op = op_name
+                if rot < 0:
+                    rot = - rot
+                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
+            args = [A, ExprInt(rot, expr.size)]
+            op_name = op
+
+        else:
+            # Do not consider this case, too tricky (overflow on addition /
+            # subtraction)
+            pass
+
+    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
+    # To be sure, only consider the simplification when X.msb and Y.msb are 0
+    if (op_name in ['<<', '>>'] and
+        args[0].is_op(op_name)):
+        X = args[0].args[1]
+        Y = args[1]
+        if (e_s(X.msb()) == ExprInt(0, 1) and
+            e_s(Y.msb()) == ExprInt(0, 1)):
+            args = [args[0].args[0], X + Y]
+
+    # ((var >> int1) << int1) => var & mask
+    # ((var << int1) >> int1) => var & mask
+    if (op_name in ['<<', '>>'] and
+        args[0].is_op() and
+        args[0].op in ['<<', '>>'] and
+        op_name != args[0]):
+        var = args[0].args[0]
+        int1 = args[0].args[1]
+        int2 = args[1]
+        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
+            if op_name == '>>':
+                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
+            else:
+                mask = ExprInt(
+                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
+                    expr.size
+                )
+            ret = var & mask
+            return ret
+
+    # ((A & A.mask)
+    if op_name == "&" and args[-1] == expr.mask:
+        args = args[:-1]
+        if len(args) == 1:
+            return args[0]
+        return ExprOp('&', *args)
+
+    # ((A | A.mask)
+    if op_name == "|" and args[-1] == expr.mask:
+        return args[-1]
+
+    # ! (!X + int) => X - int
+    # TODO
+
+    # ((A & mask) >> shift) with mask < 2**shift => 0
+    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
+        if (args[0].args[1].is_int() and
+            2 ** int(args[1]) > int(args[0].args[1])):
+            return ExprInt(0, args[0].size)
+
+    # parity(int) => int
+    if op_name == 'parity' and args[0].is_int():
+        return ExprInt(parity(int(args[0])), 1)
+
+    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
+    if op_name == "*" and len(args) > 1:
+        new_args = []
+        counter = 0
+        for arg in args:
+            if arg.is_op('-') and len(arg.args) == 1:
+                new_args.append(arg.args[0])
+                counter += 1
+            else:
+                new_args.append(arg)
+        if counter % 2:
+            return -ExprOp(op_name, *new_args)
+        args = new_args
+
+    # -(a * b * int) => a * b * (-int)
+    if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
+        args = args[0].args
+        return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))
+
+    # A << int with A ExprCompose => move index
+    if (op_name == "<<" and args[0].is_compose() and
+        args[1].is_int() and int(args[1]) != 0):
+        final_size = args[0].size
+        shift = int(args[1])
+        new_args = []
+        # shift indexes
+        for index, arg in args[0].iter_args():
+            new_args.append((arg, index+shift, index+shift+arg.size))
+        # filter out expression
+        filter_args = []
+        min_index = final_size
+        for tmp, start, stop in new_args:
+            if start >= final_size:
+                continue
+            if stop > final_size:
+                tmp = tmp[:tmp.size  - (stop - final_size)]
+            filter_args.append(tmp)
+            min_index = min(start, min_index)
+        # create entry 0
+        assert min_index != 0
+        tmp = ExprInt(0, min_index)
+        args = [tmp] + filter_args
+        return ExprCompose(*args)
+
+    # A >> int with A ExprCompose => move index
+    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
+        final_size = args[0].size
+        shift = int(args[1])
+        new_args = []
+        # shift indexes
+        for index, arg in args[0].iter_args():
+            new_args.append((arg, index-shift, index+arg.size-shift))
+        # filter out expression
+        filter_args = []
+        max_index = 0
+        for tmp, start, stop in new_args:
+            if stop <= 0:
+                continue
+            if start < 0:
+                tmp = tmp[-start:]
+            filter_args.append(tmp)
+            max_index = max(stop, max_index)
+        # create entry 0
+        tmp = ExprInt(0, final_size - max_index)
+        args = filter_args + [tmp]
+        return ExprCompose(*args)
+
+
+    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
+    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
+        bounds = set()
+        for arg in args:
+            bound = tuple([tmp.size for tmp in arg.args])
+            bounds.add(bound)
+        if len(bounds) == 1:
+            new_args = [[tmp] for tmp in args[0].args]
+            for sub_arg in args[1:]:
+                for i, tmp in enumerate(sub_arg.args):
+                    new_args[i].append(tmp)
+            args = []
+            for i, arg in enumerate(new_args):
+                args.append(ExprOp(op_name, *arg))
+            return ExprCompose(*args)
+
+    return ExprOp(op_name, *args)
+
+
+def simp_cond_op_int(_, expr):
+    "Extract conditions from operations"
+
+
+    # x?a:b + x?c:d + e => x?(a+c+e:b+d+e)
+    if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
+        return expr
+    if len(expr.args) < 2:
+        return expr
+    conds = set()
+    for arg in expr.args:
+        if arg.is_cond():
+            conds.add(arg)
+    if len(conds) != 1:
+        return expr
+    cond = list(conds).pop()
+
+    args1, args2 = [], []
+    for arg in expr.args:
+        if arg.is_cond():
+            args1.append(arg.src1)
+            args2.append(arg.src2)
+        else:
+            args1.append(arg)
+            args2.append(arg)
+
+    return ExprCond(cond.cond,
+                    ExprOp(expr.op, *args1),
+                    ExprOp(expr.op, *args2))
+
+
+def simp_cond_factor(e_s, expr):
+    "Merge similar conditions"
+    if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
+        return expr
+    if len(expr.args) < 2:
+        return expr
+
+    if expr.op in ['>>', '<<', 'a>>']:
+        assert len(expr.args) == 2
+
+    # Note: the following code is correct for non-commutative operation only if
+    # there is 2 arguments. Otherwise, the order is not conserved
+
+    # Regroup sub-expression by similar conditions
+    conds = {}
+    not_conds = []
+    multi_cond = False
+    for arg in expr.args:
+        if not arg.is_cond():
+            not_conds.append(arg)
+            continue
+        cond = arg.cond
+        if not cond in conds:
+            conds[cond] = []
+        else:
+            multi_cond = True
+        conds[cond].append(arg)
+    if not multi_cond:
+        return expr
+
+    # Rebuild the new expression
+    c_out = not_conds
+    for cond, vals in viewitems(conds):
+        new_src1 = [x.src1 for x in vals]
+        new_src2 = [x.src2 for x in vals]
+        src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1))
+        src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2))
+        c_out.append(ExprCond(cond, src1, src2))
+
+    if len(c_out) == 1:
+        new_e = c_out[0]
+    else:
+        new_e = ExprOp(expr.op, *c_out)
+    return new_e
+
+
+def simp_slice(e_s, expr):
+    "Slice optimization"
+
+    # slice(A, 0, a.size) => A
+    if expr.start == 0 and expr.stop == expr.arg.size:
+        return expr.arg
+    # Slice(int) => int
+    if expr.arg.is_int():
+        total_bit = expr.stop - expr.start
+        mask = (1 << (expr.stop - expr.start)) - 1
+        return ExprInt(int((int(expr.arg) >> expr.start) & mask), total_bit)
+    # Slice(Slice(A, x), y) => Slice(A, z)
+    if expr.arg.is_slice():
+        if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
+            raise ValueError('slice in slice: getting more val', str(expr))
+
+        return ExprSlice(expr.arg.arg, expr.start + expr.arg.start,
+                         expr.start + expr.arg.start + (expr.stop - expr.start))
+    if expr.arg.is_compose():
+        # Slice(Compose(A), x) => Slice(A, y)
+        for index, arg in expr.arg.iter_args():
+            if index <= expr.start and index+arg.size >= expr.stop:
+                return arg[expr.start - index:expr.stop - index]
+        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
+        out = []
+        for index, arg in expr.arg.iter_args():
+            # arg is before slice start
+            if expr.start >= index + arg.size:
+                continue
+            # arg is after slice stop
+            elif expr.stop <= index:
+                continue
+            # arg is fully included in slice
+            elif expr.start <= index and index + arg.size <= expr.stop:
+                out.append(arg)
+                continue
+            # arg is truncated at start
+            if expr.start > index:
+                slice_start = expr.start - index
+            else:
+                # arg is not truncated at start
+                slice_start = 0
+            # a is truncated at stop
+            if expr.stop < index + arg.size:
+                slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start
+            else:
+                slice_stop = arg.size
+            out.append(arg[slice_start:slice_stop])
+
+        return ExprCompose(*out)
+
+    # ExprMem(x, size)[:A] => ExprMem(x, a)
+    # XXXX todo hum, is it safe?
+    if (expr.arg.is_mem() and
+          expr.start == 0 and
+          expr.arg.size > expr.stop and expr.stop % 8 == 0):
+        return ExprMem(expr.arg.ptr, size=expr.stop)
+    # distributivity of slice and &
+    # (a & int)[x:y] => 0 if int[x:y] == 0
+    if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
+        tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop])
+        if tmp.is_int(0):
+            return tmp
+    # distributivity of slice and exprcond
+    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
+    # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y])
+    if (expr.arg.is_cond() and
+        (expr.arg.src1.is_int() or expr.arg.src1.is_compose()) and
+        (expr.arg.src2.is_int() or expr.arg.src2.is_compose())):
+        src1 = expr.arg.src1[expr.start:expr.stop]
+        src2 = expr.arg.src2[expr.start:expr.stop]
+        return ExprCond(expr.arg.cond, src1, src2)
+
+    # (a * int)[0:y] => (a[0:y] * int[0:y])
+    if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
+        args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args]
+        return ExprOp(expr.arg.op, *args)
+
+    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
+    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
+    if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and
+          expr.arg.args[1].is_int()):
+        arg, shift = expr.arg.args
+        shift = int(shift)
+        if expr.arg.op == ">>":
+            if shift + expr.stop <= arg.size:
+                return arg[expr.start + shift:expr.stop + shift]
+        elif expr.arg.op == "<<":
+            if expr.start - shift >= 0:
+                return arg[expr.start - shift:expr.stop - shift]
+        else:
+            raise ValueError('Bad case')
+
+    return expr
+
+
+def simp_compose(e_s, expr):
+    "Commons simplification on ExprCompose"
+    args = merge_sliceto_slice(expr)
+    out = []
+    # compose of compose
+    for arg in args:
+        if arg.is_compose():
+            out += arg.args
+        else:
+            out.append(arg)
+    args = out
+    # Compose(a) with a.size = compose.size => a
+    if len(args) == 1 and args[0].size == expr.size:
+        return args[0]
+
+    # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z)
+    if len(args) == 2 and args[1].is_int(0):
+        if (args[0].is_slice() and
+            args[0].stop == args[0].arg.size and
+            args[0].size + args[1].size == args[0].arg.size):
+            new_expr = args[0].arg >> ExprInt(args[0].start, args[0].arg.size)
+            return new_expr
+
+    # {@X[base + i] 0 X, @Y[base + i + X] X (X + Y)} => @(X+Y)[base + i]
+    for i, arg in enumerate(args[:-1]):
+        nxt = args[i + 1]
+        if arg.is_mem() and nxt.is_mem():
+            gap = e_s(nxt.ptr - arg.ptr)
+            if gap.is_int() and arg.size % 8 == 0 and int(gap) == arg.size // 8:
+                args = args[:i] + [ExprMem(arg.ptr,
+                                          arg.size + nxt.size)] + args[i + 2:]
+                return ExprCompose(*args)
+    # {A, signext(A)[32:64]} => signext(A)
+    if len(args) == 2 and args[0].size == args[1].size:
+        arg1, arg2 = args
+        size = arg1.size
+        sign_ext = arg1.signExtend(arg1.size*2)
+        if arg2 == sign_ext[size:2*size]:
+            return sign_ext
+
+
+    # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
+    conds = set(arg.cond for arg in expr.args if arg.is_cond())
+    if len(conds) == 1:
+        cond = list(conds)[0]
+        args1, args2 = [], []
+        for arg in expr.args:
+            if arg.is_cond():
+                args1.append(arg.src1)
+                args2.append(arg.src2)
+            else:
+                args1.append(arg)
+                args2.append(arg)
+        arg1 = e_s(ExprCompose(*args1))
+        arg2 = e_s(ExprCompose(*args2))
+        return ExprCond(cond, arg1, arg2)
+    return ExprCompose(*args)
+
+def simp_cond(_, expr):
+    """
+    Common simplifications on ExprCond.
+    Eval exprcond src1/src2 with satifiable/unsatisfiable condition propagation
+    """
+    if (not expr.cond.is_int()) and expr.cond.size == 1:
+        src1 = expr.src1.replace_expr({expr.cond: ExprInt(1, 1)})
+        src2 = expr.src2.replace_expr({expr.cond: ExprInt(0, 1)})
+        if src1 != expr.src1 or src2 != expr.src2:
+            return ExprCond(expr.cond, src1, src2)
+
+    # -A ? B:C => A ? B:C
+    if expr.cond.is_op('-') and len(expr.cond.args) == 1:
+        expr = ExprCond(expr.cond.args[0], expr.src1, expr.src2)
+    # a?x:x
+    elif expr.src1 == expr.src2:
+        expr = expr.src1
+    # int ? A:B => A or B
+    elif expr.cond.is_int():
+        if int(expr.cond) == 0:
+            expr = expr.src2
+        else:
+            expr = expr.src1
+    # a?(a?b:c):x => a?b:x
+    elif expr.src1.is_cond() and expr.cond == expr.src1.cond:
+        expr = ExprCond(expr.cond, expr.src1.src1, expr.src2)
+    # a?x:(a?b:c) => a?x:c
+    elif expr.src2.is_cond() and expr.cond == expr.src2.cond:
+        expr = ExprCond(expr.cond, expr.src1, expr.src2.src2)
+    # a|int ? b:c => b with int != 0
+    elif (expr.cond.is_op('|') and
+          expr.cond.args[1].is_int() and
+          expr.cond.args[1].arg != 0):
+        return expr.src1
+
+    # (C?int1:int2)?(A:B) =>
+    elif (expr.cond.is_cond() and
+          expr.cond.src1.is_int() and
+          expr.cond.src2.is_int()):
+        int1 = int(expr.cond.src1)
+        int2 = int(expr.cond.src2)
+        if int1 and int2:
+            expr = expr.src1
+        elif int1 == 0 and int2 == 0:
+            expr = expr.src2
+        elif int1 == 0 and int2:
+            expr = ExprCond(expr.cond.cond, expr.src2, expr.src1)
+        elif int1 and int2 == 0:
+            expr = ExprCond(expr.cond.cond, expr.src1, expr.src2)
+
+    elif expr.cond.is_compose():
+        # {0, X, 0}?(A:B) => X?(A:B)
+        args = [arg for arg in expr.cond.args if not arg.is_int(0)]
+        if len(args) == 1:
+            arg = args.pop()
+            return ExprCond(arg, expr.src1, expr.src2)
+        elif len(args) < len(expr.cond.args):
+            return ExprCond(ExprCompose(*args), expr.src1, expr.src2)
+    return expr
+
+
+def simp_mem(_, expr):
+    """
+    Common simplifications on ExprMem:
+    @32[x?a:b] => x?@32[a]:@32[b]
+    """
+    if expr.ptr.is_cond():
+        cond = expr.ptr
+        ret = ExprCond(cond.cond,
+                       ExprMem(cond.src1, expr.size),
+                       ExprMem(cond.src2, expr.size))
+        return ret
+    return expr
+
+
+
+
+def test_cc_eq_args(expr, *sons_op):
+    """
+    Return True if expression's arguments match the list in sons_op, and their
+    sub arguments are identical. Ex:
+    CC_S<=(
+              FLAG_SIGN_SUB(A, B),
+              FLAG_SUB_OF(A, B),
+              FLAG_EQ_CMP(A, B)
+    )
+    """
+    if not expr.is_op():
+        return False
+    if len(expr.args) != len(sons_op):
+        return False
+    all_args = set()
+    for i, arg in enumerate(expr.args):
+        if not arg.is_op(sons_op[i]):
+            return False
+        all_args.add(arg.args)
+    return len(all_args) == 1
+
+
+def simp_cc_conds(_, expr):
+    """
+    High level simplifications. Example:
+    CC_U<(FLAG_SUB_CF(A, B) => A <u B
+    """
+    if (expr.is_op("CC_U>=") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SUB_CF"
+          )):
+        expr = ExprCond(
+            ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size))
+
+    elif (expr.is_op("CC_U<") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SUB_CF"
+          )):
+        expr = ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args)
+
+    elif (expr.is_op("CC_NEG") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_SUB"
+          )):
+        expr = ExprOp(TOK_INF_SIGNED, *expr.args[0].args)
+
+    elif (expr.is_op("CC_POS") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_SUB"
+          )):
+        expr = ExprCond(
+            ExprOp(TOK_INF_SIGNED, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+    elif (expr.is_op("CC_EQ") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_EQ"
+          )):
+        arg = expr.args[0].args[0]
+        expr = ExprOp(TOK_EQUAL, arg, ExprInt(0, arg.size))
+
+    elif (expr.is_op("CC_NE") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_EQ"
+          )):
+        arg = expr.args[0].args[0]
+        expr = ExprCond(
+            ExprOp(TOK_EQUAL,arg, ExprInt(0, arg.size)),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+    elif (expr.is_op("CC_NE") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_EQ_CMP"
+          )):
+        expr = ExprCond(
+            ExprOp(TOK_EQUAL, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+    elif (expr.is_op("CC_EQ") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_EQ_CMP"
+          )):
+        expr = ExprOp(TOK_EQUAL, *expr.args[0].args)
+
+    elif (expr.is_op("CC_NE") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_EQ_AND"
+          )):
+        expr = ExprOp("&", *expr.args[0].args)
+
+    elif (expr.is_op("CC_EQ") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_EQ_AND"
+          )):
+        expr = ExprCond(
+            ExprOp("&", *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+    elif (expr.is_op("CC_S>") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_SUB",
+              "FLAG_SUB_OF",
+              "FLAG_EQ_CMP",
+          )):
+        expr = ExprCond(
+            ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+    elif (expr.is_op("CC_S>") and
+          len(expr.args) == 3 and
+          expr.args[0].is_op("FLAG_SIGN_SUB") and
+          expr.args[2].is_op("FLAG_EQ_CMP") and
+          expr.args[0].args == expr.args[2].args and
+          expr.args[1].is_int(0)):
+        expr = ExprCond(
+            ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+
+
+    elif (expr.is_op("CC_S>=") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_SUB",
+              "FLAG_SUB_OF"
+          )):
+        expr = ExprCond(
+            ExprOp(TOK_INF_SIGNED, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+    elif (expr.is_op("CC_S>=") and
+          len(expr.args) == 2 and
+          expr.args[0].is_op("FLAG_SIGN_SUB") and
+          expr.args[0].args[1].is_int(0) and
+          expr.args[1].is_int(0)):
+        expr = ExprOp(
+            TOK_INF_EQUAL_SIGNED,
+            expr.args[0].args[1],
+            expr.args[0].args[0],
+        )
+
+    elif (expr.is_op("CC_S<") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_SUB",
+              "FLAG_SUB_OF"
+          )):
+        expr = ExprOp(TOK_INF_SIGNED, *expr.args[0].args)
+
+    elif (expr.is_op("CC_S<=") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_SUB",
+              "FLAG_SUB_OF",
+              "FLAG_EQ_CMP",
+          )):
+        expr = ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args)
+
+    elif (expr.is_op("CC_S<=") and
+          len(expr.args) == 3 and
+          expr.args[0].is_op("FLAG_SIGN_SUB") and
+          expr.args[2].is_op("FLAG_EQ_CMP") and
+          expr.args[0].args == expr.args[2].args and
+          expr.args[1].is_int(0)):
+        expr = ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args)
+
+    elif (expr.is_op("CC_U<=") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SUB_CF",
+              "FLAG_EQ_CMP",
+          )):
+        expr = ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args)
+
+    elif (expr.is_op("CC_U>") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SUB_CF",
+              "FLAG_EQ_CMP",
+          )):
+        expr = ExprCond(
+            ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args),
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
+        )
+
+    elif (expr.is_op("CC_S<") and
+          test_cc_eq_args(
+              expr,
+              "FLAG_SIGN_ADD",
+              "FLAG_ADD_OF"
+          )):
+        arg0, arg1 = expr.args[0].args
+        expr = ExprOp(TOK_INF_SIGNED, arg0, -arg1)
+
+    return expr
+
+
+
+def simp_cond_flag(_, expr):
+    """FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B"""
+    cond = expr.cond
+    if cond.is_op("FLAG_EQ_CMP"):
+        return ExprCond(ExprOp(TOK_EQUAL, *cond.args), expr.src1, expr.src2)
+    return expr
+
+
+def simp_sub_cf_zero(_, expr):
+    """FLAG_SUB_CF(0, X) => (X)?1:0"""
+    if not expr.is_op("FLAG_SUB_CF"):
+        return expr
+    if not expr.args[0].is_int(0):
+        return expr
+    return ExprCond(expr.args[1], ExprInt(1, 1), ExprInt(0, 1))
+
+def simp_cond_cc_flag(expr_simp, expr):
+    """
+    ExprCond(CC_><(bit), X, Y) => ExprCond(bit, X, Y)
+    ExprCond(CC_U>=(bit), X, Y) => ExprCond(bit, Y, X)
+    """
+    if not expr.is_cond():
+        return expr
+    if not expr.cond.is_op():
+        return expr
+    expr_op = expr.cond
+    if expr_op.op not in ["CC_U<", "CC_U>="]:
+        return expr
+    arg = expr_op.args[0]
+    if arg.size != 1:
+        return expr
+    if expr_op.op == "CC_U<":
+        return ExprCond(arg, expr.src1, expr.src2)
+    if expr_op.op == "CC_U>=":
+        return ExprCond(arg, expr.src2, expr.src1)
+    return expr
+
+def simp_cond_sub_cf(expr_simp, expr):
+    """
+    ExprCond(FLAG_SUB_CF(A, B), X, Y) => ExprCond(A <u B, X, Y)
+    """
+    if not expr.is_cond():
+        return expr
+    if not expr.cond.is_op("FLAG_SUB_CF"):
+        return expr
+    cond = ExprOp(TOK_INF_UNSIGNED, *expr.cond.args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+
+def simp_cmp_int(expr_simp, expr):
+    """
+    ({X, 0} == int) => X == int[:]
+    X + int1 == int2 => X == int2-int1
+    X ^ int1 == int2 => X == int1^int2
+    """
+    if (expr.is_op(TOK_EQUAL) and
+          expr.args[1].is_int() and
+          expr.args[0].is_compose() and
+          len(expr.args[0].args) == 2 and
+          expr.args[0].args[1].is_int(0)):
+        # ({X, 0} == int) => X == int[:]
+        src = expr.args[0].args[0]
+        int_val = int(expr.args[1])
+        new_int = ExprInt(int_val, src.size)
+        expr = expr_simp(
+            ExprOp(TOK_EQUAL, src, new_int)
+        )
+    elif not expr.is_op(TOK_EQUAL):
+        return expr
+    assert len(expr.args) == 2
+
+    left, right = expr.args
+    if left.is_int() and not right.is_int():
+        left, right = right, left
+    if not right.is_int():
+        return expr
+    if not (left.is_op() and left.op in ['+', '^']):
+        return expr
+    if not left.args[-1].is_int():
+        return expr
+    # X + int1 == int2 => X == int2-int1
+    # WARNING:
+    # X - 0x10 <=u 0x20 gives X in [0x10 0x30]
+    # which is not equivalet to A <=u 0x10
+
+    left_orig = left
+    left, last_int = left.args[:-1], left.args[-1]
+
+    if len(left) == 1:
+        left = left[0]
+    else:
+        left = ExprOp(left_orig.op, *left)
+
+    if left_orig.op == "+":
+        new_int = expr_simp(right - last_int)
+    elif left_orig.op == '^':
+        new_int = expr_simp(right ^ last_int)
+    else:
+        raise RuntimeError("Unsupported operator")
+
+    expr = expr_simp(
+        ExprOp(TOK_EQUAL, left, new_int),
+    )
+    return expr
+
+
+
+def simp_cmp_int_arg(_, expr):
+    """
+    (0x10 <= R0) ? A:B
+    =>
+    (R0 < 0x10) ? B:A
+    """
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    op = cond.op
+    if op not in [
+            TOK_EQUAL,
+            TOK_INF_SIGNED,
+            TOK_INF_EQUAL_SIGNED,
+            TOK_INF_UNSIGNED,
+            TOK_INF_EQUAL_UNSIGNED
+    ]:
+        return expr
+    arg1, arg2 = cond.args
+    if arg2.is_int():
+        return expr
+    if not arg1.is_int():
+        return expr
+    src1, src2 = expr.src1, expr.src2
+    if op == TOK_EQUAL:
+        return ExprCond(ExprOp(TOK_EQUAL, arg2, arg1), src1, src2)
+
+    arg1, arg2 = arg2, arg1
+    src1, src2 = src2, src1
+    if op == TOK_INF_SIGNED:
+        op = TOK_INF_EQUAL_SIGNED
+    elif op == TOK_INF_EQUAL_SIGNED:
+        op = TOK_INF_SIGNED
+    elif op == TOK_INF_UNSIGNED:
+        op = TOK_INF_EQUAL_UNSIGNED
+    elif op == TOK_INF_EQUAL_UNSIGNED:
+        op = TOK_INF_UNSIGNED
+    return ExprCond(ExprOp(op, arg1, arg2), src1, src2)
+
+
+
+def simp_cmp_bijective_op(expr_simp, expr):
+    """
+    A + B == A => A == 0
+
+    X + A == X + B => A == B
+    X ^ A == X ^ B => A == B
+
+    TODO:
+    3 * A + B == A + C => 2 * A + B == C
+    """
+
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    op_a = expr.args[0]
+    op_b = expr.args[1]
+
+    # a == a
+    if op_a == op_b:
+        return ExprInt(1, 1)
+
+    # Case:
+    # a + b + c == a
+    if op_a.is_op() and op_a.op in ["+", "^"]:
+        args = list(op_a.args)
+        if op_b in args:
+            args.remove(op_b)
+            if not args:
+                raise ValueError("Can be here")
+            elif len(args) == 1:
+                op_a = args[0]
+            else:
+                op_a = ExprOp(op_a.op, *args)
+            return ExprOp(TOK_EQUAL, op_a, ExprInt(0, args[0].size))
+    # a == a + b + c
+    if op_b.is_op() and op_b.op in ["+", "^"]:
+        args = list(op_b.args)
+        if op_a in args:
+            args.remove(op_a)
+            if not args:
+                raise ValueError("Can be here")
+            elif len(args) == 1:
+                op_b = args[0]
+            else:
+                op_b = ExprOp(op_b.op, *args)
+            return ExprOp(TOK_EQUAL, op_b, ExprInt(0, args[0].size))
+
+    if not (op_a.is_op() and op_b.is_op()):
+        return expr
+    if op_a.op != op_b.op:
+        return expr
+    op = op_a.op
+    if op not in ["+", "^"]:
+        return expr
+    common = set(op_a.args).intersection(op_b.args)
+    if not common:
+        return expr
+
+    args_a = list(op_a.args)
+    args_b = list(op_b.args)
+    for value in common:
+        while value in args_a and value in args_b:
+            args_a.remove(value)
+            args_b.remove(value)
+
+    # a + b == a + b + c
+    if not args_a:
+        return ExprOp(TOK_EQUAL, ExprOp(op, *args_b), ExprInt(0, args_b[0].size))
+    # a + b + c == a + b
+    if not args_b:
+        return ExprOp(TOK_EQUAL, ExprOp(op, *args_a), ExprInt(0, args_a[0].size))
+
+    arg_a = ExprOp(op, *args_a)
+    arg_b = ExprOp(op, *args_b)
+    return ExprOp(TOK_EQUAL, arg_a, arg_b)
+
+
+def simp_subwc_cf(_, expr):
+    """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D})"""
+    if not expr.is_op('FLAG_SUBWC_CF'):
+        return expr
+    op3 = expr.args[2]
+    if not op3.is_op("FLAG_SUB_CF"):
+        return expr
+
+    op1 = ExprCompose(expr.args[0], op3.args[0])
+    op2 = ExprCompose(expr.args[1], op3.args[1])
+
+    return ExprOp("FLAG_SUB_CF", op1, op2)
+
+
+def simp_subwc_of(_, expr):
+    """SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D})"""
+    if not expr.is_op('FLAG_SUBWC_OF'):
+        return expr
+    op3 = expr.args[2]
+    if not op3.is_op("FLAG_SUB_CF"):
+        return expr
+
+    op1 = ExprCompose(expr.args[0], op3.args[0])
+    op2 = ExprCompose(expr.args[1], op3.args[1])
+
+    return ExprOp("FLAG_SUB_OF", op1, op2)
+
+
+def simp_sign_subwc_cf(_, expr):
+    """SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D})"""
+    if not expr.is_op('FLAG_SIGN_SUBWC'):
+        return expr
+    op3 = expr.args[2]
+    if not op3.is_op("FLAG_SUB_CF"):
+        return expr
+
+    op1 = ExprCompose(expr.args[0], op3.args[0])
+    op2 = ExprCompose(expr.args[1], op3.args[1])
+
+    return ExprOp("FLAG_SIGN_SUB", op1, op2)
+
+def simp_double_zeroext(_, expr):
+    """A.zeroExt(X).zeroExt(Y) => A.zeroExt(Y)"""
+    if not (expr.is_op() and expr.op.startswith("zeroExt")):
+        return expr
+    arg1 = expr.args[0]
+    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
+        return expr
+    arg2 = arg1.args[0]
+    return ExprOp(expr.op, arg2)
+
+def simp_double_signext(_, expr):
+    """A.signExt(X).signExt(Y) => A.signExt(Y)"""
+    if not (expr.is_op() and expr.op.startswith("signExt")):
+        return expr
+    arg1 = expr.args[0]
+    if not (arg1.is_op() and arg1.op.startswith("signExt")):
+        return expr
+    arg2 = arg1.args[0]
+    return ExprOp(expr.op, arg2)
+
+def simp_zeroext_eq_cst(_, expr):
+    """A.zeroExt(X) == int => A == int[:A.size]"""
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
+        return expr
+    src = arg1.args[0]
+    if int(arg2) > (1 << src.size):
+        # Always false
+        return ExprInt(0, expr.size)
+    return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size))
+
+def simp_cond_zeroext(_, expr):
+    """
+    X.zeroExt()?(A:B) => X ? A:B
+    X.signExt()?(A:B) => X ? A:B
+    """
+    if not (
+            expr.cond.is_op() and
+            (
+                expr.cond.op.startswith("zeroExt") or
+                expr.cond.op.startswith("signExt")
+            )
+    ):
+        return expr
+
+    ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2)
+    return ret
+
+def simp_ext_eq_ext(_, expr):
+    """
+    A.zeroExt(X) == B.zeroExt(X) => A == B
+    A.signExt(X) == B.signExt(X) => A == B
+    """
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if (not ((arg1.is_op() and arg1.op.startswith("zeroExt") and
+              arg2.is_op() and arg2.op.startswith("zeroExt")) or
+             (arg1.is_op() and arg1.op.startswith("signExt") and
+               arg2.is_op() and arg2.op.startswith("signExt")))):
+        return expr
+    if arg1.args[0].size != arg2.args[0].size:
+        return expr
+    return ExprOp(TOK_EQUAL, arg1.args[0], arg2.args[0])
+
+def simp_cond_eq_zero(_, expr):
+    """(X == 0)?(A:B) => X?(B:A)"""
+    cond = expr.cond
+    if not cond.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = cond.args
+    if not arg2.is_int(0):
+        return expr
+    new_expr = ExprCond(arg1, expr.src2, expr.src1)
+    return new_expr
+
+def simp_sign_inf_zeroext(expr_s, expr):
+    """
+    [!] Ensure before: X.zeroExt(X.size) => X
+
+    X.zeroExt() <s 0 => 0
+    X.zeroExt() <=s 0 => X == 0
+
+    X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive)
+    X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive)
+
+    X.zeroExt() <s cst => 0 (cst negative)
+    X.zeroExt() <=s cst => 0 (cst negative)
+
+    """
+    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
+        return expr
+    src = arg1.args[0]
+    assert src.size < arg1.size
+
+    # If cst is zero
+    if arg2.is_int(0):
+        if expr.is_op(TOK_INF_SIGNED):
+            # X.zeroExt() <s 0 => 0
+            return ExprInt(0, expr.size)
+        else:
+            # X.zeroExt() <=s 0 => X == 0
+            return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size))
+
+    # cst is not zero
+    cst = int(arg2)
+    if cst & (1 << (arg2.size - 1)):
+        # cst is negative
+        return ExprInt(0, expr.size)
+    # cst is positive
+    if expr.is_op(TOK_INF_SIGNED):
+        # X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive)
+        return ExprOp(TOK_INF_UNSIGNED, src, expr_s(arg2[:src.size]))
+    # X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive)
+    return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size]))
+
+
+def simp_zeroext_and_cst_eq_cst(expr_s, expr):
+    """
+    A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size]
+    """
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not arg1.is_op('&'):
+        return expr
+    is_ok = True
+    sizes = set()
+    for arg in arg1.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt")):
+            sizes.add(arg.args[0].size)
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    if len(sizes) != 1:
+        return expr
+    size = list(sizes)[0]
+    if int(arg2) > ((1 << size) - 1):
+        return expr
+    args = [expr_s(arg[:size]) for arg in arg1.args]
+    left = ExprOp('&', *args)
+    right = expr_s(arg2[:size])
+    ret = ExprOp(TOK_EQUAL, left, right)
+    return ret
+
+
+def test_one_bit_set(arg):
+    """
+    Return True if arg has form 1 << X
+    """
+    return arg != 0  and ((arg & (arg - 1)) == 0)
+
+def simp_x_and_cst_eq_cst(_, expr):
+    """
+    (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B
+    """
+    cond = expr.cond
+    if not cond.is_op(TOK_EQUAL):
+        return expr
+    arg1, mask2 = cond.args
+    if not mask2.is_int():
+        return expr
+    if not test_one_bit_set(int(mask2)):
+        return expr
+    if not arg1.is_op('&'):
+        return expr
+    mask1 = arg1.args[-1]
+    if mask1 != mask2:
+        return expr
+    cond = ExprOp('&', *arg1.args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+def simp_cmp_int_int(_, expr):
+    """
+    IntA <s IntB => int
+    IntA <u IntB => int
+    IntA <=s IntB => int
+    IntA <=u IntB => int
+    IntA == IntB => int
+    """
+    if expr.op not in [
+            TOK_EQUAL,
+            TOK_INF_SIGNED, TOK_INF_UNSIGNED,
+            TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
+    ]:
+        return expr
+    if not all(arg.is_int() for arg in expr.args):
+        return expr
+    int_a, int_b = expr.args
+    if expr.is_op(TOK_EQUAL):
+        if int_a == int_b:
+            return ExprInt(1, 1)
+        return ExprInt(0, expr.size)
+
+    if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]:
+        int_a = int(mod_size2int[int_a.size](int(int_a)))
+        int_b = int(mod_size2int[int_b.size](int(int_b)))
+    else:
+        int_a = int(mod_size2uint[int_a.size](int(int_a)))
+        int_b = int(mod_size2uint[int_b.size](int(int_b)))
+
+    if expr.op in [TOK_INF_SIGNED, TOK_INF_UNSIGNED]:
+        ret = int_a < int_b
+    else:
+        ret = int_a <= int_b
+
+    if ret:
+        ret = 1
+    else:
+        ret = 0
+    return ExprInt(ret, 1)
+
+
+def simp_ext_cst(_, expr):
+    """
+    Int.zeroExt(X) => Int
+    Int.signExt(X) => Int
+    """
+    if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")):
+        return expr
+    arg = expr.args[0]
+    if not arg.is_int():
+        return expr
+    if expr.op.startswith("zeroExt"):
+        ret = int(arg)
+    else:
+        ret = int(mod_size2int[arg.size](int(arg)))
+    ret = ExprInt(ret, expr.size)
+    return ret
+
+
+
+def simp_ext_cond_int(e_s, expr):
+    """
+    zeroExt(ExprCond(X, Int, Int)) => ExprCond(X, Int, Int)
+    """
+    if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")):
+        return expr
+    arg = expr.args[0]
+    if not arg.is_cond():
+        return expr
+    if not (arg.src1.is_int() and arg.src2.is_int()):
+        return expr
+    src1 = ExprOp(expr.op, arg.src1)
+    src2 = ExprOp(expr.op, arg.src2)
+    return e_s(ExprCond(arg.cond, src1, src2))
+
+
+def simp_slice_of_ext(_, expr):
+    """
+    C.zeroExt(X)[A:B] => 0 if A >= size(C)
+    C.zeroExt(X)[A:B] => C[A:B] if B <= size(C)
+    A.zeroExt(X)[0:Y] => A.zeroExt(Y)
+    """
+    if not expr.arg.is_op():
+        return expr
+    if not expr.arg.op.startswith("zeroExt"):
+        return expr
+    arg = expr.arg.args[0]
+
+    if expr.start >= arg.size:
+        # C.zeroExt(X)[A:B] => 0 if A >= size(C)
+        return ExprInt(0, expr.size)
+    if expr.stop <= arg.size:
+        # C.zeroExt(X)[A:B] => C[A:B] if B <= size(C)
+        return arg[expr.start:expr.stop]
+    if expr.start == 0:
+        # A.zeroExt(X)[0:Y] => A.zeroExt(Y)
+        return arg.zeroExtend(expr.stop)
+    return expr
+
+def simp_slice_of_sext(e_s, expr):
+    """
+    with Y <= size(A)
+    A.signExt(X)[0:Y] => A[0:Y]
+    """
+    if not expr.arg.is_op():
+        return expr
+    if not expr.arg.op.startswith("signExt"):
+        return expr
+    arg = expr.arg.args[0]
+    if expr.start != 0:
+        return expr
+    if expr.stop <= arg.size:
+        return e_s.expr_simp(arg[:expr.stop])
+    return expr
+
+
+def simp_slice_of_op_ext(expr_s, expr):
+    """
+    (X.zeroExt() + {Z, } + ... + Int)[0:8] => X + ... + int[:]
+    (X.zeroExt() | ... | Int)[0:8] => X | ... | int[:]
+    ...
+    """
+    if expr.start != 0:
+        return expr
+    src = expr.arg
+    if not src.is_op():
+        return expr
+    if src.op not in ['+', '|', '^', '&']:
+        return expr
+    is_ok = True
+    for arg in src.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt") and
+            arg.args[0].size == expr.stop):
+            continue
+        if arg.is_compose():
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    args = [expr_s(arg[:expr.stop]) for arg in src.args]
+    return ExprOp(src.op, *args)
+
+
+def simp_cond_logic_ext(expr_s, expr):
+    """(X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B"""
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in ["&", "^", "|"]:
+        return expr
+    is_ok = True
+    sizes = set()
+    for arg in cond.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt")):
+            sizes.add(arg.args[0].size)
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    if len(sizes) != 1:
+        return expr
+    size = list(sizes)[0]
+    args = [expr_s(arg[:size]) for arg in cond.args]
+    cond = ExprOp(cond.op, *args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+
+def simp_cond_sign_bit(_, expr):
+    """(a & .. & 0x80000000) ? A:B => (a & ...) <s 0 ? A:B"""
+    cond = expr.cond
+    if not cond.is_op('&'):
+        return expr
+    last = cond.args[-1]
+    if not last.is_int(1 << (last.size - 1)):
+        return expr
+    zero = ExprInt(0, expr.cond.size)
+    if len(cond.args) == 2:
+        args = [cond.args[0], zero]
+    else:
+        args = [ExprOp('&', *list(cond.args[:-1])), zero]
+    cond = ExprOp(TOK_INF_SIGNED, *args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+
+def simp_cond_add(expr_s, expr):
+    """
+    (a+b)?X:Y => (a == b)?Y:X
+    (a^b)?X:Y => (a == b)?Y:X
+    """
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in ['+', '^']:
+        return expr
+    if len(cond.args) != 2:
+        return expr
+    arg1, arg2 = cond.args
+    if cond.is_op('+'):
+        new_cond = ExprOp('==', arg1, expr_s(-arg2))
+    elif cond.is_op('^'):
+        new_cond = ExprOp('==', arg1, arg2)
+    else:
+        raise ValueError('Bad case')
+    return ExprCond(new_cond, expr.src2, expr.src1)
+
+
+def simp_cond_eq_1_0(expr_s, expr):
+    """
+    (a == b)?ExprInt(1, 1):ExprInt(0, 1) => a == b
+    (a <s b)?ExprInt(1, 1):ExprInt(0, 1) => a == b
+    ...
+    """
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in [
+            TOK_EQUAL,
+            TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED,
+            TOK_INF_UNSIGNED, TOK_INF_EQUAL_UNSIGNED
+            ]:
+        return expr
+    if expr.src1 != ExprInt(1, 1) or expr.src2 != ExprInt(0, 1):
+        return expr
+    return cond
+
+
+def simp_cond_inf_eq_unsigned_zero(expr_s, expr):
+    """
+    (a <=u 0) => a == 0
+    """
+    if not expr.is_op(TOK_INF_EQUAL_UNSIGNED):
+        return expr
+    if not expr.args[1].is_int(0):
+        return expr
+    return ExprOp(TOK_EQUAL, expr.args[0], expr.args[1])
+
+
+def simp_test_signext_inf(expr_s, expr):
+    """A.signExt() <s int => A <s int[:]"""
+    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
+        return expr
+    arg, cst = expr.args
+    if not (arg.is_op() and arg.op.startswith("signExt")):
+        return expr
+    if not cst.is_int():
+        return expr
+    base = arg.args[0]
+    tmp = int(mod_size2int[cst.size](int(cst)))
+    if -(1 << (base.size - 1)) <= tmp < (1 << (base.size - 1)):
+        # Can trunc integer
+        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
+    if (tmp >= (1 << (base.size - 1)) or
+        tmp < -(1 << (base.size - 1)) ):
+        return ExprInt(1, 1)
+    return expr
+
+
+def simp_test_zeroext_inf(expr_s, expr):
+    """A.zeroExt() <u int => A <u int[:]"""
+    if not (expr.is_op(TOK_INF_UNSIGNED) or expr.is_op(TOK_INF_EQUAL_UNSIGNED)):
+        return expr
+    arg, cst = expr.args
+    if not (arg.is_op() and arg.op.startswith("zeroExt")):
+        return expr
+    if not cst.is_int():
+        return expr
+    base = arg.args[0]
+    tmp = int(mod_size2uint[cst.size](int(cst)))
+    if 0 <= tmp < (1 << base.size):
+        # Can trunc integer
+        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
+    if tmp >= (1 << base.size):
+        return ExprInt(1, 1)
+    return expr
+
+
+def simp_add_multiple(_, expr):
+    """
+    X + X => 2 * X
+    X + X * int1 => X * (1 + int1)
+    X * int1 + (- X) => X * (int1 - 1)
+    X + (X << int1) => X * (1 + 2 ** int1)
+    Correct even if addition overflow/underflow
+    """
+    if not expr.is_op('+'):
+        return expr
+
+    # Extract each argument and its counter
+    operands = {}
+    for arg in expr.args:
+        if arg.is_op('*') and arg.args[1].is_int():
+            base_expr, factor = arg.args
+            operands[base_expr] = operands.get(base_expr, 0) + int(factor)
+        elif arg.is_op('<<') and arg.args[1].is_int():
+            base_expr, factor = arg.args
+            operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor)
+        elif arg.is_op("-"):
+            arg = arg.args[0]
+            if arg.is_op('<<') and arg.args[1].is_int():
+                base_expr, factor = arg.args
+                operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor))
+            else:
+                operands[arg] = operands.get(arg, 0) - 1
+        else:
+            operands[arg] = operands.get(arg, 0) + 1
+    out = []
+
+    # Best effort to factor common args:
+    # (a + b) * 3 + a + b => (a + b) * 4
+    # Does not factor:
+    # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a
+    modified = True
+    while modified:
+        modified = False
+        for arg, count in list(viewitems(operands)):
+            if not arg.is_op('+'):
+                continue
+            components = arg.args
+            if not all(component in operands for component in components):
+                continue
+            counters = set(operands[component] for component in components)
+            if len(counters) != 1:
+                continue
+            counter = counters.pop()
+            for component in components:
+                del operands[component]
+            operands[arg] += counter
+            modified = True
+            break
+
+    for arg, count in viewitems(operands):
+        if count == 0:
+            continue
+        if count == 1:
+            out.append(arg)
+            continue
+        out.append(arg * ExprInt(count, expr.size))
+
+    if len(out) == len(expr.args):
+        # No reductions
+        return expr
+    if not out:
+        return ExprInt(0, expr.size)
+    if len(out) == 1:
+        return out[0]
+    return ExprOp('+', *out)
+
+def simp_compose_and_mask(_, expr):
+    """
+    {X 0 8, Y 8 32} & 0xFF => zeroExt(X)
+    {X 0 8, Y 8 16, Z 16 32} & 0xFFFF => {X 0 8, Y 8 16, 0x0 16 32}
+    {X 0 8, 0x123456 8 32} & 0xFFFFFF => {X 0 8, 0x1234 8 24, 0x0 24 32}
+    """
+    if not expr.is_op('&'):
+        return expr
+    # handle the case where arg2 = arg1.mask
+    if len(expr.args) != 2:
+        return expr
+    arg1, arg2 = expr.args
+    if not arg1.is_compose():
+        return expr
+    if not arg2.is_int():
+        return expr
+    int2 = int(arg2)
+    if (int2 + 1) & int2 != 0:
+        return expr
+    mask_size = int2.bit_length() + 7 // 8
+    out = []
+    for offset, arg in arg1.iter_args():
+        if offset == mask_size:
+            return ExprCompose(*out).zeroExtend(expr.size)
+        elif mask_size > offset and mask_size < offset+arg.size and arg.is_int():
+            out.append(ExprSlice(arg, 0, mask_size-offset))
+            return ExprCompose(*out).zeroExtend(expr.size)
+        else:
+            out.append(arg)
+    return expr
+
+def simp_bcdadd_cf(_, expr):
+    """bcdadd(const, const) => decimal"""
+    if not(expr.is_op('bcdadd_cf')):
+        return expr
+    arg1 = expr.args[0]
+    arg2 = expr.args[1]
+    if not(arg1.is_int() and arg2.is_int()):
+        return expr
+
+    carry = 0
+    res = 0
+    nib_1, nib_2 = 0, 0
+    for i in range(0,16,4):
+        nib_1 = (arg1.arg >> i) & (0xF)
+        nib_2 = (arg2.arg >> i) & (0xF)
+
+        j = (carry + nib_1 + nib_2)
+        if (j >= 10):
+            carry = 1
+            j -= 10
+            j &= 0xF
+        else:
+            carry = 0
+    return ExprInt(carry, 1)
+
+def simp_bcdadd(_, expr):
+    """bcdadd(const, const) => decimal"""
+    if not(expr.is_op('bcdadd')):
+        return expr
+    arg1 = expr.args[0]
+    arg2 = expr.args[1]
+    if not(arg1.is_int() and arg2.is_int()):
+        return expr
+
+    carry = 0
+    res = 0
+    nib_1, nib_2 = 0, 0
+    for i in range(0,16,4):
+        nib_1 = (arg1.arg >> i) & (0xF)
+        nib_2 = (arg2.arg >> i) & (0xF)
+
+        j = (carry + nib_1 + nib_2)
+        if (j >= 10):
+            carry = 1
+            j -= 10
+            j &= 0xF
+        else:
+            carry = 0
+        res += j << i
+    return ExprInt(res, arg1.size)
+
+
+def simp_smod_sext(expr_s, expr):
+    """
+    a.size == b.size
+    smod(a.signExtend(X), b.signExtend(X)) => smod(a, b).signExtend(X)
+    """
+    if not expr.is_op("smod"):
+        return expr
+    arg1, arg2 = expr.args
+    if arg1.is_op() and arg1.op.startswith("signExt"):
+        src1 = arg1.args[0]
+        if arg2.is_op() and arg2.op.startswith("signExt"):
+            src2 = arg2.args[0]
+            if src1.size == src2.size:
+                # Case: a.signext(), b.signext()
+                return ExprOp("smod", src1, src2).signExtend(expr.size)
+            return expr
+        elif arg2.is_int():
+            src2 = expr_s.expr_simp(arg2[:src1.size])
+            if expr_s.expr_simp(src2.signExtend(arg2.size)) == arg2:
+                # Case: a.signext(), int
+                return ExprOp("smod", src1, src2).signExtend(expr.size)
+            return expr
+    # Case: int        , b.signext()
+    if arg2.is_op() and arg2.op.startswith("signExt"):
+        src2 = arg2.args[0]
+        if arg1.is_int():
+            src1 = expr_s.expr_simp(arg1[:src2.size])
+            if expr_s.expr_simp(src1.signExtend(arg1.size)) == arg1:
+                # Case: int, b.signext()
+                return ExprOp("smod", src1, src2).signExtend(expr.size)
+    return expr
+
+# FLAG_SUB_OF(CST1, CST2) => CST
+def simp_flag_cst(expr_simp, expr):
+    if expr.op not in [
+            "FLAG_EQ", "FLAG_EQ_AND", "FLAG_SIGN_SUB", "FLAG_EQ_CMP", "FLAG_ADD_CF",
+            "FLAG_SUB_CF", "FLAG_ADD_OF", "FLAG_SUB_OF", "FLAG_EQ_ADDWC", "FLAG_ADDWC_OF",
+            "FLAG_SUBWC_OF", "FLAG_ADDWC_CF", "FLAG_SUBWC_CF", "FLAG_SIGN_ADDWC",
+            "FLAG_SIGN_SUBWC", "FLAG_EQ_SUBWC",
+            "CC_U<=", "CC_U>=", "CC_S<", "CC_S>", "CC_S<=", "CC_S>=", "CC_U>",
+            "CC_U<", "CC_NEG", "CC_EQ", "CC_NE", "CC_POS"
+    ]:
+        return expr
+    if not all(arg.is_int() for arg in expr.args):
+        return expr
+    new_expr = expr_simp(simp_flags(expr_simp, expr))
+    return new_expr