about summary refs log tree commit diff stats
path: root/src/miasm/core/cpu.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/miasm/core/cpu.py')
-rw-r--r--src/miasm/core/cpu.py1715
1 files changed, 1715 insertions, 0 deletions
diff --git a/src/miasm/core/cpu.py b/src/miasm/core/cpu.py
new file mode 100644
index 00000000..7df9f991
--- /dev/null
+++ b/src/miasm/core/cpu.py
@@ -0,0 +1,1715 @@
+#-*- coding:utf-8 -*-
+
+from builtins import range
+import re
+import struct
+import logging
+from collections import defaultdict
+
+
+from future.utils import viewitems, viewvalues
+
+import pyparsing
+
+from miasm.core.utils import decode_hex
+import miasm.expression.expression as m2_expr
+from miasm.core.bin_stream import bin_stream, bin_stream_str
+from miasm.core.utils import Disasm_Exception
+from miasm.expression.simplifications import expr_simp
+
+
+from miasm.core.asm_ast import AstNode, AstInt, AstId, AstOp
+from miasm.core import utils
+from future.utils import with_metaclass
+
+log = logging.getLogger("cpuhelper")
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s"))
+log.addHandler(console_handler)
+log.setLevel(logging.WARN)
+
+
+class bitobj(object):
+
+    def __init__(self, s=b""):
+        if not s:
+            bits = []
+        else:
+            bits = [int(x) for x in bin(int(encode_hex(s), 16))[2:]]
+            if len(bits) % 8:
+                bits = [0 for x in range(8 - (len(bits) % 8))] + bits
+        self.bits = bits
+        self.offset = 0
+
+    def __len__(self):
+        return len(self.bits) - self.offset
+
+    def getbits(self, n):
+        if not n:
+            return 0
+        if n > len(self.bits) - self.offset:
+            raise ValueError('not enough bits %r %r' % (n, len(self.bits)))
+        b = self.bits[self.offset:self.offset + n]
+        b = int("".join(str(x) for x in b), 2)
+        self.offset += n
+        return b
+
+    def putbits(self, b, n):
+        if not n:
+            return
+        bits = list(bin(b)[2:])
+        bits = [int(x) for x in bits]
+        bits = [0 for x in range(n - len(bits))] + bits
+        self.bits += bits
+
+    def tostring(self):
+        if len(self.bits) % 8:
+            raise ValueError(
+                'num bits must be 8 bit aligned: %d' % len(self.bits)
+            )
+        b = int("".join(str(x) for x in self.bits), 2)
+        b = "%X" % b
+        b = '0' * (len(self.bits) // 4 - len(b)) + b
+        b = decode_hex(b.encode())
+        return b
+
+    def reset(self):
+        self.offset = 0
+
+    def copy_state(self):
+        b = self.__class__()
+        b.bits = self.bits
+        b.offset = self.offset
+        return b
+
+
+def literal_list(l):
+    l = l[:]
+    l.sort()
+    l = l[::-1]
+    o = pyparsing.Literal(l[0])
+    for x in l[1:]:
+        o |= pyparsing.Literal(x)
+    return o
+
+
+class reg_info(object):
+
+    def __init__(self, reg_str, reg_expr):
+        self.str = reg_str
+        self.expr = reg_expr
+        self.parser = literal_list(reg_str).setParseAction(self.cb_parse)
+
+    def cb_parse(self, tokens):
+        assert len(tokens) == 1
+        i = self.str.index(tokens[0])
+        reg = self.expr[i]
+        result = AstId(reg)
+        return result
+
+    def reg2expr(self, s):
+        i = self.str.index(s[0])
+        return self.expr[i]
+
+    def expr2regi(self, e):
+        return self.expr.index(e)
+
+
+class reg_info_dct(object):
+
+    def __init__(self, reg_expr):
+        self.dct_str_inv = dict((v.name, k) for k, v in viewitems(reg_expr))
+        self.dct_expr = reg_expr
+        self.dct_expr_inv = dict((v, k) for k, v in viewitems(reg_expr))
+        reg_str = [v.name for v in viewvalues(reg_expr)]
+        self.parser = literal_list(reg_str).setParseAction(self.cb_parse)
+
+    def cb_parse(self, tokens):
+        assert len(tokens) == 1
+        i = self.dct_str_inv[tokens[0]]
+        reg = self.dct_expr[i]
+        result = AstId(reg)
+        return result
+
+    def reg2expr(self, s):
+        i = self.dct_str_inv[s[0]]
+        return self.dct_expr[i]
+
+    def expr2regi(self, e):
+        return self.dct_expr_inv[e]
+
+
+def gen_reg(reg_name, sz=32):
+    """Gen reg expr and parser"""
+    reg = m2_expr.ExprId(reg_name, sz)
+    reginfo = reg_info([reg_name], [reg])
+    return reg, reginfo
+
+
+def gen_reg_bs(reg_name, reg_info, base_cls):
+    """
+    Generate:
+        class bs_reg_name(base_cls):
+            reg = reg_info
+
+        bs_reg_name = bs(l=0, cls=(bs_reg_name,))
+    """
+
+    bs_name = "bs_%s" % reg_name
+    cls = type(bs_name, base_cls, {'reg': reg_info})
+
+    bs_obj = bs(l=0, cls=(cls,))
+
+    return cls, bs_obj
+
+
+def gen_regs(rnames, env, sz=32):
+    regs_str = []
+    regs_expr = []
+    regs_init = []
+    for rname in rnames:
+        r = m2_expr.ExprId(rname, sz)
+        r_init = m2_expr.ExprId(rname+'_init', sz)
+        regs_str.append(rname)
+        regs_expr.append(r)
+        regs_init.append(r_init)
+        env[rname] = r
+
+    reginfo = reg_info(regs_str, regs_expr)
+    return regs_expr, regs_init, reginfo
+
+
+LPARENTHESIS = pyparsing.Literal("(")
+RPARENTHESIS = pyparsing.Literal(")")
+
+
+def int2expr(tokens):
+    v = tokens[0]
+    return (m2_expr.ExprInt, v)
+
+
+def parse_op(tokens):
+    v = tokens[0]
+    return (m2_expr.ExprOp, v)
+
+
+def parse_id(tokens):
+    v = tokens[0]
+    return (m2_expr.ExprId, v)
+
+
+def ast_parse_op(tokens):
+    if len(tokens) == 1:
+        return tokens[0]
+    if len(tokens) == 2:
+        if tokens[0] in ['-', '+', '!']:
+            return m2_expr.ExprOp(tokens[0], tokens[1])
+    if len(tokens) == 3:
+        if tokens[1] == '-':
+            # a - b => a + (-b)
+            tokens[1] = '+'
+            tokens[2] = - tokens[2]
+        return m2_expr.ExprOp(tokens[1], tokens[0], tokens[2])
+    tokens = tokens[::-1]
+    while len(tokens) >= 3:
+        o1, op, o2 = tokens.pop(), tokens.pop(), tokens.pop()
+        if op == '-':
+            # a - b => a + (-b)
+            op = '+'
+            o2 = - o2
+        e = m2_expr.ExprOp(op, o1, o2)
+        tokens.append(e)
+    if len(tokens) != 1:
+        raise NotImplementedError('strange op')
+    return tokens[0]
+
+
+def ast_id2expr(a):
+    return m2_expr.ExprId(a, 32)
+
+
+def ast_int2expr(a):
+    return m2_expr.ExprInt(a, 32)
+
+
+def neg_int(tokens):
+    x = -tokens[0]
+    return x
+
+
+integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0]))
+hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums)
+hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16))
+
+# str_int = (Optional('-') + (hex_int | integer))
+str_int_pos = (hex_int | integer)
+str_int_neg = (pyparsing.Suppress('-') + \
+                   (hex_int | integer)).setParseAction(neg_int)
+
+str_int = str_int_pos | str_int_neg
+str_int.setParseAction(int2expr)
+
+logicop = pyparsing.oneOf('& | ^ >> << <<< >>>')
+signop = pyparsing.oneOf('+ -')
+multop = pyparsing.oneOf('* / %')
+plusop = pyparsing.oneOf('+ -')
+
+
+##########################
+
+def literal_list(l):
+    l = l[:]
+    l.sort()
+    l = l[::-1]
+    o = pyparsing.Literal(l[0])
+    for x in l[1:]:
+        o |= pyparsing.Literal(x)
+    return o
+
+
+def cb_int(tokens):
+    assert len(tokens) == 1
+    integer = AstInt(tokens[0])
+    return integer
+
+
+def cb_parse_id(tokens):
+    assert len(tokens) == 1
+    reg = tokens[0]
+    return AstId(reg)
+
+
+def cb_op_not(tokens):
+    tokens = tokens[0]
+    assert len(tokens) == 2
+    assert tokens[0] == "!"
+    result = AstOp("!", tokens[1])
+    return result
+
+
+def merge_ops(tokens, op):
+    args = []
+    if len(tokens) >= 3:
+        args = [tokens.pop(0)]
+        i = 0
+        while i < len(tokens):
+            op_tmp = tokens[i]
+            arg = tokens[i+1]
+            i += 2
+            if op_tmp != op:
+                raise ValueError("Bad operator")
+            args.append(arg)
+    result = AstOp(op, *args)
+    return result
+
+
+def cb_op_and(tokens):
+    result = merge_ops(tokens[0], "&")
+    return result
+
+
+def cb_op_xor(tokens):
+    result = merge_ops(tokens[0], "^")
+    return result
+
+
+def cb_op_sign(tokens):
+    assert len(tokens) == 1
+    op, value = tokens[0]
+    return -value
+
+
+def cb_op_div(tokens):
+    tokens = tokens[0]
+    assert len(tokens) == 3
+    assert tokens[1] == "/"
+    result = AstOp("/", tokens[0], tokens[2])
+    return result
+
+
+def cb_op_plusminus(tokens):
+    tokens = tokens[0]
+    if len(tokens) == 3:
+        # binary op
+        assert isinstance(tokens[0], AstNode)
+        assert isinstance(tokens[2], AstNode)
+        op, args = tokens[1], [tokens[0], tokens[2]]
+    elif len(tokens) > 3:
+        args = [tokens.pop(0)]
+        i = 0
+        while i < len(tokens):
+            op = tokens[i]
+            arg = tokens[i+1]
+            i += 2
+            if op == '-':
+                arg = -arg
+            elif op == '+':
+                pass
+            else:
+                raise ValueError("Bad operator")
+            args.append(arg)
+        op = '+'
+    else:
+        raise ValueError("Parsing error")
+    assert all(isinstance(arg, AstNode) for arg in args)
+    result = AstOp(op, *args)
+    return result
+
+
+def cb_op_mul(tokens):
+    tokens = tokens[0]
+    assert len(tokens) == 3
+    assert isinstance(tokens[0], AstNode)
+    assert isinstance(tokens[2], AstNode)
+
+    # binary op
+    op, args = tokens[1], [tokens[0], tokens[2]]
+    result = AstOp(op, *args)
+    return result
+
+
+integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0]))
+hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums)
+hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16))
+
+str_int_pos = (hex_int | integer)
+
+str_int = str_int_pos
+str_int.setParseAction(cb_int)
+
+notop = pyparsing.oneOf('!')
+andop = pyparsing.oneOf('&')
+orop = pyparsing.oneOf('|')
+xorop = pyparsing.oneOf('^')
+shiftop = pyparsing.oneOf('>> <<')
+rotop = pyparsing.oneOf('<<< >>>')
+signop = pyparsing.oneOf('+ -')
+mulop = pyparsing.oneOf('*')
+plusop = pyparsing.oneOf('+ -')
+divop = pyparsing.oneOf('/')
+
+
+variable = pyparsing.Word(pyparsing.alphas + "_$.", pyparsing.alphanums + "_")
+variable.setParseAction(cb_parse_id)
+operand = str_int | variable
+
+base_expr = pyparsing.infixNotation(operand,
+                               [(notop,   1, pyparsing.opAssoc.RIGHT, cb_op_not),
+                                (andop, 2, pyparsing.opAssoc.RIGHT, cb_op_and),
+                                (xorop, 2, pyparsing.opAssoc.RIGHT, cb_op_xor),
+                                (signop,  1, pyparsing.opAssoc.RIGHT, cb_op_sign),
+                                (mulop,  2, pyparsing.opAssoc.RIGHT, cb_op_mul),
+                                (divop,  2, pyparsing.opAssoc.RIGHT, cb_op_div),
+                                (plusop,  2, pyparsing.opAssoc.LEFT, cb_op_plusminus),
+                                ])
+
+
+default_prio = 0x1337
+
+
+def isbin(s):
+    return re.match(r'[0-1]+$', s)
+
+
+def int2bin(i, l):
+    s = '0' * l + bin(i)[2:]
+    return s[-l:]
+
+
+def myror32(v, r):
+    return ((v & 0xFFFFFFFF) >> r) | ((v << (32 - r)) & 0xFFFFFFFF)
+
+
+def myrol32(v, r):
+    return ((v & 0xFFFFFFFF) >> (32 - r)) | ((v << r) & 0xFFFFFFFF)
+
+
+class bs(object):
+    all_new_c = {}
+    prio = default_prio
+
+    def __init__(self, strbits=None, l=None, cls=None,
+                 fname=None, order=0, flen=None, **kargs):
+        if fname is None:
+            fname = hex(id(str((strbits, l, cls, fname, order, flen, kargs))))
+        if strbits is None:
+            strbits = ""  # "X"*l
+        elif l is None:
+            l = len(strbits)
+        if strbits and isbin(strbits):
+            value = int(strbits, 2)
+        elif 'default_val' in kargs:
+            value = int(kargs['default_val'], 2)
+        else:
+            value = None
+        allbits = list(strbits)
+        allbits.reverse()
+        fbits = 0
+        fmask = 0
+        while allbits:
+            a = allbits.pop()
+            if a == " ":
+                continue
+            fbits <<= 1
+            fmask <<= 1
+            if a in '01':
+                a = int(a)
+                fbits |= a
+                fmask |= 1
+        lmask = (1 << l) - 1
+        # gen conditional field
+        if cls:
+            for b in cls:
+                if 'flen' in b.__dict__:
+                    flen = getattr(b, 'flen')
+
+        self.strbits = strbits
+        self.l = l
+        self.cls = cls
+        self.fname = fname
+        self.order = order
+        self.fbits = fbits
+        self.fmask = fmask
+        self.flen = flen
+        self.value = value
+        self.kargs = kargs
+
+    lmask = property(lambda self:(1 << self.l) - 1)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    def __repr__(self):
+        o = self.__class__.__name__
+        if self.fname:
+            o += "_%s" % self.fname
+        o += "_%(strbits)s" % self
+        if self.cls:
+            o += '_' + '_'.join([x.__name__ for x in self.cls])
+        return o
+
+    def gen(self, parent):
+        c_name = 'nbsi'
+        if self.cls:
+            c_name += '_' + '_'.join([x.__name__ for x in self.cls])
+            bases = list(self.cls)
+        else:
+            bases = []
+        # bsi added at end of list
+        # used to use first function of added class
+        bases += [bsi]
+        k = c_name, tuple(bases)
+        if k in self.all_new_c:
+            new_c = self.all_new_c[k]
+        else:
+            new_c = type(c_name, tuple(bases), {})
+            self.all_new_c[k] = new_c
+        c = new_c(parent,
+                  self.strbits, self.l, self.cls,
+                  self.fname, self.order, self.lmask, self.fbits,
+                  self.fmask, self.value, self.flen, **self.kargs)
+        return c
+
+    def check_fbits(self, v):
+        return v & self.fmask == self.fbits
+
+    @classmethod
+    def flen(cls, v):
+        raise NotImplementedError('not fully functional')
+
+
+class dum_arg(object):
+
+    def __init__(self, e=None):
+        self.expr = e
+
+
+class bsopt(bs):
+
+    def ispresent(self):
+        return True
+
+
+class bsi(object):
+
+    def __init__(self, parent, strbits, l, cls, fname, order,
+                 lmask, fbits, fmask, value, flen, **kargs):
+        self.parent = parent
+        self.strbits = strbits
+        self.l = l
+        self.cls = cls
+        self.fname = fname
+        self.order = order
+        self.fbits = fbits
+        self.fmask = fmask
+        self.flen = flen
+        self.value = value
+        self.kargs = kargs
+        self.__dict__.update(self.kargs)
+
+    lmask = property(lambda self:(1 << self.l) - 1)
+
+    def decode(self, v):
+        self.value = v & self.lmask
+        return True
+
+    def encode(self):
+        return True
+
+    def clone(self):
+        s = self.__class__(self.parent,
+                           self.strbits, self.l, self.cls,
+                           self.fname, self.order, self.lmask, self.fbits,
+                           self.fmask, self.value, self.flen, **self.kargs)
+        s.__dict__.update(self.kargs)
+        if hasattr(self, 'expr'):
+            s.expr = self.expr
+        return s
+
+    def __hash__(self):
+        kargs = []
+        for k, v in list(viewitems(self.kargs)):
+            if isinstance(v, list):
+                v = tuple(v)
+            kargs.append((k, v))
+        l = [self.strbits, self.l, self.cls,
+             self.fname, self.order, self.lmask, self.fbits,
+             self.fmask, self.value]  # + kargs
+
+        return hash(tuple(l))
+
+
+class bs_divert(object):
+    prio = default_prio
+
+    def __init__(self, **kargs):
+        self.args = kargs
+
+    def __getattr__(self, item):
+        if item in self.__dict__:
+            return self.__dict__[item]
+        elif item in self.args:
+            return self.args.get(item)
+        else:
+            raise AttributeError
+
+
+class bs_name(bs_divert):
+    prio = 1
+
+    def divert(self, i, candidates):
+        out = []
+        for cls, _, bases, dct, fields in candidates:
+            for new_name, value in viewitems(self.args['name']):
+                nfields = fields[:]
+                s = int2bin(value, self.args['l'])
+                args = dict(self.args)
+                args.update({'strbits': s})
+                f = bs(**args)
+                nfields[i] = f
+                ndct = dict(dct)
+                ndct['name'] = new_name
+                out.append((cls, new_name, bases, ndct, nfields))
+        return out
+
+
+class bs_mod_name(bs_divert):
+    prio = 2
+
+    def divert(self, i, candidates):
+        out = []
+        for cls, _, bases, dct, fields in candidates:
+            tab = self.args['mn_mod']
+            if isinstance(tab, list):
+                tmp = {}
+                for j, v in enumerate(tab):
+                    tmp[j] = v
+                tab = tmp
+            for value, new_name in viewitems(tab):
+                nfields = fields[:]
+                s = int2bin(value, self.args['l'])
+                args = dict(self.args)
+                args.update({'strbits': s})
+                f = bs(**args)
+                nfields[i] = f
+                ndct = dict(dct)
+                ndct['name'] = self.modname(ndct['name'], value)
+                out.append((cls, new_name, bases, ndct, nfields))
+        return out
+
+    def modname(self, name, i):
+        return name + self.args['mn_mod'][i]
+
+
+class bs_cond(bsi):
+    pass
+
+
+class bs_swapargs(bs_divert):
+
+    def divert(self, i, candidates):
+        out = []
+        for cls, name, bases, dct, fields in candidates:
+            # args not permuted
+            ndct = dict(dct)
+            nfields = fields[:]
+            # gen fix field
+            f = gen_bsint(0, self.args['l'], self.args)
+            nfields[i] = f
+            out.append((cls, name, bases, ndct, nfields))
+
+            # args permuted
+            ndct = dict(dct)
+            nfields = fields[:]
+            ap = ndct['args_permut'][:]
+            a = ap.pop(0)
+            b = ap.pop(0)
+            ndct['args_permut'] = [b, a] + ap
+            # gen fix field
+            f = gen_bsint(1, self.args['l'], self.args)
+            nfields[i] = f
+
+            out.append((cls, name, bases, ndct, nfields))
+        return out
+
+
+class m_arg(object):
+
+    def fromstring(self, text, loc_db, parser_result=None):
+        if parser_result:
+            e, start, stop = parser_result[self.parser]
+            self.expr = e
+            return start, stop
+        try:
+            v, start, stop = next(self.parser.scanString(text))
+        except StopIteration:
+            return None, None
+        arg = v[0]
+        expr = self.asm_ast_to_expr(arg, loc_db)
+        self.expr = expr
+        return start, stop
+
+    def asm_ast_to_expr(self, arg, loc_db, **kwargs):
+        raise NotImplementedError("Virtual")
+
+
+class m_reg(m_arg):
+    prio = default_prio
+
+    @property
+    def parser(self):
+        return self.reg.parser
+
+    def decode(self, v):
+        self.expr = self.reg.expr[0]
+        return True
+
+    def encode(self):
+        return self.expr == self.reg.expr[0]
+
+
+class reg_noarg(object):
+    reg_info = None
+    parser = None
+
+    def fromstring(self, text, loc_db, parser_result=None):
+        if parser_result:
+            e, start, stop = parser_result[self.parser]
+            self.expr = e
+            return start, stop
+        try:
+            v, start, stop = next(self.parser.scanString(text))
+        except StopIteration:
+            return None, None
+        arg = v[0]
+        expr = self.parses_to_expr(arg, loc_db)
+        self.expr = expr
+        return start, stop
+
+    def decode(self, v):
+        v = v & self.lmask
+        if v >= len(self.reg_info.expr):
+            return False
+        self.expr = self.reg_info.expr[v]
+        return True
+
+    def encode(self):
+        if not self.expr in self.reg_info.expr:
+            log.debug("cannot encode reg %r", self.expr)
+            return False
+        self.value = self.reg_info.expr.index(self.expr)
+        return True
+
+    def check_fbits(self, v):
+        return v & self.fmask == self.fbits
+
+
+class mn_prefix(object):
+    pass
+
+
+def swap16(v):
+    return struct.unpack('<H', struct.pack('>H', v))[0]
+
+
+def swap32(v):
+    return struct.unpack('<I', struct.pack('>I', v))[0]
+
+
+def perm_inv(p):
+    o = [None for x in range(len(p))]
+    for i, x in enumerate(p):
+        o[x] = i
+    return o
+
+
+def gen_bsint(value, l, args):
+    s = int2bin(value, l)
+    args = dict(args)
+    args.update({'strbits': s})
+    f = bs(**args)
+    return f
+
+total_scans = 0
+
+
+def branch2nodes(branch, nodes=None):
+    if nodes is None:
+        nodes = []
+    for k, v in viewitems(branch):
+        if not isinstance(v, dict):
+            continue
+        for k2 in v:
+            nodes.append((k, k2))
+        branch2nodes(v, nodes)
+
+
+def factor_one_bit(tree):
+    if isinstance(tree, set):
+        return tree
+    new_keys = defaultdict(lambda: defaultdict(dict))
+    if len(tree) == 1:
+        return tree
+    for k, v in viewitems(tree):
+        if k == "mn":
+            new_keys[k] = v
+            continue
+        l, fmask, fbits, fname, flen = k
+        if flen is not None or l <= 1:
+            new_keys[k] = v
+            continue
+        cfmask = fmask >> (l - 1)
+        nfmask = fmask & ((1 << (l - 1)) - 1)
+        cfbits = fbits >> (l - 1)
+        nfbits = fbits & ((1 << (l - 1)) - 1)
+        ck = 1, cfmask, cfbits, None, flen
+        nk = l - 1, nfmask, nfbits, fname, flen
+        if nk in new_keys[ck]:
+            raise NotImplementedError('not fully functional')
+        new_keys[ck][nk] = v
+    for k, v in list(viewitems(new_keys)):
+        new_keys[k] = factor_one_bit(v)
+    # try factor sons
+    if len(new_keys) != 1:
+        return new_keys
+    subtree = next(iter(viewvalues(new_keys)))
+    if len(subtree) != 1:
+        return new_keys
+    if next(iter(subtree)) == 'mn':
+        return new_keys
+
+    return new_keys
+
+
+def factor_fields(tree):
+    if not isinstance(tree, dict):
+        return tree
+    if len(tree) != 1:
+        return tree
+    # merge
+    k1, v1 = next(iter(viewitems(tree)))
+    if k1 == "mn":
+        return tree
+    l1, fmask1, fbits1, fname1, flen1 = k1
+    if fname1 is not None:
+        return tree
+    if flen1 is not None:
+        return tree
+
+    if not isinstance(v1, dict):
+        return tree
+    if len(v1) != 1:
+        return tree
+    k2, v2 = next(iter(viewitems(v1)))
+    if k2 == "mn":
+        return tree
+    l2, fmask2, fbits2, fname2, flen2 = k2
+    if fname2 is not None:
+        return tree
+    if flen2 is not None:
+        return tree
+    l = l1 + l2
+    fmask = (fmask1 << l2) | fmask2
+    fbits = (fbits1 << l2) | fbits2
+    fname = fname2
+    flen = flen2
+    k = l, fmask, fbits, fname, flen
+    new_keys = {k: v2}
+    return new_keys
+
+
+def factor_fields_all(tree):
+    if not isinstance(tree, dict):
+        return tree
+    new_keys = {}
+    for k, v in viewitems(tree):
+        v = factor_fields(v)
+        new_keys[k] = factor_fields_all(v)
+    return new_keys
+
+
+def graph_tree(tree):
+    nodes = []
+    branch2nodes(tree, nodes)
+
+    out = """
+          digraph G {
+          """
+    for a, b in nodes:
+        if b == 'mn':
+            continue
+        out += "%s -> %s;\n" % (id(a), id(b))
+    out += "}"
+    open('graph.txt', 'w').write(out)
+
+
+def add_candidate_to_tree(tree, c):
+    branch = tree
+    for f in c.fields:
+        if f.l == 0:
+            continue
+        node = f.l, f.fmask, f.fbits, f.fname, f.flen
+
+        if not node in branch:
+            branch[node] = {}
+        branch = branch[node]
+    if not 'mn' in branch:
+        branch['mn'] = set()
+    branch['mn'].add(c)
+
+
+def add_candidate(bases, c):
+    add_candidate_to_tree(bases[0].bintree, c)
+
+
+def getfieldby_name(fields, fname):
+    f = [x for x in fields if hasattr(x, 'fname') and x.fname == fname]
+    if len(f) != 1:
+        raise ValueError('more than one field with name: %s' % fname)
+    return f[0]
+
+
+def getfieldindexby_name(fields, fname):
+    for i, f in enumerate(fields):
+        if hasattr(f, 'fname') and f.fname == fname:
+            return f, i
+    return None
+
+
+class metamn(type):
+
+    def __new__(mcs, name, bases, dct):
+        if name == "cls_mn" or name.startswith('mn_'):
+            return type.__new__(mcs, name, bases, dct)
+        alias = dct.get('alias', False)
+
+        fields = bases[0].mod_fields(dct['fields'])
+        if not 'name' in dct:
+            dct["name"] = bases[0].getmn(name)
+        if 'args' in dct:
+            # special case for permuted arguments
+            o = []
+            p = []
+            for i, a in enumerate(dct['args']):
+                o.append((i, a))
+                if a in fields:
+                    p.append((fields.index(a), a))
+            p.sort()
+            p = [x[1] for x in p]
+            p = [dct['args'].index(x) for x in p]
+            dct['args_permut'] = perm_inv(p)
+        # order fields
+        f_ordered = [x for x in enumerate(fields)]
+        f_ordered.sort(key=lambda x: (x[1].prio, x[0]))
+        candidates = bases[0].gen_modes(mcs, name, bases, dct, fields)
+        for i, fc in f_ordered:
+            if isinstance(fc, bs_divert):
+                candidates = fc.divert(i, candidates)
+        for cls, name, bases, dct, fields in candidates:
+            ndct = dict(dct)
+            fields = [f for f in fields if f]
+            ndct['fields'] = fields
+            ndct['mn_len'] = sum([x.l for x in fields])
+            c = type.__new__(cls, name, bases, ndct)
+            c.alias = alias
+            c.check_mnemo(fields)
+            c.num = bases[0].num
+            bases[0].num += 1
+            bases[0].all_mn.append(c)
+            mode = dct['mode']
+            bases[0].all_mn_mode[mode].append(c)
+            bases[0].all_mn_name[c.name].append(c)
+            i = c()
+            i.init_class()
+            bases[0].all_mn_inst[c].append(i)
+            add_candidate(bases, c)
+            # gen byte lookup
+            o = ""
+            for f in i.fields_order:
+                if not isinstance(f, bsi):
+                    raise ValueError('f is not bsi')
+                if f.l == 0:
+                    continue
+                o += f.strbits
+        return c
+
+
+class instruction(object):
+    __slots__ = ["name", "mode", "args",
+                 "l", "b", "offset", "data",
+                 "additional_info", "delayslot"]
+
+    def __init__(self, name, mode, args, additional_info=None):
+        self.name = name
+        self.mode = mode
+        self.args = args
+        self.additional_info = additional_info
+        self.offset = None
+        self.l = None
+        self.b = None
+        self.delayslot = 0
+
+    def gen_args(self, args):
+        out = ', '.join([str(x) for x in args])
+        return out
+
+    def __str__(self):
+        return self.to_string()
+
+    def to_string(self, loc_db=None):
+        o = "%-10s " % self.name
+        args = []
+        for i, arg in enumerate(self.args):
+            if not isinstance(arg, m2_expr.Expr):
+                raise ValueError('zarb arg type')
+            x = self.arg2str(arg, i, loc_db)
+            args.append(x)
+        o += self.gen_args(args)
+        return o
+
+    def to_html(self, loc_db=None):
+        out = "%-10s " % self.name
+        out = '<font color="%s">%s</font>' % (utils.COLOR_MNEMO, out)
+
+        args = []
+        for i, arg in enumerate(self.args):
+            if not isinstance(arg, m2_expr.Expr):
+                raise ValueError('zarb arg type')
+            x = self.arg2html(arg, i, loc_db)
+            args.append(x)
+        out += self.gen_args(args)
+        return out
+
+    def get_asm_offset(self, expr):
+        return m2_expr.ExprInt(self.offset, expr.size)
+
+    def get_asm_next_offset(self, expr):
+        return m2_expr.ExprInt(self.offset+self.l, expr.size)
+
+    def resolve_args_with_symbols(self, loc_db):
+        args_out = []
+        for expr in self.args:
+            # try to resolve symbols using loc_db (0 for default value)
+            loc_keys = m2_expr.get_expr_locs(expr)
+            fixed_expr = {}
+            for exprloc in loc_keys:
+                loc_key = exprloc.loc_key
+                names = loc_db.get_location_names(loc_key)
+                # special symbols
+                if '$' in names:
+                    fixed_expr[exprloc] = self.get_asm_offset(exprloc)
+                    continue
+                if '_' in names:
+                    fixed_expr[exprloc] = self.get_asm_next_offset(exprloc)
+                    continue
+                arg_int = loc_db.get_location_offset(loc_key)
+                if arg_int is not None:
+                    fixed_expr[exprloc] = m2_expr.ExprInt(arg_int, exprloc.size)
+                    continue
+                if not names:
+                    raise ValueError('Unresolved symbol: %r' % exprloc)
+
+                offset = loc_db.get_location_offset(loc_key)
+                if offset is None:
+                    raise ValueError(
+                        'The offset of loc_key "%s" cannot be determined' % names
+                    )
+                else:
+                    # Fix symbol with its offset
+                    size = exprloc.size
+                    if size is None:
+                        default_size = self.get_symbol_size(exprloc, loc_db)
+                        size = default_size
+                    value = m2_expr.ExprInt(offset, size)
+                fixed_expr[exprloc] = value
+
+            expr = expr.replace_expr(fixed_expr)
+            expr = expr_simp(expr)
+            args_out.append(expr)
+        return args_out
+
+    def get_info(self, c):
+        return
+
+
+class cls_mn(with_metaclass(metamn, object)):
+    args_symb = []
+    instruction = instruction
+    # Block's offset alignment
+    alignment = 1
+
+    @classmethod
+    def guess_mnemo(cls, bs, attrib, pre_dis_info, offset):
+        candidates = []
+
+        candidates = set()
+
+        fname_values = pre_dis_info
+        todo = [
+            (dict(fname_values), branch, offset * 8)
+            for branch in list(viewitems(cls.bintree))
+        ]
+        for fname_values, branch, offset_b in todo:
+            (l, fmask, fbits, fname, flen), vals = branch
+
+            if flen is not None:
+                l = flen(attrib, fname_values)
+            if l is not None:
+                try:
+                    v = cls.getbits(bs, attrib, offset_b, l)
+                except IOError:
+                    # Raised if offset is out of bound
+                    continue
+                offset_b += l
+                if v & fmask != fbits:
+                    continue
+                if fname is not None and not fname in fname_values:
+                    fname_values[fname] = v
+            for nb, v in viewitems(vals):
+                if 'mn' in nb:
+                    candidates.update(v)
+                else:
+                    todo.append((dict(fname_values), (nb, v), offset_b))
+
+        return [c for c in candidates]
+
+    def reset_class(self):
+        for f in self.fields_order:
+            if f.strbits and isbin(f.strbits):
+                f.value = int(f.strbits, 2)
+            elif 'default_val' in f.kargs:
+                f.value = int(f.kargs['default_val'], 2)
+            else:
+                f.value = None
+            if f.fname:
+                setattr(self, f.fname, f)
+
+    def init_class(self):
+        args = []
+        fields_order = []
+        to_decode = []
+        off = 0
+        for i, fc in enumerate(self.fields):
+            f = fc.gen(self)
+            f.offset = off
+            off += f.l
+            fields_order.append(f)
+            to_decode.append((i, f))
+
+            if isinstance(f, m_arg):
+                args.append(f)
+            if f.fname:
+                setattr(self, f.fname, f)
+        if hasattr(self, 'args_permut'):
+            args = [args[self.args_permut[i]]
+                    for i in range(len(self.args_permut))]
+        to_decode.sort(key=lambda x: (x[1].order, x[0]))
+        to_decode = [fields_order.index(f[1]) for f in to_decode]
+        self.args = args
+        self.fields_order = fields_order
+        self.to_decode = to_decode
+
+    def add_pre_dis_info(self, prefix=None):
+        return True
+
+    @classmethod
+    def getbits(cls, bs, attrib, offset_b, l):
+        return bs.getbits(offset_b, l)
+
+    @classmethod
+    def getbytes(cls, bs, offset, l):
+        return bs.getbytes(offset, l)
+
+    @classmethod
+    def pre_dis(cls, v_o, attrib, offset):
+        return {}, v_o, attrib, offset, 0
+
+    def post_dis(self):
+        return self
+
+    @classmethod
+    def check_mnemo(cls, fields):
+        pass
+
+    @classmethod
+    def mod_fields(cls, fields):
+        return fields
+
+    @classmethod
+    def dis(cls, bs_o, mode_o = None, offset=0):
+        if not isinstance(bs_o, bin_stream):
+            bs_o = bin_stream_str(bs_o)
+
+        bs_o.enter_atomic_mode()
+
+        offset_o = offset
+        try:
+            pre_dis_info, bs, mode, offset, prefix_len = cls.pre_dis(
+                bs_o, mode_o, offset)
+        except:
+            bs_o.leave_atomic_mode()
+            raise
+        candidates = cls.guess_mnemo(bs, mode, pre_dis_info, offset)
+        if not candidates:
+            bs_o.leave_atomic_mode()
+            raise Disasm_Exception('cannot disasm (guess) at %X' % offset)
+
+        out = []
+        out_c = []
+        if hasattr(bs, 'getlen'):
+            bs_l = bs.getlen()
+        else:
+            bs_l = len(bs)
+
+        alias = False
+        for c in candidates:
+            log.debug("*" * 40, mode, c.mode)
+            log.debug(c.fields)
+
+            c = cls.all_mn_inst[c][0]
+
+            c.reset_class()
+            c.mode = mode
+
+            if not c.add_pre_dis_info(pre_dis_info):
+                continue
+
+            todo = {}
+            getok = True
+            fname_values = dict(pre_dis_info)
+            offset_b = offset * 8
+
+            total_l = 0
+            for i, f in enumerate(c.fields_order):
+                if f.flen is not None:
+                    l = f.flen(mode, fname_values)
+                else:
+                    l = f.l
+                if l is not None:
+                    total_l += l
+                    f.l = l
+                    f.is_present = True
+                    log.debug("FIELD %s %s %s %s", f.__class__, f.fname,
+                              offset_b, l)
+                    if bs_l * 8 - offset_b < l:
+                        getok = False
+                        break
+                    try:
+                        bv = cls.getbits(bs, mode, offset_b, l)
+                    except:
+                        bs_o.leave_atomic_mode()
+                        raise
+                    offset_b += l
+                    if not f.fname in fname_values:
+                        fname_values[f.fname] = bv
+                    todo[i] = bv
+                else:
+                    f.is_present = False
+                    todo[i] = None
+
+            if not getok:
+                continue
+
+            c.l = prefix_len + total_l // 8
+            for i in c.to_decode:
+                f = c.fields_order[i]
+                if f.is_present:
+                    ret = f.decode(todo[i])
+                    if not ret:
+                        log.debug("cannot decode %r", f)
+                        break
+
+            if not ret:
+                continue
+            for a in c.args:
+                a.expr = expr_simp(a.expr)
+
+            c.b = cls.getbytes(bs, offset_o, c.l)
+            c.offset = offset_o
+            c = c.post_dis()
+            if c is None:
+                continue
+            c_args = [a.expr for a in c.args]
+            instr = cls.instruction(c.name, mode, c_args,
+                                    additional_info=c.additional_info())
+            instr.l = prefix_len + total_l // 8
+            instr.b = cls.getbytes(bs, offset_o, instr.l)
+            instr.offset = offset_o
+            instr.get_info(c)
+            if c.alias:
+                alias = True
+            out.append(instr)
+            out_c.append(c)
+
+        bs_o.leave_atomic_mode()
+
+        if not out:
+            raise Disasm_Exception('cannot disasm at %X' % offset_o)
+        if len(out) != 1:
+            if not alias:
+                log.warning('dis multiple args ret default')
+
+            for i, o in enumerate(out_c):
+                if o.alias:
+                    return out[i]
+            raise NotImplementedError(
+                'Multiple disas: \n' +
+                "\n".join(str(x) for x in out)
+            )
+        return out[0]
+
+    @classmethod
+    def fromstring(cls, text, loc_db, mode = None):
+        global total_scans
+        name = re.search(r'(\S+)', text).groups()
+        if not name:
+            raise ValueError('cannot find name', text)
+        name = name[0]
+
+        if not name in cls.all_mn_name:
+            raise ValueError('unknown name', name)
+        clist = [x for x in cls.all_mn_name[name]]
+        out = []
+        out_args = []
+        parsers = defaultdict(dict)
+
+        for cc in clist:
+            for c in cls.get_cls_instance(cc, mode):
+                args_expr = []
+                args_str = text[len(name):].strip(' ')
+
+                start = 0
+                cannot_parse = False
+                len_o = len(args_str)
+
+                for i, f in enumerate(c.args):
+                    start_i = len_o - len(args_str)
+                    if type(f.parser) == tuple:
+                        parser = f.parser
+                    else:
+                        parser = (f.parser,)
+                    for p in parser:
+                        if p in parsers[(i, start_i)]:
+                            continue
+                        try:
+                            total_scans += 1
+                            v, start, stop = next(p.scanString(args_str))
+                        except StopIteration:
+                            v, start, stop = [None], None, None
+                        if start != 0:
+                            v, start, stop = [None], None, None
+                        if v != [None]:
+                            v = f.asm_ast_to_expr(v[0], loc_db)
+                        if v is None:
+                            v, start, stop = [None], None, None
+                        parsers[(i, start_i)][p] = v, start, stop
+                    start, stop = f.fromstring(args_str, loc_db, parsers[(i, start_i)])
+                    if start != 0:
+                        log.debug("cannot fromstring %r", args_str)
+                        cannot_parse = True
+                        break
+                    if f.expr is None:
+                        raise NotImplementedError('not fully functional')
+                    f.expr = expr_simp(f.expr)
+                    args_expr.append(f.expr)
+                    args_str = args_str[stop:].strip(' ')
+                    if args_str.startswith(','):
+                        args_str = args_str[1:]
+                    args_str = args_str.strip(' ')
+                if args_str:
+                    cannot_parse = True
+                if cannot_parse:
+                    continue
+
+                out.append(c)
+                out_args.append(args_expr)
+                break
+
+        if len(out) == 0:
+            raise ValueError('cannot fromstring %r' % text)
+        if len(out) != 1:
+            log.debug('fromstring multiple args ret default')
+        c = out[0]
+        c_args = out_args[0]
+
+        instr = cls.instruction(c.name, mode, c_args,
+                                additional_info=c.additional_info())
+        return instr
+
+    def dup_info(self, infos):
+        return
+
+    @classmethod
+    def get_cls_instance(cls, cc, mode, infos=None):
+        c = cls.all_mn_inst[cc][0]
+
+        c.reset_class()
+        c.add_pre_dis_info()
+        c.dup_info(infos)
+
+        c.mode = mode
+        yield c
+
+    @classmethod
+    def asm(cls, instr, loc_db=None):
+        """
+        Re asm instruction by searching mnemo using name and args. We then
+        can modify args and get the hex of a modified instruction
+        """
+        clist = cls.all_mn_name[instr.name]
+        clist = [x for x in clist]
+        vals = []
+        candidates = []
+        args = instr.resolve_args_with_symbols(loc_db)
+
+        for cc in clist:
+
+            for c in cls.get_cls_instance(
+                cc, instr.mode, instr.additional_info):
+
+                cannot_parse = False
+                if len(c.args) != len(instr.args):
+                    continue
+
+                # only fix args expr
+                for i in range(len(c.args)):
+                    c.args[i].expr = args[i]
+
+                v = c.value(instr.mode)
+                if not v:
+                    log.debug("cannot encode %r", c)
+                    cannot_parse = True
+                if cannot_parse:
+                    continue
+                vals += v
+                candidates.append((c, v))
+        if len(vals) == 0:
+            raise ValueError(
+                'cannot asm %r %r' %
+                (instr.name, [str(x) for x in instr.args])
+            )
+        if len(vals) != 1:
+            log.debug('asm multiple args ret default')
+
+        vals = cls.filter_asm_candidates(instr, candidates)
+        return vals
+
+    @classmethod
+    def filter_asm_candidates(cls, instr, candidates):
+        o = []
+        for _, v in candidates:
+            o += v
+        o.sort(key=len)
+        return o
+
+    def value(self, mode):
+        todo = [(0, 0, [(x, self.fields_order[x]) for x in self.to_decode[::-1]])]
+
+        result = []
+        done = []
+
+        while todo:
+            index, cur_len, to_decode = todo.pop()
+            # TEST XXX
+            for _, f in to_decode:
+                setattr(self, f.fname, f)
+            if (index, [x[1].value for x in to_decode]) in done:
+                continue
+            done.append((index, [x[1].value for x in to_decode]))
+
+            can_encode = True
+            for i, f in to_decode[index:]:
+                f.parent.l = cur_len
+                ret = f.encode()
+                if not ret:
+                    log.debug('cannot encode %r', f)
+                    can_encode = False
+                    break
+
+                if f.value is not None and f.l:
+                    if f.value > f.lmask:
+                        log.debug('cannot encode %r', f)
+                        can_encode = False
+                        break
+                    cur_len += f.l
+                index += 1
+                if ret is True:
+                    continue
+
+                for _ in ret:
+                    o = []
+                    if ((index, cur_len, [xx[1].value for xx in to_decode]) in todo or
+                        (index, cur_len, [xx[1].value for xx in to_decode]) in done):
+                        raise NotImplementedError('not fully functional')
+
+                    for p, f in to_decode:
+                        fnew = f.clone()
+                        o.append((p, fnew))
+                    todo.append((index, cur_len, o))
+                can_encode = False
+
+                break
+            if not can_encode:
+                continue
+            result.append(to_decode)
+
+        return self.decoded2bytes(result)
+
+    def encodefields(self, decoded):
+        bits = bitobj()
+        for _, f in decoded:
+            setattr(self, f.fname, f)
+
+            if f.value is None:
+                continue
+            bits.putbits(f.value, f.l)
+
+        return bits.tostring()
+
+    def decoded2bytes(self, result):
+        if not result:
+            return []
+
+        out = []
+        for decoded in result:
+            decoded.sort()
+
+            o = self.encodefields(decoded)
+            if o is None:
+                continue
+            out.append(o)
+        out = list(set(out))
+        return out
+
+    def gen_args(self, args):
+        out = ', '.join([str(x) for x in args])
+        return out
+
+    def args2str(self):
+        args = []
+        for arg in self.args:
+            # XXX todo test
+            if not (isinstance(arg, m2_expr.Expr) or
+                    isinstance(arg.expr, m2_expr.Expr)):
+                raise ValueError('zarb arg type')
+            x = str(arg)
+            args.append(x)
+        return args
+
+    def __str__(self):
+        o = "%-10s " % self.name
+        args = []
+        for arg in self.args:
+            # XXX todo test
+            if not (isinstance(arg, m2_expr.Expr) or
+                    isinstance(arg.expr, m2_expr.Expr)):
+                raise ValueError('zarb arg type')
+            x = str(arg)
+            args.append(x)
+
+        o += self.gen_args(args)
+        return o
+
+    def parse_prefix(self, v):
+        return 0
+
+    def set_dst_symbol(self, loc_db):
+        dst = self.getdstflow(loc_db)
+        args = []
+        for d in dst:
+            if isinstance(d, m2_expr.ExprInt):
+                l = loc_db.get_or_create_offset_location(int(d))
+
+                a = m2_expr.ExprId(l.name, d.size)
+            else:
+                a = d
+            args.append(a)
+        self.args_symb = args
+
+    def getdstflow(self, loc_db):
+        return [self.args[0].expr]
+
+
+class imm_noarg(object):
+    intsize = 32
+    intmask = (1 << intsize) - 1
+
+    def int2expr(self, v):
+        if (v & ~self.intmask) != 0:
+            return None
+        return m2_expr.ExprInt(v, self.intsize)
+
+    def expr2int(self, e):
+        if not isinstance(e, m2_expr.ExprInt):
+            return None
+        v = int(e)
+        if v & ~self.intmask != 0:
+            return None
+        return v
+
+    def fromstring(self, text, loc_db, parser_result=None):
+        if parser_result:
+            e, start, stop = parser_result[self.parser]
+        else:
+            try:
+                e, start, stop = next(self.parser.scanString(text))
+            except StopIteration:
+                return None, None
+        if e == [None]:
+            return None, None
+
+        assert(m2_expr.is_expr(e))
+        self.expr = e
+        if self.expr is None:
+            log.debug('cannot fromstring int %r', text)
+            return None, None
+        return start, stop
+
+    def decodeval(self, v):
+        return v
+
+    def encodeval(self, v):
+        return v
+
+    def decode(self, v):
+        v = v & self.lmask
+        v = self.decodeval(v)
+        e = self.int2expr(v)
+        if not e:
+            return False
+        self.expr = e
+        return True
+
+    def encode(self):
+        v = self.expr2int(self.expr)
+        if v is None:
+            return False
+        v = self.encodeval(v)
+        if v is False:
+            return False
+        self.value = v
+        return True
+
+
+class imm08_noarg(object):
+    int2expr = lambda self, x: m2_expr.ExprInt(x, 8)
+
+
+class imm16_noarg(object):
+    int2expr = lambda self, x: m2_expr.ExprInt(x, 16)
+
+
+class imm32_noarg(object):
+    int2expr = lambda self, x: m2_expr.ExprInt(x, 32)
+
+
+class imm64_noarg(object):
+    int2expr = lambda self, x: m2_expr.ExprInt(x, 64)
+
+
+class int32_noarg(imm_noarg):
+    intsize = 32
+    intmask = (1 << intsize) - 1
+
+    def decode(self, v):
+        v = sign_ext(v, self.l, self.intsize)
+        v = self.decodeval(v)
+        self.expr = self.int2expr(v)
+        return True
+
+    def encode(self):
+        if not isinstance(self.expr, m2_expr.ExprInt):
+            return False
+        v = int(self.expr)
+        if sign_ext(v & self.lmask, self.l, self.intsize) != v:
+            return False
+        v = self.encodeval(v & self.lmask)
+        if v is False:
+            return False
+        self.value = v & self.lmask
+        return True
+
+class bs8(bs):
+    prio = default_prio
+
+    def __init__(self, v, cls=None, fname=None, **kargs):
+        super(bs8, self).__init__(int2bin(v, 8), 8,
+                                  cls=cls, fname=fname, **kargs)
+
+
+
+
+def swap_uint(size, i):
+    if size == 8:
+        return i & 0xff
+    elif size == 16:
+        return struct.unpack('<H', struct.pack('>H', i & 0xffff))[0]
+    elif size == 32:
+        return struct.unpack('<I', struct.pack('>I', i & 0xffffffff))[0]
+    elif size == 64:
+        return struct.unpack('<Q', struct.pack('>Q', i & 0xffffffffffffffff))[0]
+    raise ValueError('unknown int len %r' % size)
+
+
+def swap_sint(size, i):
+    if size == 8:
+        return i
+    elif size == 16:
+        return struct.unpack('<h', struct.pack('>H', i & 0xffff))[0]
+    elif size == 32:
+        return struct.unpack('<i', struct.pack('>I', i & 0xffffffff))[0]
+    elif size == 64:
+        return struct.unpack('<q', struct.pack('>Q', i & 0xffffffffffffffff))[0]
+    raise ValueError('unknown int len %r' % size)
+
+
+def sign_ext(v, s_in, s_out):
+    assert(s_in <= s_out)
+    v &= (1 << s_in) - 1
+    sign_in = v & (1 << (s_in - 1))
+    if not sign_in:
+        return v
+    m = (1 << (s_out)) - 1
+    m ^= (1 << s_in) - 1
+    v |= m
+    return v