about summary refs log tree commit diff stats
path: root/miasm2/core/cpu.py
diff options
context:
space:
mode:
authorFabrice Desclaux <fabrice.desclaux@cea.fr>2018-05-05 23:13:12 +0200
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2018-05-14 10:29:27 +0200
commit94d49ed54f07e3d399de74de13f5422837c031fa (patch)
treeb3f7fd34c7ff8d17bd9f26d53511b30935485092 /miasm2/core/cpu.py
parentdb4fd7f58d6a4ed87fc7d6f28c7c2af31e61fb65 (diff)
downloadmiasm-94d49ed54f07e3d399de74de13f5422837c031fa.tar.gz
miasm-94d49ed54f07e3d399de74de13f5422837c031fa.zip
Core: updt parser structure
Diffstat (limited to 'miasm2/core/cpu.py')
-rw-r--r--miasm2/core/cpu.py384
1 files changed, 206 insertions, 178 deletions
diff --git a/miasm2/core/cpu.py b/miasm2/core/cpu.py
index 061752f8..ca419458 100644
--- a/miasm2/core/cpu.py
+++ b/miasm2/core/cpu.py
@@ -13,6 +13,9 @@ from miasm2.core.bin_stream import bin_stream, bin_stream_str
 from miasm2.core.utils import Disasm_Exception
 from miasm2.expression.simplifications import expr_simp
 
+
+from miasm2.core.asm_ast import AstNode, AstInt, AstId, AstMem, AstOp
+
 log = logging.getLogger("cpuhelper")
 console_handler = logging.StreamHandler()
 console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
@@ -85,12 +88,19 @@ def literal_list(l):
     return o
 
 
-class reg_info:
+class reg_info(object):
 
     def __init__(self, reg_str, reg_expr):
         self.str = reg_str
         self.expr = reg_expr
-        self.parser = literal_list(reg_str).setParseAction(self.reg2expr)
+        self.parser = literal_list(reg_str).setParseAction(self.cb_parse)
+
+    def cb_parse(self, t):
+        assert len(t) == 1
+        i = self.str.index(t[0])
+        reg = self.expr[i]
+        result = AstId(reg)
+        return result
 
     def reg2expr(self, s):
         i = self.str.index(s[0])
@@ -100,15 +110,21 @@ class reg_info:
         return self.expr.index(e)
 
 
-
-class reg_info_dct:
+class reg_info_dct(object):
 
     def __init__(self, reg_expr):
         self.dct_str_inv = dict((v.name, k) for k, v in reg_expr.iteritems())
         self.dct_expr = reg_expr
         self.dct_expr_inv = dict((v, k) for k, v in reg_expr.iteritems())
         reg_str = [v.name for v in reg_expr.itervalues()]
-        self.parser = literal_list(reg_str).setParseAction(self.reg2expr)
+        self.parser = literal_list(reg_str).setParseAction(self.cb_parse)
+
+    def cb_parse(self, t):
+        assert len(t) == 1
+        i = self.dct_str_inv[t[0]]
+        reg = self.dct_expr[i]
+        result = AstId(reg)
+        return result
 
     def reg2expr(self, s):
         i = self.dct_str_inv[s[0]]
@@ -118,34 +134,30 @@ class reg_info_dct:
         return self.dct_expr_inv[e]
 
 
-def gen_reg(rname, env, sz=32):
-    """
-    Gen reg expr and parser
-    Equivalent to:
-        PC = ExprId('PC')
-        reg_pc_str = ['PC']
-        reg_pc_expr = [ExprId(x, sz) for x in reg_pc_str]
-        regpc = reg_info(reg_pc_str, reg_pc_expr)
+def gen_reg(reg_name, sz=32):
+    """Gen reg expr and parser"""
+    reg_name_lower = reg_name.lower()
+    reg = m2_expr.ExprId(reg_name, sz)
+    reginfo = reg_info([reg_name], [reg])
+    return reg, reginfo
 
-        class bs_rname(m_reg):
-            reg = regi_rname
 
-        bsrname = bs(l=0, cls=(bs_rname,))
+def gen_reg_bs(reg_name, reg_info, base_cls):
+    """
+    Generate:
+        class bs_reg_name(base_cls):
+            reg = reg_info
 
+        bs_reg_name = bs(l=0, cls=(bs_reg_name,))
     """
-    rnamel = rname.lower()
-    r = m2_expr.ExprId(rname, sz)
-    reg_str = [rname]
-    reg_expr = [r]
-    regi = reg_info(reg_str, reg_expr)
-    # define as global val
-    cname = "bs_" + rnamel
-    c = type(cname, (m_reg,), {'reg': regi})
-    env[rname] = r
-    env["regi_" + rnamel] = regi
-    env[cname] = c
-    env["bs" + rnamel] = bs(l=0, cls=(c,))
-    return r, regi
+    reg_name_lower = reg_name.lower()
+
+    bs_name = "bs_%s" % reg_name
+    cls = type(bs_name, base_cls, {'reg': reg_info})
+
+    bs_obj = bs(l=0, cls=(cls,))
+
+    return cls, bs_obj
 
 
 def gen_regs(rnames, env, sz=32):
@@ -217,125 +229,6 @@ def ast_int2expr(a):
     return m2_expr.ExprInt(a, 32)
 
 
-
-class ParseAst(object):
-
-    def __init__(self, id2expr, int2expr, default_size=32):
-        self.id2expr = id2expr
-        self.int2expr = int2expr
-        self.default_size = default_size
-
-    def int_from_size(self, size, value):
-        """Transform a string into ExprInt.
-        * if @size is None, use provided int2expr
-        * else, use @size to generate integer
-        @size: size of int; None if not forced.
-        @value: string representing an integer
-        """
-        if size is None:
-            return self.int2expr(value)
-        else:
-            return m2_expr.ExprInt(value, size)
-
-    def id_from_size(self, size, value):
-        """Transform a string into ExprId.
-        * if @size is None, use provided id2expr
-        * else, use @size to generate id
-        @size: size of id; None if not forced.
-        @value: string representing the id
-        """
-        value = self.id2expr(value)
-        if isinstance(value, m2_expr.Expr):
-            return value
-        if size is None:
-            size = self.default_size
-        assert value is not None
-        return m2_expr.ExprId(asmblock.AsmLabel(value), size)
-
-    def ast_to_expr(self, size, ast):
-        """Transform a typed ast into a Miasm expression
-        @size: default size
-        @ast: typed ast
-        """
-        assert(isinstance(ast, tuple))
-        if ast[0] is m2_expr.ExprId:
-            expr = self.id_from_size(size, ast[1])
-            if isinstance(expr, str):
-                expr = self.id_from_size(size, expr)
-        elif ast[0] is m2_expr.ExprInt:
-            expr = self.int_from_size(size, ast[1])
-        elif ast[0] is m2_expr.ExprOp:
-            out = []
-            for arg in ast[1]:
-                if isinstance(arg, tuple):
-                    arg = self.ast_to_expr(size, arg)
-                out.append(arg)
-            expr = ast_parse_op(out)
-        else:
-            raise TypeError('unknown type')
-        return expr
-
-    def ast_get_ids(self, ast):
-        """Retrieve every node of type ExprId in @ast
-        @ast: typed ast
-        """
-        assert(isinstance(ast, tuple))
-        if ast[0] is m2_expr.ExprId:
-            return set([ast[1]])
-        elif ast[0] is m2_expr.ExprInt:
-            return set()
-        elif ast[0] is m2_expr.ExprOp:
-            out = set()
-            for x in ast[1]:
-                if isinstance(x, tuple):
-                    out.update(self.ast_get_ids(x))
-            return out
-        raise TypeError('unknown type')
-
-    def _extract_ast_core(self, ast):
-        assert(isinstance(ast, tuple))
-        if ast[0] in [m2_expr.ExprInt, m2_expr.ExprId]:
-            return ast
-        elif ast[0] is m2_expr.ExprOp:
-            out = []
-            for arg in ast[1]:
-                if isinstance(arg, tuple):
-                    arg = self._extract_ast_core(arg)
-                out.append(arg)
-            return tuple([ast[0]] + [out])
-        else:
-            raise TypeError('unknown type')
-
-    def extract_ast_core(self, ast):
-        """
-        Trasform an @ast into a Miasm expression.
-        Use registers size to deduce label and integers sizes.
-        """
-        ast = self._extract_ast_core(ast)
-        ids = self.ast_get_ids(ast)
-        ids_expr = [self.id2expr(x) for x in ids]
-        sizes = set([expr.size for expr in ids_expr
-                     if isinstance(expr, m2_expr.Expr)])
-        if not sizes:
-            size = None
-        elif len(sizes) == 1:
-            size = sizes.pop()
-        else:
-            # Multiple sizes in ids
-            raise StopIteration
-        return self.ast_to_expr(size, ast)
-
-    def __call__(self, ast):
-        """
-        Trasform an @ast into a Miasm expression.
-        Use registers size to deduce label and integers sizes.
-        """
-        ast = ast[0]
-        if isinstance(ast, m2_expr.Expr):
-            return ast
-        return self.extract_ast_core(ast)
-
-
 def neg_int(t):
     x = -t[0]
     return x
@@ -361,29 +254,154 @@ multop = pyparsing.oneOf('* / %')
 plusop = pyparsing.oneOf('+ -')
 
 
-def gen_base_expr():
-    variable = pyparsing.Word(pyparsing.alphas + "_$.",
-                              pyparsing.alphanums + "_")
-    variable.setParseAction(parse_id)
-    operand = str_int | variable
-    base_expr = pyparsing.operatorPrecedence(operand,
-                                   [("!", 1, pyparsing.opAssoc.RIGHT, parse_op),
-                                    (logicop, 2, pyparsing.opAssoc.RIGHT,
-                                     parse_op),
-                                    (signop, 1, pyparsing.opAssoc.RIGHT,
-                                     parse_op),
-                                    (multop, 2, pyparsing.opAssoc.LEFT,
-                                     parse_op),
-                                    (plusop, 2, pyparsing.opAssoc.LEFT,
-                                     parse_op),
-                                    ])
-    return variable, operand, base_expr
+##########################
 
+def literal_list(l):
+    l = l[:]
+    l.sort()
+    l = l[::-1]
+    o = pyparsing.Literal(l[0])
+    for x in l[1:]:
+        o |= pyparsing.Literal(x)
+    return o
+
+
+def cb_int(t):
+    assert len(t) == 1
+    integer = AstInt(t[0])
+    return integer
+
+
+def cb_parse_id(t):
+    assert len(t) == 1
+    reg = t[0]
+    return AstId(reg)
+
+
+def cb_op_not(t):
+    tokens = t[0]
+    assert len(tokens) == 2
+    assert tokens[0] == "!"
+    result = AstOp("!", tokens[1])
+    return result
+
+
+def merge_ops(tokens, op):
+    args = []
+    if len(tokens) >= 3:
+        args = [tokens.pop(0)]
+        i = 0
+        while i < len(tokens):
+            op_tmp = tokens[i]
+            arg = tokens[i+1]
+            i += 2
+            if op_tmp != op:
+                raise ValueError("Bad operator")
+            args.append(arg)
+    result = AstOp(op, *args)
+    return result
+
+
+def cb_op_and(t):
+    result = merge_ops(t[0], "&")
+    return result
+
+
+def cb_op_xor(t):
+    result = merge_ops(t[0], "^")
+    return result
+
+
+def cb_op_sign(t):
+    assert len(t) == 1
+    op, value = t[0]
+    return -value
+
+
+def cb_op_div(t):
+    tokens = t[0]
+    assert len(tokens) == 3
+    assert tokens[1] == "/"
+    result = AstOp("/", tokens[0], tokens[2])
+    return result
+
+
+def cb_op_plusminus(t):
+    tokens = t[0]
+    if len(tokens) == 3:
+        # binary op
+        assert isinstance(tokens[0], AstNode)
+        assert isinstance(tokens[2], AstNode)
+        op, args = tokens[1], [tokens[0], tokens[2]]
+    elif len(tokens) > 3:
+        args = [tokens.pop(0)]
+        i = 0
+        while i < len(tokens):
+            op = tokens[i]
+            arg = tokens[i+1]
+            i += 2
+            if op == '-':
+                arg = -arg
+            elif op == '+':
+                pass
+            else:
+                raise ValueError("Bad operator")
+            args.append(arg)
+        op = '+'
+    else:
+        raise ValueError("Parsing error")
+    assert all(isinstance(arg, AstNode) for arg in args)
+    result = AstOp(op, *args)
+    return result
 
-variable, operand, base_expr = gen_base_expr()
 
-my_var_parser = ParseAst(ast_id2expr, ast_int2expr)
-base_expr.setParseAction(my_var_parser)
+def cb_op_mul(t):
+    tokens = t[0]
+    assert len(tokens) == 3
+    assert isinstance(tokens[0], AstNode)
+    assert isinstance(tokens[2], AstNode)
+
+    # binary op
+    op, args = tokens[1], [tokens[0], tokens[2]]
+    result = AstOp(op, *args)
+    return result
+
+
+integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda t: int(t[0]))
+hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums)
+hex_int = pyparsing.Combine(hex_word).setParseAction(lambda t: int(t[0], 16))
+
+str_int_pos = (hex_int | integer)
+
+str_int = str_int_pos
+str_int.setParseAction(cb_int)
+
+notop = pyparsing.oneOf('!')
+andop = pyparsing.oneOf('&')
+orop = pyparsing.oneOf('|')
+xorop = pyparsing.oneOf('^')
+shiftop = pyparsing.oneOf('>> <<')
+rotop = pyparsing.oneOf('<<< >>>')
+signop = pyparsing.oneOf('+ -')
+mulop = pyparsing.oneOf('*')
+plusop = pyparsing.oneOf('+ -')
+divop = pyparsing.oneOf('/')
+
+
+variable = pyparsing.Word(pyparsing.alphas + "_$.", pyparsing.alphanums + "_")
+variable.setParseAction(cb_parse_id)
+operand = str_int | variable
+
+base_expr = pyparsing.operatorPrecedence(operand,
+                               [(notop,   1, pyparsing.opAssoc.RIGHT, cb_op_not),
+                                (andop, 2, pyparsing.opAssoc.RIGHT, cb_op_and),
+                                (xorop, 2, pyparsing.opAssoc.RIGHT, cb_op_xor),
+                                (signop,  1, pyparsing.opAssoc.RIGHT, cb_op_sign),
+                                (mulop,  2, pyparsing.opAssoc.RIGHT, cb_op_mul),
+                                (divop,  2, pyparsing.opAssoc.RIGHT, cb_op_div),
+                                (plusop,  2, pyparsing.opAssoc.LEFT, cb_op_plusminus),
+                                ])
+
 
 default_prio = 0x1337
 
@@ -656,7 +674,7 @@ class bs_swapargs(bs_divert):
 
 class m_arg(object):
 
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             e, start, stop = parser_result[self.parser]
             self.expr = e
@@ -665,9 +683,14 @@ class m_arg(object):
             v, start, stop = self.parser.scanString(text).next()
         except StopIteration:
             return None, None
-        self.expr = v[0]
+        arg = v[0]
+        expr = self.asm_ast_to_expr(arg, symbol_pool)
+        self.expr = expr
         return start, stop
 
+    def asm_ast_to_expr(self, arg, symbol_pool):
+        raise NotImplementedError("Virtual")
+
 
 class m_reg(m_arg):
     prio = default_prio
@@ -688,7 +711,7 @@ class reg_noarg(object):
     reg_info = None
     parser = None
 
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             e, start, stop = parser_result[self.parser]
             self.expr = e
@@ -697,7 +720,9 @@ class reg_noarg(object):
             v, start, stop = self.parser.scanString(text).next()
         except StopIteration:
             return None, None
-        self.expr = v[0]
+        arg = v[0]
+        expr = self.parses_to_expr(arg, symbol_pool)
+        self.expr = expr
         return start, stop
 
     def decode(self, v):
@@ -1252,7 +1277,7 @@ class cls_mn(object):
         return out[0]
 
     @classmethod
-    def fromstring(cls, text, mode = None):
+    def fromstring(cls, text, symbol_pool, mode = None):
         global total_scans
         name = re.search('(\S+)', text).groups()
         if not name:
@@ -1291,9 +1316,12 @@ class cls_mn(object):
                             v, start, stop = [None], None, None
                         if start != 0:
                             v, start, stop = [None], None, None
-                        parsers[(i, start_i)][p] = v[0], start, stop
-
-                    start, stop = f.fromstring(args_str, parsers[(i, start_i)])
+                        if v != [None]:
+                            v = f.asm_ast_to_expr(v[0], symbol_pool)
+                        if v is None:
+                            v, start, stop = [None], None, None
+                        parsers[(i, start_i)][p] = v, start, stop
+                    start, stop = f.fromstring(args_str, symbol_pool, parsers[(i, start_i)])
                     if start != 0:
                         log.debug("cannot fromstring %r", args_str)
                         cannot_parse = True
@@ -1532,7 +1560,7 @@ class imm_noarg(object):
             return None
         return v
 
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             e, start, stop = parser_result[self.parser]
         else:
@@ -1540,7 +1568,7 @@ class imm_noarg(object):
                 e, start, stop = self.parser.scanString(text).next()
             except StopIteration:
                 return None, None
-        if e is None:
+        if e == [None]:
             return None, None
 
         assert(isinstance(e, m2_expr.Expr))