diff options
Diffstat (limited to '')
| -rw-r--r-- | miasm2/analysis/data_flow.py | 505 | ||||
| -rw-r--r-- | miasm2/analysis/depgraph.py | 6 | ||||
| -rw-r--r-- | miasm2/analysis/ssa.py | 120 | ||||
| -rw-r--r-- | miasm2/arch/aarch64/arch.py | 8 | ||||
| -rw-r--r-- | miasm2/arch/aarch64/regs.py | 3 | ||||
| -rw-r--r-- | miasm2/arch/aarch64/sem.py | 535 | ||||
| -rw-r--r-- | miasm2/arch/arm/jit.py | 10 | ||||
| -rw-r--r-- | miasm2/arch/arm/sem.py | 376 | ||||
| -rw-r--r-- | miasm2/arch/mep/arch.py | 5 | ||||
| -rw-r--r-- | miasm2/arch/x86/sem.py | 484 | ||||
| -rw-r--r-- | miasm2/core/graph.py | 18 | ||||
| -rw-r--r-- | miasm2/expression/expression.py | 35 | ||||
| -rw-r--r-- | miasm2/expression/simplifications.py | 45 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 348 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_explicit.py | 155 | ||||
| -rw-r--r-- | miasm2/ir/ir.py | 15 | ||||
| -rw-r--r-- | miasm2/ir/symbexec.py | 16 | ||||
| -rw-r--r-- | miasm2/ir/translators/C.py | 24 | ||||
| -rw-r--r-- | miasm2/ir/translators/z3_ir.py | 6 | ||||
| -rw-r--r-- | miasm2/jitter/codegen.py | 8 | ||||
| -rw-r--r-- | miasm2/jitter/jitcore_python.py | 7 | ||||
| -rw-r--r-- | miasm2/jitter/llvmconvert.py | 22 |
22 files changed, 2316 insertions, 435 deletions
diff --git a/miasm2/analysis/data_flow.py b/miasm2/analysis/data_flow.py index 0a224319..a0ff867b 100644 --- a/miasm2/analysis/data_flow.py +++ b/miasm2/analysis/data_flow.py @@ -3,8 +3,9 @@ from collections import namedtuple from miasm2.core.graph import DiGraph from miasm2.ir.ir import AssignBlock, IRBlock -from miasm2.expression.expression import ExprLoc - +from miasm2.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt +from miasm2.expression.simplifications import expr_simp +from miasm2.core.interval import interval class ReachingDefinitions(dict): """ @@ -364,7 +365,7 @@ def _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct): ) # Link parent to new dst - ircfg.add_edge(parent, son_loc_key) + ircfg.add_uniq_edge(parent, son_loc_key) # Unlink block ircfg.blocks[new_block.loc_key] = new_block @@ -513,3 +514,501 @@ def remove_empty_assignblks(ircfg): ircfg.blocks[loc_key] = IRBlock(loc_key, irs) return modified + + + +class SSADefUse(DiGraph): + """ + Generate DefUse information from SSA transformation + Links are not valid for ExprMem. + """ + + def add_var_def(self, node, src): + lbl, index, dst = node + index2dst = self._links.setdefault(lbl, {}) + dst2src = index2dst.setdefault(index, {}) + dst2src[dst] = src + + def add_def_node(self, def_nodes, node, src): + lbl, index, dst = node + if dst.is_id(): + def_nodes[dst] = node + + def add_use_node(self, use_nodes, node, src): + lbl, index, dst = node + sources = set() + if dst.is_mem(): + sources.update(dst.arg.get_r(mem_read=True)) + sources.update(src.get_r(mem_read=True)) + for source in sources: + if not source.is_mem(): + use_nodes.setdefault(source, set()).add(node) + + def get_node_target(self, node): + lbl, index, reg = node + return self._links[lbl][index][reg] + + def set_node_target(self, node, src): + lbl, index, reg = node + self._links[lbl][index][reg] = src + + @classmethod + def from_ssa(cls, ssa): + """ + Return a DefUse DiGraph from a SSA graph + @ssa: SSADiGraph instance + """ + + graph = cls() + # First pass + # Link line to its use and def + def_nodes = {} + use_nodes = {} + graph._links = {} + for lbl in ssa.graph.nodes(): + block = ssa.graph.blocks.get(lbl, None) + if block is None: + continue + for index, assignblk in enumerate(block): + for dst, src in assignblk.iteritems(): + node = lbl, index, dst + graph.add_var_def(node, src) + graph.add_def_node(def_nodes, node, src) + graph.add_use_node(use_nodes, node, src) + + for dst, node in def_nodes.iteritems(): + graph.add_node(node) + if dst not in use_nodes: + continue + for use in use_nodes[dst]: + graph.add_uniq_edge(node, use) + + return graph + + + + +def expr_test_visit(expr, test): + result = set() + expr.visit( + lambda expr: expr, + lambda expr: test(expr, result) + ) + if result: + return True + else: + return False + + +def expr_has_mem_test(expr, result): + if result: + # Don't analyse if we already found a candidate + return False + if expr.is_mem(): + result.add(expr) + return False + return True + + +def expr_has_mem(expr): + """ + Return True if expr contains at least one memory access + @expr: Expr instance + """ + return expr_test_visit(expr, expr_has_mem_test) + + +def expr_has_call_test(expr, result): + if result: + # Don't analyse if we already found a candidate + return False + if expr.is_op() and expr.op.startswith("call"): + result.add(expr) + return False + return True + + +def expr_has_call(expr): + """ + Return True if expr contains at least one "call" operator + @expr: Expr instance + """ + return expr_test_visit(expr, expr_has_call_test) + + +class PropagateExpr(object): + + def assignblk_is_propagation_barrier(self, assignblk): + for dst, src in assignblk.iteritems(): + if expr_has_call(src): + return True + if dst.is_mem(): + return True + return False + + def has_propagation_barrier(self, assignblks): + for assignblk in assignblks: + for dst, src in assignblk.iteritems(): + if expr_has_call(src): + return True + if dst.is_mem(): + return True + return False + + def is_mem_written(self, ssa, node, successor): + loc_a, index_a, reg_a = node + loc_b, index_b, reg_b = successor + block_b = ssa.graph.blocks[loc_b] + + nodes_to_do = self.compute_reachable_nodes_from_a_to_b(ssa.graph, loc_a, loc_b) + + + if loc_a == loc_b: + # src is dst + assert nodes_to_do == set([loc_a]) + if self.has_propagation_barrier(block_b.assignblks[index_a:index_b]): + return True + else: + # Check everyone but loc_a and loc_b + for loc in nodes_to_do - set([loc_a, loc_b]): + block = ssa.graph.blocks[loc] + if self.has_propagation_barrier(block.assignblks): + return True + # Check loc_a partially + block_a = ssa.graph.blocks[loc_a] + if self.has_propagation_barrier(block_a.assignblks[index_a:]): + return True + if nodes_to_do.intersection(ssa.graph.successors(loc_b)): + # There is a path from loc_b to loc_b => Check loc_b fully + if self.has_propagation_barrier(block_b.assignblks): + return True + else: + # Check loc_b partially + if self.has_propagation_barrier(block_b.assignblks[:index_b]): + return True + return False + + def compute_reachable_nodes_from_a_to_b(self, ssa, loc_a, loc_b): + reachables_a = set(ssa.reachable_sons(loc_a)) + reachables_b = set(ssa.reachable_parents_stop_node(loc_b, loc_a)) + return reachables_a.intersection(reachables_b) + + def propagation_allowed(self, ssa, to_replace, node_a, node_b): + """ + Return True if we can replace @node source into @node_b + """ + loc_a, index_a, reg_a = node_a + if not expr_has_mem(to_replace[reg_a]): + return True + if self.is_mem_written(ssa, node_a, node_b): + return False + return True + + def propagate(self, ssa, head): + defuse = SSADefUse.from_ssa(ssa) + to_replace = {} + node_to_reg = {} + for node in defuse.nodes(): + lbl, index, reg = node + src = defuse.get_node_target(node) + if expr_has_call(src): + continue + if src.is_op('Phi'): + continue + if reg.is_mem(): + continue + to_replace[reg] = src + node_to_reg[node] = reg + + modified = False + for node, reg in node_to_reg.iteritems(): + src = to_replace[reg] + + for successor in defuse.successors(node): + if not self.propagation_allowed(ssa, to_replace, node, successor): + continue + + loc_a, index_a, reg_a = node + loc_b, index_b, reg_b = successor + block = ssa.graph.blocks[loc_b] + + replace = {reg_a: to_replace[reg_a]} + # Replace + assignblks = list(block) + assignblk = block[index_b] + out = {} + for dst, src in assignblk.iteritems(): + if src.is_op('Phi'): + out[dst] = src + continue + + if src.is_mem(): + ptr = src.arg + ptr = ptr.replace_expr(replace) + new_src = ExprMem(ptr, src.size) + else: + new_src = src.replace_expr(replace) + + if dst.is_id(): + new_dst = dst + elif dst.is_mem(): + ptr = dst.arg + ptr = ptr.replace_expr(replace) + new_dst = ExprMem(ptr, dst.size) + else: + new_dst = dst.replace_expr(replace) + if not (new_dst.is_id() or new_dst.is_mem()): + new_dst = dst + if src != new_src or dst != new_dst: + modified = True + out[new_dst] = new_src + out = AssignBlock(out, assignblk.instr) + assignblks[index_b] = out + new_block = IRBlock(block.loc_key, assignblks) + ssa.graph.blocks[block.loc_key] = new_block + return modified + + +def stack_to_reg(expr): + if expr.is_mem(): + ptr = expr.arg + SP = ir_arch_a.sp + if ptr == SP: + return ExprId("STACK.0", expr.size) + elif (ptr.is_op('+') and + len(ptr.args) == 2 and + ptr.args[0] == SP and + ptr.args[1].is_int()): + diff = int(ptr.args[1]) + assert diff % 4 == 0 + diff = (0 - diff) & 0xFFFFFFFF + return ExprId("STACK.%d" % (diff / 4), expr.size) + return False + + +def is_stack_access(ir_arch_a, expr): + if not expr.is_mem(): + return False + ptr = expr.arg + diff = expr_simp(ptr - ir_arch_a.sp) + if not diff.is_int(): + return False + return expr + + +def visitor_get_stack_accesses(ir_arch_a, expr, stack_vars): + if is_stack_access(ir_arch_a, expr): + stack_vars.add(expr) + return expr + + +def get_stack_accesses(ir_arch_a, expr): + result = set() + expr.visit(lambda expr:visitor_get_stack_accesses(ir_arch_a, expr, result)) + return result + + +def get_interval_length(interval_in): + length = 0 + for start, stop in interval_in.intervals: + length += stop + 1 - start + return length + + +def check_expr_below_stack(ir_arch_a, expr): + """ + Return False if expr pointer is below original stack pointer + @ir_arch_a: ira instance + @expr: Expression instance + """ + ptr = expr.arg + diff = expr_simp(ptr - ir_arch_a.sp) + if not diff.is_int(): + return True + if int(diff) == 0 or int(expr_simp(diff.msb())) == 0: + return False + return True + + +def retrieve_stack_accesses(ir_arch_a, ssa): + """ + Walk the ssa graph and find stack based variables. + Return a dictionnary linking stack base address to its size/name + @ir_arch_a: ira instance + @ssa: SSADiGraph instance + """ + stack_vars = set() + for block in ssa.graph.blocks.itervalues(): + for assignblk in block: + for dst, src in assignblk.iteritems(): + stack_vars.update(get_stack_accesses(ir_arch_a, dst)) + stack_vars.update(get_stack_accesses(ir_arch_a, src)) + stack_vars = filter(lambda expr: check_expr_below_stack(ir_arch_a, expr), stack_vars) + + base_to_var = {} + for var in stack_vars: + base_to_var.setdefault(var.arg, set()).add(var) + + + base_to_interval = {} + for addr, vars in base_to_var.iteritems(): + var_interval = interval() + for var in vars: + offset = expr_simp(addr - ir_arch_a.sp) + if not offset.is_int(): + # skip non linear stack offset + continue + + start = int(offset) + stop = int(expr_simp(offset + ExprInt(var.size / 8, offset.size))) + mem = interval([(start, stop-1)]) + var_interval += mem + base_to_interval[addr] = var_interval + if not base_to_interval: + return {} + # Check if not intervals overlap + _, tmp = base_to_interval.popitem() + while base_to_interval: + addr, mem = base_to_interval.popitem() + assert (tmp & mem).empty + tmp += mem + + base_to_info = {} + base_to_name = {} + for addr, vars in base_to_var.iteritems(): + name = "var_%d" % (len(base_to_info)) + size = max([var.size for var in vars]) + base_to_info[addr] = size, name + return base_to_info + + +def fix_stack_vars(expr, base_to_info): + """ + Replace local stack accesses in expr using informations in @base_to_info + @expr: Expression instance + @base_to_info: dictionnary linking stack base address to its size/name + """ + if not expr.is_mem(): + return expr + ptr = expr.arg + if ptr not in base_to_info: + return expr + size, name = base_to_info[ptr] + var = ExprId(name, size) + if size == expr.size: + return var + assert expr.size < size + return var[:expr.size] + + +def replace_mem_stack_vars(expr, base_to_info): + return expr.visit(lambda expr:fix_stack_vars(expr, base_to_info)) + + +def replace_stack_vars(ir_arch_a, ssa): + """ + Try to replace stack based memory accesses by variables. + WARNING: may fail + + @ir_arch_a: ira instance + @ssa: SSADiGraph instance + """ + defuse = SSADefUse.from_ssa(ssa) + + base_to_info = retrieve_stack_accesses(ir_arch_a, ssa) + stack_vars = {} + modified = False + for block in ssa.graph.blocks.itervalues(): + assignblks = [] + for assignblk in block: + out = {} + for dst, src in assignblk.iteritems(): + new_dst = dst.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info)) + new_src = src.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info)) + if new_dst != dst or new_src != src: + modified |= True + + out[new_dst] = new_src + + out = AssignBlock(out, assignblk.instr) + assignblks.append(out) + new_block = IRBlock(block.loc_key, assignblks) + ssa.graph.blocks[block.loc_key] = new_block + return modified + + +def memlookup_test(expr, bs, is_addr_ro_variable, result): + if expr.is_mem() and expr.arg.is_int(): + ptr = int(expr.arg) + if is_addr_ro_variable(bs, ptr, expr.size): + result.add(expr) + return False + return True + + +def memlookup_visit(expr, bs, is_addr_ro_variable): + result = set() + expr.visit(lambda expr: expr, + lambda expr: memlookup_test(expr, bs, is_addr_ro_variable, result)) + return result + + +def get_memlookup(expr, bs, is_addr_ro_variable): + return memlookup_visit(expr, bs, is_addr_ro_variable) + + +def read_mem(bs, expr): + ptr = int(expr.arg) + var_bytes = bs.getbytes(ptr, expr.size / 8)[::-1] + try: + value = int(var_bytes.encode('hex'), 16) + except: + return expr + return ExprInt(value, expr.size) + + +def load_from_int(ir_arch, bs, is_addr_ro_variable): + """ + Replace memory read based on constant with static value + @ir_arch: ira instance + @bs: binstream instance + @is_addr_ro_variable: callback(addr, size) to test memory candidate + """ + + modified = False + for label, block in ir_arch.blocks.iteritems(): + assignblks = list() + for assignblk in block: + out = {} + for dst, src in assignblk.iteritems(): + # Test src + mems = get_memlookup(src, bs, is_addr_ro_variable) + src_new = src + if mems: + replace = {} + for mem in mems: + value = read_mem(bs, mem) + replace[mem] = value + src_new = src.replace_expr(replace) + if src_new != src: + modified = True + # Test dst pointer if dst is mem + if dst.is_mem(): + ptr = dst.arg + mems = get_memlookup(ptr, bs, is_addr_ro_variable) + ptr_new = ptr + if mems: + replace = {} + for mem in mems: + value = read_mem(bs, mem) + replace[mem] = value + ptr_new = ptr.replace_expr(replace) + if ptr_new != ptr: + modified = True + dst = ExprMem(ptr_new, dst.size) + out[dst] = src_new + out = AssignBlock(out, assignblk.instr) + assignblks.append(out) + block = IRBlock(block.loc_key, assignblks) + ir_arch.blocks[block.loc_key] = block + return modified diff --git a/miasm2/analysis/depgraph.py b/miasm2/analysis/depgraph.py index a5f3f0fd..46a83d2d 100644 --- a/miasm2/analysis/depgraph.py +++ b/miasm2/analysis/depgraph.py @@ -3,7 +3,7 @@ from miasm2.expression.expression import ExprInt, ExprLoc, ExprAff from miasm2.core.graph import DiGraph from miasm2.core.locationdb import LocationDB -from miasm2.expression.simplifications import expr_simp +from miasm2.expression.simplifications import expr_simp_explicit from miasm2.ir.symbexec import SymbolicExecutionEngine from miasm2.ir.ir import IRBlock, AssignBlock from miasm2.ir.translators import Translator @@ -456,7 +456,7 @@ class DependencyGraph(object): @implicit: (optional) Track IRDst for each block in the resulting path Following arguments define filters used to generate dependencies - @apply_simp: (optional) Apply expr_simp + @apply_simp: (optional) Apply expr_simp_explicit @follow_mem: (optional) Track memory syntactically @follow_call: (optional) Track through "call" """ @@ -480,7 +480,7 @@ class DependencyGraph(object): """ follow = set() for expr in exprs: - follow.add(expr_simp(expr)) + follow.add(expr_simp_explicit(expr)) return follow, set() @staticmethod diff --git a/miasm2/analysis/ssa.py b/miasm2/analysis/ssa.py index 63d0c4fb..2f25e4b8 100644 --- a/miasm2/analysis/ssa.py +++ b/miasm2/analysis/ssa.py @@ -568,7 +568,127 @@ class SSADiGraph(SSA): into IRBlock at the beginning""" for loc_key in self._phinodes: irblock = self.get_block(loc_key) + if irblock is None: + continue assignblk = AssignBlock(self._phinodes[loc_key]) # insert at the beginning new_irs = IRBlock(loc_key, [assignblk] + list(irblock.assignblks)) self.ircfg.blocks[loc_key] = new_irs + + + +def get_assignblk(graph, loc, index): + """ + Return the dictionnary of the AssignBlock from @graph at location @loc at + @index + @graph: IRCFG instance + @loc: Location instance + @index: assignblock index + """ + + irblock = graph.blocks[loc] + assignblks = irblock.assignblks + assignblk = assignblks[index] + assignblk_dct = dict(assignblk) + return assignblk_dct + + +def set_assignblk(graph, loc, index, assignblk_dct): + """ + Set the Assignblock in @graph at location @loc at @index using dictionnary + @assignblk_dct + + @graph: IRCFG instance + @loc: Location instance + @index: assignblock index + @assignblk_dct: dictionnary representing the AssignBlock + """ + + irblock = graph.blocks[loc] + assignblks = list(irblock.assignblks) + assignblk = assignblks[index] + + assignblks[index] = AssignBlock( + assignblk_dct, + assignblk.instr + ) + new_irblock = IRBlock(loc, assignblks) + return new_irblock + + +def remove_phi(ssa, head): + """ + Remove Phi using naive algorithm + Note: The _ssa_variable_to_expr must be up to date + + @ssa: a SSADiGraph instance + @head: the loc_key of the graph head + """ + + phivar2var = {} + + all_ssa_vars = ssa._ssa_variable_to_expr + + # Retrive Phi nodes + phi_nodes = [] + for irblock in ssa.graph.blocks.itervalues(): + for index, assignblk in enumerate(irblock): + for dst, src in assignblk.iteritems(): + if src.is_op('Phi'): + phi_nodes.append((irblock.loc_key, index)) + + + for phi_loc, phi_index in phi_nodes: + assignblk_dct = get_assignblk(ssa.graph, phi_loc, phi_index) + for dst, src in assignblk_dct.iteritems(): + if src.is_op('Phi'): + break + else: + raise RuntimeError('Cannot find phi?') + node_src = src + var = dst + + # Create new variable + new_var = ExprId('var%d' % len(phivar2var), var.size) + phivar2var[var] = new_var + phi_sources = set(node_src.args) + + # Place var init for non ssa variables + to_remove = set() + for phi_source in list(phi_sources): + if phi_source not in all_ssa_vars.union(phivar2var): + assignblk_dct = get_assignblk(ssa.graph, head, 0) + assignblk_dct[new_var] = phi_source + new_irblock = set_assignblk(ssa.graph, head, 0, assignblk_dct) + ssa.graph.blocks[head] = new_irblock + to_remove.add(phi_source) + phi_sources.difference_update(to_remove) + + var_to_replace = set([var]) + var_to_replace.update(phi_sources) + + + + + # Replace variables + to_replace_dct = {x:new_var for x in var_to_replace} + for loc in ssa.graph.blocks: + irblock = ssa.graph.blocks[loc] + assignblks = [] + for assignblk in irblock: + assignblk_dct = {} + for dst, src in assignblk.iteritems(): + dst = dst.replace_expr(to_replace_dct) + src = src.replace_expr(to_replace_dct) + assignblk_dct[dst] = src + assignblks.append(AssignBlock(assignblk_dct, assignblk.instr)) + new_irblock = IRBlock(loc, assignblks) + ssa.graph.blocks[loc] = new_irblock + + # Remove phi + assignblk_dct = get_assignblk(ssa.graph, phi_loc, phi_index) + del assignblk_dct[new_var] + + + new_irblock = set_assignblk(ssa.graph, phi_loc, phi_index, assignblk_dct) + ssa.graph.blocks[phi_loc] = new_irblock diff --git a/miasm2/arch/aarch64/arch.py b/miasm2/arch/aarch64/arch.py index 529621c4..8cb681f6 100644 --- a/miasm2/arch/aarch64/arch.py +++ b/miasm2/arch/aarch64/arch.py @@ -1839,6 +1839,14 @@ aarch64op("bics", [sf, bs('11'), bs('01010'), shift, bs('1'), rm_sft, imm6, rn, aarch64op("mov", [sf, bs('01'), bs('01010'), bs('00'), bs('0'), rmz, bs('000000'), bs('11111'), rd], [rd, rmz], alias=True) +aarch64op("adc", [sf, bs('00'), bs('11010000'), rm, bs('000000'), rn, rd], [rd, rn, rm]) +aarch64op("adcs", [sf, bs('01'), bs('11010000'), rm, bs('000000'), rn, rd], [rd, rn, rm]) + + +aarch64op("sbc", [sf, bs('10'), bs('11010000'), rm, bs('000000'), rn, rd], [rd, rn, rm]) +aarch64op("sbcs", [sf, bs('11'), bs('11010000'), rm, bs('000000'), rn, rd], [rd, rn, rm]) + + bcond = bs_mod_name(l=4, fname='cond', mn_mod=['EQ', 'NE', 'CS', 'CC', 'MI', 'PL', 'VS', 'VC', diff --git a/miasm2/arch/aarch64/regs.py b/miasm2/arch/aarch64/regs.py index c9da0653..85c8425a 100644 --- a/miasm2/arch/aarch64/regs.py +++ b/miasm2/arch/aarch64/regs.py @@ -1,6 +1,7 @@ #-*- coding:utf-8 -*- -from miasm2.expression.expression import * +from miasm2.expression.expression import ExprId, ExprInt, ExprLoc, ExprMem, \ + ExprSlice, ExprCond, ExprCompose, ExprOp from miasm2.core.cpu import gen_reg, gen_regs exception_flags = ExprId('exception_flags', 32) diff --git a/miasm2/arch/aarch64/sem.py b/miasm2/arch/aarch64/sem.py index 646065f4..c8077ebf 100644 --- a/miasm2/arch/aarch64/sem.py +++ b/miasm2/arch/aarch64/sem.py @@ -1,4 +1,5 @@ -from miasm2.expression import expression as m2_expr +from miasm2.expression.expression import ExprId, ExprInt, ExprLoc, ExprMem, \ + ExprSlice, ExprCond, ExprCompose, ExprOp, ExprAff from miasm2.ir.ir import IntermediateRepresentation, IRBlock, AssignBlock from miasm2.arch.aarch64.arch import mn_aarch64, conds_expr, replace_regs from miasm2.arch.aarch64.regs import * @@ -10,11 +11,20 @@ from miasm2.jitter.csts import EXCEPT_DIV_BY_ZERO, EXCEPT_INT_XX def update_flag_zf(a): - return [m2_expr.ExprAff(zf, m2_expr.ExprCond(a, m2_expr.ExprInt(0, 1), m2_expr.ExprInt(1, 1)))] + return [ExprAff(zf, ExprOp("FLAG_EQ", a))] -def update_flag_nf(a): - return [m2_expr.ExprAff(nf, a.msb())] +def update_flag_zf_eq(a, b): + return [ExprAff(zf, ExprOp("FLAG_EQ_CMP", a, b))] + + +def update_flag_nf(arg): + return [ + ExprAff( + nf, + ExprOp("FLAG_SIGN_SUB", arg, ExprInt(0, arg.size)) + ) + ] def update_flag_zn(a): @@ -24,103 +34,153 @@ def update_flag_zn(a): return e -def update_flag_logic(a): +def check_ops_msb(a, b, c): + if not a or not b or not c or a != b or a != c: + raise ValueError('bad ops size %s %s %s' % (a, b, c)) + + +def update_flag_add_cf(op1, op2): + "Compute cf in @op1 + @op2" + return [ExprAff(cf, ExprOp("FLAG_ADD_CF", op1, op2))] + + +def update_flag_add_of(op1, op2): + "Compute of in @op1 + @op2" + return [ExprAff(of, ExprOp("FLAG_ADD_OF", op1, op2))] + + +def update_flag_sub_cf(op1, op2): + "Compote CF in @op1 - @op2" + return [ExprAff(cf, ExprOp("FLAG_SUB_CF", op1, op2) ^ ExprInt(1, 1))] + + +def update_flag_sub_of(op1, op2): + "Compote OF in @op1 - @op2" + return [ExprAff(of, ExprOp("FLAG_SUB_OF", op1, op2))] + + +def update_flag_arith_add_co(arg1, arg2): e = [] - e += update_flag_zn(a) - # XXX TODO: set cf if ROT imm in argument - # e.append(m2_expr.ExprAff(cf, m2_expr.ExprInt(0, 1))) + e += update_flag_add_cf(arg1, arg2) + e += update_flag_add_of(arg1, arg2) return e -def update_flag_arith(a): +def update_flag_arith_add_zn(arg1, arg2): + """ + Compute zf and nf flags for (arg1 + arg2) + """ e = [] - e += update_flag_zn(a) + e += update_flag_zf_eq(arg1, -arg2) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_SUB", arg1, -arg2))] return e -def check_ops_msb(a, b, c): - if not a or not b or not c or a != b or a != c: - raise ValueError('bad ops size %s %s %s' % (a, b, c)) +def update_flag_arith_sub_co(arg1, arg2): + """ + Compute cf and of flags for (arg1 - arg2) + """ + e = [] + e += update_flag_sub_cf(arg1, arg2) + e += update_flag_sub_of(arg1, arg2) + return e -def arith_flag(a, b, c): - a_s, b_s, c_s = a.size, b.size, c.size - check_ops_msb(a_s, b_s, c_s) - a_s, b_s, c_s = a.msb(), b.msb(), c.msb() - return a_s, b_s, c_s +def update_flag_arith_sub_zn(arg1, arg2): + """ + Compute zf and nf flags for (arg1 - arg2) + """ + e = [] + e += update_flag_zf_eq(arg1, arg2) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_SUB", arg1, arg2))] + return e -# checked: ok for adc add because b & c before +cf -def update_flag_add_cf(op1, op2, res): - "Compute cf in @res = @op1 + @op2" - return m2_expr.ExprAff(cf, (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb()) +def update_flag_zfaddwc_eq(arg1, arg2, arg3): + return [ExprAff(zf, ExprOp("FLAG_EQ_ADDWC", arg1, arg2, arg3))] -def update_flag_add_of(op1, op2, res): - "Compute of in @res = @op1 + @op2" - return m2_expr.ExprAff(of, (((op1 ^ res) & (~(op1 ^ op2)))).msb()) +def update_flag_zfsubwc_eq(arg1, arg2, arg3): + return [ExprAff(zf, ExprOp("FLAG_EQ_SUBWC", arg1, arg2, arg3))] -# checked: ok for sbb add because b & c before +cf -def update_flag_sub_cf(op1, op2, res): - "Compote CF in @res = @op1 - @op2" - return m2_expr.ExprAff(cf, - ((((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb()) ^ m2_expr.ExprInt(1, 1)) +def update_flag_arith_addwc_zn(arg1, arg2, arg3): + """ + Compute znp flags for (arg1 + arg2 + cf) + """ + e = [] + e += update_flag_zfaddwc_eq(arg1, arg2, arg3) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_ADDWC", arg1, arg2, arg3))] + return e -def update_flag_sub_of(op1, op2, res): - "Compote OF in @res = @op1 - @op2" - return m2_expr.ExprAff(of, (((op1 ^ res) & (op1 ^ op2))).msb()) +def update_flag_arith_subwc_zn(arg1, arg2, arg3): + """ + Compute znp flags for (arg1 - (arg2 + cf)) + """ + e = [] + e += update_flag_zfsubwc_eq(arg1, arg2, arg3) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_SUBWC", arg1, arg2, arg3))] + return e -# clearing cv flags for bics (see C5.6.25) +def update_flag_addwc_cf(op1, op2, op3): + "Compute cf in @res = @op1 + @op2 + @op3" + return [ExprAff(cf, ExprOp("FLAG_ADDWC_CF", op1, op2, op3))] -def update_flag_bics (): - "Clear CF and OF" - return [ExprAff(cf, ExprInt (0,1)), ExprAff(of, ExprInt (0,1))] -# z = x+y (+cf?) +def update_flag_addwc_of(op1, op2, op3): + "Compute of in @res = @op1 + @op2 + @op3" + return [ExprAff(of, ExprOp("FLAG_ADDWC_OF", op1, op2, op3))] -def update_flag_add(x, y, z): +def update_flag_arith_addwc_co(arg1, arg2, arg3): e = [] - e.append(update_flag_add_cf(x, y, z)) - e.append(update_flag_add_of(x, y, z)) + e += update_flag_addwc_cf(arg1, arg2, arg3) + e += update_flag_addwc_of(arg1, arg2, arg3) return e -# z = x-y (+cf?) -def update_flag_sub(x, y, z): +def update_flag_subwc_cf(op1, op2, op3): + "Compute cf in @res = @op1 + @op2 + @op3" + return [ExprAff(cf, ExprOp("FLAG_SUBWC_CF", op1, op2, op3) ^ ExprInt(1, 1))] + + +def update_flag_subwc_of(op1, op2, op3): + "Compute of in @res = @op1 + @op2 + @op3" + return [ExprAff(of, ExprOp("FLAG_SUBWC_OF", op1, op2, op3))] + + +def update_flag_arith_subwc_co(arg1, arg2, arg3): e = [] - e.append(update_flag_sub_cf(x, y, z)) - e.append(update_flag_sub_of(x, y, z)) + e += update_flag_subwc_cf(arg1, arg2, arg3) + e += update_flag_subwc_of(arg1, arg2, arg3) return e -cond2expr = {'EQ': zf, - 'NE': zf ^ m2_expr.ExprInt(1, 1), - 'CS': cf, - 'CC': cf ^ m2_expr.ExprInt(1, 1), - 'MI': nf, - 'PL': nf ^ m2_expr.ExprInt(1, 1), - 'VS': of, - 'VC': of ^ m2_expr.ExprInt(1, 1), - 'HI': cf & (zf ^ m2_expr.ExprInt(1, 1)), - 'LS': (cf ^ m2_expr.ExprInt(1, 1)) | zf, - 'GE': nf ^ of ^ m2_expr.ExprInt(1, 1), - 'LT': nf ^ of, - 'GT': ((zf ^ m2_expr.ExprInt(1, 1)) & - (nf ^ of ^ m2_expr.ExprInt(1, 1))), - 'LE': zf | (nf ^ of), - 'AL': m2_expr.ExprInt(1, 1), - 'NV': m2_expr.ExprInt(0, 1) +cond2expr = {'EQ': ExprOp("CC_EQ", zf), + 'NE': ExprOp("CC_NE", zf), + 'CS': ExprOp("CC_U>=", cf ^ ExprInt(1, 1)), # inv cf + 'CC': ExprOp("CC_U<", cf ^ ExprInt(1, 1)), # inv cf + 'MI': ExprOp("CC_NEG", nf), + 'PL': ExprOp("CC_POS", nf), + 'VS': ExprOp("CC_sOVR", of), + 'VC': ExprOp("CC_sNOOVR", of), + 'HI': ExprOp("CC_U>", cf ^ ExprInt(1, 1), zf), # inv cf + 'LS': ExprOp("CC_U<=", cf ^ ExprInt(1, 1), zf), # inv cf + 'GE': ExprOp("CC_S>=", nf, of), + 'LT': ExprOp("CC_S<", nf, of), + 'GT': ExprOp("CC_S>", nf, of, zf), + 'LE': ExprOp("CC_S<=", nf, of, zf), + 'AL': ExprInt(1, 1), + 'NV': ExprInt(0, 1) } def extend_arg(dst, arg): - if not isinstance(arg, m2_expr.ExprOp): + if not isinstance(arg, ExprOp): return arg op, (reg, shift) = arg.op, arg.args @@ -156,7 +216,7 @@ def extend_arg(dst, arg): raise NotImplementedError('Unknown shifter operator') out = ExprOp(op, base, (shift.zeroExtend(dst.size) - & m2_expr.ExprInt(dst.size - 1, dst.size))) + & ExprInt(dst.size - 1, dst.size))) return out @@ -169,7 +229,7 @@ ctx = {"PC": PC, "of": of, "cond2expr": cond2expr, "extend_arg": extend_arg, - "m2_expr":m2_expr, + "ExprId":ExprId, "exception_flags": exception_flags, "interrupt_num": interrupt_num, "EXCEPT_DIV_BY_ZERO": EXCEPT_DIV_BY_ZERO, @@ -228,9 +288,14 @@ def bic(arg1, arg2, arg3): def bics(ir, instr, arg1, arg2, arg3): e = [] - arg1 = arg2 & (~extend_arg(arg2, arg3)) - e += update_flag_logic (arg1) - e += update_flag_bics () + tmp1, tmp2 = arg2, (~extend_arg(arg2, arg3)) + + e += [ExprAff(zf, ExprOp('FLAG_EQ_AND', tmp1, tmp2))] + e += update_flag_nf(res) + + e.append(ExprAff(arg1, res)) + + e += null_flag_co() return e, [] @@ -243,9 +308,12 @@ def adds(ir, instr, arg1, arg2, arg3): e = [] arg3 = extend_arg(arg2, arg3) res = arg2 + arg3 - e += update_flag_arith(res) - e += update_flag_add(arg2, arg3, res) - e.append(m2_expr.ExprAff(arg1, res)) + + e += update_flag_arith_add_zn(arg2, arg3) + e += update_flag_arith_add_co(arg2, arg3) + + e.append(ExprAff(arg1, res)) + return e, [] @@ -253,18 +321,22 @@ def subs(ir, instr, arg1, arg2, arg3): e = [] arg3 = extend_arg(arg2, arg3) res = arg2 - arg3 - e += update_flag_arith(res) - e += update_flag_sub(arg2, arg3, res) - e.append(m2_expr.ExprAff(arg1, res)) + + + e += update_flag_arith_sub_zn(arg2, arg3) + e += update_flag_arith_sub_co(arg2, arg3) + + e.append(ExprAff(arg1, res)) return e, [] def cmp(ir, instr, arg1, arg2): e = [] arg2 = extend_arg(arg1, arg2) - res = arg1 - arg2 - e += update_flag_arith(res) - e += update_flag_sub(arg1, arg2, res) + + e += update_flag_arith_sub_zn(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2) + return e, [] @@ -272,8 +344,11 @@ def cmn(ir, instr, arg1, arg2): e = [] arg2 = extend_arg(arg1, arg2) res = arg1 + arg2 - e += update_flag_arith(res) - e += update_flag_add(arg1, arg2, res) + + + e += update_flag_arith_add_zn(arg1, arg2) + e += update_flag_arith_add_co(arg1, arg2) + return e, [] @@ -281,32 +356,38 @@ def ands(ir, instr, arg1, arg2, arg3): e = [] arg3 = extend_arg(arg2, arg3) res = arg2 & arg3 - e += update_flag_logic(res) - e.append(m2_expr.ExprAff(arg1, res)) + + e += [ExprAff(zf, ExprOp('FLAG_EQ_AND', arg2, arg3))] + e += update_flag_nf(res) + + e.append(ExprAff(arg1, res)) return e, [] def tst(ir, instr, arg1, arg2): e = [] arg2 = extend_arg(arg1, arg2) res = arg1 & arg2 - e += update_flag_logic(res) + + e += [ExprAff(zf, ExprOp('FLAG_EQ_AND', arg1, arg2))] + e += update_flag_nf(res) + return e, [] @sbuild.parse def lsl(arg1, arg2, arg3): - arg1 = arg2 << (arg3 & m2_expr.ExprInt(arg3.size - 1, arg3.size)) + arg1 = arg2 << (arg3 & ExprInt(arg3.size - 1, arg3.size)) @sbuild.parse def lsr(arg1, arg2, arg3): - arg1 = arg2 >> (arg3 & m2_expr.ExprInt(arg3.size - 1, arg3.size)) + arg1 = arg2 >> (arg3 & ExprInt(arg3.size - 1, arg3.size)) @sbuild.parse def asr(arg1, arg2, arg3): - arg1 = m2_expr.ExprOp( - 'a>>', arg2, (arg3 & m2_expr.ExprInt(arg3.size - 1, arg3.size))) + arg1 = ExprOp( + 'a>>', arg2, (arg3 & ExprInt(arg3.size - 1, arg3.size))) @sbuild.parse @@ -316,15 +397,15 @@ def mov(arg1, arg2): def movk(ir, instr, arg1, arg2): e = [] - if isinstance(arg2, m2_expr.ExprOp): + if isinstance(arg2, ExprOp): assert(arg2.op == 'slice_at' and - isinstance(arg2.args[0], m2_expr.ExprInt) and - isinstance(arg2.args[1], m2_expr.ExprInt)) + isinstance(arg2.args[0], ExprInt) and + isinstance(arg2.args[1], ExprInt)) value, shift = int(arg2.args[0].arg), int(arg2.args[1]) e.append( - m2_expr.ExprAff(arg1[shift:shift + 16], m2_expr.ExprInt(value, 16))) + ExprAff(arg1[shift:shift + 16], ExprInt(value, 16))) else: - e.append(m2_expr.ExprAff(arg1[:16], m2_expr.ExprInt(int(arg2), 16))) + e.append(ExprAff(arg1[:16], ExprInt(int(arg2), 16))) return e, [] @@ -343,7 +424,7 @@ def movn(arg1, arg2): def bl(arg1): PC = arg1 ir.IRDst = arg1 - LR = m2_expr.ExprInt(instr.offset + instr.l, 64) + LR = ExprInt(instr.offset + instr.l, 64) @sbuild.parse def csel(arg1, arg2, arg3, arg4): @@ -353,7 +434,7 @@ def csel(arg1, arg2, arg3, arg4): def ccmp(ir, instr, arg1, arg2, arg3, arg4): e = [] if(arg2.is_int): - arg2=m2_expr.ExprInt(arg2.arg.arg,arg1.size) + arg2=ExprInt(arg2.arg.arg,arg1.size) default_nf = arg3[0:1] default_zf = arg3[1:2] default_cf = arg3[2:3] @@ -365,71 +446,102 @@ def ccmp(ir, instr, arg1, arg2, arg3, arg4): new_cf = update_flag_sub_cf(arg1, arg2, res).src new_of = update_flag_sub_of(arg1, arg2, res).src - e.append(m2_expr.ExprAff(nf, m2_expr.ExprCond(cond_expr, + e.append(ExprAff(nf, ExprCond(cond_expr, new_nf, default_nf))) - e.append(m2_expr.ExprAff(zf, m2_expr.ExprCond(cond_expr, + e.append(ExprAff(zf, ExprCond(cond_expr, new_zf, default_zf))) - e.append(m2_expr.ExprAff(cf, m2_expr.ExprCond(cond_expr, + e.append(ExprAff(cf, ExprCond(cond_expr, new_cf, default_cf))) - e.append(m2_expr.ExprAff(of, m2_expr.ExprCond(cond_expr, + e.append(ExprAff(of, ExprCond(cond_expr, new_of, default_of))) return e, [] - + def csinc(ir, instr, arg1, arg2, arg3, arg4): e = [] cond_expr = cond2expr[arg4.name] - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprCond(cond_expr, - arg2, - arg3 + m2_expr.ExprInt(1, arg3.size)))) + e.append( + ExprAff( + arg1, + ExprCond( + cond_expr, + arg2, + arg3 + ExprInt(1, arg3.size) + ) + ) + ) return e, [] def csinv(ir, instr, arg1, arg2, arg3, arg4): e = [] cond_expr = cond2expr[arg4.name] - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprCond(cond_expr, - arg2, - ~arg3))) + e.append( + ExprAff( + arg1, + ExprCond( + cond_expr, + arg2, + ~arg3) + ) + ) return e, [] def csneg(ir, instr, arg1, arg2, arg3, arg4): e = [] cond_expr = cond2expr[arg4.name] - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprCond(cond_expr, - arg2, - -arg3))) + e.append( + ExprAff( + arg1, + ExprCond( + cond_expr, + arg2, + -arg3) + ) + ) return e, [] def cset(ir, instr, arg1, arg2): e = [] cond_expr = cond2expr[arg2.name] - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprCond(cond_expr, - m2_expr.ExprInt( - 1, arg1.size), - m2_expr.ExprInt(0, arg1.size)))) + e.append( + ExprAff( + arg1, + ExprCond( + cond_expr, + ExprInt(1, arg1.size), + ExprInt(0, arg1.size) + ) + ) + ) return e, [] def csetm(ir, instr, arg1, arg2): e = [] cond_expr = cond2expr[arg2.name] - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprCond(cond_expr, - m2_expr.ExprInt( - -1, arg1.size), - m2_expr.ExprInt(0, arg1.size)))) + e.append( + ExprAff( + arg1, + ExprCond( + cond_expr, + ExprInt(-1, arg1.size), + ExprInt(0, arg1.size) + ) + ) + ) return e, [] def get_mem_access(mem): updt = None - if isinstance(mem, m2_expr.ExprOp): + if isinstance(mem, ExprOp): if mem.op == 'preinc': addr = mem.args[0] + mem.args[1] elif mem.op == 'segm': @@ -442,7 +554,7 @@ def get_mem_access(mem): off = reg.zeroExtend(base.size) << shift.zeroExtend(base.size) addr = base + off elif op == 'LSL': - if isinstance(shift, m2_expr.ExprInt) and int(shift) == 0: + if isinstance(shift, ExprInt) and int(shift) == 0: addr = base + reg.zeroExtend(base.size) else: addr = base + \ @@ -452,11 +564,11 @@ def get_mem_access(mem): raise NotImplementedError('bad op') elif mem.op == "postinc": addr, off = mem.args - updt = m2_expr.ExprAff(addr, addr + off) + updt = ExprAff(addr, addr + off) elif mem.op == "preinc_wb": base, off = mem.args addr = base + off - updt = m2_expr.ExprAff(base, base + off) + updt = ExprAff(base, base + off) else: raise NotImplementedError('bad op') else: @@ -468,7 +580,7 @@ def get_mem_access(mem): def ldr(ir, instr, arg1, arg2): e = [] addr, updt = get_mem_access(arg2) - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprMem(addr, arg1.size))) + e.append(ExprAff(arg1, ExprMem(addr, arg1.size))) if updt: e.append(updt) return e, [] @@ -478,7 +590,7 @@ def ldr_size(ir, instr, arg1, arg2, size): e = [] addr, updt = get_mem_access(arg2) e.append( - m2_expr.ExprAff(arg1, m2_expr.ExprMem(addr, size).zeroExtend(arg1.size))) + ExprAff(arg1, ExprMem(addr, size).zeroExtend(arg1.size))) if updt: e.append(updt) return e, [] @@ -496,7 +608,7 @@ def ldrs_size(ir, instr, arg1, arg2, size): e = [] addr, updt = get_mem_access(arg2) e.append( - m2_expr.ExprAff(arg1, m2_expr.ExprMem(addr, size).signExtend(arg1.size))) + ExprAff(arg1, ExprMem(addr, size).signExtend(arg1.size))) if updt: e.append(updt) return e, [] @@ -518,7 +630,7 @@ def ldrsw(ir, instr, arg1, arg2): def l_str(ir, instr, arg1, arg2): e = [] addr, updt = get_mem_access(arg2) - e.append(m2_expr.ExprAff(m2_expr.ExprMem(addr, arg1.size), arg1)) + e.append(ExprAff(ExprMem(addr, arg1.size), arg1)) if updt: e.append(updt) return e, [] @@ -527,7 +639,7 @@ def l_str(ir, instr, arg1, arg2): def strb(ir, instr, arg1, arg2): e = [] addr, updt = get_mem_access(arg2) - e.append(m2_expr.ExprAff(m2_expr.ExprMem(addr, 8), arg1[:8])) + e.append(ExprAff(ExprMem(addr, 8), arg1[:8])) if updt: e.append(updt) return e, [] @@ -536,7 +648,7 @@ def strb(ir, instr, arg1, arg2): def strh(ir, instr, arg1, arg2): e = [] addr, updt = get_mem_access(arg2) - e.append(m2_expr.ExprAff(m2_expr.ExprMem(addr, 16), arg1[:16])) + e.append(ExprAff(ExprMem(addr, 16), arg1[:16])) if updt: e.append(updt) return e, [] @@ -545,9 +657,9 @@ def strh(ir, instr, arg1, arg2): def stp(ir, instr, arg1, arg2, arg3): e = [] addr, updt = get_mem_access(arg3) - e.append(m2_expr.ExprAff(m2_expr.ExprMem(addr, arg1.size), arg1)) + e.append(ExprAff(ExprMem(addr, arg1.size), arg1)) e.append( - m2_expr.ExprAff(m2_expr.ExprMem(addr + m2_expr.ExprInt(arg1.size / 8, addr.size), arg2.size), arg2)) + ExprAff(ExprMem(addr + ExprInt(arg1.size / 8, addr.size), arg2.size), arg2)) if updt: e.append(updt) return e, [] @@ -556,9 +668,9 @@ def stp(ir, instr, arg1, arg2, arg3): def ldp(ir, instr, arg1, arg2, arg3): e = [] addr, updt = get_mem_access(arg3) - e.append(m2_expr.ExprAff(arg1, m2_expr.ExprMem(addr, arg1.size))) + e.append(ExprAff(arg1, ExprMem(addr, arg1.size))) e.append( - m2_expr.ExprAff(arg2, m2_expr.ExprMem(addr + m2_expr.ExprInt(arg1.size / 8, addr.size), arg2.size))) + ExprAff(arg2, ExprMem(addr + ExprInt(arg1.size / 8, addr.size), arg2.size))) if updt: e.append(updt) return e, [] @@ -570,9 +682,9 @@ def sbfm(ir, instr, arg1, arg2, arg3, arg4): if sim > rim: res = arg2[rim:sim].signExtend(arg1.size) else: - shift = m2_expr.ExprInt(arg2.size - rim, arg2.size) + shift = ExprInt(arg2.size - rim, arg2.size) res = (arg2[:sim].signExtend(arg1.size) << shift) - e.append(m2_expr.ExprAff(arg1, res)) + e.append(ExprAff(arg1, res)) return e, [] @@ -582,9 +694,9 @@ def ubfm(ir, instr, arg1, arg2, arg3, arg4): if sim > rim: res = arg2[rim:sim].zeroExtend(arg1.size) else: - shift = m2_expr.ExprInt(arg2.size - rim, arg2.size) + shift = ExprInt(arg2.size - rim, arg2.size) res = (arg2[:sim].zeroExtend(arg1.size) << shift) - e.append(m2_expr.ExprAff(arg1, res)) + e.append(ExprAff(arg1, res)) return e, [] def bfm(ir, instr, arg1, arg2, arg3, arg4): @@ -592,12 +704,77 @@ def bfm(ir, instr, arg1, arg2, arg3, arg4): rim, sim = int(arg3.arg), int(arg4) + 1 if sim > rim: res = arg2[rim:sim] - e.append(m2_expr.ExprAff(arg1[:sim-rim], res)) + e.append(ExprAff(arg1[:sim-rim], res)) else: shift_i = arg2.size - rim - shift = m2_expr.ExprInt(shift_i, arg2.size) + shift = ExprInt(shift_i, arg2.size) res = arg2[:sim] - e.append(m2_expr.ExprAff(arg1[shift_i:shift_i+sim], res)) + e.append(ExprAff(arg1[shift_i:shift_i+sim], res)) + return e, [] + + + +def mrs(ir, insr, arg1, arg2, arg3, arg4, arg5): + e = [] + if arg2.is_int(3) and arg3.is_id("c4") and arg4.is_id("c2") and arg5.is_int(0): + out = [] + out.append(ExprInt(0x0, 28)) + out.append(of) + out.append(cf) + out.append(zf) + out.append(nf) + e.append(ExprAff(arg1, ExprCompose(*out).zeroExtend(arg1.size))) + else: + raise NotImplementedError("MSR not implemented") + return e, [] + +def msr(ir, instr, arg1, arg2, arg3, arg4, arg5): + + e = [] + if arg1.is_int(3) and arg2.is_id("c4") and arg3.is_id("c2") and arg4.is_int(0): + e.append(ExprAff(nf, arg5[31:32])) + e.append(ExprAff(zf, arg5[30:31])) + e.append(ExprAff(cf, arg5[29:30])) + e.append(ExprAff(of, arg5[28:29])) + else: + raise NotImplementedError("MRS not implemented") + return e, [] + + + +def adc(ir, instr, arg1, arg2, arg3): + arg3 = extend_arg(arg2, arg3) + e = [] + r = arg2 + arg3 + cf.zeroExtend(arg3.size) + e.append(ExprAff(arg1, r)) + return e, [] + + +def adcs(ir, instr, arg1, arg2, arg3): + arg3 = extend_arg(arg2, arg3) + e = [] + r = arg2 + arg3 + cf.zeroExtend(arg3.size) + e.append(ExprAff(arg1, r)) + e += update_flag_arith_addwc_zn(arg2, arg3, cf) + e += update_flag_arith_addwc_co(arg2, arg3, cf) + return e, [] + + +def sbc(ir, instr, arg1, arg2, arg3): + arg3 = extend_arg(arg2, arg3) + e = [] + r = arg2 - (arg3 + (~cf).zeroExtend(arg3.size)) + e.append(ExprAff(arg1, r)) + return e, [] + + +def sbcs(ir, instr, arg1, arg2, arg3): + arg3 = extend_arg(arg2, arg3) + e = [] + r = arg2 - (arg3 + (~cf).zeroExtend(arg3.size)) + e.append(ExprAff(arg1, r)) + e += update_flag_arith_subwc_zn(arg2, arg3, ~cf) + e += update_flag_arith_subwc_co(arg2, arg3, ~cf) return e, [] @@ -614,30 +791,30 @@ def msub(arg1, arg2, arg3, arg4): @sbuild.parse def udiv(arg1, arg2, arg3): if arg3: - arg1 = m2_expr.ExprOp('udiv', arg2, arg3) + arg1 = ExprOp('udiv', arg2, arg3) else: - exception_flags = m2_expr.ExprInt(EXCEPT_DIV_BY_ZERO, + exception_flags = ExprInt(EXCEPT_DIV_BY_ZERO, exception_flags.size) @sbuild.parse def cbz(arg1, arg2): - dst = m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) if arg1 else arg2 + dst = ExprLoc(ir.get_next_loc_key(instr), 64) if arg1 else arg2 PC = dst ir.IRDst = dst @sbuild.parse def cbnz(arg1, arg2): - dst = arg2 if arg1 else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg2 if arg1 else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @sbuild.parse def tbz(arg1, arg2, arg3): - bitmask = m2_expr.ExprInt(1, arg1.size) << arg2 - dst = m2_expr.ExprLoc( + bitmask = ExprInt(1, arg1.size) << arg2 + dst = ExprLoc( ir.get_next_loc_key(instr), 64 ) if arg1 & bitmask else arg3 @@ -647,8 +824,8 @@ def tbz(arg1, arg2, arg3): @sbuild.parse def tbnz(arg1, arg2, arg3): - bitmask = m2_expr.ExprInt(1, arg1.size) << arg2 - dst = arg3 if arg1 & bitmask else m2_expr.ExprLoc( + bitmask = ExprInt(1, arg1.size) << arg2 + dst = arg3 if arg1 & bitmask else ExprLoc( ir.get_next_loc_key(instr), 64 ) @@ -658,14 +835,16 @@ def tbnz(arg1, arg2, arg3): @sbuild.parse def b_ne(arg1): - dst = m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) if zf else arg1 + cond = cond2expr['NE'] + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @sbuild.parse def b_eq(arg1): - dst = arg1 if zf else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + cond = cond2expr['EQ'] + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -673,7 +852,7 @@ def b_eq(arg1): @sbuild.parse def b_ge(arg1): cond = cond2expr['GE'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -681,7 +860,7 @@ def b_ge(arg1): @sbuild.parse def b_gt(arg1): cond = cond2expr['GT'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -689,7 +868,7 @@ def b_gt(arg1): @sbuild.parse def b_cc(arg1): cond = cond2expr['CC'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -697,7 +876,7 @@ def b_cc(arg1): @sbuild.parse def b_cs(arg1): cond = cond2expr['CS'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -705,7 +884,7 @@ def b_cs(arg1): @sbuild.parse def b_hi(arg1): cond = cond2expr['HI'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -713,7 +892,7 @@ def b_hi(arg1): @sbuild.parse def b_le(arg1): cond = cond2expr['LE'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -721,7 +900,7 @@ def b_le(arg1): @sbuild.parse def b_ls(arg1): cond = cond2expr['LS'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -729,7 +908,7 @@ def b_ls(arg1): @sbuild.parse def b_lt(arg1): cond = cond2expr['LT'] - dst = arg1 if cond else m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + dst = arg1 if cond else ExprLoc(ir.get_next_loc_key(instr), 64) PC = dst ir.IRDst = dst @@ -742,7 +921,7 @@ def ret(arg1): @sbuild.parse def adrp(arg1, arg2): - arg1 = (PC & m2_expr.ExprInt(0xfffffffffffff000, 64)) + arg2 + arg1 = (PC & ExprInt(0xfffffffffffff000, 64)) + arg2 @sbuild.parse @@ -765,24 +944,34 @@ def br(arg1): def blr(arg1): PC = arg1 ir.IRDst = arg1 - LR = m2_expr.ExprLoc(ir.get_next_loc_key(instr), 64) + LR = ExprLoc(ir.get_next_loc_key(instr), 64) @sbuild.parse def nop(): """Do nothing""" +def rev(ir, instr, arg1, arg2): + out = [] + for i in xrange(0, arg2.size, 8): + out.append(arg2[i:i+8]) + out.reverse() + e = [] + result = ExprCompose(*out) + e.append(ExprAff(arg1, result)) + return e, [] + @sbuild.parse def extr(arg1, arg2, arg3, arg4): - compose = m2_expr.ExprCompose(arg2, arg3) + compose = ExprCompose(arg2, arg3) arg1 = compose[int(arg4.arg):int(arg4)+arg1.size] @sbuild.parse def svc(arg1): - exception_flags = m2_expr.ExprInt(EXCEPT_INT_XX, exception_flags.size) - interrupt_num = m2_expr.ExprInt(int(arg1), interrupt_num.size) + exception_flags = ExprInt(EXCEPT_INT_XX, exception_flags.size) + interrupt_num = ExprInt(int(arg1), interrupt_num.size) mnemo_func = sbuild.functions mnemo_func.update({ @@ -847,6 +1036,16 @@ mnemo_func.update({ 'ubfm': ubfm, 'extr': extr, + 'rev': rev, + + 'msr': msr, + 'mrs': mrs, + + 'adc': adc, + 'adcs': adcs, + 'sbc': sbc, + 'sbcs': sbcs, + }) @@ -869,15 +1068,15 @@ class ir_aarch64l(IntermediateRepresentation): IntermediateRepresentation.__init__(self, mn_aarch64, "l", loc_db) self.pc = PC self.sp = SP - self.IRDst = m2_expr.ExprId('IRDst', 64) + self.IRDst = ExprId('IRDst', 64) self.addrsize = 64 def get_ir(self, instr): args = instr.args - if len(args) and isinstance(args[-1], m2_expr.ExprOp): + if len(args) and isinstance(args[-1], ExprOp): if (args[-1].op in ['<<', '>>', '<<a', 'a>>', '<<<', '>>>'] and - isinstance(args[-1].args[-1], m2_expr.ExprId)): - args[-1] = m2_expr.ExprOp(args[-1].op, + isinstance(args[-1].args[-1], ExprId)): + args[-1] = ExprOp(args[-1].op, args[-1].args[0], args[-1].args[-1][:8].zeroExtend(32)) instr_ir, extra_ir = get_mnemo_expr(self, instr, *args) @@ -891,7 +1090,7 @@ class ir_aarch64l(IntermediateRepresentation): def expraff_fix_regs_for_mode(self, e): dst = self.expr_fix_regs_for_mode(e.dst) src = self.expr_fix_regs_for_mode(e.src) - return m2_expr.ExprAff(dst, src) + return ExprAff(dst, src) def irbloc_fix_regs_for_mode(self, irblock, mode=64): irs = [] @@ -901,7 +1100,7 @@ class ir_aarch64l(IntermediateRepresentation): del(new_assignblk[dst]) # Special case for 64 bits: # If destination is a 32 bit reg, zero extend the 64 bit reg - if (isinstance(dst, m2_expr.ExprId) and + if (isinstance(dst, ExprId) and dst.size == 32 and dst in replace_regs): src = src.zeroExtend(64) @@ -915,14 +1114,14 @@ class ir_aarch64l(IntermediateRepresentation): def mod_pc(self, instr, instr_ir, extra_ir): "Replace PC by the instruction's offset" - cur_offset = m2_expr.ExprInt(instr.offset, 64) + cur_offset = ExprInt(instr.offset, 64) pc_fixed = {self.pc: cur_offset} for i, expr in enumerate(instr_ir): dst, src = expr.dst, expr.src if dst != self.pc: dst = dst.replace_expr(pc_fixed) src = src.replace_expr(pc_fixed) - instr_ir[i] = m2_expr.ExprAff(dst, src) + instr_ir[i] = ExprAff(dst, src) for idx, irblock in enumerate(extra_ir): extra_ir[idx] = irblock.modify_exprs(lambda expr: expr.replace_expr(pc_fixed) \ @@ -953,4 +1152,4 @@ class ir_aarch64b(ir_aarch64l): IntermediateRepresentation.__init__(self, mn_aarch64, "b", loc_db) self.pc = PC self.sp = SP - self.IRDst = m2_expr.ExprId('IRDst', 64) + self.IRDst = ExprId('IRDst', 64) diff --git a/miasm2/arch/arm/jit.py b/miasm2/arch/arm/jit.py index 267bcea6..716a8826 100644 --- a/miasm2/arch/arm/jit.py +++ b/miasm2/arch/arm/jit.py @@ -8,6 +8,7 @@ from miasm2.jitter.codegen import CGen from miasm2.expression.expression import ExprId, ExprAff, ExprCond from miasm2.ir.ir import IRBlock, AssignBlock from miasm2.ir.translators.C import TranslatorC +from miasm2.expression.simplifications import expr_simp_high_to_explicit log = logging.getLogger('jit_arm') hnd = logging.StreamHandler() @@ -45,6 +46,15 @@ class arm_CGen(CGen): irblock_head = self.assignblk_to_irbloc(instr, assignblk_head) irblocks = [irblock_head] + assignblks_extra + + # Simplify high level operators + out = [] + for irblock in irblocks: + new_irblock = irblock.simplify(expr_simp_high_to_explicit)[1] + out.append(new_irblock) + irblocks = out + + for irblock in irblocks: assert irblock.dst is not None irblocks_list.append(irblocks) diff --git a/miasm2/arch/arm/sem.py b/miasm2/arch/arm/sem.py index d9c2d6cd..4e99e720 100644 --- a/miasm2/arch/arm/sem.py +++ b/miasm2/arch/arm/sem.py @@ -14,11 +14,20 @@ EXCEPT_PRIV_INSN = (1 << 17) def update_flag_zf(a): - return [ExprAff(zf, ExprCond(a, ExprInt(0, 1), ExprInt(1, 1)))] + return [ExprAff(zf, ExprOp("FLAG_EQ", a))] -def update_flag_nf(a): - return [ExprAff(nf, a.msb())] +def update_flag_zf_eq(a, b): + return [ExprAff(zf, ExprOp("FLAG_EQ_CMP", a, b))] + + +def update_flag_nf(arg): + return [ + ExprAff( + nf, + ExprOp("FLAG_SIGN_SUB", arg, ExprInt(0, arg.size)) + ) + ] def update_flag_zn(a): @@ -28,73 +37,136 @@ def update_flag_zn(a): return e -def update_flag_logic(a): + +# XXX TODO: set cf if ROT imm in argument + + +def check_ops_msb(a, b, c): + if not a or not b or not c or a != b or a != c: + raise ValueError('bad ops size %s %s %s' % (a, b, c)) + +def update_flag_add_cf(op1, op2): + "Compute cf in @op1 + @op2" + return [ExprAff(cf, ExprOp("FLAG_ADD_CF", op1, op2))] + + +def update_flag_add_of(op1, op2): + "Compute of in @op1 + @op2" + return [ExprAff(of, ExprOp("FLAG_ADD_OF", op1, op2))] + + +def update_flag_sub_cf(op1, op2): + "Compote CF in @op1 - @op2" + return [ExprAff(cf, ExprOp("FLAG_SUB_CF", op1, op2) ^ ExprInt(1, 1))] + + +def update_flag_sub_of(op1, op2): + "Compote OF in @op1 - @op2" + return [ExprAff(of, ExprOp("FLAG_SUB_OF", op1, op2))] + + +def update_flag_arith_add_co(arg1, arg2): e = [] - e += update_flag_zn(a) - # XXX TODO: set cf if ROT imm in argument - #e.append(ExprAff(cf, ExprInt(0, 1))) + e += update_flag_add_cf(arg1, arg2) + e += update_flag_add_of(arg1, arg2) return e -def update_flag_arith(a): +def update_flag_arith_add_zn(arg1, arg2): + """ + Compute zf and nf flags for (arg1 + arg2) + """ e = [] - e += update_flag_zn(a) + e += update_flag_zf_eq(arg1, -arg2) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_SUB", arg1, -arg2))] return e -def check_ops_msb(a, b, c): - if not a or not b or not c or a != b or a != c: - raise ValueError('bad ops size %s %s %s' % (a, b, c)) +def update_flag_arith_sub_co(arg1, arg2): + """ + Compute cf and of flags for (arg1 - arg2) + """ + e = [] + e += update_flag_sub_cf(arg1, arg2) + e += update_flag_sub_of(arg1, arg2) + return e + + +def update_flag_arith_sub_zn(arg1, arg2): + """ + Compute zf and nf flags for (arg1 - arg2) + """ + e = [] + e += update_flag_zf_eq(arg1, arg2) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_SUB", arg1, arg2))] + return e -def arith_flag(a, b, c): - a_s, b_s, c_s = a.size, b.size, c.size - check_ops_msb(a_s, b_s, c_s) - a_s, b_s, c_s = a.msb(), b.msb(), c.msb() - return a_s, b_s, c_s -# checked: ok for adc add because b & c before +cf -def update_flag_add_cf(op1, op2, res): - "Compute cf in @res = @op1 + @op2" - return ExprAff(cf, (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb()) +def update_flag_zfaddwc_eq(arg1, arg2, arg3): + return [ExprAff(zf, ExprOp("FLAG_EQ_ADDWC", arg1, arg2, arg3))] +def update_flag_zfsubwc_eq(arg1, arg2, arg3): + return [ExprAff(zf, ExprOp("FLAG_EQ_SUBWC", arg1, arg2, arg3))] + + +def update_flag_arith_addwc_zn(arg1, arg2, arg3): + """ + Compute znp flags for (arg1 + arg2 + cf) + """ + e = [] + e += update_flag_zfaddwc_eq(arg1, arg2, arg3) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_ADDWC", arg1, arg2, arg3))] + return e -def update_flag_add_of(op1, op2, res): - "Compute of in @res = @op1 + @op2" - return ExprAff(of, (((op1 ^ res) & (~(op1 ^ op2)))).msb()) +def update_flag_arith_subwc_zn(arg1, arg2, arg3): + """ + Compute znp flags for (arg1 - (arg2 + cf)) + """ + e = [] + e += update_flag_zfsubwc_eq(arg1, arg2, arg3) + e += [ExprAff(nf, ExprOp("FLAG_SIGN_SUBWC", arg1, arg2, arg3))] + return e -# checked: ok for sbb add because b & c before +cf -def update_flag_sub_cf(op1, op2, res): - "Compote CF in @res = @op1 - @op2" - return ExprAff(cf, - ((((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb()) ^ ExprInt(1, 1)) +def update_flag_addwc_cf(op1, op2, op3): + "Compute cf in @res = @op1 + @op2 + @op3" + return [ExprAff(cf, ExprOp("FLAG_ADDWC_CF", op1, op2, op3))] -def update_flag_sub_of(op1, op2, res): - "Compote OF in @res = @op1 - @op2" - return ExprAff(of, (((op1 ^ res) & (op1 ^ op2))).msb()) -# z = x+y (+cf?) +def update_flag_addwc_of(op1, op2, op3): + "Compute of in @res = @op1 + @op2 + @op3" + return [ExprAff(of, ExprOp("FLAG_ADDWC_OF", op1, op2, op3))] -def update_flag_add(x, y, z): +def update_flag_arith_addwc_co(arg1, arg2, arg3): e = [] - e.append(update_flag_add_cf(x, y, z)) - e.append(update_flag_add_of(x, y, z)) + e += update_flag_addwc_cf(arg1, arg2, arg3) + e += update_flag_addwc_of(arg1, arg2, arg3) return e -# z = x-y (+cf?) -def update_flag_sub(x, y, z): +def update_flag_subwc_cf(op1, op2, op3): + "Compute cf in @res = @op1 + @op2 + @op3" + return [ExprAff(cf, ExprOp("FLAG_SUBWC_CF", op1, op2, op3) ^ ExprInt(1, 1))] + + +def update_flag_subwc_of(op1, op2, op3): + "Compute of in @res = @op1 + @op2 + @op3" + return [ExprAff(of, ExprOp("FLAG_SUBWC_OF", op1, op2, op3))] + + +def update_flag_arith_subwc_co(arg1, arg2, arg3): e = [] - e.append(update_flag_sub_cf(x, y, z)) - e.append(update_flag_sub_of(x, y, z)) + e += update_flag_subwc_cf(arg1, arg2, arg3) + e += update_flag_subwc_of(arg1, arg2, arg3) return e + def get_dst(a): if a == PC: return PC @@ -107,10 +179,11 @@ def adc(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b + arg1, arg2 = b, c r = b + c + cf.zeroExtend(32) if instr.name == 'ADCS' and a != PC: - e += update_flag_arith(r) - e += update_flag_add(b, c, r) + e += update_flag_arith_addwc_zn(arg1, arg2, cf) + e += update_flag_arith_addwc_co(arg1, arg2, cf) e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -122,10 +195,11 @@ def add(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b + arg1, arg2 = b, c r = b + c if instr.name == 'ADDS' and a != PC: - e += update_flag_arith(r) - e += update_flag_add(b, c, r) + e += update_flag_arith_add_zn(arg1, arg2) + e += update_flag_arith_add_co(arg1, arg2) e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -139,7 +213,9 @@ def l_and(ir, instr, a, b, c=None): b, c = a, b r = b & c if instr.name == 'ANDS' and a != PC: - e += update_flag_logic(r) + e += [ExprAff(zf, ExprOp('FLAG_EQ_AND', b, c))] + e += update_flag_nf(r) + e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -163,9 +239,10 @@ def subs(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b + arg1, arg2 = b, c r = b - c - e += update_flag_arith(r) - e += update_flag_sub(b, c, r) + e += update_flag_arith_sub_zn(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2) e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -189,8 +266,12 @@ def eors(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = b ^ c - e += update_flag_logic(r) + arg1, arg2 = b, c + r = arg1 ^ arg2 + + e += [ExprAff(zf, ExprOp('FLAG_EQ_CMP', arg1, arg2))] + e += update_flag_nf(r) + e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -214,9 +295,12 @@ def rsbs(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = c - b - e += update_flag_arith(r) - e += update_flag_sub(c, b, r) + arg1, arg2 = c, b + r = arg1 - arg2 + + e += update_flag_arith_sub_zn(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2) + e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -228,7 +312,8 @@ def sbc(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = (b + cf.zeroExtend(32)) - (c + ExprInt(1, 32)) + arg1, arg2 = b, c + r = arg1 - (arg2 + (~cf).zeroExtend(32)) e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -240,9 +325,12 @@ def sbcs(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = (b + cf.zeroExtend(32)) - (c + ExprInt(1, 32)) - e += update_flag_arith(r) - e += update_flag_sub(b, c, r) + arg1, arg2 = b, c + r = arg1 - (arg2 + (~cf).zeroExtend(32)) + + e += update_flag_arith_subwc_zn(arg1, arg2, ~cf) + e += update_flag_arith_subwc_co(arg1, arg2, ~cf) + e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -254,7 +342,8 @@ def rsc(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = (c + cf.zeroExtend(32)) - (b + ExprInt(1, 32)) + arg1, arg2 = c, b + r = arg1 - (arg2 + (~cf).zeroExtend(32)) e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -266,11 +355,14 @@ def rscs(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = (c + cf.zeroExtend(32)) - (b + ExprInt(1, 32)) - e.append(ExprAff(a, r)) - e += update_flag_arith(r) - e += update_flag_sub(c, b, r) + arg1, arg2 = c, b + r = arg1 - (arg2 + (~cf).zeroExtend(32)) + + e += update_flag_arith_subwc_zn(arg1, arg2, ~cf) + e += update_flag_arith_subwc_co(arg1, arg2, ~cf) + e.append(ExprAff(a, r)) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, r)) @@ -279,8 +371,12 @@ def rscs(ir, instr, a, b, c=None): def tst(ir, instr, a, b): e = [] - r = a & b - e += update_flag_logic(r) + arg1, arg2 = a, b + r = arg1 & arg2 + + e += [ExprAff(zf, ExprOp('FLAG_EQ_AND', arg1, arg2))] + e += update_flag_nf(r) + return e, [] @@ -288,8 +384,12 @@ def teq(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = b ^ c - e += update_flag_logic(r) + arg1, arg2 = b, c + r = arg1 ^ arg2 + + e += [ExprAff(zf, ExprOp('FLAG_EQ_CMP', arg1, arg2))] + e += update_flag_nf(r) + return e, [] @@ -297,9 +397,12 @@ def l_cmp(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b + arg1, arg2 = b, c r = b - c - e += update_flag_arith(r) - e += update_flag_sub(b, c, r) + + e += update_flag_arith_sub_zn(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2) + return e, [] @@ -307,9 +410,12 @@ def cmn(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b + arg1, arg2 = b, c r = b + c - e += update_flag_arith(r) - e += update_flag_add(b, c, r) + + e += update_flag_arith_add_zn(arg1, arg2) + e += update_flag_arith_add_co(arg1, arg2) + return e, [] @@ -341,8 +447,12 @@ def orrs(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b + arg1, arg2 = b, c r = b | c - e += update_flag_logic(r) + + e += [ExprAff(zf, ExprOp('FLAG_EQ', r))] + e += update_flag_nf(r) + e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -371,7 +481,9 @@ def movs(ir, instr, a, b): e = [] e.append(ExprAff(a, b)) # XXX TODO check - e += update_flag_logic(b) + e += [ExprAff(zf, ExprOp('FLAG_EQ', b))] + e += update_flag_nf(b) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, b)) @@ -392,13 +504,42 @@ def mvns(ir, instr, a, b): r = b ^ ExprInt(-1, 32) e.append(ExprAff(a, r)) # XXX TODO check - e += update_flag_logic(r) + e += [ExprAff(zf, ExprOp('FLAG_EQ', r))] + e += update_flag_nf(r) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, r)) return e, [] + +def mrs(ir, instr, a, b): + e = [] + if b.is_id('CPSR_cxsf'): + out = [] + out.append(ExprInt(0x10, 28)) + out.append(of) + out.append(cf) + out.append(zf) + out.append(nf) + e.append(ExprAff(a, ExprCompose(*out))) + else: + raise NotImplementedError("MSR not implemented") + return e, [] + +def msr(ir, instr, a, b): + e = [] + if a.is_id('CPSR_cf'): + e.append(ExprAff(nf, b[31:32])) + e.append(ExprAff(zf, b[30:31])) + e.append(ExprAff(cf, b[29:30])) + e.append(ExprAff(of, b[28:29])) + else: + raise NotImplementedError("MRS not implemented") + return e, [] + + def neg(ir, instr, a, b): e = [] r = - b @@ -427,8 +568,12 @@ def bics(ir, instr, a, b, c=None): e = [] if c is None: b, c = a, b - r = b & (c ^ ExprInt(-1, 32)) - e += update_flag_logic(r) + tmp1, tmp2 = b, ~c + r = tmp1 & tmp2 + + e += [ExprAff(zf, ExprOp('FLAG_EQ_AND', tmp1, tmp2))] + e += update_flag_nf(r) + e.append(ExprAff(a, r)) dst = get_dst(a) if dst is not None: @@ -836,7 +981,10 @@ def lsrs(ir, instr, a, b, c=None): b, c = a, b r = b >> c e.append(ExprAff(a, r)) - e += update_flag_logic(r) + + e += [ExprAff(zf, ExprOp('FLAG_EQ', r))] + e += update_flag_nf(r) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, r)) @@ -859,7 +1007,10 @@ def asrs(ir, instr, a, b, c=None): b, c = a, b r = ExprOp("a>>", b, c) e.append(ExprAff(a, r)) - e += update_flag_logic(r) + + e += [ExprAff(zf, ExprOp('FLAG_EQ', r))] + e += update_flag_nf(r) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, r)) @@ -883,7 +1034,10 @@ def lsls(ir, instr, a, b, c=None): b, c = a, b r = b << c e.append(ExprAff(a, r)) - e += update_flag_logic(r) + + e += [ExprAff(zf, ExprOp('FLAG_EQ', r))] + e += update_flag_nf(r) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, r)) @@ -894,7 +1048,10 @@ def rors(ir, instr, a, b): e = [] r = ExprOp(">>>", a, b) e.append(ExprAff(a, r)) - e += update_flag_logic(r) + + e += [ExprAff(zf, ExprOp('FLAG_EQ', r))] + e += update_flag_nf(r) + dst = get_dst(a) if dst is not None: e.append(ExprAff(ir.IRDst, r)) @@ -1223,31 +1380,46 @@ cond_dct = { cond_dct_inv = dict((name, num) for num, name in cond_dct.iteritems()) -tab_cond = {COND_EQ: zf, - COND_NE: ExprCond(zf, ExprInt(0, 1), ExprInt(1, 1)), - COND_CS: cf, - COND_CC: ExprCond(cf, ExprInt(0, 1), ExprInt(1, 1)), - COND_MI: nf, - COND_PL: ExprCond(nf, ExprInt(0, 1), ExprInt(1, 1)), - COND_VS: of, - COND_VC: ExprCond(of, ExprInt(0, 1), ExprInt(1, 1)), - COND_HI: cf & ExprCond(zf, ExprInt(0, 1), ExprInt(1, 1)), - # COND_HI: cf, - # COND_HI: ExprOp('==', - # ExprOp('|', cf, zf), - # ExprInt(0, 1)), - COND_LS: ExprCond(cf, ExprInt(0, 1), ExprInt(1, 1)) | zf, - COND_GE: ExprCond(nf - of, ExprInt(0, 1), ExprInt(1, 1)), - COND_LT: nf ^ of, - # COND_GT: ExprOp('|', - # ExprOp('==', zf, ExprInt(0, 1)) & (nf | of), - # ExprOp('==', nf, ExprInt(0, 1)) & ExprOp('==', of, ExprInt(0, 1))), - COND_GT: (ExprCond(zf, ExprInt(0, 1), ExprInt(1, 1)) & - ExprCond(nf - of, ExprInt(0, 1), ExprInt(1, 1))), - COND_LE: zf | (nf ^ of), + +""" +Code Meaning (for cmp or subs) Flags Tested +eq Equal. Z==1 +ne Not equal. Z==0 +cs or hs Unsigned higher or same (or carry set). C==1 +cc or lo Unsigned lower (or carry clear). C==0 +mi Negative. The mnemonic stands for "minus". N==1 +pl Positive or zero. The mnemonic stands for "plus". N==0 +vs Signed overflow. The mnemonic stands for "V set". V==1 +vc No signed overflow. The mnemonic stands for "V clear". V==0 +hi Unsigned higher. (C==1) && (Z==0) +ls Unsigned lower or same. (C==0) || (Z==1) +ge Signed greater than or equal. N==V +lt Signed less than. N!=V +gt Signed greater than. (Z==0) && (N==V) +le Signed less than or equal. (Z==1) || (N!=V) +al (or omitted) Always executed. None tested. +""" + +tab_cond = {COND_EQ: ExprOp("CC_EQ", zf), + COND_NE: ExprOp("CC_NE", zf), + COND_CS: ExprOp("CC_U>=", cf ^ ExprInt(1, 1)), # inv cf + COND_CC: ExprOp("CC_U<", cf ^ ExprInt(1, 1)), # inv cf + COND_MI: ExprOp("CC_NEG", nf), + COND_PL: ExprOp("CC_POS", nf), + COND_VS: ExprOp("CC_sOVR", of), + COND_VC: ExprOp("CC_sNOOVR", of), + COND_HI: ExprOp("CC_U>", cf ^ ExprInt(1, 1), zf), # inv cf + COND_LS: ExprOp("CC_U<=", cf ^ ExprInt(1, 1), zf), # inv cf + COND_GE: ExprOp("CC_S>=", nf, of), + COND_LT: ExprOp("CC_S<", nf, of), + COND_GT: ExprOp("CC_S>", nf, of, zf), + COND_LE: ExprOp("CC_S<=", nf, of, zf), } + + + def is_pc_written(ir, instr_ir): all_pc = ir.mn.pc.values() for ir in instr_ir: @@ -1359,6 +1531,10 @@ mnemo_condm1 = {'adds': add, 'movs': movs, 'bics': bics, 'mvns': mvns, + + 'mrs': mrs, + 'msr': msr, + 'negs': negs, 'muls': muls, diff --git a/miasm2/arch/mep/arch.py b/miasm2/arch/mep/arch.py index 3f844c06..a4c7182a 100644 --- a/miasm2/arch/mep/arch.py +++ b/miasm2/arch/mep/arch.py @@ -939,7 +939,8 @@ class mep_target24_signed(mep_target24): mep_target24.decode(self, v) v = int(self.expr.arg) - self.expr = ExprInt(v, 24).signExtend(32) + self.expr = ExprInt(sign_ext(v, 24, 32), 32) + return True @@ -1160,7 +1161,7 @@ class mep_disp12_align2_signed(mep_disp12_align2): mep_disp12_align2.decode(self, v) v = int(self.expr.arg) - self.expr = ExprInt(v, 12).signExtend(32) + self.expr = ExprInt(sign_ext(v, 12, 32), 32) return True diff --git a/miasm2/arch/x86/sem.py b/miasm2/arch/x86/sem.py index 00bdd6d7..8c140d7b 100644 --- a/miasm2/arch/x86/sem.py +++ b/miasm2/arch/x86/sem.py @@ -59,16 +59,30 @@ OF(A-B) = ((A XOR D) AND (A XOR B)) < 0 # XXX TODO make default check against 0 or not 0 (same eq as in C) +def update_flag_zf_eq(a, b): + return [m2_expr.ExprAff(zf, m2_expr.ExprOp("FLAG_EQ_CMP", a, b))] def update_flag_zf(a): - return [m2_expr.ExprAff( - zf, m2_expr.ExprCond(a, m2_expr.ExprInt(0, zf.size), - m2_expr.ExprInt(1, zf.size)))] + return [ + m2_expr.ExprAff( + zf, + m2_expr.ExprCond( + a, + m2_expr.ExprInt(0, zf.size), + m2_expr.ExprInt(1, zf.size) + ) + ) + ] -def update_flag_nf(a): - return [m2_expr.ExprAff(nf, a.msb())] +def update_flag_nf(arg): + return [ + m2_expr.ExprAff( + nf, + m2_expr.ExprOp("FLAG_SIGN_SUB", arg, m2_expr.ExprInt(0, arg.size)) + ) + ] def update_flag_pf(a): @@ -89,9 +103,15 @@ def update_flag_znp(a): return e -def update_flag_logic(a): +def update_flag_np(result): + e = [] + e += update_flag_nf(result) + e += update_flag_pf(result) + return e + + +def null_flag_co(): e = [] - e += update_flag_znp(a) e.append(m2_expr.ExprAff(of, m2_expr.ExprInt(0, of.size))) e.append(m2_expr.ExprAff(cf, m2_expr.ExprInt(0, cf.size))) return e @@ -103,6 +123,59 @@ def update_flag_arith(a): return e +def update_flag_zfaddwc_eq(arg1, arg2, arg3): + return [m2_expr.ExprAff(zf, m2_expr.ExprOp("FLAG_EQ_ADDWC", arg1, arg2, arg3))] + +def update_flag_zfsubwc_eq(arg1, arg2, arg3): + return [m2_expr.ExprAff(zf, m2_expr.ExprOp("FLAG_EQ_SUBWC", arg1, arg2, arg3))] + + +def update_flag_arith_add_znp(arg1, arg2): + """ + Compute znp flags for (arg1 + arg2) + """ + e = [] + e += update_flag_zf_eq(arg1, -arg2) + e += [m2_expr.ExprAff(nf, m2_expr.ExprOp("FLAG_SIGN_SUB", arg1, -arg2))] + e += update_flag_pf(arg1+arg2) + return e + + +def update_flag_arith_addwc_znp(arg1, arg2, arg3): + """ + Compute znp flags for (arg1 + arg2 + cf) + """ + e = [] + e += update_flag_zfaddwc_eq(arg1, arg2, arg3) + e += [m2_expr.ExprAff(nf, m2_expr.ExprOp("FLAG_SIGN_ADDWC", arg1, arg2, arg3))] + e += update_flag_pf(arg1+arg2+arg3.zeroExtend(arg2.size)) + return e + + + + +def update_flag_arith_sub_znp(arg1, arg2): + """ + Compute znp flags for (arg1 - arg2) + """ + e = [] + e += update_flag_zf_eq(arg1, arg2) + e += [m2_expr.ExprAff(nf, m2_expr.ExprOp("FLAG_SIGN_SUB", arg1, arg2))] + e += update_flag_pf(arg1 - arg2) + return e + + +def update_flag_arith_subwc_znp(arg1, arg2, arg3): + """ + Compute znp flags for (arg1 - (arg2 + cf)) + """ + e = [] + e += update_flag_zfsubwc_eq(arg1, arg2, arg3) + e += [m2_expr.ExprAff(nf, m2_expr.ExprOp("FLAG_SIGN_SUBWC", arg1, arg2, arg3))] + e += update_flag_pf(arg1 - (arg2+arg3.zeroExtend(arg2.size))) + return e + + def check_ops_msb(a, b, c): if not a or not b or not c or a != b or a != c: raise ValueError('bad ops size %s %s %s' % (a, b, c)) @@ -119,45 +192,80 @@ def arith_flag(a, b, c): def update_flag_add_cf(op1, op2, res): "Compute cf in @res = @op1 + @op2" - ret = (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() - return m2_expr.ExprAff(cf, ret) + #return [m2_expr.ExprAff(cf, m2_expr.ExprOp("FLAG_SUB_CF", op1, -op2))] + return [m2_expr.ExprAff(cf, m2_expr.ExprOp("FLAG_ADD_CF", op1, op2))] def update_flag_add_of(op1, op2, res): "Compute of in @res = @op1 + @op2" - return m2_expr.ExprAff(of, (((op1 ^ res) & (~(op1 ^ op2)))).msb()) + return [m2_expr.ExprAff(of, m2_expr.ExprOp("FLAG_ADD_OF", op1, op2))] # checked: ok for sbb add because b & c before +cf def update_flag_sub_cf(op1, op2, res): "Compote CF in @res = @op1 - @op2" - ret = (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() - return m2_expr.ExprAff(cf, ret) + return [m2_expr.ExprAff(cf, m2_expr.ExprOp("FLAG_SUB_CF", op1, op2))] def update_flag_sub_of(op1, op2, res): "Compote OF in @res = @op1 - @op2" - return m2_expr.ExprAff(of, (((op1 ^ res) & (op1 ^ op2))).msb()) + return [m2_expr.ExprAff(of, m2_expr.ExprOp("FLAG_SUB_OF", op1, op2))] + + +def update_flag_addwc_cf(op1, op2, op3): + "Compute cf in @res = @op1 + @op2 + @op3" + return [m2_expr.ExprAff(cf, m2_expr.ExprOp("FLAG_ADDWC_CF", op1, op2, op3))] + + +def update_flag_addwc_of(op1, op2, op3): + "Compute of in @res = @op1 + @op2 + @op3" + return [m2_expr.ExprAff(of, m2_expr.ExprOp("FLAG_ADDWC_OF", op1, op2, op3))] -# z = x+y (+cf?) -def update_flag_add(x, y, z): +def update_flag_subwc_cf(op1, op2, op3): + "Compute cf in @res = @op1 + @op2 + @op3" + return [m2_expr.ExprAff(cf, m2_expr.ExprOp("FLAG_SUBWC_CF", op1, op2, op3))] + + +def update_flag_subwc_of(op1, op2, op3): + "Compute of in @res = @op1 + @op2 + @op3" + return [m2_expr.ExprAff(of, m2_expr.ExprOp("FLAG_SUBWC_OF", op1, op2, op3))] + + + + +def update_flag_arith_add_co(x, y, z): e = [] - e.append(update_flag_add_cf(x, y, z)) - e.append(update_flag_add_of(x, y, z)) + e += update_flag_add_cf(x, y, z) + e += update_flag_add_of(x, y, z) return e -# z = x-y (+cf?) + +def update_flag_arith_sub_co(x, y, z): + e = [] + e += update_flag_sub_cf(x, y, z) + e += update_flag_sub_of(x, y, z) + return e + + -def update_flag_sub(x, y, z): +def update_flag_arith_addwc_co(arg1, arg2, arg3): e = [] - e.append(update_flag_sub_cf(x, y, z)) - e.append(update_flag_sub_of(x, y, z)) + e += update_flag_addwc_cf(arg1, arg2, arg3) + e += update_flag_addwc_of(arg1, arg2, arg3) return e +def update_flag_arith_subwc_co(arg1, arg2, arg3): + e = [] + e += update_flag_subwc_cf(arg1, arg2, arg3) + e += update_flag_subwc_of(arg1, arg2, arg3) + return e + + + def set_float_cs_eip(instr): e = [] # XXX TODO check float updt @@ -344,20 +452,23 @@ def lea(_, instr, dst, src): def add(_, instr, dst, src): e = [] + result = dst + src - e += update_flag_arith(result) + + e += update_flag_arith_add_znp(dst, src) + e += update_flag_arith_add_co(dst, src, result) e += update_flag_af(dst, src, result) - e += update_flag_add(dst, src, result) e.append(m2_expr.ExprAff(dst, result)) return e, [] def xadd(_, instr, dst, src): e = [] + result = dst + src - e += update_flag_arith(result) + e += update_flag_arith_add_znp(dst, src) + e += update_flag_arith_add_co(src, dst, result) e += update_flag_af(dst, src, result) - e += update_flag_add(src, dst, result) if dst != src: e.append(m2_expr.ExprAff(src, dst)) e.append(m2_expr.ExprAff(dst, result)) @@ -366,21 +477,27 @@ def xadd(_, instr, dst, src): def adc(_, instr, dst, src): e = [] - result = dst + (src + m2_expr.ExprCompose(cf, - m2_expr.ExprInt(0, dst.size - 1))) - e += update_flag_arith(result) - e += update_flag_af(dst, src, result) - e += update_flag_add(dst, src, result) + + arg1 = dst + arg2 = src + result = arg1 + (arg2 + cf.zeroExtend(src.size)) + + e += update_flag_arith_addwc_znp(arg1, arg2, cf) + e += update_flag_arith_addwc_co(arg1, arg2, cf) + e += update_flag_af(arg1, arg2, result) e.append(m2_expr.ExprAff(dst, result)) return e, [] def sub(_, instr, dst, src): e = [] + arg1, arg2 = dst, src result = dst - src - e += update_flag_arith(result) + + e += update_flag_arith_sub_znp(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2, result) e += update_flag_af(dst, src, result) - e += update_flag_sub(dst, src, result) + e.append(m2_expr.ExprAff(dst, result)) return e, [] @@ -389,11 +506,13 @@ def sub(_, instr, dst, src): def sbb(_, instr, dst, src): e = [] - result = dst - (src + m2_expr.ExprCompose(cf, - m2_expr.ExprInt(0, dst.size - 1))) - e += update_flag_arith(result) - e += update_flag_af(dst, src, result) - e += update_flag_sub(dst, src, result) + arg1 = dst + arg2 = src + result = arg1 - (arg2 + cf.zeroExtend(src.size)) + + e += update_flag_arith_subwc_znp(arg1, arg2, cf) + e += update_flag_af(arg1, arg2, result) + e += update_flag_arith_subwc_co(arg1, arg2, cf) e.append(m2_expr.ExprAff(dst, result)) return e, [] @@ -401,10 +520,12 @@ def sbb(_, instr, dst, src): def neg(_, instr, src): e = [] dst = m2_expr.ExprInt(0, src.size) - result = dst - src - e += update_flag_arith(result) - e += update_flag_sub(dst, src, result) - e += update_flag_af(dst, src, result) + arg1, arg2 = dst, src + result = arg1 - arg2 + + e += update_flag_arith_sub_znp(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2, result) + e += update_flag_af(arg1, arg2, result) e.append(m2_expr.ExprAff(src, result)) return (e, []) @@ -418,9 +539,11 @@ def l_not(_, instr, dst): def l_cmp(_, instr, dst, src): e = [] + arg1, arg2 = dst, src result = dst - src - e += update_flag_arith(result) - e += update_flag_sub(dst, src, result) + + e += update_flag_arith_sub_znp(arg1, arg2) + e += update_flag_arith_sub_co(arg1, arg2, result) e += update_flag_af(dst, src, result) return (e, []) @@ -428,7 +551,9 @@ def l_cmp(_, instr, dst, src): def xor(_, instr, dst, src): e = [] result = dst ^ src - e += update_flag_logic(result) + e += [m2_expr.ExprAff(zf, m2_expr.ExprOp('FLAG_EQ_CMP', dst, src))] + e += update_flag_np(result) + e += null_flag_co() e.append(m2_expr.ExprAff(dst, result)) return (e, []) @@ -443,7 +568,9 @@ def pxor(_, instr, dst, src): def l_or(_, instr, dst, src): e = [] result = dst | src - e += update_flag_logic(result) + e += [m2_expr.ExprAff(zf, m2_expr.ExprOp('FLAG_EQ', dst | src))] + e += update_flag_np(result) + e += null_flag_co() e.append(m2_expr.ExprAff(dst, result)) return (e, []) @@ -451,7 +578,10 @@ def l_or(_, instr, dst, src): def l_and(_, instr, dst, src): e = [] result = dst & src - e += update_flag_logic(result) + e += [m2_expr.ExprAff(zf, m2_expr.ExprOp('FLAG_EQ_AND', dst, src))] + e += update_flag_np(result) + e += null_flag_co() + e.append(m2_expr.ExprAff(dst, result)) return (e, []) @@ -459,7 +589,12 @@ def l_and(_, instr, dst, src): def l_test(_, instr, dst, src): e = [] result = dst & src - e += update_flag_logic(result) + + e += [m2_expr.ExprAff(zf, m2_expr.ExprOp('FLAG_EQ_CMP', result, m2_expr.ExprInt(0, result.size)))] + e += [m2_expr.ExprAff(nf, m2_expr.ExprOp("FLAG_SIGN_SUB", result, m2_expr.ExprInt(0, result.size)))] + e += update_flag_pf(result) + e += null_flag_co() + return (e, []) @@ -717,23 +852,27 @@ def sti(_, instr): def inc(_, instr, dst): e = [] src = m2_expr.ExprInt(1, dst.size) + arg1, arg2 = dst, src result = dst + src - e += update_flag_arith(result) - e += update_flag_af(dst, src, result) - e.append(update_flag_add_of(dst, src, result)) + e += update_flag_arith_add_znp(arg1, arg2) + e += update_flag_af(arg1, arg2, result) + e += update_flag_add_of(arg1, arg2, result) + e.append(m2_expr.ExprAff(dst, result)) return e, [] def dec(_, instr, dst): e = [] - src = m2_expr.ExprInt(-1, dst.size) - result = dst + src - e += update_flag_arith(result) - e += update_flag_af(dst, src, ~result) + src = m2_expr.ExprInt(1, dst.size) + arg1, arg2 = dst, src + result = dst - src + + e += update_flag_arith_sub_znp(arg1, arg2) + e += update_flag_af(arg1, arg2, result) + e += update_flag_sub_of(arg1, arg2, result) - e.append(update_flag_add_of(dst, src, result)) e.append(m2_expr.ExprAff(dst, result)) return e, [] @@ -796,16 +935,22 @@ def popw(ir, instr, src): def sete(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(zf, m2_expr.ExprInt(1, dst.size), - m2_expr.ExprInt(0, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_EQ", zf).zeroExtend(dst.size), + ) + ) return e, [] def setnz(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(zf, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_EQ", ~zf).zeroExtend(dst.size), + ) + ) return e, [] @@ -813,17 +958,21 @@ def setl(_, instr, dst): e = [] e.append( m2_expr.ExprAff( - dst, m2_expr.ExprCond(nf - of, m2_expr.ExprInt(1, dst.size), - m2_expr.ExprInt(0, dst.size)))) + dst, + m2_expr.ExprOp("CC_S<", nf, of).zeroExtend(dst.size), + ) + ) return e, [] def setg(_, instr, dst): e = [] - a0 = m2_expr.ExprInt(0, dst.size) - a1 = m2_expr.ExprInt(1, dst.size) - ret = m2_expr.ExprCond(zf, a0, a1) & m2_expr.ExprCond(nf - of, a0, a1) - e.append(m2_expr.ExprAff(dst, ret)) + e.append( + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_S>", nf, of, zf).zeroExtend(dst.size), + ) + ) return e, [] @@ -831,128 +980,172 @@ def setge(_, instr, dst): e = [] e.append( m2_expr.ExprAff( - dst, m2_expr.ExprCond(nf - of, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + dst, + m2_expr.ExprOp("CC_S>=", nf, of).zeroExtend(dst.size), + ) + ) return e, [] def seta(_, instr, dst): e = [] - e.append(m2_expr.ExprAff(dst, m2_expr.ExprCond(cf | zf, - m2_expr.ExprInt( - 0, dst.size), - m2_expr.ExprInt(1, dst.size)))) - + e.append( + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U>", cf, zf).zeroExtend(dst.size), + ) + ) return e, [] def setae(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(cf, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U>=", cf).zeroExtend(dst.size), + ) + ) return e, [] def setb(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(cf, m2_expr.ExprInt(1, dst.size), - m2_expr.ExprInt(0, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U<", cf).zeroExtend(dst.size), + ) + ) return e, [] def setbe(_, instr, dst): e = [] - e.append(m2_expr.ExprAff(dst, m2_expr.ExprCond(cf | zf, - m2_expr.ExprInt( - 1, dst.size), - m2_expr.ExprInt(0, dst.size))) - ) + e.append( + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U<=", cf, zf).zeroExtend(dst.size), + ) + ) return e, [] def setns(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(nf, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_NEG", ~nf).zeroExtend(dst.size), + ) + ) return e, [] def sets(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(nf, m2_expr.ExprInt(1, dst.size), - m2_expr.ExprInt(0, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_NEG", nf).zeroExtend(dst.size), + ) + ) return e, [] def seto(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(of, m2_expr.ExprInt(1, dst.size), - m2_expr.ExprInt(0, dst.size)))) + m2_expr.ExprAff( + dst, + of.zeroExtend(dst.size) + ) + ) return e, [] def setp(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(pf, m2_expr.ExprInt(1, dst.size), - m2_expr.ExprInt(0, dst.size)))) + m2_expr.ExprAff( + dst, + pf.zeroExtend(dst.size) + ) + ) return e, [] def setnp(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(pf, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprCond( + pf, + m2_expr.ExprInt(0, dst.size), + m2_expr.ExprInt(1, dst.size) + ) + ) + ) return e, [] def setle(_, instr, dst): e = [] - a0 = m2_expr.ExprInt(0, dst.size) - a1 = m2_expr.ExprInt(1, dst.size) - ret = m2_expr.ExprCond(zf, a1, a0) | m2_expr.ExprCond(nf ^ of, a1, a0) - e.append(m2_expr.ExprAff(dst, ret)) + e.append( + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_S<=", nf, of, zf).zeroExtend(dst.size), + ) + ) return e, [] def setna(_, instr, dst): e = [] - a0 = m2_expr.ExprInt(0, dst.size) - a1 = m2_expr.ExprInt(1, dst.size) - ret = m2_expr.ExprCond(cf, a1, a0) & m2_expr.ExprCond(zf, a1, a0) - e.append(m2_expr.ExprAff(dst, ret)) + e.append( + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U<=", cf, zf).zeroExtend(dst.size), + ) + ) return e, [] def setnbe(_, instr, dst): e = [] - e.append(m2_expr.ExprAff(dst, m2_expr.ExprCond(cf | zf, - m2_expr.ExprInt( - 0, dst.size), - m2_expr.ExprInt(1, dst.size))) - ) + e.append( + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U>", cf, zf).zeroExtend(dst.size), + ) + ) return e, [] def setno(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(of, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprCond( + of, + m2_expr.ExprInt(0, dst.size), + m2_expr.ExprInt(1, dst.size) + ) + ) + ) return e, [] def setnb(_, instr, dst): e = [] e.append( - m2_expr.ExprAff(dst, m2_expr.ExprCond(cf, m2_expr.ExprInt(0, dst.size), - m2_expr.ExprInt(1, dst.size)))) + m2_expr.ExprAff( + dst, + m2_expr.ExprOp("CC_U>=", cf).zeroExtend(dst.size), + ) + ) return e, [] @@ -1358,7 +1551,8 @@ def jmp(ir, instr, dst): def jz(ir, instr, dst): - return gen_jcc(ir, instr, zf, dst, True) + #return gen_jcc(ir, instr, zf, dst, True) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_EQ", zf), dst, True) def jcxz(ir, instr, dst): @@ -1374,7 +1568,9 @@ def jrcxz(ir, instr, dst): def jnz(ir, instr, dst): - return gen_jcc(ir, instr, zf, dst, False) + #return gen_jcc(ir, instr, zf, dst, False) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_EQ", zf), dst, False) + def jp(ir, instr, dst): @@ -1386,43 +1582,55 @@ def jnp(ir, instr, dst): def ja(ir, instr, dst): - return gen_jcc(ir, instr, cf | zf, dst, False) + #return gen_jcc(ir, instr, cf | zf, dst, False) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_U>", cf, zf), dst, True) def jae(ir, instr, dst): - return gen_jcc(ir, instr, cf, dst, False) + #return gen_jcc(ir, instr, cf, dst, False) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_U>=", cf), dst, True) def jb(ir, instr, dst): - return gen_jcc(ir, instr, cf, dst, True) + #return gen_jcc(ir, instr, cf, dst, True) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_U<", cf), dst, True) def jbe(ir, instr, dst): - return gen_jcc(ir, instr, cf | zf, dst, True) + #return gen_jcc(ir, instr, cf | zf, dst, True) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_U<=", cf, zf), dst, True) def jge(ir, instr, dst): - return gen_jcc(ir, instr, nf - of, dst, False) + #return gen_jcc(ir, instr, nf - of, dst, False) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_S>=", nf, of), dst, True) def jg(ir, instr, dst): - return gen_jcc(ir, instr, zf | (nf - of), dst, False) + #return gen_jcc(ir, instr, zf | (nf - of), dst, False) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_S>", nf, of, zf), dst, True) def jl(ir, instr, dst): - return gen_jcc(ir, instr, nf - of, dst, True) + #return gen_jcc(ir, instr, nf - of, dst, True) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_S<", nf, of), dst, True) def jle(ir, instr, dst): - return gen_jcc(ir, instr, zf | (nf - of), dst, True) + #return gen_jcc(ir, instr, zf | (nf - of), dst, True) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_S<=", nf, of, zf), dst, True) + def js(ir, instr, dst): - return gen_jcc(ir, instr, nf, dst, True) + #return gen_jcc(ir, instr, nf, dst, True) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_NEG", nf), dst, True) + def jns(ir, instr, dst): - return gen_jcc(ir, instr, nf, dst, False) + #return gen_jcc(ir, instr, nf, dst, False) + return gen_jcc(ir, instr, m2_expr.ExprOp("CC_NEG", nf), dst, False) def jo(ir, instr, dst): @@ -2957,11 +3165,13 @@ def sldt(_, instr, dst): def cmovz(ir, instr, dst, src): - return gen_cmov(ir, instr, zf, dst, src, True) + #return gen_cmov(ir, instr, zf, dst, src, True) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_EQ", zf), dst, src, True) def cmovnz(ir, instr, dst, src): - return gen_cmov(ir, instr, zf, dst, src, False) + #return gen_cmov(ir, instr, zf, dst, src, False) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_EQ", zf), dst, src, False) def cmovpe(ir, instr, dst, src): @@ -2973,35 +3183,43 @@ def cmovnp(ir, instr, dst, src): def cmovge(ir, instr, dst, src): - return gen_cmov(ir, instr, nf ^ of, dst, src, False) + #return gen_cmov(ir, instr, nf ^ of, dst, src, False) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_S>=", nf, of), dst, src, True) def cmovg(ir, instr, dst, src): - return gen_cmov(ir, instr, zf | (nf ^ of), dst, src, False) + #return gen_cmov(ir, instr, zf | (nf ^ of), dst, src, False) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_S>", nf, of, zf), dst, src, True) def cmovl(ir, instr, dst, src): - return gen_cmov(ir, instr, nf ^ of, dst, src, True) + #return gen_cmov(ir, instr, nf ^ of, dst, src, True) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_S<", nf, of), dst, src, True) def cmovle(ir, instr, dst, src): - return gen_cmov(ir, instr, zf | (nf ^ of), dst, src, True) + #return gen_cmov(ir, instr, zf | (nf ^ of), dst, src, True) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_S<=", nf, of, zf), dst, src, True) def cmova(ir, instr, dst, src): - return gen_cmov(ir, instr, cf | zf, dst, src, False) + #return gen_cmov(ir, instr, cf | zf, dst, src, False) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_U>", cf, zf), dst, src, True) def cmovae(ir, instr, dst, src): - return gen_cmov(ir, instr, cf, dst, src, False) + #return gen_cmov(ir, instr, cf, dst, src, False) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_U>=", cf), dst, src, True) def cmovbe(ir, instr, dst, src): - return gen_cmov(ir, instr, cf | zf, dst, src, True) + #return gen_cmov(ir, instr, cf | zf, dst, src, True) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_U<=", cf, zf), dst, src, True) def cmovb(ir, instr, dst, src): - return gen_cmov(ir, instr, cf, dst, src, True) + #return gen_cmov(ir, instr, cf, dst, src, True) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_U<", cf), dst, src, True) def cmovo(ir, instr, dst, src): @@ -3013,11 +3231,13 @@ def cmovno(ir, instr, dst, src): def cmovs(ir, instr, dst, src): - return gen_cmov(ir, instr, nf, dst, src, True) + #return gen_cmov(ir, instr, nf, dst, src, True) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_NEG", nf), dst, src, True) def cmovns(ir, instr, dst, src): - return gen_cmov(ir, instr, nf, dst, src, False) + #return gen_cmov(ir, instr, nf, dst, src, False) + return gen_cmov(ir, instr, m2_expr.ExprOp("CC_NEG", nf), dst, src, False) def icebp(_, instr): diff --git a/miasm2/core/graph.py b/miasm2/core/graph.py index d35148b1..a817c024 100644 --- a/miasm2/core/graph.py +++ b/miasm2/core/graph.py @@ -267,6 +267,12 @@ class DiGraph(object): for next_node in next_cb(node): todo.add(next_node) + def predecessors_stop_node_iter(self, node, head): + if node == head: + raise StopIteration + for next_node in self.predecessors_iter(node): + yield next_node + def reachable_sons(self, head): """Compute all nodes reachable from node @head. Each son is an immediate successor of an arbitrary, already yielded son of @head""" @@ -277,6 +283,18 @@ class DiGraph(object): predecessor of an arbitrary, already yielded parent of @leaf""" return self._reachable_nodes(leaf, self.predecessors_iter) + def reachable_parents_stop_node(self, leaf, head): + """Compute all parents of node @leaf. Each parent is an immediate + predecessor of an arbitrary, already yielded parent of @leaf. + Do not compute reachables past @head node""" + return self._reachable_nodes( + leaf, + lambda node_cur: self.predecessors_stop_node_iter( + node_cur, head + ) + ) + + @staticmethod def _compute_generic_dominators(head, reachable_cb, prev_cb, next_cb): """Generic algorithm to compute either the dominators or postdominators diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py index 11400e9e..954ba00a 100644 --- a/miasm2/expression/expression.py +++ b/miasm2/expression/expression.py @@ -367,8 +367,7 @@ class Expr(object): assert self.size <= size if self.size == size: return self - ad_size = size - self.size - return ExprCompose(self, ExprInt(0, ad_size)) + return ExprOp('zeroExt_%d' % size, self) def signExtend(self, size): """Sign extend to size @@ -377,11 +376,7 @@ class Expr(object): assert self.size <= size if self.size == size: return self - ad_size = size - self.size - return ExprCompose(self, - ExprCond(self.msb(), - ExprInt(size2mask(ad_size), ad_size), - ExprInt(0, ad_size))) + return ExprOp('signExt_%d' % size, self) def graph_recursive(self, graph): """Recursive method used by graph @@ -994,7 +989,14 @@ class ExprOp(Expr): if len(sizes) != 1: # Special cases : operande sizes can differ - if op not in ["segm"]: + if op not in [ + "segm", + "FLAG_EQ_ADDWC", "FLAG_EQ_SUBWC", + "FLAG_SIGN_ADDWC", "FLAG_SIGN_SUBWC", + "FLAG_ADDWC_CF", "FLAG_ADDWC_OF", + "FLAG_SUBWC_CF", "FLAG_SUBWC_OF", + + ]: raise ValueError( "sanitycheck: ExprOp args must have same size! %s" % ([(str(arg), arg.size) for arg in args])) @@ -1026,6 +1028,23 @@ class ExprOp(Expr): size = int(self._op[len("fp_to_sint"):]) elif self._op.startswith("fpconvert_fp"): size = int(self._op[len("fpconvert_fp"):]) + elif self._op in [ + "FLAG_ADD_CF", "FLAG_SUB_CF", + "FLAG_ADD_OF", "FLAG_SUB_OF", + "FLAG_EQ", "FLAG_EQ_CMP", + "FLAG_SIGN_SUB", "FLAG_SIGN_ADD", + "FLAG_EQ_AND", + "FLAG_EQ_ADDWC", "FLAG_EQ_SUBWC", + "FLAG_SIGN_ADDWC", "FLAG_SIGN_SUBWC", + "FLAG_ADDWC_CF", "FLAG_ADDWC_OF", + "FLAG_SUBWC_CF", "FLAG_SUBWC_OF", + ]: + size = 1 + + elif self._op.startswith('signExt_'): + size = int(self._op[8:]) + elif self._op.startswith('zeroExt_'): + size = int(self._op[8:]) elif self._op in ['segm']: size = self._args[1].size else: diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py index e6c5dc54..712488e3 100644 --- a/miasm2/expression/simplifications.py +++ b/miasm2/expression/simplifications.py @@ -6,6 +6,7 @@ import logging from miasm2.expression import simplifications_common from miasm2.expression import simplifications_cond +from miasm2.expression import simplifications_explicit from miasm2.expression.expression_helper import fast_unify import miasm2.expression.expression as m2_expr @@ -32,13 +33,30 @@ class ExpressionSimplifier(object): # Common passes PASS_COMMONS = { - m2_expr.ExprOp: [simplifications_common.simp_cst_propagation, - simplifications_common.simp_cond_op_int, - simplifications_common.simp_cond_factor], + m2_expr.ExprOp: [ + simplifications_common.simp_cst_propagation, + simplifications_common.simp_cond_op_int, + simplifications_common.simp_cond_factor, + # CC op + simplifications_common.simp_cc_conds, + simplifications_common.simp_subwc_cf, + simplifications_common.simp_subwc_of, + simplifications_common.simp_sign_subwc_cf, + simplifications_common.simp_zeroext_eq_cst, + + ], + m2_expr.ExprSlice: [simplifications_common.simp_slice], m2_expr.ExprCompose: [simplifications_common.simp_compose], - m2_expr.ExprCond: [simplifications_common.simp_cond], + m2_expr.ExprCond: [ + simplifications_common.simp_cond, + # CC op + simplifications_common.simp_cond_flag, + simplifications_common.simp_cond_int, + simplifications_common.simp_cmp_int_arg, + ], m2_expr.ExprMem: [simplifications_common.simp_mem], + } # Heavy passes @@ -55,6 +73,16 @@ class ExpressionSimplifier(object): } + # Available passes lists are: + # - highlevel: transform high level operators to explicit computations + PASS_HIGH_TO_EXPLICIT = { + m2_expr.ExprOp: [ + simplifications_explicit.simp_flags, + simplifications_explicit.simp_ext, + ], + } + + def __init__(self): self.expr_simp_cb = {} self.simplified_exprs = set() @@ -136,3 +164,12 @@ class ExpressionSimplifier(object): # Public ExprSimplificationPass instance with commons passes expr_simp = ExpressionSimplifier() expr_simp.enable_passes(ExpressionSimplifier.PASS_COMMONS) + + + +expr_simp_high_to_explicit = ExpressionSimplifier() +expr_simp_high_to_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT) + +expr_simp_explicit = ExpressionSimplifier() +expr_simp_explicit.enable_passes(ExpressionSimplifier.PASS_COMMONS) +expr_simp_explicit.enable_passes(ExpressionSimplifier.PASS_HIGH_TO_EXPLICIT) diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 149c5b8d..fa2370bd 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -518,7 +518,10 @@ def simp_slice(e_s, expr): return tmp # distributivity of slice and exprcond # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y]) - if expr.arg.is_cond() and expr.arg.src1.is_int() and expr.arg.src2.is_int(): + # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y]) + if (expr.arg.is_cond() and + (expr.arg.src1.is_int() or expr.arg.src1.is_compose()) and + (expr.arg.src2.is_int() or expr.arg.src2.is_compose())): src1 = expr.arg.src1[expr.start:expr.stop] src2 = expr.arg.src2[expr.start:expr.stop] return ExprCond(expr.arg.cond, src1, src2) @@ -645,6 +648,15 @@ def simp_cond(e_s, expr): expr = ExprCond(expr.cond.cond, expr.src2, expr.src1) elif int1 and int2 == 0: expr = ExprCond(expr.cond.cond, expr.src1, expr.src2) + + elif expr.cond.is_compose(): + # {0, X, 0}?(A:B) => X?(A:B) + args = [arg for arg in expr.cond.args if not arg.is_int(0)] + if len(args) == 1: + arg = args.pop() + return ExprCond(arg, expr.src1, expr.src2) + elif len(args) < len(expr.cond.args): + return ExprCond(ExprCompose(*args), expr.src1, expr.src2) return expr @@ -659,3 +671,337 @@ def simp_mem(e_s, expr): ExprMem(cond.src2, expr.size)) return ret return expr + + + + +def test_cc_eq_args(expr, *sons_op): + if not expr.is_op(): + return False + if len(expr.args) != len(sons_op): + return False + all_args = set() + for i, arg in enumerate(expr.args): + if not arg.is_op(sons_op[i]): + return False + all_args.add(arg.args) + return len(all_args) == 1 + + +def simp_cc_conds(expr_simp, expr): + if (expr.is_op("CC_U>=") and + test_cc_eq_args( + expr, + "FLAG_SUB_CF" + )): + expr = ExprCond( + ExprOp("<u", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1)) + + elif (expr.is_op("CC_U<") and + test_cc_eq_args( + expr, + "FLAG_SUB_CF" + )): + expr = ExprOp("<u", *expr.args[0].args) + + elif (expr.is_op("CC_NEG") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB" + )): + expr = ExprOp("<s", *expr.args[0].args) + + elif (expr.is_op("CC_POS") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB" + )): + expr = ExprCond( + ExprOp("<s", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + elif (expr.is_op("CC_EQ") and + test_cc_eq_args( + expr, + "FLAG_EQ" + )): + arg = expr.args[0].args[0] + expr = ExprOp("==", arg, ExprInt(0, arg.size)) + + elif (expr.is_op("CC_NE") and + test_cc_eq_args( + expr, + "FLAG_EQ" + )): + arg = expr.args[0].args[0] + expr = ExprCond( + ExprOp("==",arg, ExprInt(0, arg.size)), + ExprInt(0, 1), + ExprInt(1, 1) + ) + elif (expr.is_op("CC_NE") and + test_cc_eq_args( + expr, + "FLAG_EQ_CMP" + )): + expr = ExprCond( + ExprOp("==", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + elif (expr.is_op("CC_EQ") and + test_cc_eq_args( + expr, + "FLAG_EQ_CMP" + )): + expr = ExprOp("==", *expr.args[0].args) + + elif (expr.is_op("CC_NE") and + test_cc_eq_args( + expr, + "FLAG_EQ_AND" + )): + expr = ExprOp("&", *expr.args[0].args) + + elif (expr.is_op("CC_EQ") and + test_cc_eq_args( + expr, + "FLAG_EQ_AND" + )): + expr = ExprCond( + ExprOp("&", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + elif (expr.is_op("CC_S>") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF", + "FLAG_EQ_CMP", + )): + expr = ExprCond( + ExprOp("<=s", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + elif (expr.is_op("CC_S>") and + len(expr.args) == 3 and + expr.args[0].is_op("FLAG_SIGN_SUB") and + expr.args[2].is_op("FLAG_EQ_CMP") and + expr.args[0].args == expr.args[2].args and + expr.args[1].is_int(0)): + expr = ExprCond( + ExprOp("<=s", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + + + elif (expr.is_op("CC_S>=") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF" + )): + expr = ExprCond( + ExprOp("<s", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + elif (expr.is_op("CC_S<") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF" + )): + expr = ExprOp("<s", *expr.args[0].args) + + elif (expr.is_op("CC_S<=") and + test_cc_eq_args( + expr, + "FLAG_SIGN_SUB", + "FLAG_SUB_OF", + "FLAG_EQ_CMP", + )): + expr = ExprOp("<=s", *expr.args[0].args) + + elif (expr.is_op("CC_S<=") and + len(expr.args) == 3 and + expr.args[0].is_op("FLAG_SIGN_SUB") and + expr.args[2].is_op("FLAG_EQ_CMP") and + expr.args[0].args == expr.args[2].args and + expr.args[1].is_int(0)): + expr = ExprOp("<=s", *expr.args[0].args) + + elif (expr.is_op("CC_U<=") and + test_cc_eq_args( + expr, + "FLAG_SUB_CF", + "FLAG_EQ_CMP", + )): + expr = ExprOp("<=u", *expr.args[0].args) + + elif (expr.is_op("CC_U>") and + test_cc_eq_args( + expr, + "FLAG_SUB_CF", + "FLAG_EQ_CMP", + )): + expr = ExprCond( + ExprOp("<=u", *expr.args[0].args), + ExprInt(0, 1), + ExprInt(1, 1) + ) + + elif (expr.is_op("CC_S<") and + test_cc_eq_args( + expr, + "FLAG_SIGN_ADD", + "FLAG_ADD_OF" + )): + arg0, arg1 = expr.args[0].args + expr = ExprOp("<s", arg0, -arg1) + + return expr + + + +def simp_cond_flag(expr_simp, expr): + # FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B + cond = expr.cond + if cond.is_op("FLAG_EQ_CMP"): + return ExprCond(ExprOp("==", *cond.args), expr.src1, expr.src2) + return expr + + +def simp_cond_int(expr_simp, expr): + if (expr.cond.is_op('==') and + expr.cond.args[1].is_int() and + expr.cond.args[0].is_compose() and + len(expr.cond.args[0].args) == 2 and + expr.cond.args[0].args[1].is_int(0)): + # ({X, 0} == int) => X == int[:] + src = expr.cond.args[0].args[0] + int_val = int(expr.cond.args[1]) + new_int = ExprInt(int_val, src.size) + expr = expr_simp(ExprCond(ExprOp("==", src, new_int), expr.src1, expr.src2)) + elif (expr.cond.is_op() and + expr.cond.op in ['==', '<s', '<=s', '<u', '<=u'] and + expr.cond.args[1].is_int() and + expr.cond.args[0].is_op("+") and + expr.cond.args[0].args[-1].is_int()): + # X + int1 == int2 => X == int2-int1 + left, right = expr.cond.args + left, int_diff = left.args[:-1], left.args[-1] + if len(left) == 1: + left = left[0] + else: + left = ExprOp('+', *left) + new_int = expr_simp(right - int_diff) + expr = expr_simp(ExprCond(ExprOp(expr.cond.op, left, new_int), expr.src1, expr.src2)) + return expr + + + +def simp_cmp_int_arg(expr_simp, expr): + """ + (0x10 <= R0) ? A:B + => + (R0 < 0x10) ? B:A + """ + cond = expr.cond + if not cond.is_op(): + return expr + op = cond.op + if op not in ['==', '<s', '<=s', '<u', '<=u']: + return expr + arg1, arg2 = cond.args + if arg2.is_int(): + return expr + if not arg1.is_int(): + return expr + src1, src2 = expr.src1, expr.src2 + if op == "==": + return ExprCond(ExprOp('==', arg2, arg1), src1, src2) + + arg1, arg2 = arg2, arg1 + src1, src2 = src2, src1 + if op == '<s': + op = '<=s' + elif op == '<=s': + op = '<s' + elif op == '<u': + op = '<=u' + elif op == '<=u': + op = '<u' + return ExprCond(ExprOp(op, arg1, arg2), src1, src2) + + + + +def simp_subwc_cf(expr_s, expr): + # SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D}) + if not expr.is_op('FLAG_SUBWC_CF'): + return expr + op3 = expr.args[2] + if not op3.is_op("FLAG_SUB_CF"): + return expr + + op1 = ExprCompose(expr.args[0], op3.args[0]) + op2 = ExprCompose(expr.args[1], op3.args[1]) + + return ExprOp("FLAG_SUB_CF", op1, op2) + + +def simp_subwc_of(expr_s, expr): + # SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D}) + if not expr.is_op('FLAG_SUBWC_OF'): + return expr + op3 = expr.args[2] + if not op3.is_op("FLAG_SUB_CF"): + return expr + + op1 = ExprCompose(expr.args[0], op3.args[0]) + op2 = ExprCompose(expr.args[1], op3.args[1]) + + return ExprOp("FLAG_SUB_OF", op1, op2) + + +def simp_sign_subwc_cf(expr_s, expr): + # SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D}) + if not expr.is_op('FLAG_SIGN_SUBWC'): + return expr + op3 = expr.args[2] + if not op3.is_op("FLAG_SUB_CF"): + return expr + + op1 = ExprCompose(expr.args[0], op3.args[0]) + op2 = ExprCompose(expr.args[1], op3.args[1]) + + return ExprOp("FLAG_SIGN_SUB", op1, op2) + + +def simp_zeroext_eq_cst(expr_s, expr): + # A.zeroExt(X) == int => A == int[:A.size] + if not expr.is_op("=="): + return expr + arg1, arg2 = expr.args + if not arg2.is_int(): + return expr + if not (arg1.is_op() and arg1.op.startswith("zeroExt")): + return expr + src = arg1.args[0] + if int(arg2) > (1 << src.size): + # Always false + return ExprInt(0, 1) + return ExprOp("==", src, ExprInt(int(arg2), src.size)) diff --git a/miasm2/expression/simplifications_explicit.py b/miasm2/expression/simplifications_explicit.py new file mode 100644 index 00000000..78e056ec --- /dev/null +++ b/miasm2/expression/simplifications_explicit.py @@ -0,0 +1,155 @@ +from miasm2.expression.modint import size2mask +from miasm2.expression.expression import ExprInt, ExprCond, ExprOp, \ + ExprCompose + + +def simp_ext(_, expr): + if expr.op.startswith('zeroExt_'): + arg = expr.args[0] + if expr.size == arg.size: + return arg + return ExprCompose(arg, ExprInt(0, expr.size - arg.size)) + + if expr.op.startswith("signExt_"): + arg = expr.args[0] + add_size = expr.size - arg.size + new_expr = ExprCompose( + arg, + ExprCond( + arg.msb(), + ExprInt(size2mask(add_size), add_size), + ExprInt(0, add_size) + ) + ) + return new_expr + return expr + + +def simp_flags(_, expr): + args = expr.args + + if expr.is_op("FLAG_EQ"): + return ExprCond(args[0], ExprInt(0, 1), ExprInt(1, 1)) + + elif expr.is_op("FLAG_EQ_AND"): + op1, op2 = args + return ExprCond(op1 & op2, ExprInt(0, 1), ExprInt(1, 1)) + + elif expr.is_op("FLAG_SIGN_SUB"): + return (args[0] - args[1]).msb() + + elif expr.is_op("FLAG_EQ_CMP"): + return ExprCond( + args[0] - args[1], + ExprInt(0, 1), + ExprInt(1, 1), + ) + + elif expr.is_op("FLAG_ADD_CF"): + op1, op2 = args + res = op1 + op2 + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUB_CF"): + op1, op2 = args + res = op1 - op2 + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_ADD_OF"): + op1, op2 = args + res = op1 + op2 + return (((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUB_OF"): + op1, op2 = args + res = op1 - op2 + return (((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_EQ_ADDWC"): + op1, op2, op3 = args + return ExprCond( + op1 + op2 + op3.zeroExtend(op1.size), + ExprInt(0, 1), + ExprInt(1, 1), + ) + + elif expr.is_op("FLAG_ADDWC_OF"): + op1, op2, op3 = args + res = op1 + op2 + op3.zeroExtend(op1.size) + return (((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUBWC_OF"): + op1, op2, op3 = args + res = op1 - (op2 + op3.zeroExtend(op1.size)) + return (((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_ADDWC_CF"): + op1, op2, op3 = args + res = op1 + op2 + op3.zeroExtend(op1.size) + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (~(op1 ^ op2)))).msb() + + elif expr.is_op("FLAG_SUBWC_CF"): + op1, op2, op3 = args + res = op1 - (op2 + op3.zeroExtend(op1.size)) + return (((op1 ^ op2) ^ res) ^ ((op1 ^ res) & (op1 ^ op2))).msb() + + elif expr.is_op("FLAG_SIGN_ADDWC"): + op1, op2, op3 = args + return (op1 + op2 + op3.zeroExtend(op1.size)).msb() + + elif expr.is_op("FLAG_SIGN_SUBWC"): + op1, op2, op3 = args + return (op1 - (op2 + op3.zeroExtend(op1.size))).msb() + + + elif expr.is_op("FLAG_EQ_SUBWC"): + op1, op2, op3 = args + res = op1 - (op2 + op3.zeroExtend(op1.size)) + return ExprCond(res, ExprInt(0, 1), ExprInt(1, 1)) + + elif expr.is_op("CC_U<="): + op_cf, op_zf = args + return op_cf | op_zf + + elif expr.is_op("CC_U>="): + op_cf, = args + return ~op_cf + + elif expr.is_op("CC_S<"): + op_nf, op_of = args + return op_nf ^ op_of + + elif expr.is_op("CC_S>"): + op_nf, op_of, op_zf = args + return ~(op_zf | (op_nf ^ op_of)) + + elif expr.is_op("CC_S<="): + op_nf, op_of, op_zf = args + return op_zf | (op_nf ^ op_of) + + elif expr.is_op("CC_S>="): + op_nf, op_of = args + return ~(op_nf ^ op_of) + + elif expr.is_op("CC_U>"): + op_cf, op_zf = args + return ~(op_cf | op_zf) + + elif expr.is_op("CC_U<"): + op_cf, = args + return op_cf + + elif expr.is_op("CC_NEG"): + op_nf, = args + return op_nf + + elif expr.is_op("CC_EQ"): + op_zf, = args + return op_zf + + elif expr.is_op("CC_NE"): + op_zf, = args + return ~op_zf + + return expr + diff --git a/miasm2/ir/ir.py b/miasm2/ir/ir.py index 721101e2..38a24263 100644 --- a/miasm2/ir/ir.py +++ b/miasm2/ir/ir.py @@ -440,6 +440,21 @@ class IRBlock(object): return '\n'.join(out) + def simplify(self, simplifier): + """ + Simplify expressions in each assignblock + @simplifier: ExpressionSimplifier instance + """ + modified = False + assignblks = [] + for assignblk in self: + new_assignblk = assignblk.simplify(simplifier) + if assignblk != new_assignblk: + modified = True + assignblks.append(new_assignblk) + return modified, IRBlock(self.loc_key, assignblks) + + class irbloc(IRBlock): """ DEPRECATED object diff --git a/miasm2/ir/symbexec.py b/miasm2/ir/symbexec.py index 9ab455da..1a077de5 100644 --- a/miasm2/ir/symbexec.py +++ b/miasm2/ir/symbexec.py @@ -4,7 +4,7 @@ from collections import MutableMapping from miasm2.expression.expression import ExprOp, ExprId, ExprLoc, ExprInt, \ ExprMem, ExprCompose, ExprSlice, ExprCond -from miasm2.expression.simplifications import expr_simp +from miasm2.expression.simplifications import expr_simp_explicit from miasm2.ir.ir import AssignBlock log = logging.getLogger("symbexec") @@ -138,7 +138,7 @@ class MemArray(MutableMapping): """ - def __init__(self, base, expr_simp=expr_simp): + def __init__(self, base, expr_simp=expr_simp_explicit): self._base = base self.expr_simp = expr_simp self._mask = int(base.mask) @@ -461,7 +461,7 @@ class MemSparse(object): """ - def __init__(self, addrsize, expr_simp=expr_simp): + def __init__(self, addrsize, expr_simp=expr_simp_explicit): """ @addrsize: size (in bits) of the addresses manipulated by the MemSparse @expr_simp: an ExpressionSimplifier instance @@ -604,7 +604,7 @@ class MemSparse(object): class SymbolMngr(object): """Symbolic store manager (IDs and MEMs)""" - def __init__(self, init=None, addrsize=None, expr_simp=expr_simp): + def __init__(self, init=None, addrsize=None, expr_simp=expr_simp_explicit): assert addrsize is not None if init is None: init = {} @@ -807,7 +807,7 @@ class SymbolicExecutionEngine(object): def __init__(self, ir_arch, state=None, func_read=None, func_write=None, - sb_expr_simp=expr_simp): + sb_expr_simp=expr_simp_explicit): self.expr_to_visitor = { ExprInt: self.eval_exprint, @@ -823,7 +823,7 @@ class SymbolicExecutionEngine(object): if state is None: state = {} - self.symbols = SymbolMngr(addrsize=ir_arch.addrsize, expr_simp=expr_simp) + self.symbols = SymbolMngr(addrsize=ir_arch.addrsize, expr_simp=sb_expr_simp) for dst, src in state.iteritems(): self.symbols.write(dst, src) @@ -1270,9 +1270,9 @@ class symbexec(SymbolicExecutionEngine): def __init__(self, ir_arch, known_symbols, func_read=None, func_write=None, - sb_expr_simp=expr_simp): + sb_expr_simp=expr_simp_explicit): warnings.warn("Deprecated API: use SymbolicExecutionEngine") super(symbexec, self).__init__(ir_arch, known_symbols, func_read, func_write, - sb_expr_simp=expr_simp) + sb_expr_simp=sb_expr_simp) diff --git a/miasm2/ir/translators/C.py b/miasm2/ir/translators/C.py index 11ccf137..33c21049 100644 --- a/miasm2/ir/translators/C.py +++ b/miasm2/ir/translators/C.py @@ -1,7 +1,7 @@ from miasm2.ir.translators.translator import Translator from miasm2.core import asmblock from miasm2.expression.modint import size2mask - +from miasm2.expression.expression import ExprInt, ExprCond, ExprCompose def int_size_to_bn(value, size): if size < 32: @@ -125,6 +125,28 @@ class TranslatorC(Translator): out = 'parity(%s)' % out return out + elif expr.op.startswith("zeroExt_"): + arg = expr.args[0] + if expr.size == arg.size: + return arg + return self.from_expr(ExprCompose(arg, ExprInt(0, expr.size - arg.size))) + + elif expr.op.startswith("signExt_"): + arg = expr.args[0] + if expr.size == arg.size: + return arg + add_size = expr.size - arg.size + new_expr = ExprCompose( + arg, + ExprCond( + arg.msb(), + ExprInt(size2mask(add_size), add_size), + ExprInt(0, add_size) + ) + ) + return self.from_expr(new_expr) + + elif expr.op in ['cntleadzeros', 'cnttrailzeros']: arg = expr.args[0] out = self.from_expr(arg) diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py index 887c68d0..1b0578b7 100644 --- a/miasm2/ir/translators/z3_ir.py +++ b/miasm2/ir/translators/z3_ir.py @@ -229,6 +229,12 @@ class TranslatorZ3(Translator): index = - i % size out = size - (index + 1) res = z3.If((src & (1 << index)) != 0, out, res) + elif expr.op.startswith("zeroExt"): + arg, = expr.args + res = z3.ZeroExt(expr.size - arg.size, self.from_expr(arg)) + elif expr.op.startswith("signExt"): + arg, = expr.args + res = z3.SignExt(expr.size - arg.size, self.from_expr(arg)) else: raise NotImplementedError("Unsupported OP yet: %s" % expr.op) diff --git a/miasm2/jitter/codegen.py b/miasm2/jitter/codegen.py index abba9843..69e83de5 100644 --- a/miasm2/jitter/codegen.py +++ b/miasm2/jitter/codegen.py @@ -8,6 +8,7 @@ from miasm2.ir.ir import IRBlock, AssignBlock from miasm2.ir.translators.C import TranslatorC, int_size_to_bn from miasm2.core.asmblock import AsmBlockBad +from miasm2.expression.simplifications import expr_simp_high_to_explicit TRANSLATOR_NO_SYMBOL = TranslatorC(loc_db=None) @@ -166,6 +167,13 @@ class CGen(object): irblock_head = self.assignblk_to_irbloc(instr, assignblk_head) irblocks = [irblock_head] + assignblks_extra + # Simplify high level operators + out = [] + for irblock in irblocks: + new_irblock = irblock.simplify(expr_simp_high_to_explicit)[1] + out.append(new_irblock) + irblocks = out + for irblock in irblocks: assert irblock.dst is not None irblocks_list.append(irblocks) diff --git a/miasm2/jitter/jitcore_python.py b/miasm2/jitter/jitcore_python.py index 61bd98d0..b97727cd 100644 --- a/miasm2/jitter/jitcore_python.py +++ b/miasm2/jitter/jitcore_python.py @@ -1,7 +1,7 @@ import miasm2.jitter.jitcore as jitcore import miasm2.expression.expression as m2_expr import miasm2.jitter.csts as csts -from miasm2.expression.simplifications import ExpressionSimplifier +from miasm2.expression.simplifications import ExpressionSimplifier, expr_simp_explicit from miasm2.jitter.emulatedsymbexec import EmulatedSymbExec ################################################################################ @@ -20,12 +20,11 @@ class JitCore_Python(jitcore.JitCore): self.ircfg = self.ir_arch.new_ircfg() # CPU & VM (None for now) will be set later - expr_simp = ExpressionSimplifier() - expr_simp.enable_passes(ExpressionSimplifier.PASS_COMMONS) + self.symbexec = self.SymbExecClass( None, None, self.ir_arch, {}, - sb_expr_simp=expr_simp + sb_expr_simp=expr_simp_explicit ) self.symbexec.enable_emulated_simplifications() diff --git a/miasm2/jitter/llvmconvert.py b/miasm2/jitter/llvmconvert.py index 4a0eae93..de5f19df 100644 --- a/miasm2/jitter/llvmconvert.py +++ b/miasm2/jitter/llvmconvert.py @@ -830,6 +830,28 @@ class LLVMFunction(): self.update_cache(expr, ret) return ret + + if op.startswith('zeroExt_'): + arg = expr.args[0] + if expr.size == arg.size: + return arg + new_expr = ExprCompose(arg, ExprInt(0, expr.size - arg.size)) + return self.add_ir(new_expr) + + if op.startswith("signExt_"): + arg = expr.args[0] + add_size = expr.size - arg.size + new_expr = ExprCompose( + arg, + ExprCond( + arg.msb(), + ExprInt(size2mask(add_size), add_size), + ExprInt(0, add_size) + ) + ) + return self.add_ir(new_expr) + + if op == "segm": fc_ptr = self.mod.get_global("segm2addr") |