#!/usr/bin/env python
#-*- coding:utf-8 -*-

import logging
import re
from miasm2.expression.expression import *
from pyparsing import *
from miasm2.core.cpu import *
from collections import defaultdict
import regs as regs_module
from regs import *
from miasm2.ir.ir import *

# Module logger; `logging` is imported explicitly above instead of relying on
# a star import to leak it into this namespace.
log = logging.getLogger("x86_arch")
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
log.addHandler(console_handler)
log.setLevel(logging.WARN)

# Keys used in the modrm description dictionaries: f_isad flags a memory
# (address) operand, f_imm holds the displacement format tag below.
f_isad = "AD"
f_s08 = "S08"
f_u08 = "U08"
f_s16 = "S16"
f_u16 = "U16"
f_s32 = "S32"
f_u32 = "U32"
f_s64 = "S64"
f_u64 = "U64"
f_imm = 'IMM'

# Bit width of each immediate/displacement format tag.
f_imm2size = {f_s08: 8, f_s16: 16, f_s32: 32, f_s64: 64,
              f_u08: 8, f_u16: 16, f_u32: 32, f_u64: 64}

# Operand size (bits) -> general purpose register bank.
size2gpregs = {8: gpregs08, 16: gpregs16, 32: gpregs32, 64: gpregs64}

# Per-mode tables rewriting each sub-register into a slice of the widest
# register available in that mode (applied through replace_expr).
replace_regs64 = {
    AL: RAX[:8], CL: RCX[:8], DL: RDX[:8], BL: RBX[:8],
    AH: RAX[8:16], CH: RCX[8:16], DH: RDX[8:16], BH: RBX[8:16],
    SPL: RSP[0:8], BPL: RBP[0:8], SIL: RSI[0:8], DIL: RDI[0:8],
    R8B: R8[0:8], R9B: R9[0:8], R10B: R10[0:8], R11B: R11[0:8],
    R12B: R12[0:8], R13B: R13[0:8], R14B: R14[0:8], R15B: R15[0:8],

    AX: RAX[:16], CX: RCX[:16], DX: RDX[:16], BX: RBX[:16],
    SP: RSP[:16], BP: RBP[:16], SI: RSI[:16], DI: RDI[:16],
    R8W: R8[:16], R9W: R9[:16], R10W: R10[:16], R11W: R11[:16],
    R12W: R12[:16], R13W: R13[:16], R14W: R14[:16], R15W: R15[:16],

    EAX: RAX[:32], ECX: RCX[:32], EDX: RDX[:32], EBX: RBX[:32],
    ESP: RSP[:32], EBP: RBP[:32], ESI: RSI[:32], EDI: RDI[:32],
    R8D: R8[:32], R9D: R9[:32], R10D: R10[:32], R11D: R11[:32],
    R12D: R12[:32], R13D: R13[:32], R14D: R14[:32], R15D: R15[:32],

    IP: RIP[:16], EIP: RIP[:32],
}

replace_regs32 = {
    AL: EAX[:8], CL: ECX[:8], DL: EDX[:8], BL: EBX[:8],
    AH: EAX[8:16], CH: ECX[8:16], DH: EDX[8:16], BH: EBX[8:16],

    AX: EAX[:16], CX: ECX[:16], DX: EDX[:16], BX: EBX[:16],
    SP: ESP[:16], BP: EBP[:16], SI: ESI[:16], DI: EDI[:16],

    IP: EIP[:16]
}

replace_regs16 = {
    AL: AX[:8], CL: CX[:8], DL: DX[:8], BL: BX[:8],
    AH: AX[8:16], CH: CX[8:16], DH: DX[8:16], BH: BX[8:16],

    AX: AX[:16], CX: CX[:16], DX: DX[:16], BX: BX[:16],
    SP: SP[:16], BP: BP[:16], SI: SI[:16], DI: DI[:16],
}

# Processor mode (bits) -> sub-register rewriting table.
replace_regs = {16: replace_regs16, 32: replace_regs32, 64: replace_regs64}


# parser helper ###########
PLUS = Suppress("+")
MULT = Suppress("*")

COLON = Suppress(":")

LBRACK = Suppress("[")
RBRACK = Suppress("]")

# Registers accepted inside a memory dereference (16/32/64-bit GPRs).
dbreg = Group(gpregs16.parser | gpregs32.parser | gpregs64.parser)
gpreg = (gpregs08.parser | gpregs08_64.parser | gpregs16.parser |
         gpregs32.parser | gpregs64.parser | gpregs_xmm.parser |
         gpregs_mm.parser)


def reg2exprid(r):
    """Return the canonical register expression registered under r.name.

    Raises ValueError('unknown reg') if the name is not a known register.
    """
    if r.name not in all_regs_ids_byname:
        raise ValueError('unknown reg')
    return all_regs_ids_byname[r.name]


# The parse_deref_* functions below are pyparsing parse actions
# (string, location, tokens); each receives one Group token and returns the
# address expression for one bracketed addressing form.

def parse_deref_reg(s, l, t):
    """[reg] -> reg."""
    t = t[0][0]
    return t[0]


def parse_deref_int(s, l, t):
    """[expr] -> expr."""
    t = t[0]
    return t[0]


def parse_deref_regint(s, l, t):
    """[reg + disp] -> reg + disp (disp sized with ExprInt_from(reg, ...))."""
    t = t[0]
    r1 = reg2exprid(t[0][0])
    i1 = ExprInt_from(r1, t[1].arg)
    return r1 + i1


def parse_deref_regreg(s, l, t):
    """[reg1 + reg2] -> reg1 + reg2."""
    t = t[0]
    return t[0][0] + t[1][0]


def parse_deref_regregint(s, l, t):
    """[reg1 + reg2 + disp] -> reg1 + reg2 + disp."""
    t = t[0]
    r1 = reg2exprid(t[0][0])
    r2 = reg2exprid(t[1][0])
    i1 = ExprInt_from(r1, t[2].arg)
    return r1 + r2 + i1


def parse_deref_reg_intmreg(s, l, t):
    """[reg1 + reg2 * scale] -> reg1 + (reg2 * scale)."""
    t = t[0]
    r1 = reg2exprid(t[0][0])
    r2 = reg2exprid(t[1][0])
    i1 = ExprInt_from(r1, t[2].arg)
    return r1 + (r2 * i1)


def parse_deref_reg_intmreg_int(s, l, t):
    """[reg1 + reg2 * scale + disp] -> reg1 + (reg2 * scale) + disp."""
    t = t[0]
    r1 = reg2exprid(t[0][0])
    r2 = reg2exprid(t[1][0])
    i1 = ExprInt_from(r1, t[2].arg)
    i2 = ExprInt_from(r1, t[3].arg)
    return r1 + (r2 * i1) + i2


def parse_deref_intmreg(s, l, t):
    """[reg * scale] -> reg * scale."""
    t = t[0]
    r1 = reg2exprid(t[0][0])
    i1 = ExprInt_from(r1, t[1].arg)
    return r1 * i1


def parse_deref_intmregint(s, l, t):
    """[reg * scale + disp] -> (reg * scale) + disp.

    Tokens are (reg group, scale, disp); the displacement is t[2].
    """
    t = t[0]
    r1 = reg2exprid(t[0][0])
    i1 = ExprInt_from(r1, t[1].arg)
    # Previously read t[1] (the scale) again, silently dropping the real
    # displacement.
    i2 = ExprInt_from(r1, t[2].arg)
    return (r1 * i1) + i2


def getreg(s, l, t):
    """Unwrap a single grouped register token."""
    t = t[0]
    return t[0]


def parse_deref_ptr(s, l, t):
    """segment:offset -> ExprMem(ExprOp('segm', segment, offset))."""
    t = t[0]
    return ExprMem(ExprOp('segm', t[0], t[1]))


variable, operand, base_expr = gen_base_expr()


def ast_id2expr(t):
    """Map an identifier from the expression parser to an expression.

    Known register names resolve to their register ExprId; any other name
    becomes a fresh ExprId (e.g. a label/symbol).
    """
    if t not in mn_x86.regs.all_regs_ids_byname:
        r = ExprId(t)
    else:
        r = mn_x86.regs.all_regs_ids_byname[t]
    return r
def ast_int2expr(a):
    """Integers from the expression parser are built as 64-bit ExprInt."""
    return ExprInt64(a)


my_var_parser = parse_ast(ast_id2expr, ast_int2expr)
base_expr.setParseAction(my_var_parser)

int_or_expr = base_expr

# Bracketed addressing forms. Alternatives are tried in order (MatchFirst).
deref_mem_ad = Group(LBRACK + dbreg + RBRACK).setParseAction(parse_deref_reg)
deref_mem_ad |= Group(
    LBRACK + int_or_expr + RBRACK).setParseAction(parse_deref_int)
deref_mem_ad |= Group(
    LBRACK + dbreg + PLUS +
    int_or_expr + RBRACK).setParseAction(parse_deref_regint)
deref_mem_ad |= Group(
    LBRACK + dbreg + PLUS +
    dbreg + RBRACK).setParseAction(parse_deref_regreg)
deref_mem_ad |= Group(
    LBRACK + dbreg + PLUS + dbreg + PLUS +
    int_or_expr + RBRACK).setParseAction(parse_deref_regregint)
deref_mem_ad |= Group(
    LBRACK + dbreg + PLUS + dbreg + MULT +
    int_or_expr + RBRACK).setParseAction(parse_deref_reg_intmreg)
deref_mem_ad |= Group(
    LBRACK + dbreg + PLUS + dbreg + MULT + int_or_expr +
    PLUS + int_or_expr + RBRACK).setParseAction(parse_deref_reg_intmreg_int)
deref_mem_ad |= Group(
    LBRACK + dbreg + MULT +
    int_or_expr + RBRACK).setParseAction(parse_deref_intmreg)
deref_mem_ad |= Group(
    LBRACK + dbreg + MULT + int_or_expr +
    PLUS + int_or_expr + RBRACK).setParseAction(parse_deref_intmregint)


deref_ptr = Group(int_or_expr + COLON +
                  int_or_expr).setParseAction(parse_deref_ptr)


PTR = Suppress('PTR')


BYTE = Literal('BYTE')
WORD = Literal('WORD')
DWORD = Literal('DWORD')
QWORD = Literal('QWORD')
TBYTE = Literal('TBYTE')


def parse_deref_mem(s, l, t):
    """Parse action for 'SIZE PTR [seg:][address]'.

    Returns an ExprMem of the announced size; when a segment is present the
    address is wrapped in ExprOp('segm', segment, address).
    """
    sz = {'BYTE': 8, 'WORD': 16, 'DWORD': 32, 'QWORD': 64, 'TBYTE': 80}
    t = t[0]
    if len(t) == 2:
        size_tok, ptr = t
        return ExprMem(ptr, sz[size_tok[0]])
    elif len(t) == 3:
        size_tok, segm, ptr = t
        return ExprMem(ExprOp('segm', segm[0], ptr), sz[size_tok[0]])
    else:
        raise ValueError('len(t) > 3')


mem_size = Group(BYTE | DWORD | QWORD | WORD | TBYTE)
deref_mem = Group(mem_size + PTR +
                  Optional(Group(int_or_expr + COLON)) +
                  deref_mem_ad).setParseAction(parse_deref_mem)


# r/m operand: a register of any supported bank, or a sized memory deref.
rmarg = Group(gpregs08.parser |
              gpregs08_64.parser |
              gpregs16.parser |
              gpregs32.parser |
              gpregs64.parser |
              gpregs_mm.parser |
              gpregs_xmm.parser
              ).setParseAction(getreg)

rmarg |= deref_mem


# Shift-count operand: CL or an immediate expression.
cl_or_imm = Group(r08_ecx.parser).setParseAction(getreg)
cl_or_imm |= int_or_expr


# Implicit operands bound to a single register (no encoded field).
class r_al(reg_noarg, m_arg):
    reg_info = r08_eax
    parser = reg_info.parser


class r_ax(reg_noarg, m_arg):
    reg_info = r16_eax
    parser = reg_info.parser


class r_dx(reg_noarg, m_arg):
    reg_info = r16_edx
    parser = reg_info.parser


class r_eax(reg_noarg, m_arg):
    reg_info = r32_eax
    parser = reg_info.parser


class r_rax(reg_noarg, m_arg):
    reg_info = r64_eax
    parser = reg_info.parser


class r_cl(reg_noarg, m_arg):
    reg_info = r08_ecx
    parser = reg_info.parser


# A size-override prefix swaps the default 16/32-bit size.
invmode = {16: 32, 32: 16}


def opmode_prefix(mode):
    """Operand size for mode = (size, opmode, admode).

    16/32-bit modes: opmode toggles between 16 and 32.
    64-bit mode: 16 with opmode set, otherwise 32.
    Raises NotImplementedError for any other size.
    """
    size, opmode, admode = mode
    if size in [16, 32]:
        if opmode:
            return invmode[size]
        else:
            return size
    elif size == 64:
        if opmode:
            return 16
        else:
            return 32
    raise NotImplementedError('not fully functional')


def admode_prefix(mode):
    """Address size for mode = (size, opmode, admode).

    16/32-bit modes: admode toggles between 16 and 32.
    64-bit mode: always 64.
    Raises NotImplementedError for any other size.
    """
    size, opmode, admode = mode
    if size in [16, 32]:
        if admode:
            return invmode[size]
        else:
            return size
    elif size == 64:
        # NOTE(review): admode is ignored here, unlike v_admode_info which
        # returns 32 in 64-bit mode when admode == 1 -- confirm intended.
        return 64
    raise NotImplementedError('not fully functional')


def v_opmode_info(size, opmode, rex_w, stk):
    """Effective operand size.

    16/32-bit modes: opmode toggles 16 <-> 32.
    64-bit mode: REX.W forces 64; stack instructions (stk) default to 64
    (16 with opmode == 1); otherwise 16 with opmode == 1, else 32.
    Any other size yields 32.
    """
    if size in [16, 32]:
        if opmode:
            return invmode[size]
        else:
            return size
    elif size == 64:
        if rex_w == 1:
            return 64
        elif stk:
            if opmode == 1:
                return 16
            else:
                return 64
        elif opmode == 1:
            return 16
        return 32


def v_opmode(p):
    """Effective operand size of instruction candidate p."""
    stk = hasattr(p, 'stk')
    return v_opmode_info(p.mode, p.opmode, p.rex_w.value, stk)


def v_admode_info(size, admode):
    """Effective address size.

    16/32-bit modes: admode toggles 16 <-> 32; 64-bit mode: 32 when
    admode == 1, else 64. Other sizes return None.
    """
    if size in [16, 32]:
        if admode:
            return invmode[size]
        else:
            return size
    elif size == 64:
        if admode == 1:
            return 32
        return 64


def v_admode(p):
    """Effective address size of instruction candidate p."""
    return v_admode_info(p.mode, p.admode)


def offsize(p):
    """16 when the operand-size prefix is set, else the processor mode."""
    if p.opmode:
        return 16
    else:
        return p.mode


def get_prefix(s):
    """Split the first whitespace-terminated token off s.

    Returns (token, remainder after the separating whitespace), or
    (None, s) when no token is followed by whitespace.
    """
    g = re.search(r'(\S+)(\s+)', s)
    if not g:
        return None, s
    prefix, b = g.groups()
    # Slice from the match end: the match need not start at index 0 (e.g.
    # leading whitespace), so len(prefix) + len(b) would cut mid-token.
    return prefix, s[g.end():]


# String instructions that accept REP/REPE/REPNE prefixes.
repeat_mn = ["INS", "OUTS",
             "MOVSB", "MOVSW", "MOVSD", "MOVSQ",
             "SCASB", "SCASW", "SCASD", "SCASQ",
             "LODSB", "LODSW", "LODSD", "LODSQ",
             "STOSB", "STOSW", "STOSD", "STOSQ",
             "CMPSB", "CMPSW", "CMPSD", "CMPSQ",
             ]
segm2enc = {CS: 1, SS: 2, DS: 3, ES: 4, FS: 5, GS: 6} enc2segm = dict([(x[1], x[0]) for x in segm2enc.items()]) class group: def __init__(self): self.value = None class additional_info: def __init__(self): self.except_on_instr = False self.g1 = group() self.g2 = group() self.vopmode = None self.stk = False self.v_opmode = None self.v_admode = None self.prefixed = '' class instruction_x86(instruction): delayslot = 0 def __init__(self, *args, **kargs): super(instruction_x86, self).__init__(*args, **kargs) self.additional_info.stk = hasattr(self, 'stk') def v_opmode(self): return self.additional_info.v_opmode def v_admode(self): return self.additional_info.v_admode def dstflow(self): if self.name.startswith('J'): return True if self.name.startswith('LOOP'): return True # repxx yyy generate split flow # if self.g1.value & 6 and self.name in repeat_mn: # return True return self.name in ['CALL'] def dstflow2label(self, symbol_pool): if self.additional_info.g1.value & 6 and self.name in repeat_mn: return e = self.args[0] if isinstance(e, ExprId) and not e.name in all_regs_ids_byname: l = symbol_pool.getby_name_create(e.name) s = ExprId(l, e.size) self.args[0] = s elif isinstance(e, ExprInt): ad = e.arg + int(self.offset) + self.l l = symbol_pool.getby_offset_create(ad) s = ExprId(l, e.size) self.args[0] = s else: return def breakflow(self): if self.name.startswith('J'): return True if self.name.startswith('LOOP'): return True if self.name.startswith('RET'): return True if self.name.startswith('INT'): return True if self.name.startswith('SYS'): return True # repxx yyy generate split flow # if self.g1.value & 6 and self.name in repeat_mn: # return True return self.name in ['CALL', 'HLT', 'IRET', 'ICEBP'] def splitflow(self): if self.name.startswith('JMP'): return False if self.name.startswith('J'): return True if self.name.startswith('LOOP'): return True if self.name.startswith('INT'): return True if self.name.startswith('SYS'): return True # repxx yyy generate split flow # 
if self.g1.value & 6 and self.name in repeat_mn: # return True return self.name in ['CALL'] def setdstflow(self, a): return def is_subcall(self): return self.name in ['CALL'] def getdstflow(self, symbol_pool): if self.additional_info.g1.value & 6 and self.name in repeat_mn: ad = int(self.offset) l = symbol_pool.getby_offset_create(ad) # XXX size ??? s = ExprId(l, self.v_opmode()) return [s] return [self.args[0]] def get_symbol_size(self, symbol, symbol_pool): return self.mode def fixDstOffset(self): e = self.args[0] if self.offset is None: raise ValueError('symbol not resolved %s' % l) if not isinstance(e, ExprInt): # raise ValueError('dst must be int or label') log.warning('dynamic dst %r' % e) return # return ExprInt32(e.arg - (self.offset + self.l)) self.args[0] = ExprInt_fromsize( self.mode, e.arg - (self.offset + self.l)) def get_info(self, c): self.additional_info.g1.value = c.g1.value self.additional_info.g2.value = c.g2.value self.additional_info.v_opmode = c.v_opmode() self.additional_info.v_admode = c.v_admode() self.additional_info.prefix = c.prefix self.additional_info.prefixed = getattr(c, "prefixed", "") def __str__(self): o = super(instruction_x86, self).__str__() if self.additional_info.g1.value & 1: o = "LOCK %s" % o if self.additional_info.g1.value & 2: if getattr(self.additional_info.prefixed, 'default', "") != "\xF2": o = "REPNE %s" % o if self.additional_info.g1.value & 4: if getattr(self.additional_info.prefixed, 'default', "") != "\xF3": o = "REPE %s" % o return o def get_args_expr(self): args = [] for a in self.args: a = a.replace_expr(replace_regs[self.mode]) args.append(a) return args class mn_x86(cls_mn): name = "x86" prefix_op_size = False prefix_ad_size = False regs = regs_module all_mn = [] all_mn_mode = defaultdict(list) all_mn_name = defaultdict(list) all_mn_inst = defaultdict(list) bintree = {} num = 0 delayslot = 0 pc = {16: IP, 32: EIP, 64: RIP} sp = {16: SP, 32: ESP, 64: RSP} instruction = instruction_x86 max_instruction_len = 15 
@classmethod def getpc(cls, attrib): return cls.pc[attrib] @classmethod def getsp(cls, attrib): return cls.sp[attrib] def v_opmode(self): if hasattr(self, 'stk'): stk = 1 else: stk = 0 return v_opmode_info(self.mode, self.opmode, self.rex_w.value, stk) def v_admode(self): size, opmode, admode = self.mode, self.opmode, self.admode if size in [16, 32]: if admode: return invmode[size] else: return size elif size == 64: if admode == 1: return 32 return 64 def additional_info(self): info = additional_info() info.g1.value = self.g1.value info.g2.value = self.g2.value info.v_opmode = self.v_opmode() info.prefixed = "" if hasattr(self, 'prefixed'): info.prefixed = self.prefixed.default return info @classmethod def check_mnemo(cls, fields): pass @classmethod def getmn(cls, name): return name.upper() @classmethod def mod_fields(cls, fields): prefix = [d_g1, d_g2, d_rex_p, d_rex_w, d_rex_r, d_rex_x, d_rex_b] return prefix + fields @classmethod def gen_modes(cls, subcls, name, bases, dct, fields): dct['mode'] = None return [(subcls, name, bases, dct, fields)] @classmethod def fromstring(cls, s, mode): pref = 0 prefix, new_s = get_prefix(s) if prefix == "LOCK": pref |= 1 s = new_s elif prefix == "REPNE": pref |= 2 s = new_s elif prefix == "REPE": pref |= 4 s = new_s c = super(mn_x86, cls).fromstring(s, mode) c.additional_info.g1.value = pref return c @classmethod def pre_dis(cls, v, mode, offset): offset_o = offset pre_dis_info = {'opmode': 0, 'admode': 0, 'g1': 0, 'g2': 0, 'rex_p': 0, 'rex_w': 0, 'rex_r': 0, 'rex_x': 0, 'rex_b': 0, 'prefix': "", 'prefixed': "", } while True: c = v.getbytes(offset) if c == '\x66': # pre_dis_info.opmode = 1 pre_dis_info['opmode'] = 1 elif c == '\x67': pre_dis_info['admode'] = 1 elif c == '\xf0': pre_dis_info['g1'] = 1 elif c == '\xf2': pre_dis_info['g1'] = 2 elif c == '\xf3': pre_dis_info['g1'] = 4 elif c == '\x2e': pre_dis_info['g2'] = 1 elif c == '\x36': pre_dis_info['g2'] = 2 elif c == '\x3e': pre_dis_info['g2'] = 3 elif c == '\x26': 
pre_dis_info['g2'] = 4 elif c == '\x64': pre_dis_info['g2'] = 5 elif c == '\x65': pre_dis_info['g2'] = 6 elif mode == 64 and c in '@ABCDEFGHIJKLMNO': x = ord(c) pre_dis_info['rex_p'] = 1 pre_dis_info['rex_w'] = (x >> 3) & 1 pre_dis_info['rex_r'] = (x >> 2) & 1 pre_dis_info['rex_x'] = (x >> 1) & 1 pre_dis_info['rex_b'] = (x >> 0) & 1 offset += 1 break else: c = '' break pre_dis_info['prefix'] += c offset += 1 # pre_dis_info.b = v[:offset] return pre_dis_info, v, mode, offset, offset - offset_o @classmethod def get_cls_instance(cls, cc, mode, infos=None): for opmode in [0, 1]: for admode in [0, 1]: # c = cls.all_mn_inst[cc][0] c = cc() c.init_class() c.reset_class() c.add_pre_dis_info() c.dup_info(infos) c.mode = mode c.opmode = opmode c.admode = admode if hasattr(c, "fopmode") and c.fopmode.mode == 64: c.rex_w.value = 1 yield c def post_dis(self): if self.g2.value: for a in self.args: if not isinstance(a.expr, ExprMem): continue m = a.expr a.expr = ExprMem( ExprOp('segm', enc2segm[self.g2.value], m.arg), m.size) if self.name == 'LEA': if not isinstance(self.args[1].expr, ExprMem): return None return self def dup_info(self, infos): if infos is not None: self.g1.value = infos.g1.value self.g2.value = infos.g2.value def reset_class(self): super(mn_x86, self).reset_class() # self.rex_w.value, self.rex_b.value, # self.rex_x.value = None, None, None # self.opmode.value, self.admode.value = None, None if hasattr(self, "opmode"): del(self.opmode) if hasattr(self, "admode"): del(self.admode) # self.opmode = 0 # self.admode = 0 def add_pre_dis_info(self, pre_dis_info=None): # print 'add_pre_dis_info', pre_dis_info if pre_dis_info is None: return True if hasattr(self, "prefixed") and self.prefixed.default == "\x66": pre_dis_info['opmode'] = 0 # if self.opmode != 0: # return False # if pre_dis_info['opmode'] != self.opmode: # return False # if pre_dis_info['admode'] != self.admode: # return False self.opmode = pre_dis_info['opmode'] self.admode = pre_dis_info['admode'] if 
hasattr(self, 'no_xmm_pref') and\ pre_dis_info['prefix'] and\ pre_dis_info['prefix'][-1] in '\x66\xf2\xf3': return False if (hasattr(self, "prefixed") and not pre_dis_info['prefix'].endswith(self.prefixed.default)): return False # print self.rex_w.value, pre_dis_info['rex_w'] # print 'rex', self.rex_w.value, self.rex_b.value, self.rex_x.value if (self.rex_w.value is not None and self.rex_w.value != pre_dis_info['rex_w']): return False else: self.rex_w.value = pre_dis_info['rex_w'] self.rex_r.value = pre_dis_info['rex_r'] self.rex_b.value = pre_dis_info['rex_b'] self.rex_x.value = pre_dis_info['rex_x'] self.rex_p.value = pre_dis_info['rex_p'] self.g1.value = pre_dis_info['g1'] self.g2.value = pre_dis_info['g2'] self.prefix = pre_dis_info['prefix'] # self.prefixed = pre_dis_info['prefixed'] """ if hasattr(self, "p_"): self.prefixed = self.p_.default if self.p_.default == "\x66": pre_dis_info['opmode'] = 0 if self.opmode != 0: return False #self.pre_dis_info = pre_dis_info """ return True def post_asm(self, v): return v def encodefields(self, decoded): v = super(mn_x86, self).encodefields(decoded) rex = 0x40 if self.g1.value is None: self.g1.value = 0 if self.g2.value is None: self.g2.value = 0 if self.rex_w.value: rex |= 0x8 if self.rex_r.value: rex |= 0x4 if self.rex_x.value: rex |= 0x2 if self.rex_b.value: rex |= 0x1 if rex != 0x40 or self.rex_p.value == 1: v = chr(rex) + v if hasattr(self, 'prefixed'): v = self.prefixed.default + v if self.g1.value & 1: v = "\xf0" + v if self.g1.value & 2: if hasattr(self, 'no_xmm_pref'): return None v = "\xf2" + v if self.g1.value & 4: if hasattr(self, 'no_xmm_pref'): return None v = "\xf3" + v if self.g2.value: v = {1: '\x2e', 2: '\x36', 3: '\x3e', 4: '\x26', 5: '\x64', 6: '\x65'}[self.g2.value] + v # mode prefix if hasattr(self, "admode") and self.admode: v = "\x67" + v if hasattr(self, "opmode") and self.opmode: if hasattr(self, 'no_xmm_pref'): return None v = "\x66" + v return v def getnextflow(self, symbol_pool): raise 
NotImplementedError('not fully functional') return self.offset + 4 def ir_pre_instruction(self): return [ExprAff(mRIP[self.mode], ExprInt_from(mRIP[self.mode], self.offset + self.l))] @classmethod def filter_asm_candidates(cls, instr, candidates): cand_same_mode = [] cand_diff_mode = [] out = [] for c, v in candidates: if (hasattr(c, 'no_xmm_pref') and (c.g1.value & 2 or c.g1.value & 4 or c.opmode)): continue if hasattr(c, "fopmode") and v_opmode(c) != c.fopmode.mode: # print 'DROP', c, v_opmode(c), c.fopmode.mode continue if hasattr(c, "fadmode") and v_admode(c) != c.fadmode.mode: # print 'DROP', c, v_opmode(c), c.fopmode.mode continue # relative dstflow must not have opmode set # (affect IP instead of EIP for instance) if (instr.dstflow() and instr.name not in ["JCXZ", "JECXZ", "JRCXZ"] and len(instr.args) == 1 and isinstance(instr.args[0], ExprInt) and c.opmode): continue out.append((c, v)) candidates = out # return [x[1][0] for x in candidates] for c, v in candidates: if v_opmode(c) == instr.mode: cand_same_mode += v for c, v in candidates: if v_opmode(c) != instr.mode: cand_diff_mode += v cand_same_mode.sort(key=lambda x: len(x)) cand_diff_mode.sort(key=lambda x: len(x)) return cand_same_mode + cand_diff_mode class bs8(bs): prio = default_prio def __init__(self, v, cls=None, fname=None, **kargs): super(bs8, self).__init__(int2bin(v, 8), 8, cls=cls, fname=fname, **kargs) class bs_modname_size(bs_divert): prio = 1 def divert(self, i, candidates): out = [] for candidate in candidates: cls, name, bases, dct, fields = candidate fopmode = opmode_prefix( (dct['mode'], dct['opmode'], dct['admode'])) mode = dct['mode'] size, opmode, admode = dct['mode'], dct['opmode'], dct['admode'] # no mode64 existance in name means no 64bit version of mnemo if mode == 64: if mode in self.args['name']: nfields = fields[:] f, i = getfieldindexby_name(nfields, 'rex_w') # f = bs("1", l=0, fname = 'rex_w') f = bs("1", l=0, cls=(bs_fbit,), fname="rex_w") osize = v_opmode_info(size, 
opmode, 1, 0) nfields[i] = f nfields = nfields[:-1] args = dict(self.args) ndct = dict(dct) if osize in self.args['name']: ndct['name'] = self.args['name'][osize] out.append((cls, ndct['name'], bases, ndct, nfields)) nfields = fields[:] nfields = nfields[:-1] f, i = getfieldindexby_name(nfields, 'rex_w') # f = bs("0", l=0, fname = 'rex_w') f = bs("0", l=0, cls=(bs_fbit,), fname="rex_w") osize = v_opmode_info(size, opmode, 0, 0) nfields[i] = f args = dict(self.args) ndct = dict(dct) if osize in self.args['name']: ndct['name'] = self.args['name'][osize] out.append((cls, ndct['name'], bases, ndct, nfields)) else: l = opmode_prefix((dct['mode'], dct['opmode'], dct['admode'])) osize = v_opmode_info(size, opmode, None, 0) nfields = fields[:-1] args = dict(self.args) ndct = dict(dct) if osize in self.args['name']: ndct['name'] = self.args['name'][osize] out.append((cls, ndct['name'], bases, ndct, nfields)) return out class bs_modname_jecx(bs_divert): prio = 1 def divert(self, i, candidates): out = [] for candidate in candidates: cls, name, bases, dct, fields = candidate fopmode = opmode_prefix( (dct['mode'], dct['opmode'], dct['admode'])) mode = dct['mode'] size, opmode, admode = dct['mode'], dct['opmode'], dct['admode'] nfields = fields[:] nfields = nfields[:-1] args = dict(self.args) ndct = dict(dct) if mode == 64: if admode: ndct['name'] = "JECXZ" else: ndct['name'] = "JRCXZ" elif mode == 32: if admode: ndct['name'] = "JCXZ" else: ndct['name'] = "JECXZ" elif mode == 16: if admode: ndct['name'] = "JECXZ" else: ndct['name'] = "JCXZ" else: raise ValueError('unhandled mode') out.append((cls, ndct['name'], bases, ndct, nfields)) return out class bs_modname_mode(bs_divert): prio = 1 def divert(self, i, candidates): out = [] for candidate in candidates: cls, name, bases, dct, fields = candidate fopmode = opmode_prefix( (dct['mode'], dct['opmode'], dct['admode'])) size, opmode, admode = dct['mode'], dct['opmode'], dct['admode'] mode = dct['mode'] l = 
opmode_prefix((dct['mode'], dct['opmode'], dct['admode'])) osize = v_opmode_info(size, opmode, None, 0) nfields = fields[:-1] args = dict(self.args) ndct = dict(dct) if mode == 64 or osize == 32: ndct['name'] = self.args['name'][mode] else: ndct['name'] = self.args['name'][16] out.append((cls, ndct['name'], bases, ndct, nfields)) return out class x86_imm(imm_noarg): parser = base_expr def decodeval(self, v): return swap_uint(self.l, v) def encodeval(self, v): return swap_uint(self.l, v) class x86_imm_fix(imm_noarg): parser = base_expr def decodeval(self, v): return self.ival def encodeval(self, v): if v != self.ival: return False return self.ival class x86_08(x86_imm): intsize = 8 intmask = (1 << intsize) - 1 class x86_16(x86_imm): intsize = 16 intmask = (1 << intsize) - 1 class x86_32(x86_imm): intsize = 32 intmask = (1 << intsize) - 1 class x86_64(x86_imm): intsize = 64 intmask = (1 << intsize) - 1 class x86_08_ne(x86_imm): intsize = 8 intmask = (1 << intsize) - 1 def encode(self): return True def decode(self, v): v = swap_uint(self.l, v) p = self.parent admode = p.v_admode() e = sign_ext(v, self.intsize, admode) e = ExprInt_fromsize(admode, e) self.expr = e return True class x86_16_ne(x86_08_ne): intsize = 16 intmask = (1 << intsize) - 1 class x86_32_ne(x86_08_ne): intsize = 32 intmask = (1 << intsize) - 1 class x86_64_ne(x86_08_ne): intsize = 64 intmask = (1 << intsize) - 1 class x86_s08to16(x86_imm): in_size = 8 out_size = 16 def myexpr(self, x): return ExprInt16(x) def int2expr(self, v): return self.myexpr(v) def expr2int(self, e): if not isinstance(e, ExprInt): return None v = int(e.arg) if v & ~((1 << self.l) - 1) != 0: return None return v def decode(self, v): v = v & self.lmask v = self.decodeval(v) if self.parent.v_opmode() == 64: self.expr = ExprInt64(sign_ext(v, self.in_size, 64)) else: if (1 << (self.l - 1)) & v: v = sign_ext(v, self.l, self.out_size) self.expr = self.myexpr(v) return True def encode(self): if not isinstance(self.expr, ExprInt): 
return False v = int(self.expr.arg) opmode = self.parent.v_opmode() out_size = self.out_size if opmode != self.out_size: if opmode == 32 and self.out_size == 64: out_size = opmode if v == sign_ext( int(v & ((1 << self.in_size) - 1)), self.in_size, out_size): pass else: # print 'cannot encode1', hex(v), # print hex(sign_ext(int(v&((1<> 6) & 3, (c >> 3) & 7, c & 7 def setmodrm(mod, re, rm): return ((mod & 3) << 6) | ((re & 7) << 3) | (rm & 7) def sib(c): return modrm(c) db_afs_64 = [] sib_64_s08_ebp = [] def gen_modrm_form(): global db_afs_64, sib_64_s08_ebp ebp = 5 sib_s08_ebp = [{f_isad: True} for i in range(0x100)] sib_u32_ebp = [{f_isad: True} for i in range(0x100)] sib_u32 = [{f_isad: True} for i in range(0x100)] sib_u64 = [] for rex_x in xrange(2): o = [] for rex_b in xrange(2): x = [{f_isad: True} for i in range(0x100)] o.append(x) sib_u64.append(o) sib_u64_ebp = [] for rex_x in xrange(2): o = [] for rex_b in xrange(2): x = [{f_isad: True} for i in range(0x100)] o.append(x) sib_u64_ebp.append(o) sib_64_s08_ebp = [] for rex_x in xrange(2): o = [] for rex_b in xrange(2): x = [{f_isad: True} for i in range(0x100)] o.append(x) sib_64_s08_ebp.append(o) for sib_rez in [sib_s08_ebp, sib_u32_ebp, sib_u32, sib_64_s08_ebp, sib_u64_ebp, sib_u64, ]: for index in range(0x100): ss, i, b = getmodrm(index) if b == 0b101: if sib_rez == sib_s08_ebp: sib_rez[index][f_imm] = f_s08 sib_rez[index][ebp] = 1 elif sib_rez == sib_u32_ebp: sib_rez[index][f_imm] = f_u32 sib_rez[index][ebp] = 1 elif sib_rez == sib_u32: sib_rez[index][f_imm] = f_u32 elif sib_rez == sib_u64_ebp: for rex_b in xrange(2): for rex_x in xrange(2): sib_rez[rex_x][rex_b][index][f_imm] = f_u32 sib_rez[rex_x][rex_b][index][ebp + 8 * rex_b] = 1 elif sib_rez == sib_u64: for rex_b in xrange(2): for rex_x in xrange(2): sib_rez[rex_x][rex_b][index][f_imm] = f_u32 elif sib_rez == sib_64_s08_ebp: for rex_b in xrange(2): for rex_x in xrange(2): sib_rez[rex_x][rex_b][index][f_imm] = f_s08 sib_rez[rex_x][rex_b][index][ebp + 8 
* rex_b] = 1 else: if sib_rez == sib_s08_ebp: sib_rez[index][b] = 1 sib_rez[index][f_imm] = f_s08 elif sib_rez == sib_u32_ebp: sib_rez[index][b] = 1 sib_rez[index][f_imm] = f_u32 elif sib_rez == sib_u32: sib_rez[index][b] = 1 elif sib_rez == sib_u64_ebp: for rex_b in xrange(2): for rex_x in xrange(2): sib_rez[rex_x][rex_b][index][b + 8 * rex_b] = 1 sib_rez[rex_x][rex_b][index][f_imm] = f_u32 elif sib_rez == sib_u64: for rex_b in xrange(2): for rex_x in xrange(2): sib_rez[rex_x][rex_b][index][b + 8 * rex_b] = 1 elif sib_rez == sib_64_s08_ebp: for rex_b in xrange(2): for rex_x in xrange(2): sib_rez[rex_x][rex_b][index][f_imm] = f_s08 sib_rez[rex_x][rex_b][index][b + 8 * rex_b] = 1 if i == 0b100 and sib_rez in [sib_s08_ebp, sib_u32_ebp, sib_u32]: continue if sib_rez in [sib_s08_ebp, sib_u32_ebp, sib_u32]: tmp = i if not tmp in sib_rez[index]: sib_rez[index][tmp] = 0 # 1 << ss sib_rez[index][tmp] += 1 << ss else: for rex_b in xrange(2): for rex_x in xrange(2): tmp = i + 8 * rex_x if i == 0b100 and rex_x == 0: continue if not tmp in sib_rez[rex_x][rex_b][index]: sib_rez[rex_x][rex_b][index][tmp] = 0 # 1 << ss sib_rez[rex_x][rex_b][index][tmp] += 1 << ss # 32bit db_afs_32 = [None for i in range(0x100)] for i in range(0x100): index = i mod, re, rm = getmodrm(i) if mod == 0b00: if rm == 0b100: db_afs_32[index] = sib_u32 elif rm == 0b101: db_afs_32[index] = {f_isad: True, f_imm: f_u32} else: db_afs_32[index] = {f_isad: True, rm: 1} elif mod == 0b01: if rm == 0b100: db_afs_32[index] = sib_s08_ebp continue tmp = {f_isad: True, rm: 1, f_imm: f_s08} db_afs_32[index] = tmp elif mod == 0b10: if rm == 0b100: db_afs_32[index] = sib_u32_ebp else: db_afs_32[index] = {f_isad: True, rm: 1, f_imm: f_u32} elif mod == 0b11: db_afs_32[index] = {f_isad: False, rm: 1} # 64bit db_afs_64 = [None for i in range(0x400)] for i in range(0x400): index = i rex_x = (index >> 9) & 1 rex_b = (index >> 8) & 1 mod, re, rm = getmodrm(i & 0xff) if mod == 0b00: if rm == 0b100: db_afs_64[i] = 
sib_u64[rex_x][rex_b] elif rm == 0b101: db_afs_64[i] = {f_isad: True, f_imm: f_u32, 16: 1} else: db_afs_64[i] = {f_isad: True, rm + 8 * rex_b: 1} elif mod == 0b01: if rm == 0b100: db_afs_64[i] = sib_64_s08_ebp[rex_x][rex_b] continue tmp = {f_isad: True, rm + 8 * rex_b: 1, f_imm: f_s08} db_afs_64[i] = tmp elif mod == 0b10: if rm == 0b100: db_afs_64[i] = sib_u64_ebp[rex_x][rex_b] else: db_afs_64[i] = {f_isad: True, rm + 8 * rex_b: 1, f_imm: f_u32} elif mod == 0b11: db_afs_64[i] = {f_isad: False, rm + 8 * rex_b: 1} # 16bit db_afs_16 = [None for i in range(0x100)] _si = 6 _di = 7 _bx = 3 _bp = 5 for i in range(0x100): index = i mod, re, rm = getmodrm(i) if mod == 0b00: if rm == 0b100: db_afs_16[index] = {f_isad: True, _si: 1} elif rm == 0b101: db_afs_16[index] = {f_isad: True, _di: 1} elif rm == 0b110: db_afs_16[index] = { f_isad: True, f_imm: f_u16} # {f_isad:True,_bp:1} elif rm == 0b111: db_afs_16[index] = {f_isad: True, _bx: 1} else: db_afs_16[index] = {f_isad: True, [_si, _di][rm % 2]: 1, [_bx, _bp][(rm >> 1) % 2]: 1} elif mod in [0b01, 0b10]: if mod == 0b01: my_imm = f_s08 else: my_imm = f_u16 if rm == 0b100: db_afs_16[index] = {f_isad: True, _si: 1, f_imm: my_imm} elif rm == 0b101: db_afs_16[index] = {f_isad: True, _di: 1, f_imm: my_imm} elif rm == 0b110: db_afs_16[index] = {f_isad: True, _bp: 1, f_imm: my_imm} elif rm == 0b111: db_afs_16[index] = {f_isad: True, _bx: 1, f_imm: my_imm} else: db_afs_16[index] = {f_isad: True, [_si, _di][rm % 2]: 1, [_bx, _bp][(rm >> 1) % 2]: 1, f_imm: my_imm} elif mod == 0b11: db_afs_16[index] = {f_isad: False, rm: 1} byte2modrm = {} byte2modrm[16] = db_afs_16 byte2modrm[32] = db_afs_32 byte2modrm[64] = db_afs_64 modrm2byte = {16: defaultdict(list), 32: defaultdict(list), 64: defaultdict(list), } for size, db_afs in byte2modrm.items(): for i, modrm in enumerate(db_afs): if not isinstance(modrm, list): modrm = modrm.items() modrm.sort() modrm = tuple(modrm) modrm2byte[size][modrm].append(i) continue for j, modrm_f in 
enumerate(modrm): modrm_f = modrm_f.items() modrm_f.sort() modrm_f = tuple(modrm_f) modrm2byte[size][modrm_f].append((i, j)) return byte2modrm, modrm2byte byte2modrm, modrm2byte = gen_modrm_form() # ret is modr; ret is displacement def exprfindmod(e, o=None): if o is None: o = {} if isinstance(e, ExprInt): return e if isinstance(e, ExprId): i = size2gpregs[e.size].expr.index(e) o[i] = 1 return None elif isinstance(e, ExprOp): out = None if e.op == '+': for a in e.args: r = exprfindmod(a, o) if out and r1: raise ValueError('multiple displacement!') out = r return out elif e.op == "*": mul = int(e.args[1].arg) a = e.args[0] i = size2gpregs[a.size].expr.index(a) o[i] = mul else: raise ValueError('bad op') return None def expr2modrm(e, p, w8, sx=0, xmm=0, mm=0): o = defaultdict(lambda x: 0) if e.size == 64 and not e in gpregs_mm.expr: if hasattr(p, 'sd'): p.sd.value = 1 # print 'set64pref', str(e) elif hasattr(p, 'wd'): pass elif hasattr(p, 'stk'): pass else: p.rex_w.value = 1 opmode = p.v_opmode() if sx == 1: opmode = 16 if sx == 2: opmode = 32 if e.size == 8 and w8 != 0: return None, None, False if w8 == 0 and e.size != 8: return None, None, False if not isinstance(e, ExprMem): o[f_isad] = False if xmm: if e in gpregs_xmm.expr: i = gpregs_xmm.expr.index(e) o[i] = 1 return [o], None, True else: return None, None, False if mm: if e in gpregs_mm.expr: i = gpregs_mm.expr.index(e) o[i] = 1 return [o], None, True else: return None, None, False if w8 == 0: # if (p.v_opmode() == 64 or p.rex_p.value == 1) and e in # gpregs08_64.expr: if p.mode == 64 and e in gpregs08_64.expr: r = gpregs08_64 p.rex_p.value = 1 else: p.rex_p.value = 0 p.rex_x.value = 0 r = size2gpregs[8] if not e in r.expr: return None, None, False i = r.expr.index(e) o[i] = 1 return [o], None, True # print "ttt", opmode, e.size if opmode != e.size: # print "FFFF" return None, None, False if not e in size2gpregs[opmode].expr: return None, None, False i = size2gpregs[opmode].expr.index(e) # print 'aaa', p.mode, 
i if i > 7: if p.mode == 64: # p.rex_b.value = 1 # i -=7 # print "SET REXB" pass else: return None, None, False o[i] = 1 return [o], None, True if e.is_op_segm() and isinstance(e.arg.args[0], ExprInt): return None, None, False if e.is_op_segm(): segm = e.arg.args[0] ptr = e.arg.args[1] else: segm = None ptr = e.arg o[f_isad] = True ad_size = ptr.size admode = p.v_admode() if ad_size != admode: return None, None, False """ if e.size == 64: if hasattr(p, 'sd'): p.sd.value = 1 else: p.rex_w.value = 1 """ if w8 == 1 and e.size != opmode: # p.v_opmode(): if not (hasattr(p, 'sd') or hasattr(p, 'wd')): return None, None, False # print 'tttt' if hasattr(p, 'wd'): s = e.size if s == 16: p.wd.value = 1 elif s == 32: pass else: return None, None, False if p.mode == 64 and ptr.size == 32: if p.admode != 1: return None, None, False o = {f_isad: True} disp = exprfindmod(ptr, o) out = [] if disp is None: # add 0 disp disp = ExprInt32(0) if disp is not None: for s, x in [(f_s08, ExprInt8), (f_s16, ExprInt16), (f_s32, ExprInt32), (f_u08, ExprInt8), (f_u16, ExprInt16), (f_u32, ExprInt32)]: # print "1", disp v = x(int(disp.arg)) # print "2", v, hex(sign_ext(int(v.arg), v.size, disp.size)) if int(disp.arg) != sign_ext(int(v.arg), v.size, disp.size): # print 'nok' continue # print 'ok', s, v x1 = dict(o) x1[f_imm] = (s, v) out.append(x1) else: out = [o] return out, segm, True def modrm2expr(m, p, w8, sx=0, xmm=0, mm=0): o = [] if not m[f_isad]: k = [x[0] for x in m.items() if x[1] == 1] if len(k) != 1: raise ValueError('strange reg encoding %r' % m) k = k[0] if w8 == 0: opmode = 8 elif sx == 1: opmode = 16 elif sx == 2: opmode = 32 else: opmode = p.v_opmode() """ if k > 7: # XXX HACK TODO e = size2gpregs[64].expr[k] else: e = size2gpregs[opmode].expr[k] """ # print 'yyy', opmode, k if xmm: e = gpregs_xmm.expr[k] elif mm: e = gpregs_mm.expr[k] elif opmode == 8 and (p.v_opmode() == 64 or p.rex_p.value == 1): e = gpregs08_64.expr[k] else: e = size2gpregs[opmode].expr[k] return e # print 
"enc", m, p.v_admode(), p.prefix.opmode, p.prefix.admode admode = p.v_admode() opmode = p.v_opmode() for k, v in m.items(): if type(k) in [int, long]: e = size2gpregs[admode].expr[k] if v != 1: e = ExprInt_fromsize(admode, v) * e o.append(e) # print [str(x) for x in o] if f_imm in m: if p.disp.value is None: return None o.append(ExprInt_fromsize(admode, p.disp.expr.arg)) e = ExprOp('+', *o) if w8 == 0: opmode = 8 elif sx == 1: opmode = 16 elif sx == 2: opmode = 32 e = ExprMem(e, size=opmode) # print "mem size", opmode, e return e class x86_rm_arg(m_arg): parser = rmarg def fromstring(self, s, parser_result=None): start, stop = super(x86_rm_arg, self).fromstring(s, parser_result) e = self.expr p = self.parent if start is None: return None, None s = e.size return start, stop @staticmethod def arg2str(e): if isinstance(e, ExprId): o = str(e) elif isinstance(e, ExprMem): sz = {8: 'BYTE', 16: 'WORD', 32: 'DWORD', 64: 'QWORD', 80: 'TBYTE'}[e.size] segm = "" if e.is_op_segm(): segm = "%s:" % e.arg.args[0] e = e.arg.args[1] else: e = e.arg if isinstance(e, ExprOp): # s = str(e.arg)[1:-1] s = str(e).replace('(', '').replace(')', '') else: s = str(e) o = sz + ' PTR %s[%s]' % (segm, s) else: raise ValueError('check this %r' % e) return "%s" % o def get_modrm(self): p = self.parent admode = p.v_admode() if not admode in [16, 32, 64]: raise ValueError('strange admode %r', admode) v = setmodrm(p.mod.value, 0, p.rm.value) v |= p.rex_b.value << 8 v |= p.rex_x.value << 9 if p.mode == 64: # XXXx to check admode = 64 xx = byte2modrm[admode][v] if isinstance(xx, list): if not p.sib_scale: return False v = setmodrm(p.sib_scale.value, p.sib_index.value, p.sib_base.value) # print 'SIB', hex(v) # v |= p.rex_b.value << 8 # v |= p.rex_x.value << 9 # if v >= 0x100: # pass xx = xx[v] return xx def decode(self, v): p = self.parent xx = self.get_modrm() mm = hasattr(self.parent, "mm") xmm = hasattr(self.parent, "xmm") e = modrm2expr(xx, p, 1, xmm=xmm, mm=mm) if e is None: return False self.expr 
= e return True def gen_cand(self, v_cand, admode): # print "GEN CAND" if not admode in modrm2byte: # XXX TODO: 64bit raise StopIteration if not v_cand: raise StopIteration p = self.parent o_rex_x = p.rex_x.value o_rex_b = p.rex_b.value # add candidate without 0 imm new_v_cand = [] moddd = False for v in v_cand: new_v_cand.append(v) # print 'CANDI', v, admode if f_imm in v and int(v[f_imm][1].arg) == 0: v = dict(v) del(v[f_imm]) new_v_cand.append(v) moddd = True v_cand = new_v_cand out_c = [] for v in v_cand: disp = None # patch value in modrm if f_imm in v: size, disp = v[f_imm] disp = int(disp.arg) # disp = swap_uint(f_imm2size[size], int(disp)) v[f_imm] = size vo = v # print 'vv', v, disp v = v.items() v.sort() v = tuple(v) # print "II", e, admode # print 'III', v # if (8, 1) in v: # pass if not v in modrm2byte[admode]: # print 'cannot find' continue # print "FOUND1", v xx = modrm2byte[admode][v] # if opmode == 64 and admode == 64: # pdb.set_trace() # print "FOUND2", xx # default case for x in xx: if type(x) == tuple: modrm, sib = x else: modrm = x sib = None # print 'mod sib', hex(modrm), sib # print p.sib_scale # print p.sib_base # print p.sib_index # 16 bit cannot have sib if (not sib is None) and admode == 16: continue # if ((p.sib_scale and sib is None) or # (p.sib_scale is None and sib)): # log.debug('dif sib %r %r'%(p.sib_scale, sib)) # continue # print hex(modrm), sib # p.mod.value, dum, p.rm.value = getmodrm(modrm) rex = modrm >> 8 # 0# XXX HACK REM temporary REX modrm>>8 if rex and admode != 64: continue # print 'prefix', hex(rex) # p.rex_x.value = o_rex_x # p.rex_b.value = o_rex_b p.rex_x.value = (rex >> 1) & 1 p.rex_b.value = rex & 1 if o_rex_x is not None and p.rex_x.value != o_rex_x: continue if o_rex_b is not None and p.rex_b.value != o_rex_b: continue mod, re, rm = getmodrm(modrm) # check re on parent if re != p.reg.value: continue # p.mod.value.append(mod) # p.rm.value.append(rm) if sib: # print 'REX', p.rex_x.value, p.rex_b.value # print 
hex(modrm), hex(sib) # if (modrm & 0xFF == 4 and sib & 0xFF == 0x5 # and p.rex_b.value ==1 and p.rex_x.value == 0): # pass s_scale, s_index, s_base = getmodrm(sib) # p.sib_scale.value, p.sib_index.value, # p.sib_base.value = getmodrm(sib) # p.sib_scale.decode(mod) # p.sib_index.decode(re) # p.sib_base.decode(rm) # p.sib_scale.value.append(mod) # p.sib_index.value.append(re) # p.sib_base.value.append(rm) else: # p.sib_scale.value.append(None) # p.sib_index.value.append(None) # p.sib_base.value.append(None) s_scale, s_index, s_base = None, None, None # print 'IIII', repr(p.disp), f_imm in v # if p.disp and not f_imm in vo: # continue # if not p.disp and f_imm in vo: # continue # if p.disp: # if p.disp.l != f_imm2size[vo[f_imm]]: # continue # print "DISP", repr(p.disp), p.disp.l # p.disp.value = int(disp.arg) # print 'append' # print mod, rm, s_scale, s_index, s_base, disp # print p.mod, p.rm # out_c.append((mod, rm, s_scale, s_index, s_base, disp)) p.mod.value = mod p.rm.value = rm p.sib_scale.value = s_scale p.sib_index.value = s_index p.sib_base.value = s_base p.disp.value = disp if disp is not None: p.disp.l = f_imm2size[vo[f_imm]] yield True raise StopIteration def encode(self): e = self.expr # print "eee", e if isinstance(e, ExprInt): raise StopIteration p = self.parent admode = p.v_admode() mode = e.size mm = hasattr(self.parent, 'mm') xmm = hasattr(self.parent, 'xmm') v_cand, segm, ok = expr2modrm(e, p, 1, xmm=xmm, mm=mm) if segm: p.g2.value = segm2enc[segm] # print "REZ1", v_cand, ok for x in self.gen_cand(v_cand, admode): yield x class x86_rm_w8(x86_rm_arg): def decode(self, v): p = self.parent xx = self.get_modrm() e = modrm2expr(xx, p, p.w8.value) self.expr = e return e is not None def encode(self): e = self.expr if isinstance(e, ExprInt): raise StopIteration p = self.parent if p.w8.value is None: if e.size == 8: p.w8.value = 0 else: p.w8.value = 1 # print 'TTTTT', e v_cand, segm, ok = expr2modrm(e, p, p.w8.value) if segm: p.g2.value = segm2enc[segm] # 
print "REZ2", v_cand, ok for x in self.gen_cand(v_cand, p.v_admode()): # print 'REZ', p.rex_x.value yield x class x86_rm_sx(x86_rm_arg): def decode(self, v): p = self.parent xx = self.get_modrm() e = modrm2expr(xx, p, p.w8.value, 1) self.expr = e return e is not None def encode(self): e = self.expr if isinstance(e, ExprInt): raise StopIteration p = self.parent if p.w8.value is None: if e.size == 8: p.w8.value = 0 else: p.w8.value = 1 v_cand, segm, ok = expr2modrm(e, p, p.w8.value, 1) if segm: p.g2.value = segm2enc[segm] for x in self.gen_cand(v_cand, p.v_admode()): yield x class x86_rm_sxd(x86_rm_arg): def decode(self, v): p = self.parent xx = self.get_modrm() e = modrm2expr(xx, p, 1, 2) self.expr = e return e is not None def encode(self): e = self.expr if isinstance(e, ExprInt): raise StopIteration p = self.parent v_cand, segm, ok = expr2modrm(e, p, 1, 2) if segm: p.g2.value = segm2enc[segm] for x in self.gen_cand(v_cand, p.v_admode()): yield x class x86_rm_sd(x86_rm_arg): def decode(self, v): p = self.parent xx = self.get_modrm() e = modrm2expr(xx, p, 1) if not isinstance(e, ExprMem): return False if p.sd.value == 0: e = ExprMem(e.arg, 32) else: e = ExprMem(e.arg, 64) self.expr = e return e is not None def encode(self): e = self.expr if isinstance(e, ExprInt): raise StopIteration p = self.parent if not e.size in [32, 64]: raise StopIteration p.sd.value = 0 v_cand, segm, ok = expr2modrm(e, p, 1) for x in self.gen_cand(v_cand, p.v_admode()): yield x class x86_rm_wd(x86_rm_arg): def decode(self, v): p = self.parent xx = self.get_modrm() e = modrm2expr(xx, p, 1) if not isinstance(e, ExprMem): return False if p.wd.value == 0: e = ExprMem(e.arg, 32) else: e = ExprMem(e.arg, 16) self.expr = e return e is not None def encode(self): e = self.expr if isinstance(e, ExprInt): raise StopIteration p = self.parent p.wd.value = 0 v_cand, segm, ok = expr2modrm(e, p, 1) for x in self.gen_cand(v_cand, p.v_admode()): yield x class x86_rm_m80(x86_rm_arg): msize = 80 def decode(self, 
v): p = self.parent xx = self.get_modrm() # print "aaa", xx e = modrm2expr(xx, p, 1) if not isinstance(e, ExprMem): return False e = ExprMem(e.arg, self.msize) self.expr = e return e is not None def encode(self): e = self.expr if isinstance(e, ExprInt): raise StopIteration if not isinstance(e, ExprMem) or e.size != self.msize: raise StopIteration p = self.parent mode = p.mode if mode == 64: mode = 32 e = ExprMem(e.arg, mode) v_cand, segm, ok = expr2modrm(e, p, 1) for x in self.gen_cand(v_cand, p.v_admode()): yield x class x86_rm_m08(x86_rm_arg): msize = 8 def decode(self, v): p = self.parent xx = self.get_modrm() e = modrm2expr(xx, p, 0) self.expr = e return e is not None def encode(self): e = self.expr if e.size != 8: raise StopIteration """ if not isinstance(e, ExprMem) or e.size != self.msize: raise StopIteration """ p = self.parent mode = p.mode # if mode == 64: # mode = 32 # e = ExprMem(e.arg, mode) v_cand, segm, ok = expr2modrm(e, p, 0) for x in self.gen_cand(v_cand, p.v_admode()): yield x class x86_rm_m16(x86_rm_m80): msize = 16 class x86_rm_m64(x86_rm_m80): msize = 64 class x86_rm_reg_noarg(object): prio = default_prio + 1 parser = gpreg def fromstring(self, s, parser_result=None): # print 'parsing reg', s, opmode if not hasattr(self.parent, 'sx') and hasattr(self.parent, "w8"): self.parent.w8.value = 1 if parser_result: e, start, stop = parser_result[self.parser] # print 'reg result', e, start, stop if e is None: return None, None self.expr = e if self.expr.size == 8: if hasattr(self.parent, 'sx') or not hasattr(self.parent, 'w8'): return None, None self.parent.w8.value = 0 return start, stop try: v, start, stop = self.parser.scanString(s).next() except StopIteration: return None, None self.expr = v[0] if self.expr.size == 0: if hasattr(self.parent, 'sx') or not hasattr(self.parent, 'w8'): return None, None self.parent.w8.value = 0 # print 'parsed', s, self.expr return start, stop def getrexsize(self): return self.parent.rex_r.value def setrexsize(self, 
v): self.parent.rex_r.value = v def decode(self, v): v = v & self.lmask p = self.parent opmode = p.v_opmode() # if hasattr(p, 'sx'): # opmode = 16 if not hasattr(p, 'sx') and (hasattr(p, 'w8') and p.w8.value == 0): opmode = 8 r = size2gpregs[opmode] if p.mode == 64 and self.getrexsize(): v |= 0x8 # print "XXX", p.v_opmode(), p.rex_p.value if p.v_opmode() == 64 or p.rex_p.value == 1: if not hasattr(p, 'sx') and (hasattr(p, 'w8') and p.w8.value == 0): # if (hasattr(p, 'w8') and p.w8.value == 0): r = gpregs08_64 """ if v < 8: self.expr = r.expr[v] else: self.expr = size2gpregs[64].expr[v] """ if hasattr(p, "xmm") or hasattr(p, "xmmreg"): e = gpregs_xmm.expr[v] elif hasattr(p, "mm") or hasattr(p, "mmreg"): e = gpregs_mm.expr[v] else: e = r.expr[v] self.expr = e return True def encode(self): if not isinstance(self.expr, ExprId): return False if self.expr in gpregs64.expr and not hasattr(self.parent, 'stk'): self.parent.rex_w.value = 1 # print self.parent.opmode # fd opmode = self.parent.v_opmode() # if hasattr(self.parent, 'sx'): # opmode = 16 # print 'reg encode', self.expr, opmode if not hasattr(self.parent, 'sx') and hasattr(self.parent, 'w8'): self.parent.w8.value = 1 if self.expr.size == 8: if hasattr(self.parent, 'sx') or not hasattr(self.parent, 'w8'): return False self.parent.w8.value = 0 opmode = 8 r = size2gpregs[opmode] # print "YYY", opmode, self.expr if ((hasattr(self.parent, 'xmm') or hasattr(self.parent, 'xmmreg')) and self.expr in gpregs_xmm.expr): i = gpregs_xmm.expr.index(self.expr) elif ((hasattr(self.parent, 'mm') or hasattr(self.parent, 'mmreg')) and self.expr in gpregs_mm.expr): i = gpregs_mm.expr.index(self.expr) elif self.expr in r.expr: i = r.expr.index(self.expr) elif (opmode == 8 and self.parent.mode == 64 and self.expr in gpregs08_64.expr): i = gpregs08_64.expr.index(self.expr) self.parent.rex_p.value = 1 else: log.debug("cannot encode reg %r" % self.expr) return False # print "zzz", opmode, self.expr, i, self.parent.mode if 
self.parent.v_opmode() == 64: if i > 7: self.setrexsize(1) i -= 8 elif self.parent.mode == 64 and i > 7: i -= 8 # print 'rrr', self.getrexsize() # self.parent.rex_b.value = 1 self.setrexsize(1) if hasattr(self.parent, 'xmm') or hasattr(self.parent, 'mm'): if i > 7: i -= 8 self.value = i if self.value > self.lmask: log.debug("cannot encode field value %x %x" % (self.value, self.lmask)) return False # print 'RR ok' return True class x86_rm_reg(x86_rm_reg_noarg, m_arg): pass class x86_reg(x86_rm_reg): def getrexsize(self): return self.parent.rex_b.value def setrexsize(self, v): self.parent.rex_b.value = v class x86_reg_noarg(x86_rm_reg_noarg): def getrexsize(self): return self.parent.rex_b.value def setrexsize(self, v): self.parent.rex_b.value = v class x86_rm_segm(reg_noarg, m_arg): prio = default_prio + 1 reg_info = segmreg parser = reg_info.parser class x86_rm_cr(reg_noarg, m_arg): prio = default_prio + 1 reg_info = crregs parser = reg_info.parser class x86_rm_dr(reg_noarg, m_arg): prio = default_prio + 1 reg_info = drregs parser = reg_info.parser class x86_rm_flt(reg_noarg, m_arg): prio = default_prio + 1 reg_info = fltregs parser = reg_info.parser class bs_fbit(bsi): def decode(self, v): # value already decoded in pre_dis_info # print "jj", self.value return True class bs_cl1(bsi, m_arg): parser = cl_or_imm def decode(self, v): if v == 1: self.expr = regs08_expr[1] else: self.expr = ExprInt8(1) return True def encode(self): if self.expr == regs08_expr[1]: self.value = 1 elif isinstance(self.expr, ExprInt) and int(self.expr.arg) == 1: self.value = 0 else: return False return True def sib_cond(cls, mode, v): if admode_prefix((mode, v["opmode"], v["admode"])) == 16: return None if v['mod'] == 0b11: return None elif v['rm'] == 0b100: return cls.ll else: return None return v['rm'] == 0b100 class bs_cond_scale(bs_cond): # cond must return field len ll = 2 @classmethod def flen(cls, mode, v): return sib_cond(cls, mode, v) def encode(self): if self.value is None: 
self.value = 0 self.l = 0 return True return super(bs_cond, self).encode() def decode(self, v): self.value = v return True class bs_cond_index(bs_cond_scale): ll = 3 @classmethod def flen(cls, mode, v): return sib_cond(cls, mode, v) class bs_cond_disp(bs_cond): # cond must return field len @classmethod def flen(cls, mode, v): # print 'disp cond', mode, # print v, v_admode_info(mode, v['opmode'], v['admode']) # if v_admode_info(mode, v['opmode'], v['admode']) ==16: if admode_prefix((mode, v['opmode'], v['admode'])) == 16: if v['mod'] == 0b00: if v['rm'] == 0b110: return 16 else: return None elif v['mod'] == 0b01: return 8 elif v['mod'] == 0b10: return 16 return None # 32, 64 if 'sib_base' in v and v['sib_base'] == 0b101: if v['mod'] == 0b00: return 32 elif v['mod'] == 0b01: return 8 elif v['mod'] == 0b10: return 32 else: return None if v['mod'] == 0b00: if v['rm'] == 0b101: return 32 else: return None elif v['mod'] == 0b01: return 8 elif v['mod'] == 0b10: return 32 else: return None def encode(self): if self.value is None: self.value = 0 self.l = 0 return True self.value = swap_uint(self.l, self.value) return True def decode(self, v): admode = self.parent.v_admode() v = swap_uint(self.l, v) self.value = v v = sign_ext(v, self.l, admode) v = ExprInt_fromsize(admode, v) self.expr = v return True class bs_cond_imm(bs_cond_scale, m_arg): parser = int_or_expr max_size = 32 def fromstring(self, s, parser_result=None): if parser_result: e, start, stop = parser_result[self.parser] else: try: e, start, stop = self.parser.scanString(s).next() except StopIteration: e = None self.expr = e if len(self.parent.args) > 1: l = self.parent.args[0].expr.size else: l = self.parent.v_opmode() # l = min(l, self.max_size) # l = offsize(self.parent) if isinstance(self.expr, ExprInt): v = int(self.expr.arg) mask = ((1 << l) - 1) v = v & mask e = ExprInt_fromsize(l, v) self.expr = e if self.expr is None: log.debug('cannot fromstring int %r' % s) return None, None return start, stop 
@classmethod def flen(cls, mode, v): if 'w8' not in v or v['w8'] == 1: if 'se' in v and v['se'] == 1: return 8 else: # osize = v_opmode_info(mode, v['opmode'], v['admode']) # osize = opmode_prefix((mode, v['opmode'], v['admode'])) osize = v_opmode_info(mode, v['opmode'], v['rex_w'], 0) osize = min(osize, cls.max_size) return osize return 8 def getmaxlen(self): return 32 def encode(self): if not isinstance(self.expr, ExprInt): raise StopIteration arg0_expr = self.parent.args[0].expr self.parent.rex_w.value = 0 # special case for push if len(self.parent.args) == 1: v = int(self.expr.arg) l = self.parent.v_opmode() l = min(l, self.max_size) self.l = l mask = ((1 << self.l) - 1) # print 'ext', self.l, l, hex(v), hex(sign_ext(v & ((1< l: raise StopIteration if v != sign_ext(v & mask, self.l, l): raise StopIteration self.value = swap_uint(self.l, v & ((1 << self.l) - 1)) # print hex(self.value) yield True def decode(self, v): v = swap_uint(self.l, v) size = offsize(self.parent) v = sign_ext(v, self.l, size) v = ExprInt_fromsize(size, v) self.expr = v # print self.expr, repr(self.expr) return True
class bs_rel_off08(bs_rel_off):
    # Variant of bs_rel_off whose field length is fixed: flen ignores
    # `mode` and `v` and always reports 8 bits.
    @classmethod
    def flen(cls, mode, v):
        return 8
class bs_moff(bsi): @classmethod def flen(cls, mode, v): osize = v_opmode_info(mode, v['opmode'], v['rex_w'], 0) if osize == 16: return 16 else: return 32 def encode(self): if not hasattr(self.parent, "mseg"): raise StopIteration m = self.parent.mseg.expr if (not (isinstance(m, ExprMem) and m.is_op_segm() and isinstance(m.arg.args[0], ExprInt))): raise StopIteration l = self.parent.v_opmode() # self.parent.args[0].expr.size if l == 16: self.l = 16 else: self.l = 32 # print 'imm enc', l, self.parent.rex_w.value v = int(m.arg.args[1].arg) mask = ((1 << self.l) - 1) # print 'ext', self.l, l, hex(v), hex(sign_ext(v & ((1<