about summary refs log tree commit diff stats
path: root/miasm2/arch/arm/arch.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/arch/arm/arch.py')
-rw-r--r--miasm2/arch/arm/arch.py202
1 files changed, 107 insertions, 95 deletions
diff --git a/miasm2/arch/arm/arch.py b/miasm2/arch/arm/arch.py
index b607b6c2..e09619ae 100644
--- a/miasm2/arch/arm/arch.py
+++ b/miasm2/arch/arm/arch.py
@@ -8,6 +8,7 @@ from collections import defaultdict
 from miasm2.core.bin_stream import bin_stream
 import miasm2.arch.arm.regs as regs_module
 from miasm2.arch.arm.regs import *
+from miasm2.core.asm_ast import AstInt, AstId, AstMem, AstOp
 
 # A1 encoding
 
@@ -20,7 +21,7 @@ log.setLevel(logging.DEBUG)
 # arm regs ##############
 reg_dum = ExprId('DumReg', 32)
 
-gen_reg('PC', globals())
+PC, _ = gen_reg('PC')
 
 # GP
 regs_str = ['R%d' % r for r in xrange(0x10)]
@@ -104,13 +105,13 @@ barrier_info = reg_info_dct(barrier_expr)
 
 # parser helper ###########
 
-def tok_reg_duo(s, l, t):
+def cb_tok_reg_duo(t):
     t = t[0]
-    i1 = gpregs.expr.index(t[0])
-    i2 = gpregs.expr.index(t[1])
+    i1 = gpregs.expr.index(t[0].name)
+    i2 = gpregs.expr.index(t[1].name)
     o = []
     for i in xrange(i1, i2 + 1):
-        o.append(gpregs.expr[i])
+        o.append(AstId(gpregs.expr[i]))
     return o
 
 LPARENTHESIS = Literal("(")
@@ -124,14 +125,14 @@ CIRCUNFLEX = Literal("^")
 
 def check_bounds(left_bound, right_bound, value):
     if left_bound <= value and value <= right_bound:
-        return ExprInt(value, 32)
+        return AstInt(value)
     else:
         raise ValueError('shift operator immediate value out of bound')
 
 
 def check_values(values, value):
     if value in values:
-        return ExprInt(value, 32)
+        return AstInt(value)
     else:
         raise ValueError('shift operator immediate value out of bound')
 
@@ -141,11 +142,11 @@ int_1_32 = str_int.copy().setParseAction(lambda v: check_bounds(1, 32, v[0]))
 int_8_16_24 = str_int.copy().setParseAction(lambda v: check_values([8, 16, 24], v[0]))
 
 
-def reglistparse(s, l, t):
+def cb_reglistparse(s, l, t):
     t = t[0]
     if t[-1] == "^":
-        return ExprOp('sbit', ExprOp('reglist', *t[:-1]))
-    return ExprOp('reglist', *t)
+        return AstOp('sbit', AstOp('reglist', *t[:-1]))
+    return AstOp('reglist', *t)
 
 
 allshifts = ['<<', '>>', 'a>>', '>>>', 'rrx']
@@ -161,11 +162,11 @@ def op_shift2expr(s, l, t):
     return shift2expr_dct[t[0]]
 
 reg_duo = Group(gpregs.parser + MINUS +
-                gpregs.parser).setParseAction(tok_reg_duo)
+                gpregs.parser).setParseAction(cb_tok_reg_duo)
 reg_or_duo = reg_duo | gpregs.parser
 gpreg_list = Group(LACC + delimitedList(
     reg_or_duo, delim=',') + RACC + Optional(CIRCUNFLEX))
-gpreg_list.setParseAction(reglistparse)
+gpreg_list.setParseAction(cb_reglistparse)
 
 LBRACK = Suppress("[")
 RBRACK = Suppress("]")
@@ -187,130 +188,116 @@ gpreg_p = gpregs.parser
 psr_p = cpsr_regs.parser | spsr_regs.parser
 
 
-def shift2expr(t):
+def cb_shift(t):
     if len(t) == 1:
         ret = t[0]
     elif len(t) == 2:
-        ret = ExprOp(t[1], t[0])
+        ret = AstOp(t[1], t[0])
     elif len(t) == 3:
-        ret = ExprOp(t[1], t[0], t[2])
+        ret = AstOp(t[1], t[0], t[2])
     else:
         raise ValueError("Bad arg")
     return ret
 
-variable, operand, base_expr = gen_base_expr()
-
-int_or_expr = base_expr
-
-
-def ast_id2expr(t):
-    return mn_arm.regs.all_regs_ids_byname.get(t, t)
-
-
-def ast_int2expr(a):
-    return ExprInt(a, 32)
-
-
-my_var_parser = ParseAst(ast_id2expr, ast_int2expr)
-base_expr.setParseAction(my_var_parser)
-
-
 shift_off = (gpregs.parser + Optional(
     (all_unaryop_shifts_t) |
     (all_binaryop_1_31_shifts_t + (gpregs.parser | int_1_31)) |
     (all_binaryop_1_32_shifts_t + (gpregs.parser | int_1_32))
-)).setParseAction(shift2expr)
+)).setParseAction(cb_shift)
 shift_off |= base_expr
 
 
 rot2_expr = (gpregs.parser + Optional(
     (ror_shifts_t + (int_8_16_24))
-)).setParseAction(shift2expr)
+)).setParseAction(cb_shift)
 
 
 OP_LSL = Suppress("LSL")
 
-def expr_deref_reg_reg(t):
+def cb_deref_reg_reg(t):
     if len(t) != 2:
         raise ValueError("Bad mem format")
-    return ExprMem(t[0] + t[1], 8)
+    return AstMem(AstOp('+', t[0], t[1]), 8)
 
-def expr_deref_reg_reg_lsl_1(t):
+def cb_deref_reg_reg_lsl_1(t):
     if len(t) != 3:
         raise ValueError("Bad mem format")
     reg1, reg2, index = t
-    if index != ExprInt(1, 32):
+    if not isinstance(index, AstInt) or index.value != 1:
         raise ValueError("Bad index")
-    ret = ExprMem(reg1 + (reg2 << index), 16)
+    ret = AstMem(AstOp('+', reg1, AstOp('<<', reg2, index)), 16)
     return ret
 
 
-deref_reg_reg = (LBRACK + gpregs.parser + COMMA + gpregs.parser + RBRACK).setParseAction(expr_deref_reg_reg)
-deref_reg_reg_lsl_1 = (LBRACK + gpregs.parser + COMMA + gpregs.parser + OP_LSL + base_expr + RBRACK).setParseAction(expr_deref_reg_reg_lsl_1)
+deref_reg_reg = (LBRACK + gpregs.parser + COMMA + gpregs.parser + RBRACK).setParseAction(cb_deref_reg_reg)
+deref_reg_reg_lsl_1 = (LBRACK + gpregs.parser + COMMA + gpregs.parser + OP_LSL + base_expr + RBRACK).setParseAction(cb_deref_reg_reg_lsl_1)
 
 
 
 (gpregs.parser + Optional(
     (ror_shifts_t + (int_8_16_24))
-)).setParseAction(shift2expr)
+)).setParseAction(cb_shift)
 
 
 
+reg_or_base = gpregs.parser | base_expr
+
 def deref2expr_nooff(s, l, t):
     t = t[0]
     # XXX default
     return ExprOp("preinc", t[0], ExprInt(0, 32))
 
 
-def deref2expr_pre(s, l, t):
+def cb_deref_preinc(t):
     t = t[0]
     if len(t) == 1:
-        return ExprOp("preinc", t[0], ExprInt(0, 32))
+        return AstOp("preinc", t[0], AstInt(0))
     elif len(t) == 2:
-        return ExprOp("preinc", t[0], t[1])
+        return AstOp("preinc", t[0], t[1])
     else:
         raise NotImplementedError('len(t) > 2')
 
 
-def deref2expr_pre_mem(s, l, t):
+def cb_deref_pre_mem(t):
     t = t[0]
     if len(t) == 1:
-        return ExprMem(ExprOp("preinc", t[0], ExprInt(0, 32)), 32)
+        return AstMem(AstOp("preinc", t[0], AstInt(0)), 32)
     elif len(t) == 2:
-        return ExprMem(ExprOp("preinc", t[0], t[1]), 32)
+        return AstMem(AstOp("preinc", t[0], t[1]), 32)
     else:
         raise NotImplementedError('len(t) > 2')
 
 
-def deref2expr_post(s, l, t):
+def cb_deref_post(t):
     t = t[0]
-    return ExprOp("postinc", t[0], t[1])
+    return AstOp("postinc", t[0], t[1])
 
 
-def deref_wb(s, l, t):
+def cb_deref_wb(t):
     t = t[0]
     if t[-1] == '!':
-        return ExprMem(ExprOp('wback', *t[:-1]), 32)
-    return ExprMem(t[0], 32)
+        return AstMem(AstOp('wback', *t[:-1]), 32)
+    return AstMem(t[0], 32)
 
 # shift_off.setParseAction(deref_off)
 deref_nooff = Group(
     LBRACK + gpregs.parser + RBRACK).setParseAction(deref2expr_nooff)
 deref_pre = Group(LBRACK + gpregs.parser + Optional(
-    COMMA + shift_off) + RBRACK).setParseAction(deref2expr_pre)
+    COMMA + shift_off) + RBRACK).setParseAction(cb_deref_preinc)
 deref_post = Group(LBRACK + gpregs.parser + RBRACK +
-                   COMMA + shift_off).setParseAction(deref2expr_post)
+                   COMMA + shift_off).setParseAction(cb_deref_post)
 deref = Group((deref_post | deref_pre | deref_nooff)
-              + Optional('!')).setParseAction(deref_wb)
+              + Optional('!')).setParseAction(cb_deref_wb)
 
 
-def parsegpreg_wb(s, l, t):
+def cb_gpreb_wb(t):
+    assert len(t) == 1
     t = t[0]
     if t[-1] == '!':
-        return ExprOp('wback', *t[:-1])
+        return AstOp('wback', *t[:-1])
     return t[0]
 
-gpregs_wb = Group(gpregs.parser + Optional('!')).setParseAction(parsegpreg_wb)
+gpregs_wb = Group(gpregs.parser + Optional('!')).setParseAction(cb_gpreb_wb)
 
 
 cond_list_full = ['EQ', 'NE', 'CS', 'CC', 'MI', 'PL', 'VS', 'VC',
@@ -780,7 +767,31 @@ class mn_armt(cls_mn):
         return 32
 
 
-class arm_reg(reg_noarg, m_arg):
+class arm_arg(m_arg):
+    def asm_ast_to_expr(self, arg, symbol_pool):
+        if isinstance(arg, AstId):
+            if isinstance(arg.name, ExprId):
+                return arg.name
+            if arg.name in gpregs.str:
+                return None
+            label = symbol_pool.getby_name_create(arg.name)
+            return ExprId(label, 32)
+        if isinstance(arg, AstOp):
+            args = [self.asm_ast_to_expr(tmp, symbol_pool) for tmp in arg.args]
+            if None in args:
+                return None
+            return ExprOp(arg.op, *args)
+        if isinstance(arg, AstInt):
+            return ExprInt(arg.value, 32)
+        if isinstance(arg, AstMem):
+            ptr = self.asm_ast_to_expr(arg.ptr, symbol_pool)
+            if ptr is None:
+                return None
+            return ExprMem(ptr, arg.size)
+        return None
+
+
+class arm_reg(reg_noarg, arm_arg):
     pass
 
 
@@ -820,7 +831,7 @@ class arm_reg_wb(arm_reg):
         return True
 
 
-class arm_psr(m_arg):
+class arm_psr(arm_arg):
     parser = psr_p
 
     def decode(self, v):
@@ -856,7 +867,7 @@ class arm_preg(arm_reg):
     parser = reg_info.parser
 
 
-class arm_imm(imm_noarg, m_arg):
+class arm_imm(imm_noarg, arm_arg):
     parser = base_expr
 
 
@@ -900,7 +911,7 @@ class arm_offs(arm_imm):
         return True
 
 
-class arm_imm8_12(m_arg):
+class arm_imm8_12(arm_arg):
     parser = deref
 
     def decode(self, v):
@@ -956,8 +967,8 @@ class arm_imm8_12(m_arg):
         return True
 
 
-class arm_imm_4_12(m_arg):
-    parser = base_expr
+class arm_imm_4_12(arm_arg):
+    parser = reg_or_base
 
     def decode(self, v):
         v = v & self.lmask
@@ -976,7 +987,7 @@ class arm_imm_4_12(m_arg):
         return True
 
 
-class arm_imm_12_4(m_arg):
+class arm_imm_12_4(arm_arg):
     parser = base_expr
 
     def decode(self, v):
@@ -996,7 +1007,7 @@ class arm_imm_12_4(m_arg):
         return True
 
 
-class arm_op2(m_arg):
+class arm_op2(arm_arg):
     parser = shift_off
 
     def str_to_imm_rot_form(self, s, neg=False):
@@ -1168,7 +1179,7 @@ class arm_op2imm(arm_imm8_12):
 
         # if len(v) <1:
         #    raise ValueError('cannot parse', s)
-        self.parent.rn.fromstring(e.args[0])
+        self.parent.rn.expr = e.args[0]
         if len(e.args) == 1:
             self.parent.immop.value = 0
             self.value = 0
@@ -1229,7 +1240,7 @@ def reglist2str(rlist):
     return "{" + ", ".join(out) + '}'
 
 
-class arm_rlist(m_arg):
+class arm_rlist(arm_arg):
     parser = gpreg_list
 
     def encode(self):
@@ -1436,7 +1447,7 @@ class mul_part_y(bs_mod_name):
 mul_x = mul_part_x(l=1, fname='x', mn_mod=['B', 'T'])
 mul_y = mul_part_y(l=1, fname='y', mn_mod=['B', 'T'])
 
-class arm_immed(m_arg):
+class arm_immed(arm_arg):
     parser = deref
 
     def decode(self, v):
@@ -1509,7 +1520,7 @@ immedL = bs(l=4, cls=(arm_immed, m_arg), fname='immedL')
 hb = bs(l=1)
 
 
-class armt2_rot_rm(m_arg):
+class armt2_rot_rm(arm_arg):
     parser = shift_off
     def decode(self, v):
         r = self.parent.rm.expr
@@ -1530,7 +1541,7 @@ class armt2_rot_rm(m_arg):
 rot_rm = bs(l=2, cls=(armt2_rot_rm,), fname="rot_rm")
 
 
-class arm_mem_rn_imm(m_arg):
+class arm_mem_rn_imm(arm_arg):
     parser = deref
     def decode(self, v):
         value = self.parent.imm.value
@@ -1695,7 +1706,7 @@ class arm_widthm1(arm_imm, m_arg):
         return True
 
 
-class arm_rm_rot2(m_arg):
+class arm_rm_rot2(arm_arg):
     parser = rot2_expr
     def decode(self, v):
         expr = gpregs.expr[v]
@@ -1755,12 +1766,12 @@ rot2 = bs(l=2, fname="rot2")
 widthm1 = bs(l=5, cls=(arm_widthm1, m_arg))
 lsb = bs(l=5, cls=(arm_imm, m_arg))
 
-rd_nopc = bs(l=4, cls=(arm_gpreg_nopc,m_arg), fname="rd")
-rn_nopc = bs(l=4, cls=(arm_gpreg_nopc,m_arg), fname="rn")
-ra_nopc = bs(l=4, cls=(arm_gpreg_nopc,m_arg), fname="ra")
-rt_nopc = bs(l=4, cls=(arm_gpreg_nopc,m_arg), fname="rt")
+rd_nopc = bs(l=4, cls=(arm_gpreg_nopc, arm_arg), fname="rd")
+rn_nopc = bs(l=4, cls=(arm_gpreg_nopc, arm_arg), fname="rn")
+ra_nopc = bs(l=4, cls=(arm_gpreg_nopc, arm_arg), fname="ra")
+rt_nopc = bs(l=4, cls=(arm_gpreg_nopc, arm_arg), fname="rt")
 
-rn_nosp = bs(l=4, cls=(arm_gpreg_nosp,m_arg), fname="rn")
+rn_nosp = bs(l=4, cls=(arm_gpreg_nosp, arm_arg), fname="rn")
 
 rn_nopc_noarg = bs(l=4, cls=(arm_gpreg_nopc,), fname="rn")
 
@@ -1783,22 +1794,22 @@ gpregs_sppc = reg_info(regs_str[-1:] + regs_str[13:14],
                        regs_expr[-1:] + regs_expr[13:14])
 
 deref_reg_imm = Group(LBRACK + gpregs.parser + Optional(
-    COMMA + shift_off) + RBRACK).setParseAction(deref2expr_pre_mem)
+    COMMA + shift_off) + RBRACK).setParseAction(cb_deref_pre_mem)
 deref_low = Group(LBRACK + gpregs_l.parser + Optional(
-    COMMA + shift_off) + RBRACK).setParseAction(deref2expr_pre_mem)
+    COMMA + shift_off) + RBRACK).setParseAction(cb_deref_pre_mem)
 deref_pc = Group(LBRACK + gpregs_pc.parser + Optional(
-    COMMA + shift_off) + RBRACK).setParseAction(deref2expr_pre_mem)
+    COMMA + shift_off) + RBRACK).setParseAction(cb_deref_pre_mem)
 deref_sp = Group(LBRACK + gpregs_sp.parser + COMMA +
-                 shift_off + RBRACK).setParseAction(deref2expr_pre_mem)
+                 shift_off + RBRACK).setParseAction(cb_deref_pre_mem)
 
 gpregs_l_wb = Group(
-    gpregs_l.parser + Optional('!')).setParseAction(parsegpreg_wb)
+    gpregs_l.parser + Optional('!')).setParseAction(cb_gpreb_wb)
 
 
 gpregs_l_13 = reg_info(regs_str[:13], regs_expr[:13])
 
 
-class arm_offreg(m_arg):
+class arm_offreg(arm_arg):
     parser = deref_pc
 
     def decodeval(self, v):
@@ -1909,7 +1920,7 @@ class arm_off7(arm_imm):
     def encodeval(self, v):
         return v >> 2
 
-class arm_deref_reg_imm(m_arg):
+class arm_deref_reg_imm(arm_arg):
     parser = deref_reg_imm
 
     def decode(self, v):
@@ -2010,7 +2021,7 @@ class arm_offh(imm_noarg):
         return True
 
 
-class armt_rlist(m_arg):
+class armt_rlist(arm_arg):
     parser = gpreg_list
 
     def encode(self):
@@ -2366,7 +2377,7 @@ armtop("sxth", [bs('10110010'), bs('00'), rml, rdl], [rdl, rml])
 #
 # ARM Architecture Reference Manual Thumb-2 Supplement
 
-armt_gpreg_shift_off = (gpregs_nosppc.parser + allshifts_t_armt + (gpregs.parser | int_1_31)).setParseAction(shift2expr)
+armt_gpreg_shift_off = (gpregs_nosppc.parser + allshifts_t_armt + (gpregs.parser | int_1_31)).setParseAction(cb_shift)
 
 
 armt_gpreg_shift_off |= gpregs_nosppc.parser
@@ -2783,7 +2794,7 @@ aif_expr = [ExprId(x, 32) if x != None else None for x in aif_str]
 
 aif_reg = reg_info(aif_str, aif_expr)
 
-class armt_aif(reg_noarg, m_arg):
+class armt_aif(reg_noarg, arm_arg):
     reg_info = aif_reg
     parser = reg_info.parser
 
@@ -2798,14 +2809,14 @@ class armt_aif(reg_noarg, m_arg):
             return ret
         return self.value != 0
 
-    def fromstring(self, text, parser_result=None):
-        start, stop = super(armt_aif, self).fromstring(text, parser_result)
+    def fromstring(self, text, symbol_pool, parser_result=None):
+        start, stop = super(armt_aif, self).fromstring(text, symbol_pool, parser_result)
         if self.expr.name == "X":
             return None, None
         return start, stop
 
 
-class armt_it_arg(m_arg):
+class armt_it_arg(arm_arg):
     arg_E = ExprId('E', 1)
     arg_NE = ExprId('NE', 1)
 
@@ -2878,7 +2889,7 @@ class armt_cond_lsb(bs_divert):
 cond_expr = [ExprId(x, 32) for x in cond_list_full]
 cond_info = reg_info(cond_list_full, cond_expr)
 
-class armt_cond_arg(m_arg):
+class armt_cond_arg(arm_arg):
     parser = cond_info.parser
 
     def decode(self, v):
@@ -2948,7 +2959,8 @@ class armt_op2imm(arm_imm8_12):
             # XXX default
             self.parent.ppi.value = 1
 
-        self.parent.rn.fromstring(e.args[0])
+        self.parent.rn.expr = e.args[0]
+
         if len(e.args) == 1:
             self.value = 0
             return True
@@ -3033,7 +3045,7 @@ class armt_deref_reg(arm_imm8_12):
         return True
 
 
-class armt_deref_reg_reg(m_arg):
+class armt_deref_reg_reg(arm_arg):
     parser = deref_reg_reg
     reg_info = gpregs
 
@@ -3116,7 +3128,7 @@ bs_deref_reg_reg = bs(l=4, cls=(armt_deref_reg_reg,))
 bs_deref_reg_reg_lsl_1 = bs(l=4, cls=(armt_deref_reg_reg_lsl_1,))
 
 
-class armt_barrier_option(reg_noarg, m_arg):
+class armt_barrier_option(reg_noarg, arm_arg):
     reg_info = barrier_info
     parser = reg_info.parser