about summary refs log tree commit diff stats
path: root/miasm2/arch/x86/arch.py
diff options
context:
space:
mode:
author Fabrice Desclaux <fabrice.desclaux@cea.fr> 2018-05-05 23:13:12 +0200
committer Fabrice Desclaux <fabrice.desclaux@cea.fr> 2018-05-14 10:29:27 +0200
commit 94d49ed54f07e3d399de74de13f5422837c031fa (patch)
tree b3f7fd34c7ff8d17bd9f26d53511b30935485092 /miasm2/arch/x86/arch.py
parent db4fd7f58d6a4ed87fc7d6f28c7c2af31e61fb65 (diff)
download miasm-94d49ed54f07e3d399de74de13f5422837c031fa.tar.gz
miasm-94d49ed54f07e3d399de74de13f5422837c031fa.zip
Core: updt parser structure
Diffstat (limited to 'miasm2/arch/x86/arch.py')
-rw-r--r-- miasm2/arch/x86/arch.py 363
1 files changed, 176 insertions, 187 deletions
diff --git a/miasm2/arch/x86/arch.py b/miasm2/arch/x86/arch.py
index 34a765e8..9310ce1d 100644
--- a/miasm2/arch/x86/arch.py
+++ b/miasm2/arch/x86/arch.py
@@ -8,6 +8,8 @@ from collections import defaultdict
 import miasm2.arch.x86.regs as regs_module
 from miasm2.arch.x86.regs import *
 from miasm2.core.asmblock import AsmLabel
+from miasm2.core.asm_ast import AstNode, AstInt, AstId, AstMem, AstOp
+
 
 log = logging.getLogger("x86_arch")
 console_handler = logging.StreamHandler()
@@ -121,148 +123,82 @@ replace_regs = {16: replace_regs16,
                 64: replace_regs64}
 
 
-# parser helper ###########
-PLUS = Suppress("+")
-MULT = Suppress("*")
-
-COLON = Suppress(":")
-
-
-LBRACK = Suppress("[")
-RBRACK = Suppress("]")
-
-dbreg = Group(gpregs16.parser | gpregs32.parser | gpregs64.parser)
-gpreg = (gpregs08.parser | gpregs08_64.parser | gpregs16.parser   |
-         gpregs32.parser | gpregs64.parser    | gpregs_xmm.parser |
-         gpregs_mm.parser | gpregs_bnd.parser)
-
-
-def reg2exprid(r):
-    if not r.name in all_regs_ids_byname:
-        raise ValueError('unknown reg')
-    return all_regs_ids_byname[r.name]
-
-
-def parse_deref_reg(s, l, t):
-    t = t[0][0]
-    return t[0]
-
-
-def parse_deref_int(s, l, t):
-    t = t[0]
-    return t[0]
-
-
-def parse_deref_regint(s, l, t):
-    t = t[0]
-    r1 = reg2exprid(t[0][0])
-    i1 = ExprInt(t[1].arg, r1.size)
-    return r1 + i1
-
-
-def parse_deref_regreg(s, l, t):
-    t = t[0]
-    return t[0][0] + t[1][0]
-
-
-def parse_deref_regregint(s, l, t):
-    t = t[0]
-    r1 = reg2exprid(t[0][0])
-    r2 = reg2exprid(t[1][0])
-    i1 = ExprInt(t[2].arg, r1.size)
-    return r1 + r2 + i1
+segm2enc = {CS: 1, SS: 2, DS: 3, ES: 4, FS: 5, GS: 6}
+enc2segm = dict([(x[1], x[0]) for x in segm2enc.items()])
 
+segm_info = reg_info_dct(enc2segm)
 
-def parse_deref_reg_intmreg(s, l, t):
-    t = t[0]
-    r1 = reg2exprid(t[0][0])
-    r2 = reg2exprid(t[1][0])
-    i1 = ExprInt(t[2].arg, r1.size)
-    return r1 + (r2 * i1)
 
 
-def parse_deref_reg_intmreg_int(s, l, t):
-    t = t[0]
-    r1 = reg2exprid(t[0][0])
-    r2 = reg2exprid(t[1][0])
-    i1 = ExprInt(t[2].arg, r1.size)
-    i2 = ExprInt(t[3].arg, r1.size)
-    return r1 + (r2 * i1) + i2
+enc2crx = {
+    0: cr0,
+    1: cr1,
+    2: cr2,
+    3: cr3,
+    4: cr4,
+    5: cr5,
+    6: cr6,
+    7: cr7,
+}
 
+crx_info = reg_info_dct(enc2crx)
 
-def parse_deref_intmreg(s, l, t):
-    t = t[0]
-    r1 = reg2exprid(t[0][0])
-    i1 = ExprInt(t[1].arg, r1.size)
-    return r1 * i1
 
+enc2drx = {
+    0: dr0,
+    1: dr1,
+    2: dr2,
+    3: dr3,
+    4: dr4,
+    5: dr5,
+    6: dr6,
+    7: dr7,
+}
 
-def parse_deref_intmregint(s, l, t):
-    t = t[0]
-    r1 = reg2exprid(t[0][0])
-    i1 = ExprInt(t[1].arg, r1.size)
-    i2 = ExprInt(t[1].arg, r1.size)
-    return (r1 * i1) + i2
+drx_info = reg_info_dct(enc2drx)
 
 
-def getreg(s, l, t):
-    t = t[0]
-    return t[0]
 
+# parser helper ###########
+PLUS = Suppress("+")
+MULT = Suppress("*")
 
-def parse_deref_ptr(s, l, t):
-    t = t[0]
-    return ExprMem(ExprOp('segm', t[0], t[1]))
+COLON = Suppress(":")
 
-def parse_deref_segmoff(s, l, t):
-    t = t[0]
-    return ExprOp('segm', t[0], t[1])
 
+LBRACK = Suppress("[")
+RBRACK = Suppress("]")
 
-variable, operand, base_expr = gen_base_expr()
 
+gpreg = (
+    gpregs08.parser |
+    gpregs08_64.parser |
+    gpregs16.parser |
+    gpregs32.parser |
+    gpregs64.parser |
+    gpregs_xmm.parser |
+    gpregs_mm.parser |
+    gpregs_bnd.parser
+)
 
-def ast_id2expr(t):
-    return mn_x86.regs.all_regs_ids_byname.get(t, t)
 
 
-def ast_int2expr(a):
-    return ExprInt(a, 64)
 
+def cb_deref_segmoff(t):
+    assert len(t) == 2
+    return AstOp('segm', t[0], t[1])
 
-my_var_parser = ParseAst(ast_id2expr, ast_int2expr)
-base_expr.setParseAction(my_var_parser)
 
-int_or_expr = base_expr
+def cb_deref_base_expr(t):
+    tokens = t[0]
+    assert isinstance(tokens, AstNode)
+    addr = tokens
+    return addr
 
-deref_mem_ad = Group(LBRACK + dbreg + RBRACK).setParseAction(parse_deref_reg)
-deref_mem_ad |= Group(
-    LBRACK + int_or_expr + RBRACK).setParseAction(parse_deref_int)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + PLUS +
-    int_or_expr + RBRACK).setParseAction(parse_deref_regint)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + PLUS +
-    dbreg + RBRACK).setParseAction(parse_deref_regreg)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + PLUS + dbreg + PLUS +
-    int_or_expr + RBRACK).setParseAction(parse_deref_regregint)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + PLUS + dbreg + MULT +
-    int_or_expr + RBRACK).setParseAction(parse_deref_reg_intmreg)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + PLUS + dbreg + MULT + int_or_expr +
-    PLUS + int_or_expr + RBRACK).setParseAction(parse_deref_reg_intmreg_int)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + MULT +
-    int_or_expr + RBRACK).setParseAction(parse_deref_intmreg)
-deref_mem_ad |= Group(
-    LBRACK + dbreg + MULT + int_or_expr +
-    PLUS + int_or_expr + RBRACK).setParseAction(parse_deref_intmregint)
 
+deref_mem_ad = (LBRACK + base_expr + RBRACK).setParseAction(cb_deref_base_expr)
 
-deref_ptr = Group(int_or_expr + COLON +
-                  int_or_expr).setParseAction(parse_deref_segmoff)
+deref_ptr = (base_expr + COLON + base_expr).setParseAction(cb_deref_segmoff)
 
 
 PTR = Suppress('PTR')
@@ -282,31 +218,30 @@ MEMPREFIX2SIZE = {'BYTE': 8, 'WORD': 16, 'DWORD': 32,
 
 SIZE2MEMPREFIX = dict((x[1], x[0]) for x in MEMPREFIX2SIZE.items())
 
-def parse_deref_mem(s, l, t):
-    t = t[0]
+def cb_deref_mem(t):
     if len(t) == 2:
         s, ptr = t
-        return ExprMem(ptr, MEMPREFIX2SIZE[s[0]])
+        assert isinstance(ptr, AstNode)
+        return AstMem(ptr, MEMPREFIX2SIZE[s])
     elif len(t) == 3:
         s, segm, ptr = t
-        return ExprMem(ExprOp('segm', segm[0], ptr), MEMPREFIX2SIZE[s[0]])
-    else:
-        raise ValueError('len(t) > 3')
+        return AstMem(AstOp('segm', segm, ptr), MEMPREFIX2SIZE[s])
+    raise ValueError('len(t) > 3')
 
-mem_size = Group(BYTE | DWORD | QWORD | WORD | TBYTE | XMMWORD)
-deref_mem = Group(mem_size + PTR + Optional(Group(int_or_expr + COLON))
-                  + deref_mem_ad).setParseAction(parse_deref_mem)
+mem_size = (BYTE | DWORD | QWORD | WORD | TBYTE | XMMWORD)
+deref_mem = (mem_size + PTR + Optional((base_expr + COLON))+ deref_mem_ad).setParseAction(cb_deref_mem)
 
 
-rmarg = Group(gpregs08.parser |
-              gpregs08_64.parser |
-              gpregs16.parser |
-              gpregs32.parser |
-              gpregs64.parser |
-              gpregs_mm.parser |
-              gpregs_xmm.parser |
-              gpregs_bnd.parser
-              ).setParseAction(getreg)
+rmarg = (
+    gpregs08.parser |
+    gpregs08_64.parser |
+    gpregs16.parser |
+    gpregs32.parser |
+    gpregs64.parser |
+    gpregs_mm.parser |
+    gpregs_xmm.parser |
+    gpregs_bnd.parser
+)
 
 rmarg |= deref_mem
 
@@ -314,36 +249,89 @@ rmarg |= deref_mem
 mem_far = FAR + deref_mem
 
 
-cl_or_imm = Group(r08_ecx.parser).setParseAction(getreg)
-cl_or_imm |= int_or_expr
+cl_or_imm = r08_ecx.parser
+cl_or_imm |= base_expr
+
+
 
+class x86_arg(m_arg):
+    def asm_ast_to_expr(self, value, symbol_pool, size_hint=None, fixed_size=None):
+        if size_hint is None:
+            size_hint = self.parent.v_opmode()
+        if fixed_size is None:
+            fixed_size = set()
+        if isinstance(value, AstId):
+            if value.name in all_regs_ids_byname:
+                reg = all_regs_ids_byname[value.name]
+                fixed_size.add(reg.size)
+                return reg
+            if isinstance(value.name, ExprId):
+                fixed_size.add(value.name.size)
+                return value.name
+            if value.name in MEMPREFIX2SIZE:
+                return None
+            if value.name in ["FAR"]:
+                return None
+
+            label = symbol_pool.getby_name_create(value.name)
+            return ExprId(label, size_hint)
+        if isinstance(value, AstOp):
+            # First pass to retrieve fixed_size
+            if value.op == "segm":
+                segm = self.asm_ast_to_expr(value.args[0], symbol_pool)
+                ptr = self.asm_ast_to_expr(value.args[1], symbol_pool, None, fixed_size)
+                return ExprOp('segm', segm, ptr)
+            args = [self.asm_ast_to_expr(arg, symbol_pool, None, fixed_size) for arg in value.args]
+            if len(fixed_size) == 0:
+                # No fixed size
+                pass
+            elif len(fixed_size) == 1:
+                # One fixed size, regen all
+                size = list(fixed_size)[0]
+                args = [self.asm_ast_to_expr(arg, symbol_pool, size, fixed_size) for arg in value.args]
+            else:
+                raise ValueError("Size conflict")
+            if None in args:
+                return None
+            return ExprOp(value.op, *args)
+        if isinstance(value, AstInt):
+            if 1 << size_hint < value.value:
+                size_hint *= 2
+            return ExprInt(value.value, size_hint)
+        if isinstance(value, AstMem):
+            fixed_size.add(value.size)
+            ptr = self.asm_ast_to_expr(value.ptr, symbol_pool, None, set())
+            if ptr is None:
+                return None
+            return ExprMem(ptr, value.size)
+        return None
 
-class r_al(reg_noarg, m_arg):
+class r_al(reg_noarg, x86_arg):
     reg_info = r08_eax
     parser = reg_info.parser
 
 
-class r_ax(reg_noarg, m_arg):
+class r_ax(reg_noarg, x86_arg):
     reg_info = r16_eax
     parser = reg_info.parser
 
 
-class r_dx(reg_noarg, m_arg):
+class r_dx(reg_noarg, x86_arg):
     reg_info = r16_edx
     parser = reg_info.parser
 
 
-class r_eax(reg_noarg, m_arg):
+class r_eax(reg_noarg, x86_arg):
     reg_info = r32_eax
     parser = reg_info.parser
 
 
-class r_rax(reg_noarg, m_arg):
+class r_rax(reg_noarg, x86_arg):
     reg_info = r64_eax
     parser = reg_info.parser
 
 
-class r_cl(reg_noarg, m_arg):
+class r_cl(reg_noarg, x86_arg):
     reg_info = r08_ecx
     parser = reg_info.parser
 
@@ -442,9 +430,6 @@ repeat_mn = ["INS", "OUTS",
              "CMPSB", "CMPSW", "CMPSD", "CMPSQ",
              ]
 
-segm2enc = {CS: 1, SS: 2, DS: 3, ES: 4, FS: 5, GS: 6}
-enc2segm = dict([(x[1], x[0]) for x in segm2enc.items()])
-
 
 class group:
 
@@ -685,7 +670,7 @@ class mn_x86(cls_mn):
         return [(subcls, name, bases, dct, fields)]
 
     @classmethod
-    def fromstring(cls, text, mode):
+    def fromstring(cls, text, symbol_pool, mode):
         pref = 0
         prefix, new_s = get_prefix(text)
         if prefix == "LOCK":
@@ -697,7 +682,7 @@ class mn_x86(cls_mn):
         elif prefix == "REPE":
             pref |= 4
             text = new_s
-        c = super(mn_x86, cls).fromstring(text, mode)
+        c = super(mn_x86, cls).fromstring(text, symbol_pool, mode)
         c.additional_info.g1.value = pref
         return c
 
@@ -1224,7 +1209,7 @@ class x86_s32to64(x86_s08to32):
         return ExprInt(x, 64)
 
 
-class bs_eax(m_arg):
+class bs_eax(x86_arg):
     reg_info = r_eax_all
     rindex = 0
     parser = reg_info.parser
@@ -1264,7 +1249,7 @@ class bs_eax(m_arg):
             return False
         return False
 
-class bs_seg(m_arg):
+class bs_seg(x86_arg):
     reg_info = r_eax_all
     rindex = 0
     parser = reg_info.parser
@@ -1326,7 +1311,7 @@ class bs_gs(bs_seg):
     parser = reg_info.parser
 
 
-class x86_reg_st(reg_noarg, m_arg):
+class x86_reg_st(reg_noarg, x86_arg):
     reg_info = r_st_all
     parser = reg_info.parser
 
@@ -1934,11 +1919,11 @@ def modrm2expr(modrm, parent, w8, sx=0, xmm=0, mm=0, bnd=0):
     return expr
 
 
-class x86_rm_arg(m_arg):
+class x86_rm_arg(x86_arg):
     parser = rmarg
 
-    def fromstring(self, text, parser_result=None):
-        start, stop = super(x86_rm_arg, self).fromstring(text, parser_result)
+    def fromstring(self, text, symbol_pool, parser_result=None):
+        start, stop = super(x86_rm_arg, self).fromstring(text, symbol_pool, parser_result)
         p = self.parent
         if start is None:
             return None, None
@@ -2073,9 +2058,9 @@ class x86_rm_arg(m_arg):
             yield x
 
 class x86_rm_mem(x86_rm_arg):
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         self.expr = None
-        start, stop = super(x86_rm_mem, self).fromstring(text, parser_result)
+        start, stop = super(x86_rm_mem, self).fromstring(text, symbol_pool, parser_result)
         if not isinstance(self.expr, ExprMem):
             return None, None
         return start, stop
@@ -2083,9 +2068,9 @@ class x86_rm_mem(x86_rm_arg):
 
 class x86_rm_mem_far(x86_rm_arg):
     parser = mem_far
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         self.expr = None
-        start, stop = super(x86_rm_mem_far, self).fromstring(text, parser_result)
+        start, stop = super(x86_rm_mem_far, self).fromstring(text, symbol_pool, parser_result)
         if not isinstance(self.expr, ExprMem):
             return None, None
         self.expr = ExprOp('far', self.expr)
@@ -2455,24 +2440,28 @@ class x86_rm_reg_noarg(object):
 
     parser = gpreg
 
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if not hasattr(self.parent, 'sx') and hasattr(self.parent, "w8"):
             self.parent.w8.value = 1
         if parser_result:
-            e, start, stop = parser_result[self.parser]
-            if e is None:
+            result, start, stop = parser_result[self.parser]
+            if result == [None]:
                 return None, None
-            self.expr = e
+            self.expr = result
             if self.expr.size == 8:
                 if hasattr(self.parent, 'sx') or not hasattr(self.parent, 'w8'):
                     return None, None
                 self.parent.w8.value = 0
             return start, stop
         try:
-            v, start, stop = self.parser.scanString(text).next()
+            result, start, stop = self.parser.scanString(text).next()
         except StopIteration:
             return None, None
-        self.expr = v[0]
+        expr = self.asm_ast_to_expr(result[0], symbol_pool)
+        if expr is None:
+            return None, None
+
+        self.expr = expr
         if self.expr.size == 0:
             if hasattr(self.parent, 'sx') or not hasattr(self.parent, 'w8'):
                 return None, None
@@ -2541,7 +2530,7 @@ class x86_rm_reg_noarg(object):
         return True
 
 
-class x86_rm_reg_mm(x86_rm_reg_noarg, m_arg):
+class x86_rm_reg_mm(x86_rm_reg_noarg, x86_arg):
     selreg = gpregs_mm
     def decode(self, v):
         if self.parent.mode == 64 and self.getrexsize():
@@ -2571,7 +2560,7 @@ class x86_rm_reg_xmm(x86_rm_reg_mm):
 class x86_rm_reg_bnd(x86_rm_reg_mm):
     selreg = gpregs_bnd
 
-class x86_rm_reg(x86_rm_reg_noarg, m_arg):
+class x86_rm_reg(x86_rm_reg_noarg, x86_arg):
     pass
 
 
@@ -2603,25 +2592,25 @@ class x86_reg_noarg(x86_rm_reg_noarg):
         self.parent.rex_b.value = v
 
 
-class x86_rm_segm(reg_noarg, m_arg):
+class x86_rm_segm(reg_noarg, x86_arg):
     prio = default_prio + 1
     reg_info = segmreg
     parser = reg_info.parser
 
 
-class x86_rm_cr(reg_noarg, m_arg):
+class x86_rm_cr(reg_noarg, x86_arg):
     prio = default_prio + 1
     reg_info = crregs
     parser = reg_info.parser
 
 
-class x86_rm_dr(reg_noarg, m_arg):
+class x86_rm_dr(reg_noarg, x86_arg):
     prio = default_prio + 1
     reg_info = drregs
     parser = reg_info.parser
 
 
-class x86_rm_flt(reg_noarg, m_arg):
+class x86_rm_flt(reg_noarg, x86_arg):
     prio = default_prio + 1
     reg_info = fltregs
     parser = reg_info.parser
@@ -2634,7 +2623,7 @@ class bs_fbit(bsi):
         return True
 
 
-class bs_cl1(bsi, m_arg):
+class bs_cl1(bsi, x86_arg):
     parser = cl_or_imm
 
     def decode(self, v):
@@ -2751,11 +2740,11 @@ class bs_cond_disp(bs_cond):
         return True
 
 
-class bs_cond_imm(bs_cond_scale, m_arg):
-    parser = int_or_expr
+class bs_cond_imm(bs_cond_scale, x86_arg):
+    parser = base_expr
     max_size = 32
 
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             expr, start, stop = parser_result[self.parser]
         else:
@@ -2880,9 +2869,9 @@ class bs_cond_imm64(bs_cond_imm):
 
 
 class bs_rel_off(bs_cond_imm):
-    parser = int_or_expr
+    parser = base_expr
 
-    def fromstring(self, text, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             expr, start, stop = parser_result[self.parser]
         else:
@@ -2940,7 +2929,7 @@ class bs_rel_off(bs_cond_imm):
         return True
 
 class bs_s08(bs_rel_off):
-    parser = int_or_expr
+    parser = base_expr
 
     @classmethod
     def flen(cls, mode, v):
@@ -3021,10 +3010,10 @@ class bs_moff(bsi):
         return True
 
 
-class bs_movoff(m_arg):
+class bs_movoff(x86_arg):
     parser = deref_mem
 
-    def fromstring(self, s, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             e, start, stop = parser_result[self.parser]
             if e is None:
@@ -3088,10 +3077,10 @@ class bs_movoff(m_arg):
         return True
 
 
-class bs_msegoff(m_arg):
+class bs_msegoff(x86_arg):
     parser = deref_ptr
 
-    def fromstring(self, s, parser_result=None):
+    def fromstring(self, text, symbol_pool, parser_result=None):
         if parser_result:
             e, start, stop = parser_result[self.parser]
             if e is None:
@@ -3172,13 +3161,13 @@ disp = bs(l=0, cls=(bs_cond_disp,), fname = "disp")
 
 s08 = bs(l=8, cls=(bs_s08, ))
 
-u08 = bs(l=8, cls=(x86_08, m_arg))
-u07 = bs(l=7, cls=(x86_08, m_arg))
-u16 = bs(l=16, cls=(x86_16, m_arg))
-u32 = bs(l=32, cls=(x86_32, m_arg))
-s3264 = bs(l=32, cls=(x86_s32to64, m_arg))
+u08 = bs(l=8, cls=(x86_08, x86_arg))
+u07 = bs(l=7, cls=(x86_08, x86_arg))
+u16 = bs(l=16, cls=(x86_16, x86_arg))
+u32 = bs(l=32, cls=(x86_32, x86_arg))
+s3264 = bs(l=32, cls=(x86_s32to64, x86_arg))
 
-u08_3 = bs(l=0, cls=(x86_imm_fix_08, m_arg), ival = 3)
+u08_3 = bs(l=0, cls=(x86_imm_fix_08, x86_arg), ival = 3)
 
 d0 = bs("000", fname='reg')
 d1 = bs("001", fname='reg')