diff options
Diffstat (limited to 'miasm')
67 files changed, 3605 insertions, 1016 deletions
diff --git a/miasm/analysis/data_flow.py b/miasm/analysis/data_flow.py index 7bd6d72f..7340c023 100644 --- a/miasm/analysis/data_flow.py +++ b/miasm/analysis/data_flow.py @@ -1,19 +1,21 @@ """Data flow analysis based on miasm intermediate representation""" from builtins import range -from collections import namedtuple - +from collections import namedtuple, Counter +from pprint import pprint as pp from future.utils import viewitems, viewvalues from miasm.core.utils import encode_hex from miasm.core.graph import DiGraph from miasm.ir.ir import AssignBlock, IRBlock from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\ - ExprAssign, ExprOp -from miasm.expression.simplifications import expr_simp + ExprAssign, ExprOp, ExprWalk, ExprSlice, \ + is_function_call, ExprVisitorCallbackBottomToTop +from miasm.expression.simplifications import expr_simp, expr_simp_explicit from miasm.core.interval import interval from miasm.expression.expression_helper import possible_values from miasm.analysis.ssa import get_phi_sources_parent_block, \ irblock_has_phi - +from miasm.ir.symbexec import get_expr_base_offset +from collections import deque class ReachingDefinitions(dict): """ @@ -131,7 +133,7 @@ class DiGraphDefUse(DiGraph): def __init__(self, reaching_defs, - deref_mem=False, *args, **kwargs): + deref_mem=False, apply_simp=False, *args, **kwargs): """Instantiate a DiGraph @blocks: IR blocks """ @@ -144,7 +146,8 @@ class DiGraphDefUse(DiGraph): super(DiGraphDefUse, self).__init__(*args, **kwargs) self._compute_def_use(reaching_defs, - deref_mem=deref_mem) + deref_mem=deref_mem, + apply_simp=apply_simp) def edge_attr(self, src, dst): """ @@ -155,18 +158,20 @@ class DiGraphDefUse(DiGraph): return self._edge_attr[(src, dst)] def _compute_def_use(self, reaching_defs, - deref_mem=False): + deref_mem=False, apply_simp=False): for block in viewvalues(self._blocks): self._compute_def_use_block(block, reaching_defs, - deref_mem=deref_mem) + deref_mem=deref_mem, + apply_simp=apply_simp) - def _compute_def_use_block(self, block, reaching_defs, deref_mem=False): + def _compute_def_use_block(self, block, reaching_defs, deref_mem=False, apply_simp=False): for index, assignblk in enumerate(block): assignblk_reaching_defs = reaching_defs.get_definitions(block.loc_key, index) for lval, expr in viewitems(assignblk): self.add_node(AssignblkNode(block.loc_key, index, lval)) + expr = expr_simp_explicit(expr) if apply_simp else expr read_vars = expr.get_r(mem_read=deref_mem) if deref_mem and lval.is_mem(): read_vars.update(lval.ptr.get_r(mem_read=deref_mem)) @@ -223,7 +228,7 @@ class DeadRemoval(object): lval.is_mem() or self.ir_arch.IRDst == lval or lval.is_id("exception_flags") or - rval.is_function_call() + is_function_call(rval) ): return True return False @@ -723,307 +728,16 @@ class SSADefUse(DiGraph): - -def expr_test_visit(expr, test): - result = set() - expr.visit( - lambda expr: expr, - lambda expr: test(expr, result) - ) - if result: - return True - else: - return False - - -def expr_has_mem_test(expr, result): - if result: - # Don't analyse if we already found a candidate - return False - if expr.is_mem(): - result.add(expr) - return False - return True - - def expr_has_mem(expr): """ Return True if expr contains at least one memory access @expr: Expr instance """ - return expr_test_visit(expr, expr_has_mem_test) - - -class PropagateThroughExprId(object): - """ - Propagate expressions though ExprId - """ - - def has_propagation_barrier(self, assignblks): - """ - Return True if propagation cannot cross the @assignblks - @assignblks: list of AssignBlock to check - """ - for assignblk in assignblks: - for dst, src in viewitems(assignblk): - if src.is_function_call(): - return True - if dst.is_mem(): - return True - return False - - def is_mem_written(self, ssa, node_a, node_b): - """ - Return True if memory is written at least once between @node_a and - @node_b - - @node: AssignblkNode representing the start position - @successor: AssignblkNode representing the end position - """ - - block_b = ssa.graph.blocks[node_b.label] - nodes_to_do = self.compute_reachable_nodes_from_a_to_b(ssa.graph, node_a.label, node_b.label) - - if node_a.label == node_b.label: - # src is dst - assert nodes_to_do == set([node_a.label]) - if self.has_propagation_barrier(block_b.assignblks[node_a.index:node_b.index]): - return True - else: - # Check everyone but node_a.label and node_b.label - for loc in nodes_to_do - set([node_a.label, node_b.label]): - if loc not in ssa.graph.blocks: - continue - block = ssa.graph.blocks[loc] - if self.has_propagation_barrier(block.assignblks): - return True - # Check node_a.label partially - block_a = ssa.graph.blocks[node_a.label] - if self.has_propagation_barrier(block_a.assignblks[node_a.index:]): - return True - if nodes_to_do.intersection(ssa.graph.successors(node_b.label)): - # There is a path from node_b.label to node_b.label => Check node_b.label fully - if self.has_propagation_barrier(block_b.assignblks): - return True - else: - # Check node_b.label partially - if self.has_propagation_barrier(block_b.assignblks[:node_b.index]): - return True - return False - - def compute_reachable_nodes_from_a_to_b(self, ssa, loc_a, loc_b): - reachables_a = set(ssa.reachable_sons(loc_a)) - reachables_b = set(ssa.reachable_parents_stop_node(loc_b, loc_a)) - return reachables_a.intersection(reachables_b) - - def propagation_allowed(self, ssa, to_replace, node_a, node_b): - """ - Return True if we can replace @node_a source present in @to_replace into - @node_b - @node_a: AssignblkNode position - @node_b: AssignblkNode position - """ - if not expr_has_mem(to_replace[node_a.var]): - return True - if self.is_mem_written(ssa, node_a, node_b): - return False - return True - - - def get_var_definitions(self, ssa): - """ - Return a dictionary linking variable to its assignment location - @ssa: SSADiGraph instance - """ - ircfg = ssa.graph - def_dct = {} - for node in ircfg.nodes(): - block = ircfg.blocks.get(node, None) - if block is None: - continue - for index, assignblk in enumerate(block): - for dst, src in viewitems(assignblk): - if not dst.is_id(): - continue - if dst in ssa.immutable_ids: - continue - assert dst not in def_dct - def_dct[dst] = node, index - return def_dct - - def get_candidates(self, ssa, head, max_expr_depth): - def_dct = self.get_var_definitions(ssa) - defuse = SSADefUse.from_ssa(ssa) - to_replace = {} - node_to_reg = {} - for node in defuse.nodes(): - if node.var in ssa.immutable_ids: - continue - src = defuse.get_node_target(node) - if max_expr_depth is not None and len(str(src)) > max_expr_depth: - continue - if src.is_function_call(): - continue - if node.var.is_mem(): - continue - if src.is_op('Phi'): - continue - to_replace[node.var] = src - node_to_reg[node] = node.var - return node_to_reg, to_replace, defuse - - def propagate(self, ssa, head, max_expr_depth=None): - """ - Do expression propagation - @ssa: SSADiGraph instance - @head: the head location of the graph - @max_expr_depth: the maximum allowed depth of an expression - """ - node_to_reg, to_replace, defuse = self.get_candidates(ssa, head, max_expr_depth) - modified = False - for node, reg in viewitems(node_to_reg): - for successor in defuse.successors(node): - if not self.propagation_allowed(ssa, to_replace, node, successor): - continue - - node_a = node - node_b = successor - block = ssa.graph.blocks[node_b.label] - - replace = {node_a.var: to_replace[node_a.var]} - # Replace - assignblks = list(block) - assignblk = block[node_b.index] - out = {} - for dst, src in viewitems(assignblk): - if src.is_op('Phi'): - out[dst] = src - continue - - if src.is_mem(): - ptr = src.ptr.replace_expr(replace) - new_src = ExprMem(ptr, src.size) - else: - new_src = src.replace_expr(replace) - - if dst.is_id(): - new_dst = dst - elif dst.is_mem(): - ptr = dst.ptr.replace_expr(replace) - new_dst = ExprMem(ptr, dst.size) - else: - new_dst = dst.replace_expr(replace) - if not (new_dst.is_id() or new_dst.is_mem()): - new_dst = dst - if src != new_src or dst != new_dst: - modified = True - out[new_dst] = new_src - out = AssignBlock(out, assignblk.instr) - assignblks[node_b.index] = out - new_block = IRBlock(block.loc_key, assignblks) - ssa.graph.blocks[block.loc_key] = new_block - - return modified - - - -class PropagateExprIntThroughExprId(PropagateThroughExprId): - """ - Propagate ExprInt though ExprId: classic constant propagation - This is a sub family of PropagateThroughExprId. - It reduces leaves in expressions of a program. - """ - - def get_candidates(self, ssa, head, max_expr_depth): - defuse = SSADefUse.from_ssa(ssa) - - to_replace = {} - node_to_reg = {} - for node in defuse.nodes(): - src = defuse.get_node_target(node) - if not src.is_int(): - continue - if src.is_function_call(): - continue - if node.var.is_mem(): - continue - to_replace[node.var] = src - node_to_reg[node] = node.var - return node_to_reg, to_replace, defuse - - def propagation_allowed(self, ssa, to_replace, node_a, node_b): - """ - Propagating ExprInt is always ok - """ - return True - - -class PropagateThroughExprMem(object): - """ - Propagate through ExprMem in very simple cases: - - if no memory write between source and target - - if source does not contain any memory reference - """ - - def propagate(self, ssa, head, max_expr_depth=None): - ircfg = ssa.graph - todo = set() - modified = False - for block in viewvalues(ircfg.blocks): - for i, assignblk in enumerate(block): - for dst, src in viewitems(assignblk): - if not dst.is_mem(): - continue - if expr_has_mem(src): - continue - todo.add((block.loc_key, i + 1, dst, src)) - ptr = dst.ptr - for size in range(8, dst.size, 8): - todo.add((block.loc_key, i + 1, ExprMem(ptr, size), src[:size])) - - while todo: - loc_key, index, mem_dst, mem_src = todo.pop() - block = ircfg.blocks.get(loc_key, None) - if block is None: - continue - assignblks = list(block) - block_modified = False - for i in range(index, len(block)): - assignblk = block[i] - write_mem = False - assignblk_modified = False - out = dict(assignblk) - out_new = {} - for dst, src in viewitems(out): - if dst.is_mem(): - write_mem = True - ptr = dst.ptr.replace_expr({mem_dst:mem_src}) - dst = ExprMem(ptr, dst.size) - src = src.replace_expr({mem_dst:mem_src}) - out_new[dst] = src - if out != out_new: - assignblk_modified = True - - if assignblk_modified: - assignblks[i] = AssignBlock(out_new, assignblk.instr) - block_modified = True - if write_mem: - break - else: - # If no memory written, we may propagate to sons - # if son has only parent - for successor in ircfg.successors(loc_key): - predecessors = ircfg.predecessors(successor) - if len(predecessors) != 1: - continue - todo.add((successor, 0, mem_dst, mem_src)) - - if block_modified: - modified = True - new_block = IRBlock(block.loc_key, assignblks) - ircfg.blocks[block.loc_key] = new_block - return modified + def has_mem(self): + return self.is_mem() + visitor = ExprWalk(has_mem) + return visitor.visit(expr) def stack_to_reg(expr): @@ -1061,7 +775,11 @@ def visitor_get_stack_accesses(ir_arch_a, expr, stack_vars): def get_stack_accesses(ir_arch_a, expr): result = set() - expr.visit(lambda expr:visitor_get_stack_accesses(ir_arch_a, expr, result)) + def get_stack(expr_to_test): + visitor_get_stack_accesses(ir_arch_a, expr_to_test, result) + return None + visitor = ExprWalk(get_stack) + visitor.visit(expr) return result @@ -1207,11 +925,13 @@ def memlookup_test(expr, bs, is_addr_ro_variable, result): def memlookup_visit(expr, bs, is_addr_ro_variable): result = set() - expr.visit(lambda expr: expr, - lambda expr: memlookup_test(expr, bs, is_addr_ro_variable, result)) + def retrieve_memlookup(expr_to_test): + memlookup_test(expr_to_test, bs, is_addr_ro_variable, result) + return None + visitor = ExprWalk(retrieve_memlookup) + visitor.visit(expr) return result - def get_memlookup(expr, bs, is_addr_ro_variable): return memlookup_visit(expr, bs, is_addr_ro_variable) @@ -1696,3 +1416,795 @@ class DiGraphLivenessSSA(DiGraphLivenessIRA): parent_block.infos[-1].var_out = var_info todo.add(parent) + + +def get_phi_sources(phi_src, phi_dsts, ids_to_src): + """ + Return False if the @phi_src has more than one non-phi source + Else, return its source + @ids_to_src: Dictionary linking phi source to its definition + """ + true_values = set() + for src in phi_src.args: + if src in phi_dsts: + # Source is phi dst => skip + continue + true_src = ids_to_src[src] + if true_src in phi_dsts: + # Source is phi dst => skip + continue + # Check if src is not also a phi + if true_src.is_op('Phi'): + phi_dsts.add(src) + true_src = get_phi_sources(true_src, phi_dsts, ids_to_src) + if true_src is False: + return False + if true_src is True: + continue + true_values.add(true_src) + if len(true_values) != 1: + return False + if not true_values: + return True + if len(true_values) != 1: + return False + true_value = true_values.pop() + return true_value + + +class DelDummyPhi(object): + """ + Del dummy phi + """ + + def del_dummy_phi(self, ssa, head): + ids_to_src = {} + for block in viewvalues(ssa.graph.blocks): + for index, assignblock in enumerate(block): + for dst, src in viewitems(assignblock): + if not dst.is_id(): + continue + ids_to_src[dst] = src + + modified = False + for block in ssa.graph.blocks.values(): + if not irblock_has_phi(block): + continue + assignblk = block[0] + modified_assignblk = False + for dst, phi_src in viewitems(assignblk): + assert phi_src.is_op('Phi') + true_value = get_phi_sources(phi_src, set([dst]), ids_to_src) + if true_value is False: + continue + if expr_has_mem(true_value): + continue + fixed_phis = {} + for old_dst, old_phi_src in viewitems(assignblk): + if old_dst == dst: + continue + fixed_phis[old_dst] = old_phi_src + + modified = True + + assignblks = list(block) + assignblks[0] = AssignBlock(fixed_phis, assignblk.instr) + assignblks[1:1] = [AssignBlock({dst: true_value}, assignblk.instr)] + new_irblock = IRBlock(block.loc_key, assignblks) + ssa.graph.blocks[block.loc_key] = new_irblock + + return modified + + +def replace_expr_from_bottom(expr_orig, dct): + def replace(expr): + if expr in dct: + return dct[expr] + return expr + visitor = ExprVisitorCallbackBottomToTop(lambda expr:replace(expr)) + return visitor.visit(expr_orig) + + +def is_mem_sub_part(needle, mem): + """ + If @needle is a sub part of @mem, return the offset of @needle in @mem + Else, return False + @needle: ExprMem + @mem: ExprMem + """ + ptr_base_a, ptr_offset_a = get_expr_base_offset(needle.ptr) + ptr_base_b, ptr_offset_b = get_expr_base_offset(mem.ptr) + if ptr_base_a != ptr_base_b: + return False + # Test if sub part starts after mem + if not (ptr_offset_b <= ptr_offset_a < ptr_offset_b + mem.size // 8): + return False + # Test if sub part ends before mem + if not (ptr_offset_a + needle.size // 8 <= ptr_offset_b + mem.size // 8): + return False + return ptr_offset_a - ptr_offset_b + +class UnionFind(object): + """ + Implementation of UnionFind structure + __classes: a list of Set of equivalent elements + node_to_class: Dictionary linkink an element to its equivalent class + order: Dictionary link an element to it's weight + + The order attributes is used to allow the selection of a representative + element of an equivalence class + """ + + def __init__(self): + self.index = 0 + self.__classes = [] + self.node_to_class = {} + self.order = dict() + + def copy(self): + """ + Return a copy of the object + """ + unionfind = UnionFind() + unionfind.index = self.index + unionfind.__classes = [set(known_class) for known_class in self.__classes] + node_to_class = {} + for class_eq in unionfind.__classes: + for node in class_eq: + node_to_class[node] = class_eq + unionfind.node_to_class = node_to_class + unionfind.order = dict(self.order) + return unionfind + + def replace_node(self, old_node, new_node): + """ + Replace the @old_node by the @new_node + """ + classes = self.get_classes() + node_to_class = dict(self.node_to_class) + + new_classes = [] + replace_dct = {old_node:new_node} + for eq_class in classes: + new_class = set() + for node in eq_class: + new_class.add(replace_expr_from_bottom(node, replace_dct)) + new_classes.append(new_class) + + node_to_class = {} + for class_eq in new_classes: + for node in class_eq: + node_to_class[node] = class_eq + self.__classes = new_classes + self.node_to_class = node_to_class + new_order = dict() + for node,index in self.order.items(): + new_node = replace_expr_from_bottom(node, replace_dct) + new_order[new_node] = index + self.order = new_order + + def get_classes(self): + """ + Return a list of the equivalent classes + """ + classes = [] + for class_tmp in self.__classes: + classes.append(set(class_tmp)) + return classes + + def nodes(self): + for known_class in self.__classes: + for node in known_class: + yield node + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + + return Counter(frozenset(known_class) for known_class in self.__classes) == Counter(frozenset(known_class) for known_class in other.__classes) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def __str__(self): + components = self.__classes + out = ['UnionFind<'] + for component in components: + out.append("\t" + (", ".join([str(node) for node in component]))) + out.append('>') + return "\n".join(out) + + def add_equivalence(self, node_a, node_b): + """ + Add the new equivalence @node_a == @node_b + @node_a is equivalent to @node_b, but @node_b is more representative + than @node_a + """ + if node_b not in self.order: + self.order[node_b] = self.index + self.index += 1 + # As node_a is destination, we always replace its index + self.order[node_a] = self.index + self.index += 1 + + if node_a not in self.node_to_class and node_b not in self.node_to_class: + new_class = set([node_a, node_b]) + self.node_to_class[node_a] = new_class + self.node_to_class[node_b] = new_class + self.__classes.append(new_class) + elif node_a in self.node_to_class and node_b not in self.node_to_class: + known_class = self.node_to_class[node_a] + known_class.add(node_b) + self.node_to_class[node_b] = known_class + elif node_a not in self.node_to_class and node_b in self.node_to_class: + known_class = self.node_to_class[node_b] + known_class.add(node_a) + self.node_to_class[node_a] = known_class + else: + raise RuntimeError("Two nodes cannot be in two classes") + + def _get_master(self, node): + if node not in self.node_to_class: + return None + known_class = self.node_to_class[node] + best_node = node + for node in known_class: + if self.order[node] < self.order[best_node]: + best_node = node + return best_node + + def get_master(self, node): + """ + Return the representative element of the equivalence class containing + @node + @node: ExprMem or ExprId + """ + if not node.is_mem(): + return self._get_master(node) + if node in self.node_to_class: + # Full expr mem is known + return self._get_master(node) + # Test if mem is sub part of known node + for expr in self.node_to_class: + if not expr.is_mem(): + continue + ret = is_mem_sub_part(node, expr) + if ret is False: + continue + master = self._get_master(expr) + master = master[ret * 8 : ret * 8 + node.size] + return master + + return self._get_master(node) + + + def del_element(self, node): + """ + Remove @node for the equivalence classes + """ + assert node in self.node_to_class + known_class = self.node_to_class[node] + known_class.discard(node) + del(self.node_to_class[node]) + del(self.order[node]) + + def del_get_new_master(self, node): + """ + Remove @node for the equivalence classes and return it's representative + equivalent element + @node: Element to delete + """ + if node not in self.node_to_class: + return None + known_class = self.node_to_class[node] + known_class.discard(node) + del(self.node_to_class[node]) + del(self.order[node]) + + if not known_class: + return None + best_node = list(known_class)[0] + for node in known_class: + if self.order[node] < self.order[best_node]: + best_node = node + return best_node + +class ExprToGraph(ExprWalk): + """ + Transform an Expression into a tree and add link nodes to an existing tree + """ + def __init__(self, graph): + super(ExprToGraph, self).__init__(self.link_nodes) + self.graph = graph + + def link_nodes(self, expr, *args, **kwargs): + """ + Transform an Expression @expr into a tree and add link nodes to the + current tree + @expr: Expression + """ + if expr in self.graph.nodes(): + return None + self.graph.add_node(expr) + if expr.is_mem(): + self.graph.add_uniq_edge(expr, expr.ptr) + elif expr.is_slice(): + self.graph.add_uniq_edge(expr, expr.arg) + elif expr.is_cond(): + self.graph.add_uniq_edge(expr, expr.cond) + self.graph.add_uniq_edge(expr, expr.src1) + self.graph.add_uniq_edge(expr, expr.src2) + elif expr.is_compose(): + for arg in expr.args: + self.graph.add_uniq_edge(expr, arg) + elif expr.is_op(): + for arg in expr.args: + self.graph.add_uniq_edge(expr, arg) + return None + +class State(object): + """ + Object representing the state of a program at a given point + The state is represented using equivalence classes + + Each assignment can create/destroy equivalence classes. Interferences + between expression is computed using `may_interfer` function + """ + + def __init__(self): + self.equivalence_classes = UnionFind() + self.undefined = set() + + def copy(self): + state = self.__class__() + state.equivalence_classes = self.equivalence_classes.copy() + state.undefined = self.undefined.copy() + return state + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + return ( + set(self.equivalence_classes.nodes()) == set(other.equivalence_classes.nodes()) and + sorted(self.equivalence_classes.edges()) == sorted(other.equivalence_classes.edges()) and + self.undefined == other.undefined + ) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def may_interfer(self, dsts, src): + """ + Return True is @src may interfer with expressions in @dsts + @dsts: Set of Expressions + @src: expression to test + """ + + srcs = src.get_r() + for src in srcs: + for dst in dsts: + if dst in src: + return True + if dst.is_mem() and src.is_mem(): + base1, offset1 = get_expr_base_offset(dst.ptr) + base2, offset2 = get_expr_base_offset(src.ptr) + if base1 != base2: + return True + assert offset1 + dst.size // 8 - 1 <= int(base1.mask) + assert offset2 + src.size // 8 - 1 <= int(base2.mask) + interval1 = interval([(offset1, offset1 + dst.size // 8 - 1)]) + interval2 = interval([(offset2, offset2 + src.size // 8 - 1)]) + if (interval1 & interval2).empty: + continue + return True + return False + + def _get_representative_expr(self, expr): + representative = self.equivalence_classes.get_master(expr) + if representative is None: + return expr + return representative + + def get_representative_expr(self, expr): + """ + Replace each sub expression of @expr by its representative element + @expr: Expression to analyse + """ + new_expr = expr.visit(self._get_representative_expr) + return new_expr + + def propagation_allowed(self, expr): + """ + Return True if @expr can be propagated + Don't propagate: + - Phi nodes + - call_func_ret / call_func_stack operants + """ + + if ( + expr.is_op('Phi') or + (expr.is_op() and expr.op.startswith("call_func")) + ): + return False + return True + + def eval_assignblock(self, assignblock): + """ + Evaluate the @assignblock on the current state + @assignblock: AssignBlock instance + """ + + out = dict(assignblock.items()) + new_out = dict() + # Replace sub expression by their equivalence class repesentative + for dst, src in out.items(): + if src.is_op('Phi'): + # Don't replace in phi + new_src = src + else: + new_src = self.get_representative_expr(src) + if dst.is_mem(): + new_ptr = self.get_representative_expr(dst.ptr) + new_dst = ExprMem(new_ptr, dst.size) + else: + new_dst = dst + new_dst = expr_simp(new_dst) + new_src = expr_simp(new_src) + new_out[new_dst] = new_src + + # For each destination, update (or delete) dependent's node according to + # equivalence classes + classes = self.equivalence_classes + + for dst in new_out: + + replacement = classes.del_get_new_master(dst) + if replacement is None: + to_del = set([dst]) + to_replace = {} + else: + to_del = set() + to_replace = {dst:replacement} + + graph = DiGraph() + # Build en expression graph linking all classes + has_parents = False + for node in classes.nodes(): + if dst in node: + # Only dependent nodes are interesting here + has_parents = True + expr_to_graph = ExprToGraph(graph) + expr_to_graph.visit(node) + + if not has_parents: + continue + + todo = graph.leaves() + done = set() + + while todo: + node = todo.pop(0) + if node in done: + continue + # If at least one son is not done, re do later + if [son for son in graph.successors(node) if son not in done]: + todo.append(node) + continue + done.add(node) + + # If at least one son cannot be replaced (deleted), our last + # chance is to have an equivalence + if any(son in to_del for son in graph.successors(node)): + # One son has been deleted! + # Try to find a replacement of the whole expression + replacement = classes.del_get_new_master(node) + if replacement is None: + to_del.add(node) + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + continue + else: + to_replace[node] = replacement + # Continue with replacement + + # Everyson is live or has been replaced + new_node = node.replace_expr(to_replace) + + if new_node == node: + # If node is not touched (Ex: leaf node) + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + continue + + # Node has been modified, update equivalence classes + classes.replace_node(node, new_node) + to_replace[node] = new_node + + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + + continue + + new_assignblk = AssignBlock(new_out, assignblock.instr) + dsts = new_out.keys() + + # Remove interfering known classes + to_del = set() + for node in list(classes.nodes()): + if self.may_interfer(dsts, node): + # Interfer with known equivalence class + self.equivalence_classes.del_element(node) + if node.is_id() or node.is_mem(): + self.undefined.add(node) + + + # Update equivalence classes + for dst, src in new_out.items(): + # Delete equivalence class interfering with dst + to_del = set() + classes = self.equivalence_classes + for node in classes.nodes(): + if dst in node: + to_del.add(node) + for node in to_del: + self.equivalence_classes.del_element(node) + if node.is_id() or node.is_mem(): + self.undefined.add(node) + + # Don't create equivalence if self interfer + if self.may_interfer(dsts, src): + if dst in self.equivalence_classes.nodes(): + self.equivalence_classes.del_element(dst) + if dst.is_id() or dst.is_mem(): + self.undefined.add(dst) + continue + + if not self.propagation_allowed(src): + continue + + ## Dont create equivalence if dependence on undef + if dst.is_mem() and self.may_interfer(self.undefined, dst.ptr): + continue + + self.undefined.discard(dst) + if dst in self.equivalence_classes.nodes(): + self.equivalence_classes.del_element(dst) + self.equivalence_classes.add_equivalence(dst, src) + + return new_assignblk + + + def merge(self, other): + """ + Merge the current state with @other + @other: State instance + """ + classes1 = self.equivalence_classes + classes2 = other.equivalence_classes + + undefined = set(node for node in self.undefined if node.is_id() or node.is_mem()) + undefined.update(set(node for node in other.undefined if node.is_id() or node.is_mem())) + # Should we compute interference between srcs and undefined ? + # Nop => should already interfer in other state + components1 = classes1.get_classes() + components2 = classes2.get_classes() + + node_to_component2 = {} + for component in components2: + for node in component: + node_to_component2[node] = component + + out = [] + nodes_ok = set() + while components1: + component1 = components1.pop() + new_component1 = set() + for node in component1: + if node in undefined: + continue + component2 = node_to_component2.get(node) + if component2 is None: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + continue + if node not in component2: + continue + common = component1.intersection(component2) + if len(common) == 1: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + component2.discard(common.pop()) + continue + if common: + nodes_ok.update(common) + out.append(common) + diff = component1.difference(common) + if diff: + components1.append(diff) + component2.difference_update(common) + break + + # Discard remaining components2 elements + for component in components2: + for node in component: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + + all_nodes = set() + for common in out: + all_nodes.update(common) + + new_order = dict( + (node, index) for (node, index) in classes1.order.items() + if node in all_nodes + ) + + unionfind = UnionFind() + new_classes = [] + global_max_index = 0 + for common in out: + min_index = None + master = None + for node in common: + index = new_order[node] + global_max_index = max(index, global_max_index) + if min_index is None or min_index > index: + min_index = index + master = node + for node in common: + if node == master: + continue + unionfind.add_equivalence(node, master) + + unionfind.index = global_max_index + unionfind.order = new_order + state = self.__class__() + state.equivalence_classes = unionfind + state.undefined = undefined + + return state + + +class PropagateExpressions(object): + """ + Propagate expressions + + The algorithm propagates equivalence classes expressions from the entry + point. During the analyse, we replace source nodes by its equivalence + classes representative. Equivalence classes can be modified during analyse + due to memory aliasing. + + For example: + B = A+1 + C = A + A = 6 + D = [B] + + Will result in: + B = A+1 + C = A + A = 6 + D = [C+1] + """ + + @staticmethod + def new_state(): + return State() + + def merge_prev_states(self, ircfg, states, loc_key): + """ + Merge predecessors states of irblock at location @loc_key + @ircfg: IRCfg instance + @sates: Dictionary linking locations to state + @loc_key: location of the current irblock + """ + + prev_states = [] + for predecessor in ircfg.predecessors(loc_key): + prev_states.append((predecessor, states[predecessor])) + + filtered_prev_states = [] + for (_, prev_state) in prev_states: + if prev_state is not None: + filtered_prev_states.append(prev_state) + + prev_states = filtered_prev_states + if not prev_states: + state = self.new_state() + elif len(prev_states) == 1: + state = prev_states[0].copy() + else: + while prev_states: + state = prev_states.pop() + if state is not None: + break + for prev_state in prev_states: + state = state.merge(prev_state) + + return state + + def update_state(self, irblock, state): + """ + Propagate the @state through the @irblock + @irblock: IRBlock instance + @state: State instance + """ + new_assignblocks = [] + modified = False + + for index, assignblock in enumerate(irblock): + if not assignblock.items(): + continue + new_assignblk = state.eval_assignblock(assignblock) + new_assignblocks.append(new_assignblk) + if new_assignblk != assignblock: + modified = True + + new_irblock = IRBlock(irblock.loc_key, new_assignblocks) + + return new_irblock, modified + + def propagate(self, ssa, head, max_expr_depth=None): + """ + Apply algorithm on the @ssa graph + """ + ircfg = ssa.ircfg + self.loc_db = ircfg.loc_db + irblocks = ssa.ircfg.blocks + states = {} + for loc_key, irblock in irblocks.items(): + states[loc_key] = None + + todo = deque([head]) + while todo: + loc_key = todo.popleft() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state_orig = states[loc_key] + state = self.merge_prev_states(ircfg, states, loc_key) + state = state.copy() + + new_irblock, modified_irblock = self.update_state(irblock, state) + if ( + state_orig is not None and + state.equivalence_classes == state_orig.equivalence_classes and + state.undefined == state_orig.undefined + ): + continue + + if state_orig: + state.undefined.update(state_orig.undefined) + states[loc_key] = state + # Propagate to sons + for successor in ircfg.successors(loc_key): + todo.append(successor) + + # Update blocks + todo = set(loc_key for loc_key in irblocks) + modified = False + while todo: + loc_key = todo.pop() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state = self.merge_prev_states(ircfg, states, loc_key) + new_irblock, modified_irblock = self.update_state(irblock, state) + modified |= modified_irblock + irblocks[new_irblock.loc_key] = new_irblock + + return modified diff --git a/miasm/analysis/depgraph.py b/miasm/analysis/depgraph.py index 7113dd51..964dcef4 100644 --- a/miasm/analysis/depgraph.py +++ b/miasm/analysis/depgraph.py @@ -4,7 +4,8 @@ from functools import total_ordering from future.utils import viewitems -from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign +from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign, \ + ExprWalk, canonize_to_exprloc from miasm.core.graph import DiGraph from miasm.core.locationdb import LocationDB from miasm.expression.simplifications import expr_simp_explicit @@ -333,10 +334,10 @@ class DependencyResultImplicit(DependencyResult): generated loc_keys """ out = [] - expected = self._ircfg.loc_db.canonize_to_exprloc(expected) + expected = canonize_to_exprloc(self._ircfg.loc_db, expected) expected_is_loc_key = expected.is_loc() for consval in possible_values(expr): - value = self._ircfg.loc_db.canonize_to_exprloc(consval.value) + value = canonize_to_exprloc(self._ircfg.loc_db, consval.value) if expected_is_loc_key and value != expected: continue if not expected_is_loc_key and value.is_loc_key(): @@ -449,6 +450,50 @@ class FollowExpr(object): if not(only_follow) or follow_expr.follow) +class FilterExprSources(ExprWalk): + """ + Walk Expression to find sources to track + @follow_mem: (optional) Track memory syntactically + @follow_call: (optional) Track through "call" + """ + def __init__(self, follow_mem, follow_call): + super(FilterExprSources, self).__init__(lambda x:None) + self.follow_mem = follow_mem + self.follow_call = follow_call + self.nofollow = set() + self.follow = set() + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = self.visit_inner(expr, *args, **kwargs) + self.cache.add(expr) + return ret + + def visit_inner(self, expr, *args, **kwargs): + if expr.is_id(): + self.follow.add(expr) + elif expr.is_int(): + self.nofollow.add(expr) + elif expr.is_loc(): + self.nofollow.add(expr) + elif expr.is_mem(): + if self.follow_mem: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + elif expr.is_function_call(): + if self.follow_call: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + + ret = super(FilterExprSources, self).visit(expr, *args, **kwargs) + return ret + + class DependencyGraph(object): """Implementation of a dependency graph @@ -480,10 +525,14 @@ class DependencyGraph(object): self._cb_follow = [] if apply_simp: self._cb_follow.append(self._follow_simp_expr) - self._cb_follow.append(lambda exprs: self._follow_exprs(exprs, - follow_mem, - follow_call)) - self._cb_follow.append(self._follow_no_loc_key) + self._cb_follow.append(lambda exprs: self.do_follow(exprs, follow_mem, follow_call)) + + @staticmethod + def do_follow(exprs, follow_mem, follow_call): + visitor = FilterExprSources(follow_mem, follow_call) + for expr in exprs: + visitor.visit(expr) + return visitor.follow, visitor.nofollow @staticmethod def _follow_simp_expr(exprs): @@ -495,64 +544,6 @@ class DependencyGraph(object): follow.add(expr_simp_explicit(expr)) return follow, set() - @staticmethod - def get_expr(expr, follow, nofollow): - """Update @follow/@nofollow according to insteresting nodes - Returns same expression (non modifier visitor). - - @expr: expression to handle - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - """ - if expr.is_id(): - follow.add(expr) - elif expr.is_int(): - nofollow.add(expr) - elif expr.is_mem(): - follow.add(expr) - return expr - - @staticmethod - def follow_expr(expr, _, nofollow, follow_mem=False, follow_call=False): - """Returns True if we must visit sub expressions. - @expr: expression to browse - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - @follow_mem: force the visit of memory sub expressions - @follow_call: force the visit of call sub expressions - """ - if not follow_mem and expr.is_mem(): - nofollow.add(expr) - return False - if not follow_call and expr.is_function_call(): - nofollow.add(expr) - return False - return True - - @classmethod - def _follow_exprs(cls, exprs, follow_mem=False, follow_call=False): - """Extracts subnodes from exprs and returns followed/non followed - expressions according to @follow_mem/@follow_call - - """ - follow, nofollow = set(), set() - for expr in exprs: - expr.visit(lambda x: cls.get_expr(x, follow, nofollow), - lambda x: cls.follow_expr(x, follow, nofollow, - follow_mem, follow_call)) - return follow, nofollow - - @staticmethod - def _follow_no_loc_key(exprs): - """Do not follow loc_keys""" - follow = set() - for expr in exprs: - if expr.is_int() or expr.is_loc(): - continue - follow.add(expr) - - return follow, set() - def _follow_apply_cb(self, expr): """Apply callback functions to @expr @expr : FollowExpr instance""" diff --git a/miasm/analysis/dse.py b/miasm/analysis/dse.py index 3a0482a3..9cc342c7 100644 --- a/miasm/analysis/dse.py +++ b/miasm/analysis/dse.py @@ -59,7 +59,7 @@ from future.utils import viewitems from miasm.core.utils import encode_hex, force_bytes from miasm.expression.expression import ExprMem, ExprInt, ExprCompose, \ - ExprAssign, ExprId, ExprLoc, LocKey + ExprAssign, ExprId, ExprLoc, LocKey, canonize_to_exprloc from miasm.core.bin_stream import bin_stream_vm from miasm.jitter.emulatedsymbexec import EmulatedSymbExec from miasm.expression.expression_helper import possible_values @@ -258,7 +258,7 @@ class DSEEngine(object): # lambda cannot contain statement def default_func(dse): - fname = b"%s_symb" % libimp.fad2cname[dse.jitter.pc] + fname = b"%s_symb" % force_bytes(libimp.fad2cname[dse.jitter.pc]) raise RuntimeError("Symbolic stub '%s' not found" % fname) for addr, fname in viewitems(libimp.fad2cname): @@ -333,8 +333,8 @@ class DSEEngine(object): self.handle(ExprInt(cur_addr, self.ir_arch.IRDst.size)) # Avoid memory issue in ExpressionSimplifier - if len(self.symb.expr_simp.simplified_exprs) > 100000: - self.symb.expr_simp.simplified_exprs.clear() + if len(self.symb.expr_simp.cache) > 100000: + self.symb.expr_simp.cache.clear() # Get IR blocks if cur_addr in self.addr_to_cacheblocks: @@ -633,19 +633,17 @@ class DSEPathConstraint(DSEEngine): self.cur_solver.add(self.z3_trans.from_expr(cons)) def handle(self, cur_addr): - cur_addr = self.ir_arch.loc_db.canonize_to_exprloc(cur_addr) + cur_addr = canonize_to_exprloc(self.ir_arch.loc_db, cur_addr) symb_pc = self.eval_expr(self.ir_arch.IRDst) possibilities = possible_values(symb_pc) cur_path_constraint = set() # path_constraint for the concrete path if len(possibilities) == 1: dst = next(iter(possibilities)).value - dst = self.ir_arch.loc_db.canonize_to_exprloc(dst) + dst = canonize_to_exprloc(self.ir_arch.loc_db, dst) assert dst == cur_addr else: for possibility in possibilities: - target_addr = self.ir_arch.loc_db.canonize_to_exprloc( - possibility.value - ) + target_addr = canonize_to_exprloc(self.ir_arch.loc_db, possibility.value) path_constraint = set() # Set of ExprAssign for the possible path # Get constraint associated to the possible path diff --git a/miasm/analysis/gdbserver.py b/miasm/analysis/gdbserver.py index ac58cdad..b45e9f35 100644 --- a/miasm/analysis/gdbserver.py +++ b/miasm/analysis/gdbserver.py @@ -251,8 +251,8 @@ class GdbServer(object): else: raise NotImplementedError("Unknown Except") elif isinstance(ret, debugging.DebugBreakpointTerminate): - # Connexion should close, but keep it running as a TRAP - # The connexion will be close on instance destruction + # Connection should close, but keep it running as a TRAP + # The connection will be close on instance destruction print(ret) self.status = b"S05" self.send_queue.append(b"S05") diff --git a/miasm/analysis/sandbox.py b/miasm/analysis/sandbox.py index 3040a1a8..1449d7be 100644 --- a/miasm/analysis/sandbox.py +++ b/miasm/analysis/sandbox.py @@ -213,6 +213,7 @@ class OS_Win(OS): fstream.read(), load_hdr=self.options.load_hdr, name=self.fname, + winobjs=win_api_x86_32.winobjs, **kwargs ) self.name2module[fname_basename] = self.pe @@ -227,6 +228,7 @@ class OS_Win(OS): self.ALL_IMP_DLL, libs, self.modules_path, + winobjs=win_api_x86_32.winobjs, **kwargs ) ) @@ -242,6 +244,7 @@ class OS_Win(OS): self.name2module, libs, self.modules_path, + winobjs=win_api_x86_32.winobjs, **kwargs ) diff --git a/miasm/analysis/simplifier.py b/miasm/analysis/simplifier.py index 8e9005a8..43623476 100644 --- a/miasm/analysis/simplifier.py +++ b/miasm/analysis/simplifier.py @@ -11,8 +11,8 @@ from miasm.expression.simplifications import expr_simp from miasm.ir.ir import AssignBlock, IRBlock from miasm.analysis.data_flow import DeadRemoval, \ merge_blocks, remove_empty_assignblks, \ - PropagateExprIntThroughExprId, PropagateThroughExprId, \ - PropagateThroughExprMem, del_unused_edges + del_unused_edges, \ + PropagateExpressions, DelDummyPhi log = logging.getLogger("simplifier") @@ -129,9 +129,7 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): and apply out-of-ssa. Final passes of IRcfgSimplifier are applied This class apply following pass until reaching a fix point: - - do_propagate_int - - do_propagate_mem - - do_propagate_expr + - do_propagate_expressions - do_dead_simp_ssa """ @@ -143,9 +141,9 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): self.ssa_forbidden_regs = self.get_forbidden_regs() - self.propag_int = PropagateExprIntThroughExprId() - self.propag_expr = PropagateThroughExprId() - self.propag_mem = PropagateThroughExprMem() + self.propag_expressions = PropagateExpressions() + self.del_dummy_phi = DelDummyPhi() + self.deadremoval = DeadRemoval(self.ir_arch, self.all_ssa_vars) def get_forbidden_regs(self): @@ -167,9 +165,8 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): """ self.passes = [ self.simplify_ssa, - self.do_propagate_int, - self.do_propagate_mem, - self.do_propagate_expr, + self.do_propagate_expressions, + self.do_del_dummy_phi, self.do_dead_simp_ssa, self.do_remove_empty_assignblks, self.do_del_unused_edges, @@ -245,13 +242,21 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): modified = self.propag_mem.propagate(ssa, head) return modified - @fix_point - def do_propagate_expr(self, ssa, head): + def do_propagate_expressions(self, ssa, head): """ Expressions propagation through ExprId in the @ssa graph @head: Location instance of the graph head """ - modified = self.propag_expr.propagate(ssa, head) + modified = self.propag_expressions.propagate(ssa, head) + return modified + + @fix_point + def do_del_dummy_phi(self, ssa, head): + """ + Del dummy phi + @head: Location instance of the graph head + """ + modified = self.del_dummy_phi.del_dummy_phi(ssa, head) return modified @fix_point diff --git a/miasm/arch/aarch64/arch.py b/miasm/arch/aarch64/arch.py index 10e94517..768f1b03 100644 --- a/miasm/arch/aarch64/arch.py +++ b/miasm/arch/aarch64/arch.py @@ -330,19 +330,19 @@ class instruction_aarch64(instruction): op_str = expr.op return "%s %s %s" % (expr.args[0], op_str, expr.args[1]) elif isinstance(expr, m2_expr.ExprOp) and expr.op == "postinc": - if expr.args[1].arg != 0: + if int(expr.args[1]) != 0: return "[%s], %s" % (expr.args[0], expr.args[1]) else: return "[%s]" % (expr.args[0]) elif isinstance(expr, m2_expr.ExprOp) and expr.op == "preinc_wb": - if expr.args[1].arg != 0: + if int(expr.args[1]) != 0: return "[%s, %s]!" % (expr.args[0], expr.args[1]) else: return "[%s]" % (expr.args[0]) elif isinstance(expr, m2_expr.ExprOp) and expr.op == "preinc": if len(expr.args) == 1: return "[%s]" % (expr.args[0]) - elif not isinstance(expr.args[1], m2_expr.ExprInt) or expr.args[1].arg != 0: + elif not isinstance(expr.args[1], m2_expr.ExprInt) or int(expr.args[1]) != 0: return "[%s, %s]" % (expr.args[0], expr.args[1]) else: return "[%s]" % (expr.args[0]) @@ -350,7 +350,7 @@ class instruction_aarch64(instruction): arg = expr.args[1] if isinstance(arg, m2_expr.ExprId): arg = str(arg) - elif arg.op == 'LSL' and arg.args[1].arg == 0: + elif arg.op == 'LSL' and int(arg.args[1]) == 0: arg = str(arg.args[0]) else: arg = "%s %s %s" % (arg.args[0], arg.op, arg.args[1]) @@ -375,7 +375,7 @@ class instruction_aarch64(instruction): expr = self.args[index] if not expr.is_int(): return - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) loc_key = loc_db.get_or_create_offset_location(addr) self.args[index] = m2_expr.ExprLoc(loc_key, expr.size) @@ -403,7 +403,7 @@ class instruction_aarch64(instruction): if not isinstance(e, m2_expr.ExprInt): log.debug('dyn dst %r', e) return - off = e.arg - self.offset + off = (int(e) - self.offset) & int(e.mask) if int(off % 4): raise ValueError('strange offset! %r' % off) self.args[index] = m2_expr.ExprInt(int(off), 64) @@ -643,7 +643,7 @@ class aarch64_gpreg0(bsi, aarch64_arg): def encode(self): if isinstance(self.expr, m2_expr.ExprInt): - if self.expr.arg == 0: + if int(self.expr) == 0: self.value = 0x1F return True return False @@ -793,7 +793,7 @@ def set_imm_to_size(size, expr): if size > expr.size: expr = m2_expr.ExprInt(int(expr), size) else: - if expr.arg > (1 << size) - 1: + if int(expr) > (1 << size) - 1: return None expr = m2_expr.ExprInt(int(expr), size) return expr @@ -954,11 +954,11 @@ class aarch64_gpreg_ext2(reg_noarg, aarch64_arg): if arg1.op not in EXT2_OP_INV: return False self.parent.option.value = EXT2_OP_INV[arg1.op] - if arg1.args[1].arg == 0: + if int(arg1.args[1]) == 0: self.parent.shift.value = 0 return True - if arg1.args[1].arg != self.get_size(): + if int(arg1.args[1]) != self.get_size(): return False self.parent.shift.value = 1 @@ -1273,7 +1273,7 @@ class aarch64_imm_nsr(aarch64_imm_sf, aarch64_arg): return False if not test_set_sf(self.parent, self.expr.size): return False - value = self.expr.arg + value = int(self.expr) if value == 0: return False @@ -1376,7 +1376,7 @@ class aarch64_imm_hw_sc(aarch64_arg): def encode(self): if isinstance(self.expr, m2_expr.ExprInt): - if self.expr.arg > 0xFFFF: + if int(self.expr) > 0xFFFF: return False self.value = int(self.expr) self.parent.hw.value = 0 @@ -1498,7 +1498,7 @@ class aarch64_deref(aarch64_arg): def decode(self, v): reg = gpregs64_info.expr[v] - off = self.parent.imm.expr.arg + off = int(self.parent.imm.expr) op = self.get_postpre(self.parent) off = self.decode_w_size(off) self.expr = m2_expr.ExprOp(op, reg, m2_expr.ExprInt(off, 64)) @@ -1568,7 +1568,7 @@ class aarch64_deref_nooff(aarch64_deref): reg, off = expr.args if not isinstance(off, m2_expr.ExprInt): return False - if off.arg != 0: + if int(off) != 0: return False else: return False diff --git a/miasm/arch/aarch64/sem.py b/miasm/arch/aarch64/sem.py index e9088bde..2761aed4 100644 --- a/miasm/arch/aarch64/sem.py +++ b/miasm/arch/aarch64/sem.py @@ -401,7 +401,7 @@ def movk(ir, instr, arg1, arg2): assert(arg2.op == 'slice_at' and isinstance(arg2.args[0], ExprInt) and isinstance(arg2.args[1], ExprInt)) - value, shift = int(arg2.args[0].arg), int(arg2.args[1]) + value, shift = int(arg2.args[0]), int(arg2.args[1]) e.append( ExprAssign(arg1[shift:shift + 16], ExprInt(value, 16))) else: @@ -434,7 +434,7 @@ def csel(arg1, arg2, arg3, arg4): def ccmp(ir, instr, arg1, arg2, arg3, arg4): e = [] if(arg2.is_int()): - arg2=ExprInt(arg2.arg.arg,arg1.size) + arg2=ExprInt(int(arg2),arg1.size) default_nf = arg3[0:1] default_zf = arg3[1:2] default_cf = arg3[2:3] @@ -697,7 +697,7 @@ def ldp(ir, instr, arg1, arg2, arg3): def sbfm(ir, instr, arg1, arg2, arg3, arg4): e = [] - rim, sim = int(arg3.arg), int(arg4) + 1 + rim, sim = int(arg3), int(arg4) + 1 if sim > rim: res = arg2[rim:sim].signExtend(arg1.size) else: @@ -709,7 +709,7 @@ def sbfm(ir, instr, arg1, arg2, arg3, arg4): def ubfm(ir, instr, arg1, arg2, arg3, arg4): e = [] - rim, sim = int(arg3.arg), int(arg4) + 1 + rim, sim = int(arg3), int(arg4) + 1 if sim != arg1.size - 1 and rim == sim: # Simple case: lsl value = int(rim) @@ -733,7 +733,7 @@ def ubfm(ir, instr, arg1, arg2, arg3, arg4): def bfm(ir, instr, arg1, arg2, arg3, arg4): e = [] - rim, sim = int(arg3.arg), int(arg4) + 1 + rim, sim = int(arg3), int(arg4) + 1 if sim > rim: res = arg2[rim:sim] e.append(ExprAssign(arg1[:sim-rim], res)) @@ -1038,7 +1038,7 @@ def rev16(ir, instr, arg1, arg2): @sbuild.parse def extr(arg1, arg2, arg3, arg4): compose = ExprCompose(arg2, arg3) - arg1 = compose[int(arg4.arg):int(arg4)+arg1.size] + arg1 = compose[int(arg4):int(arg4)+arg1.size] @sbuild.parse diff --git a/miasm/arch/arm/arch.py b/miasm/arch/arm/arch.py index fc6a0527..2b4476f0 100644 --- a/miasm/arch/arm/arch.py +++ b/miasm/arch/arm/arch.py @@ -413,10 +413,10 @@ class instruction_arm(instruction): if isinstance(expr, ExprOp) and expr.op == 'postinc': o = '[%s]' % r - if s and not (isinstance(s, ExprInt) and s.arg == 0): + if s and not (isinstance(s, ExprInt) and int(s) == 0): o += ', %s' % s else: - if s and not (isinstance(s, ExprInt) and s.arg == 0): + if s and not (isinstance(s, ExprInt) and int(s) == 0): o = '[%s, %s]' % (r, s) else: o = '[%s]' % (r) @@ -437,9 +437,9 @@ class instruction_arm(instruction): if not isinstance(expr, ExprInt): return if self.name == 'BLX': - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) else: - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) loc_key = loc_db.get_or_create_offset_location(addr) self.args[0] = ExprLoc(loc_key, expr.size) @@ -481,7 +481,7 @@ class instruction_arm(instruction): if not isinstance(e, ExprInt): log.debug('dyn dst %r', e) return - off = e.arg - self.offset + off = (int(e) - self.offset) & int(e.mask) if int(off % 4): raise ValueError('strange offset! %r' % off) self.args[0] = ExprInt(off, 32) @@ -514,15 +514,15 @@ class instruction_armt(instruction_arm): if not isinstance(expr, ExprInt): return if self.name == 'BLX': - addr = expr.arg + (self.offset & 0xfffffffc) + addr = (int(expr) + (self.offset & 0xfffffffc)) & int(expr.mask) elif self.name == 'BL': - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) elif self.name.startswith('BP'): - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) elif self.name.startswith('CB'): - addr = expr.arg + self.offset + self.l + 2 + addr = (int(expr) + self.offset + self.l + 2) & int(expr.mask) else: - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) loc_key = loc_db.get_or_create_offset_location(addr) dst = ExprLoc(loc_key, expr.size) @@ -564,7 +564,7 @@ class instruction_armt(instruction_arm): # The first +2 is to compensate instruction len, but strangely, 32 bits # thumb2 instructions len is 2... For the second +2, didn't find it in # the doc. - off = e.arg - self.offset + off = (int(e) - self.offset) & int(e.mask) if int(off % 2): raise ValueError('strange offset! %r' % off) self.args[0] = ExprInt(off, 32) @@ -798,6 +798,9 @@ class arm_arg(m_arg): args = [self.asm_ast_to_expr(tmp, loc_db) for tmp in arg.args] if None in args: return None + if arg.op == "-": + assert len(args) == 2 + return args[0] - args[1] return ExprOp(arg.op, *args) if isinstance(arg, AstInt): return ExprInt(arg.value, 32) @@ -1345,7 +1348,7 @@ class arm_offs_blx(arm_imm): if not isinstance(self.expr, ExprInt): return False # Remove pipeline offset - v = int(self.expr.arg - 8) + v = (int(self.expr) - 8) & int(self.expr.mask) if v & 0x80000000: v &= (1 << 26) - 1 self.parent.lowb.value = (v >> 1) & 1 @@ -1657,6 +1660,33 @@ bs_mr_name = bs_name(l=1, name=mr_name) bs_addi = bs(l=1, fname="add_imm") bs_rw = bs_mod_name(l=1, fname='rw', mn_mod=['W', '']) +class armt_barrier_option(reg_noarg, arm_arg): + reg_info = barrier_info + parser = reg_info.parser + + def decode(self, v): + v = v & self.lmask + if v not in self.reg_info.dct_expr: + return False + self.expr = self.reg_info.dct_expr[v] + return True + + def encode(self): + if not self.expr in self.reg_info.dct_expr_inv: + log.debug("cannot encode reg %r", self.expr) + return False + self.value = self.reg_info.dct_expr_inv[self.expr] + if self.value > self.lmask: + log.debug("cannot encode field value %x %x", + self.value, self.lmask) + return False + return True + + def check_fbits(self, v): + return v & self.fmask == self.fbits + +barrier_option = bs(l=4, cls=(armt_barrier_option,)) + armop("mul", [bs('000000'), bs('0'), scc, rd, bs('0000'), rs, bs('1001'), rm], [rd, rm, rs]) armop("umull", [bs('000010'), bs('0'), scc, rd, rdl, rs, bs('1001'), rm], [rdl, rd, rm, rs]) armop("umlal", [bs('000010'), bs('1'), scc, rd, rdl, rs, bs('1001'), rm], [rdl, rd, rm, rs]) @@ -1706,7 +1736,8 @@ armop("rev16", [bs('01101011'), bs('1111'), rd, bs('1111'), bs('1011'), rm]) armop("pld", [bs8(0xF5), bs_addi, bs_rw, bs('01'), mem_rn_imm, bs('1111'), imm12_off]) -armop("isb", [bs8(0xF5), bs8(0x7F), bs8(0xF0), bs8(0x6F)]) +armop("dsb", [bs('111101010111'), bs('1111'), bs('1111'), bs('0000'), bs('0100'), barrier_option]) +armop("isb", [bs('111101010111'), bs('1111'), bs('1111'), bs('0000'), bs('0110'), barrier_option]) armop("nop", [bs8(0xE3), bs8(0x20), bs8(0xF0), bs8(0)]) class arm_widthm1(arm_imm, m_arg): @@ -2323,7 +2354,6 @@ class arm_sp(arm_reg): reg_info = gpregs_sp parser = reg_info.parser - off5 = bs(l=5, cls=(arm_imm,), fname="off") off3 = bs(l=3, cls=(arm_imm,), fname="off") off8 = bs(l=8, cls=(arm_imm,), fname="off") @@ -2738,7 +2768,7 @@ class armt2_imm10l(arm_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) s = 0 if v & 0x80000000: s = 1 @@ -2775,7 +2805,7 @@ class armt2_imm11l(arm_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg - 4 + v = (int(self.expr) - 4) & int(self.expr.mask) s = 0 if v & 0x80000000: s = 1 @@ -2813,7 +2843,7 @@ class armt2_imm6_11l(arm_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg - 4 + v = (int(self.expr) - 4) & int(self.expr.mask) s = 0 if v != sign_ext(v & ((1 << 22) - 1), 21, 32): return False @@ -2881,7 +2911,7 @@ class armt_imm5_1(arm_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if v & 0x1: return False self.parent.imm1.value = (v >> 6) & 1 @@ -3227,33 +3257,6 @@ bs_deref_reg_reg = bs(l=4, cls=(armt_deref_reg_reg,)) bs_deref_reg_reg_lsl_1 = bs(l=4, cls=(armt_deref_reg_reg_lsl_1,)) -class armt_barrier_option(reg_noarg, arm_arg): - reg_info = barrier_info - parser = reg_info.parser - - def decode(self, v): - v = v & self.lmask - if v not in self.reg_info.dct_expr: - return False - self.expr = self.reg_info.dct_expr[v] - return True - - def encode(self): - if not self.expr in self.reg_info.dct_expr_inv: - log.debug("cannot encode reg %r", self.expr) - return False - self.value = self.reg_info.dct_expr_inv[self.expr] - if self.value > self.lmask: - log.debug("cannot encode field value %x %x", - self.value, self.lmask) - return False - return True - - def check_fbits(self, v): - return v & self.fmask == self.fbits - -barrier_option = bs(l=4, cls=(armt_barrier_option,)) - armtop("adc", [bs('11110'), imm12_1, bs('0'), bs('1010'), scc, rn_nosppc, bs('0'), imm12_3, rd_nosppc, imm12_8]) armtop("adc", [bs('11101'), bs('01'), bs('1010'), scc, rn_nosppc, bs('0'), imm5_3, rd_nosppc, imm5_2, imm_stype, rm_sh]) armtop("bl", [bs('11110'), tsign, timm10H, bs('11'), tj1, bs('1'), tj2, timm11L]) diff --git a/miasm/arch/arm/regs.py b/miasm/arch/arm/regs.py index 63caada3..2b24b0d5 100644 --- a/miasm/arch/arm/regs.py +++ b/miasm/arch/arm/regs.py @@ -2,7 +2,7 @@ from builtins import range from miasm.expression.expression import * - +from miasm.core.cpu import gen_reg, gen_regs # GP @@ -111,4 +111,67 @@ regs_init = {} for i, r in enumerate(all_regs_ids): regs_init[r] = all_regs_ids_init[i] +coproc_reg_str = [ + "MIDR", "CTR", "TCMTR", "TLBTR", "MIDR", "MPIDR", "REVIDR", + "ID_PFR0", "ID_PFR1", "ID_DFR0", "ID_AFR0", "ID_MMFR0", "ID_MMFR1", "ID_MMFR2", "ID_MMFR3", + "ID_ISAR0", "ID_ISAR1", "ID_ISAR2", "ID_ISAR3", "ID_ISAR4", "ID_ISAR5", + "CCSIDR", "CLIDR", "AIDR", + "CSSELR", + "VPIDR", "VMPIDR", + "SCTLR", "ACTLR", "CPACR", + "SCR", "SDER", "NSACR", + "HSCTLR", "HACTLR", + "HCR", "HDCR", "HCPTR", "HSTR", "HACR", + "TTBR0", "TTBR1", "TTBCR", + "HTCR", "VTCR", + "DACR", + "DFSR", "IFSR", + "ADFSR", "AIFSR", + "HADFSR", "HAIFSR", + "HSR", + "DFAR", "IFAR", + "HDFAR", "HIFAR", "HPFAR", + "ICIALLUIS", "BPIALLIS", + "PAR", + "ICIALLU", "ICIMVAU", "CP15ISB", "BPIALL", "BPIMVA", + "DCIMVAC", "DCISW", + "ATS1CPR", "ATS1CPW", "ATS1CUR", "ATS1CUW", "ATS12NSOPR", "ATS12NSOPW", "ATS12NSOUR", "ATS12NSOUW", + "DCCMVAC", "DCCSW", "CP15DSB", "CP15DMB", + "DCCMVAU", + "DCCIMVAC", "DCCISW", + "ATS1HR", "ATS1HW", + "TLBIALLIS", "TLBIMVAIS", "TLBIASIDIS", "TLBIMVAAIS", + "ITLBIALL", "ITLBIMVA", "ITLBIASID", + "DTLBIALL", "DTLBIMVA", "DTLBIASID", + "TLBIALL", "TLBIMVA", "TLBIASID", "TLBIMVAA", + "TLBIALLHIS", "TLBIMVAHIS", "TLBIALLNSNHIS", + "TLBIALLH", "TLBIMVAH", "TLBIALLNSNH", + "PMCR", "PMCNTENSET", "PMCNTENCLR", "PMOVSR", "PMSWINC", "PMSELR", "PMCEID0", "PMCEID1", + "PMCCNTR", "PMXEVTYPER", "PMXEVCNTR", + "PMUSERENR", "PMINTENSET", "PMINTENCLR", "PMOVSSET", + "PRRR", "NMRR", + "AMAIR0", "AMAIR1", + "HMAIR0", "HMAIR1", + "HAMAIR0", "HAMAIR1", + "VBAR", "MVBAR", + "ISR", + "HVBAR", + "FCSEIDR", "CONTEXTIDR", "TPIDRURW", "TPIDRURO", "TPIDRPRW", + "HTPIDR", + "CNTFRQ", + "CNTKCTL", + "CNTP_TVAL", "CNTP_CTL", + "CNTV_TVAL", "CNTV_CTL", + "CNTHCTL", + "CNTHP_TVAL", "CNTHP_CTL" + ] +coproc_reg_expr, coproc_reg_init, coproc_reg_info = gen_regs(coproc_reg_str, globals(), 32) + +all_regs_ids = all_regs_ids + coproc_reg_expr +all_regs_ids_byname.update(dict([(x.name, x) for x in coproc_reg_expr])) +all_regs_ids_init = all_regs_ids_init + coproc_reg_init + +for i, r in enumerate(coproc_reg_expr): + regs_init[r] = coproc_reg_init[i] + regs_flt_expr = [] diff --git a/miasm/arch/arm/sem.py b/miasm/arch/arm/sem.py index 981a5060..027c3a6a 100644 --- a/miasm/arch/arm/sem.py +++ b/miasm/arch/arm/sem.py @@ -8,6 +8,219 @@ from miasm.arch.arm.regs import * from miasm.jitter.csts import EXCEPT_DIV_BY_ZERO, EXCEPT_INT_XX +coproc_reg_dict = { + ("p15", "c0", 0, "c0", 0): MIDR, + ("p15", "c0", 0, "c0", 1): CTR, + ("p15", "c0", 0, "c0", 2): TCMTR, + ("p15", "c0", 0, "c0", 3): TLBTR, + ("p15", "c0", 0, "c0", 4): MIDR, + ("p15", "c0", 0, "c0", 5): MPIDR, + ("p15", "c0", 0, "c0", 6): REVIDR, + ("p15", "c0", 0, "c0", 7): MIDR, + + ("p15", "c0", 0, "c1", 0): ID_PFR0, + ("p15", "c0", 0, "c1", 1): ID_PFR1, + ("p15", "c0", 0, "c1", 2): ID_DFR0, + ("p15", "c0", 0, "c1", 3): ID_AFR0, + ("p15", "c0", 0, "c1", 4): ID_MMFR0, + ("p15", "c0", 0, "c1", 5): ID_MMFR1, + ("p15", "c0", 0, "c1", 6): ID_MMFR2, + ("p15", "c0", 0, "c1", 7): ID_MMFR3, + + ("p15", "c0", 0, "c2", 0): ID_ISAR0, + ("p15", "c0", 0, "c2", 1): ID_ISAR1, + ("p15", "c0", 0, "c2", 2): ID_ISAR2, + ("p15", "c0", 0, "c2", 3): ID_ISAR3, + ("p15", "c0", 0, "c2", 4): ID_ISAR4, + ("p15", "c0", 0, "c2", 5): ID_ISAR5, + + ("p15", "c0", 1, "c0", 0): CCSIDR, + ("p15", "c0", 1, "c0", 1): CLIDR, + ("p15", "c0", 1, "c0", 7): AIDR, + + ("p15", "c0", 2, "c0", 0): CSSELR, + + ("p15", "c0", 4, "c0", 0): VPIDR, + ("p15", "c0", 4, "c0", 5): VMPIDR, + + ("p15", "c1", 0, "c0", 0): SCTLR, + ("p15", "c1", 0, "c0", 1): ACTLR, + ("p15", "c1", 0, "c0", 2): CPACR, + + ("p15", "c1", 0, "c1", 0): SCR, + ("p15", "c1", 0, "c1", 1): SDER, + ("p15", "c1", 0, "c1", 2): NSACR, + + ("p15", "c1", 4, "c0", 0): HSCTLR, + ("p15", "c1", 4, "c0", 1): HACTLR, + + ("p15", "c1", 4, "c1", 0): HCR, + ("p15", "c1", 4, "c1", 1): HDCR, + ("p15", "c1", 4, "c1", 2): HCPTR, + ("p15", "c1", 4, "c1", 3): HSTR, + ("p15", "c1", 4, "c1", 7): HACR, + + # TODO: TTBRO/TTBR1 64-bit + ("p15", "c2", 0, "c0", 0): TTBR0, + ("p15", "c2", 0, "c0", 1): TTBR1, + ("p15", "c2", 0, "c0", 2): TTBCR, + + ("p15", "c2", 4, "c0", 2): HTCR, + + ("p15", "c2", 4, "c1", 2): VTCR, + + # TODO: HTTBR, VTTBR + + ("p15", "c3", 0, "c0", 0): DACR, + + ("p15", "c5", 0, "c0", 0): DFSR, + ("p15", "c5", 0, "c0", 1): IFSR, + + ("p15", "c5", 0, "c1", 0): ADFSR, + ("p15", "c5", 0, "c1", 1): AIFSR, + + ("p15", "c5", 4, "c1", 0): HADFSR, + ("p15", "c5", 4, "c1", 1): HAIFSR, + + ("p15", "c5", 4, "c2", 0): HSR, + + ("p15", "c6", 0, "c1", 0): DFAR, + ("p15", "c6", 0, "c1", 2): IFAR, + + ("p15", "c6", 4, "c0", 0): HDFAR, + ("p15", "c6", 4, "c0", 2): HIFAR, + ("p15", "c6", 4, "c0", 4): HPFAR, + + ("p15", "c7", 0, "c1", 0): ICIALLUIS, + ("p15", "c7", 0, "c1", 6): BPIALLIS, + + ("p15", "c7", 0, "c4", 0): PAR, + + # TODO: PAR 64-bit + + ("p15", "c7", 0, "c5", 0): ICIALLU, + ("p15", "c7", 0, "c5", 1): ICIMVAU, + ("p15", "c7", 0, "c5", 4): CP15ISB, + ("p15", "c7", 0, "c5", 6): BPIALL, + ("p15", "c7", 0, "c5", 7): BPIMVA, + + ("p15", "c7", 0, "c6", 1): DCIMVAC, + ("p15", "c7", 0, "c6", 2): DCISW, + + ("p15", "c7", 0, "c8", 0): ATS1CPR, + ("p15", "c7", 0, "c8", 1): ATS1CPW, + ("p15", "c7", 0, "c8", 2): ATS1CUR, + ("p15", "c7", 0, "c8", 3): ATS1CUW, + ("p15", "c7", 0, "c8", 4): ATS12NSOPR, + ("p15", "c7", 0, "c8", 5): ATS12NSOPW, + ("p15", "c7", 0, "c8", 6): ATS12NSOUR, + ("p15", "c7", 0, "c8", 7): ATS12NSOUW, + + ("p15", "c7", 0, "c10", 1): DCCMVAC, + ("p15", "c7", 0, "c10", 2): DCCSW, + ("p15", "c7", 0, "c10", 4): CP15DSB, + ("p15", "c7", 0, "c10", 5): CP15DMB, + + ("p15", "c7", 0, "c11", 1): DCCMVAU, + + ("p15", "c7", 0, "c14", 1): DCCIMVAC, + ("p15", "c7", 0, "c14", 2): DCCISW, + + ("p15", "c7", 4, "c8", 0): ATS1HR, + ("p15", "c7", 4, "c8", 1): ATS1HW, + + ("p15", "c8", 0, "c3", 0): TLBIALLIS, + ("p15", "c8", 0, "c3", 1): TLBIMVAIS, + ("p15", "c8", 0, "c3", 2): TLBIASIDIS, + ("p15", "c8", 0, "c3", 3): TLBIMVAAIS, + + ("p15", "c8", 0, "c5", 0): ITLBIALL, + ("p15", "c8", 0, "c5", 1): ITLBIMVA, + ("p15", "c8", 0, "c5", 2): ITLBIASID, + + ("p15", "c8", 0, "c6", 0): DTLBIALL, + ("p15", "c8", 0, "c6", 1): DTLBIMVA, + ("p15", "c8", 0, "c6", 2): DTLBIASID, + + ("p15", "c8", 0, "c7", 0): TLBIALL, + ("p15", "c8", 0, "c7", 1): TLBIMVA, + ("p15", "c8", 0, "c7", 2): TLBIASID, + ("p15", "c8", 0, "c7", 3): TLBIMVAA, + + ("p15", "c8", 4, "c3", 0): TLBIALLHIS, + ("p15", "c8", 4, "c3", 1): TLBIMVAHIS, + ("p15", "c8", 4, "c3", 4): TLBIALLNSNHIS, + + ("p15", "c8", 4, "c7", 0): TLBIALLH, + ("p15", "c8", 4, "c7", 1): TLBIMVAH, + ("p15", "c8", 4, "c7", 2): TLBIALLNSNH, + + ("p15", "c9", 0, "c12", 0): PMCR, + ("p15", "c9", 0, "c12", 1): PMCNTENSET, + ("p15", "c9", 0, "c12", 2): PMCNTENCLR, + ("p15", "c9", 0, "c12", 3): PMOVSR, + ("p15", "c9", 0, "c12", 4): PMSWINC, + ("p15", "c9", 0, "c12", 5): PMSELR, + ("p15", "c9", 0, "c12", 6): PMCEID0, + ("p15", "c9", 0, "c12", 7): PMCEID1, + + ("p15", "c9", 0, "c13", 0): PMCCNTR, + ("p15", "c9", 0, "c13", 1): PMXEVTYPER, + ("p15", "c9", 0, "c13", 2): PMXEVCNTR, + + ("p15", "c9", 0, "c14", 0): PMUSERENR, + ("p15", "c9", 0, "c14", 1): PMINTENSET, + ("p15", "c9", 0, "c14", 2): PMINTENCLR, + ("p15", "c9", 0, "c14", 3): PMOVSSET, + + ("p15", "c10", 0, "c2", 0): PRRR, # ALIAS MAIR0 + ("p15", "c10", 0, "c2", 1): NMRR, # ALIAS MAIR1 + + ("p15", "c10", 0, "c3", 0): AMAIR0, + ("p15", "c10", 0, "c3", 1): AMAIR1, + + ("p15", "c10", 4, "c2", 0): HMAIR0, + ("p15", "c10", 4, "c2", 1): HMAIR1, + + ("p15", "c10", 4, "c3", 0): HAMAIR0, + ("p15", "c10", 4, "c3", 1): HAMAIR1, + + ("p15", "c12", 0, "c0", 0): VBAR, + ("p15", "c12", 0, "c0", 1): MVBAR, + + ("p15", "c12", 0, "c1", 0): ISR, + + ("p15", "c12", 4, "c0", 0): HVBAR, + + ("p15", "c13", 0, "c0", 0): FCSEIDR, + ("p15", "c13", 0, "c0", 1): CONTEXTIDR, + ("p15", "c13", 0, "c0", 2): TPIDRURW, + ("p15", "c13", 0, "c0", 3): TPIDRURO, + ("p15", "c13", 0, "c0", 4): TPIDRPRW, + + ("p15", "c13", 4, "c0", 2): HTPIDR, + + ("p15", "c14", 0, "c0", 0): CNTFRQ, + # TODO: CNTPCT 64-bit + + ("p15", "c14", 0, "c1", 0): CNTKCTL, + + ("p15", "c14", 0, "c2", 0): CNTP_TVAL, + ("p15", "c14", 0, "c2", 1): CNTP_CTL, + + ("p15", "c14", 0, "c3", 0): CNTV_TVAL, + ("p15", "c14", 0, "c3", 1): CNTV_CTL, + + # TODO: CNTVCT, CNTP_CVAL, CNTV_CVAL, CNTVOFF 64-bit + + ("p15", "c14", 4, "c1", 0): CNTHCTL, + + ("p15", "c14", 4, "c2", 0): CNTHP_TVAL, + ("p15", "c14", 4, "c2", 0): CNTHP_CTL + + # TODO: CNTHP_CVAL 64-bit + } + # liris.cnrs.fr/~mmrissa/lib/exe/fetch.php?media=armv7-a-r-manual.pdf EXCEPT_SOFT_BP = (1 << 1) @@ -762,7 +975,6 @@ def blx(ir, instr, a): def st_ld_r(ir, instr, a, a2, b, store=False, size=32, s_ext=False, z_ext=False): e = [] wb = False - b = b.copy() postinc = False b = b.ptr if isinstance(b, ExprOp): @@ -1320,6 +1532,10 @@ def dsb(ir, instr, a): e = [] return e, [] +def isb(ir, instr, a): + # XXX TODO + e = [] + return e, [] def cpsie(ir, instr, a): # XXX TODO @@ -1377,6 +1593,25 @@ def pkhtb(ir, instr, arg1, arg2, arg3): ) return e, [] +def mrc(ir, insr, arg1, arg2, arg3, arg4, arg5, arg6): + e = [] + sreg = (str(arg1), str(arg4), int(arg2), str(arg5), int(arg6)) + if sreg in coproc_reg_dict: + e.append(ExprAssign(arg3, coproc_reg_dict[sreg])) + else: + raise NotImplementedError("Unknown coprocessor register: %s %s %d %s %d" % (str(arg1), str(arg4), int(arg2), str(arg5), int(arg6))) + + return e, [] + +def mcr(ir, insr, arg1, arg2, arg3, arg4, arg5, arg6): + e = [] + sreg = (str(arg1), str(arg4), int(arg2), str(arg5), int(arg6)) + if sreg in coproc_reg_dict: + e.append(ExprAssign(coproc_reg_dict[sreg], arg3)) + else: + raise NotImplementedError("Unknown coprocessor register: %s %s %d %s %d" % (str(arg1), str(arg4), int(arg2), str(arg5), int(arg6))) + + return e, [] COND_EQ = 0 COND_NE = 1 @@ -1517,6 +1752,9 @@ mnemo_condm0 = {'add': add, 'sdiv': sdiv, 'udiv': udiv, + 'mrc': mrc, + 'mcr': mcr, + 'mul': mul, 'umull': umull, 'umlal': umlal, @@ -1630,6 +1868,7 @@ mnemo_nocond = {'lsr': lsr, 'tbh': tbh, 'nop': nop, 'dsb': dsb, + 'isb': isb, 'cpsie': cpsie, 'cpsid': cpsid, 'wfe': wfe, @@ -1775,7 +2014,7 @@ class ir_arml(IntermediateRepresentation): index += 1 instr = block.lines[index] - # Add conditionnal jump to current irblock + # Add conditional jump to current irblock loc_do = self.loc_db.add_location() loc_next = self.get_next_loc_key(instr) diff --git a/miasm/arch/mep/arch.py b/miasm/arch/mep/arch.py index d06a93ca..073acc57 100644 --- a/miasm/arch/mep/arch.py +++ b/miasm/arch/mep/arch.py @@ -4,8 +4,8 @@ from builtins import range from miasm.core.cpu import * from miasm.core.utils import Disasm_Exception -from miasm.expression.expression import Expr, ExprId, ExprInt, ExprLoc, \ - ExprMem, ExprOp +from miasm.expression.expression import ExprId, ExprInt, ExprLoc, \ + ExprMem, ExprOp, is_expr from miasm.core.asm_ast import AstId, AstMem from miasm.arch.mep.regs import * @@ -35,7 +35,7 @@ def ExprInt2SignedString(expr, pos_fmt="%d", neg_fmt="%d", size=None, offset=0): else: mask_length = size mask = (1 << mask_length) - 1 - value = int(expr.arg) & mask + value = int(expr) & mask # Return a signed integer if necessary if (value >> mask_length - 1) == 1: @@ -90,7 +90,7 @@ class instruction_mep(instruction): return "(%s)" % expr.ptr elif isinstance(expr, ExprMem) and isinstance(expr.ptr, ExprOp): - return "0x%X(%s)" % (expr.ptr.args[1].arg, expr.ptr.args[0]) + return "0x%X(%s)" % (int(expr.ptr.args[1]), expr.ptr.args[0]) # Raise an exception if the expression type was not processed message = "instruction_mep.arg2str(): don't know what \ @@ -111,13 +111,13 @@ class instruction_mep(instruction): if self.name == "SSARB": # The first operand is displayed in decimal, not in hex - o += " %d" % self.args[0].arg + o += " %d" % int(self.args[0]) o += self.arg2str(self.args[1]) elif self.name in ["MOV", "ADD"] and isinstance(self.args[1], ExprInt): # The second operand is displayed in decimal, not in hex o += " " + self.arg2str(self.args[0]) - o += ", %s" % ExprInt2SignedString(self.args[1].arg) + o += ", %s" % ExprInt2SignedString(self.args[1]) elif "CPI" in self.name: # The second operand ends with the '+' sign @@ -131,7 +131,7 @@ class instruction_mep(instruction): deref_reg_str = self.arg2str(self.args[1]) o += ", %s+)" % deref_reg_str[:-1] # GV: looks ugly # The third operand is displayed in decimal, not in hex - o += ", %s" % ExprInt2SignedString(self.args[2].arg) + o += ", %s" % ExprInt2SignedString(self.args[2]) elif len(self.args) == 2 and self.name in ["SB", "SH", "LBU", "LB", "LH", "LW"] and \ isinstance(self.args[1], ExprMem) and isinstance(self.args[1].ptr, ExprOp): # Major Opcodes #12 @@ -150,13 +150,13 @@ class instruction_mep(instruction): elif self.name == "SLL" and isinstance(self.args[1], ExprInt): # Major Opcodes #6 # The second operand is displayed in hex, not in decimal o += " " + self.arg2str(self.args[0]) - o += ", 0x%X" % self.args[1].arg + o += ", 0x%X" % int(self.args[1]) elif self.name in ["ADD3", "SLT3"] and isinstance(self.args[2], ExprInt): o += " %s" % self.arg2str(self.args[0]) o += ", %s" % self.arg2str(self.args[1]) # The third operand is displayed in decimal, not in hex - o += ", %s" % ExprInt2SignedString(self.args[2].arg, pos_fmt="0x%X") + o += ", %s" % ExprInt2SignedString(self.args[2], pos_fmt="0x%X") elif self.name == "(RI)": return o @@ -166,7 +166,7 @@ class instruction_mep(instruction): if self.args: o += " " for i, arg in enumerate(self.args): - if not isinstance(arg, Expr): + if not is_expr(arg): raise ValueError('zarb arg type') x = self.arg2str(arg, pos=i) args.append(x) @@ -218,7 +218,7 @@ class instruction_mep(instruction): # Compute the correct address num = self.get_dst_num() - addr = self.args[num].arg + addr = int(self.args[num]) if not self.name == "JMP": addr += self.offset @@ -671,7 +671,7 @@ class mep_deref_reg_offset(mep_arg): return False # Get the integer and check the upper bound - v = int(self.expr.ptr.args[1].arg & 0xFFFF) + v = int(self.expr.ptr.args[1]) & 0xFFFF # Encode the values self.parent.reg04_deref.value = gpr_exprs.index(self.expr.ptr.args[0]) @@ -845,7 +845,7 @@ class mep_int32_noarg(int32_noarg): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = int(self.expr.arg) + v = int(self.expr) # Note: the following lines were commented on purpose #if sign_ext(v & self.lmask, self.l, self.intsize) != v: # return False @@ -923,7 +923,7 @@ class mep_target24(mep_imm): return False # Get the integer and apply a mask - v = int(self.expr.arg) & 0x00FFFFFF + v = int(self.expr) & 0x00FFFFFF # Encode the value into two parts self.parent.imm7.value = (v & 0xFF) >> 1 @@ -940,7 +940,7 @@ class mep_target24_signed(mep_target24): """ mep_target24.decode(self, v) - v = int(self.expr.arg) + v = int(self.expr) self.expr = ExprInt(sign_ext(v, 24, 32), 32) return True @@ -1046,7 +1046,7 @@ class mep_imm7_align4(mep_imm): return False # Get the integer and check the upper bound - v = int(self.expr.arg) + v = int(self.expr) if v > 0x80: return False @@ -1129,7 +1129,7 @@ class mep_disp7_align2(mep_imm): return False # Get the integer - v = int(self.expr.arg) & self.upper_bound + v = int(self.expr) & self.upper_bound # Encode the value self.value = (v >> self.bits_shift) & self.upper_bound @@ -1161,7 +1161,7 @@ class mep_disp12_align2_signed(mep_disp12_align2): """Perform sign extension. """ mep_disp12_align2.decode(self, v) - v = int(self.expr.arg) + v = int(self.expr) self.expr = ExprInt(sign_ext(v, 12, 32), 32) return True @@ -1198,7 +1198,7 @@ class mep_imm24(mep_imm): return False # Get the integer and check the upper bound - v = int(self.expr.arg) + v = int(self.expr) if v > 0xFFFFFF: return False @@ -1236,7 +1236,7 @@ class mep_abs24(mep_imm): return False # Get the integer and check the upper bound - v = int(self.expr.ptr.arg) + v = int(self.expr.ptr) if v > 0xffffff: return False diff --git a/miasm/arch/mep/sem.py b/miasm/arch/mep/sem.py index 1736b139..df484ab5 100644 --- a/miasm/arch/mep/sem.py +++ b/miasm/arch/mep/sem.py @@ -239,7 +239,7 @@ def add3(ir, instr, reg_dst, reg_src, reg_or_imm): result = ExprOp("+", reg_src, reg_or_imm) else: # Rn <- Rm + SignExt(imm16) - value = int(reg_or_imm.arg) + value = int(reg_or_imm) result = ExprOp("+", reg_src, ExprInt(value, 32)) return [ExprAssign(reg_dst, result)], [] @@ -334,7 +334,7 @@ if False: def sltu3(r0, rn, rm_or_imm5): """SLTU3 - Set on less than (unsigned).""" - # if (Rn<Rm) R0<-1 else R0<-0 (Unigned) + # if (Rn<Rm) R0<-1 else R0<-0 (Unsigned) # if (Rn<ZeroExt(imm5)) R0<-1 else R0<-0(Unsigned) r0 = i32(1) if compute_u_inf(rn, rm_or_imm5) else i32(0) @@ -645,7 +645,7 @@ def repeat(rn, disp17): # RPB <- pc+4 // Repeat Begin RPB = PC + i32(4) # RPE <- pc+SignExt((disp17)16..1||0)) // Repeat End - RPE = PC + i32(disp17.arg & 0xFFFFFFFE) + RPE = PC + i32(int(disp17) & 0xFFFFFFFE) # RPC <- Rn RPC = rn in_erepeat = ExprInt(0, 32) @@ -660,7 +660,7 @@ def erepeat(disp17): # RPB <- pc+4 // Repeat Begin RPB = PC + i32(4) # RPE <- pc+SignExt((disp17)16..1||1)) (EREPEAT) - RPE = PC + i32(disp17.arg + 1) + RPE = PC + i32(int(disp17) + 1) # RPC <- undefined in_erepeat = ExprInt(1, 32) diff --git a/miasm/arch/mips32/arch.py b/miasm/arch/mips32/arch.py index 09ff0a24..0398be37 100644 --- a/miasm/arch/mips32/arch.py +++ b/miasm/arch/mips32/arch.py @@ -47,8 +47,8 @@ class additional_info(object): self.except_on_instr = False br_0 = ['B', 'J', 'JR', 'BAL', 'JAL', 'JALR'] -br_1 = ['BGEZ', 'BLTZ', 'BGTZ', 'BLEZ', 'BC1T', 'BC1F'] -br_2 = ['BEQ', 'BEQL', 'BNE'] +br_1 = ['BGEZ', 'BLTZ', 'BGTZ', 'BGTZL', 'BLEZ', 'BLEZL', 'BC1T', 'BC1TL', 'BC1F', 'BC1FL'] +br_2 = ['BEQ', 'BEQL', 'BNE', 'BNEL'] class instruction_mips32(cpu.instruction): @@ -95,8 +95,9 @@ class instruction_mips32(cpu.instruction): def dstflow2label(self, loc_db): if self.name in ["J", 'JAL']: - expr = self.args[0].arg - addr = (self.offset & (0xFFFFFFFF ^ ((1<< 28)-1))) + expr + expr = self.args[0] + offset = int(expr) + addr = ((self.offset & (0xFFFFFFFF ^ ((1<< 28)-1))) + offset) & int(expr.mask) loc_key = loc_db.get_or_create_offset_location(addr) self.args[0] = ExprLoc(loc_key, expr.size) return @@ -106,7 +107,7 @@ class instruction_mips32(cpu.instruction): if not isinstance(expr, ExprInt): return - addr = expr.arg + self.offset + addr = (int(expr) + self.offset) & int(expr.mask) loc_key = loc_db.get_or_create_offset_location(addr) self.args[ndx] = ExprLoc(loc_key, expr.size) @@ -157,7 +158,7 @@ class instruction_mips32(cpu.instruction): raise ValueError('symbol not resolved %s' % self.l) if not isinstance(e, ExprInt): return - off = e.arg - self.offset + off = (int(e) - self.offset) & int(e.mask) if int(off % 4): raise ValueError('strange offset! %r' % off) self.args[ndx] = ExprInt(off, 32) @@ -312,7 +313,7 @@ class mips32_s16imm_noarg(mips32_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if v & 0x80000000: nv = v & ((1 << 16) - 1) assert( v == cpu.sign_ext(nv, 16, 32)) @@ -320,6 +321,26 @@ class mips32_s16imm_noarg(mips32_imm): self.value = v return True + +class mips32_s09imm_noarg(mips32_imm): + def decode(self, v): + v = v & self.lmask + v = cpu.sign_ext(v, 9, 32) + self.expr = ExprInt(v, 32) + return True + + def encode(self): + if not isinstance(self.expr, ExprInt): + return False + v = int(self.expr) + if v & 0x80000000: + nv = v & ((1 << 9) - 1) + assert( v == cpu.sign_ext(nv, 9, 32)) + v = nv + self.value = v + return True + + class mips32_soff_noarg(mips32_imm): def decode(self, v): v = v & self.lmask @@ -333,7 +354,7 @@ class mips32_soff_noarg(mips32_imm): if not isinstance(self.expr, ExprInt): return False # Remove pipeline offset - v = int(self.expr.arg - 4) + v = (int(self.expr) - 4) & 0xFFFFFFFF if v & 0x80000000: nv = v & ((1 << 16+2) - 1) assert( v == cpu.sign_ext(nv, 16+2, 32)) @@ -345,6 +366,9 @@ class mips32_soff_noarg(mips32_imm): class mips32_s16imm(mips32_s16imm_noarg, mips32_arg): pass +class mips32_s09imm(mips32_s09imm_noarg, mips32_arg): + pass + class mips32_soff(mips32_soff_noarg, mips32_arg): pass @@ -358,7 +382,7 @@ class mips32_instr_index(mips32_imm, mips32_arg): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if v & 3: return False v>>=2 @@ -377,7 +401,7 @@ class mips32_u16imm(mips32_imm, mips32_arg): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) assert(v < (1<<16)) self.value = v return True @@ -424,7 +448,7 @@ class mips32_esize(mips32_imm, mips32_arg): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg -1 + v = int(self.expr) -1 assert(v < (1<<16)) self.value = v return True @@ -437,7 +461,7 @@ class mips32_eposh(mips32_imm, mips32_arg): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = int(self.expr.arg) + int(self.parent.epos.expr) -1 + v = int(self.expr) + int(self.parent.epos.expr) -1 self.value = v return True @@ -470,16 +494,22 @@ fd = cpu.bs(l=5, cls=(mips32_fltpreg,)) s16imm = cpu.bs(l=16, cls=(mips32_s16imm,)) u16imm = cpu.bs(l=16, cls=(mips32_u16imm,)) +s09imm = cpu.bs(l=9, cls=(mips32_s09imm,)) sa = cpu.bs(l=5, cls=(mips32_u16imm,)) base = cpu.bs(l=5, cls=(mips32_dreg_imm,)) soff = cpu.bs(l=16, cls=(mips32_soff,)) +oper = cpu.bs(l=5, cls=(mips32_u16imm,)) cpr0 = cpu.bs(l=5, cls=(mips32_imm,), fname="cpr0") cpr = cpu.bs(l=3, cls=(mips32_cpr,)) +stype = cpu.bs(l=5, cls=(mips32_u16imm,)) +hint_pref = cpu.bs(l=5, cls=(mips32_u16imm,)) s16imm_noarg = cpu.bs(l=16, cls=(mips32_s16imm_noarg,), fname="imm", order=-1) +s09imm_noarg = cpu.bs(l=9, cls=(mips32_s09imm_noarg,), fname="imm", + order=-1) hint = cpu.bs(l=5, default_val="00000") fcc = cpu.bs(l=3, cls=(mips32_fccreg,)) @@ -668,13 +698,18 @@ mips32op("mfhi", [cpu.bs('000000'), cpu.bs('0000000000'), rd, mips32op("b", [cpu.bs('000100'), cpu.bs('00000'), cpu.bs('00000'), soff], alias = True) mips32op("bne", [cpu.bs('000101'), rs, rt, soff]) +mips32op("bnel", [cpu.bs('010101'), rs, rt, soff]) + mips32op("beq", [cpu.bs('000100'), rs, rt, soff]) +mips32op("beql", [cpu.bs('010100'), rs, rt, soff]) mips32op("blez", [cpu.bs('000110'), rs, cpu.bs('00000'), soff]) +mips32op("blezl", [cpu.bs('010110'), rs, cpu.bs('00000'), soff]) mips32op("bcc", [cpu.bs('000001'), rs, bs_bcc, soff]) mips32op("bgtz", [cpu.bs('000111'), rs, cpu.bs('00000'), soff]) +mips32op("bgtzl", [cpu.bs('010111'), rs, cpu.bs('00000'), soff]) mips32op("bal", [cpu.bs('000001'), cpu.bs('00000'), cpu.bs('10001'), soff], alias = True) @@ -697,7 +732,6 @@ mips32op("mtc0", [cpu.bs('010000'), cpu.bs('00100'), rt, cpr0, cpu.bs('00000000'), cpr]) mips32op("mtc1", [cpu.bs('010001'), cpu.bs('00100'), rt, fs, cpu.bs('00000000000')]) - # XXXX TODO CFC1 mips32op("cfc1", [cpu.bs('010001'), cpu.bs('00010'), rt, fs, cpu.bs('00000000000')]) @@ -715,8 +749,12 @@ mips32op("c", [cpu.bs('010001'), bs_fmt, ft, fs, fcc, cpu.bs('0'), mips32op("bc1t", [cpu.bs('010001'), cpu.bs('01000'), fcc, cpu.bs('0'), cpu.bs('1'), soff]) +mips32op("bc1tl", [cpu.bs('010001'), cpu.bs('01000'), fcc, cpu.bs('1'), + cpu.bs('1'), soff]) mips32op("bc1f", [cpu.bs('010001'), cpu.bs('01000'), fcc, cpu.bs('0'), cpu.bs('0'), soff]) +mips32op("bc1fl", [cpu.bs('010001'), cpu.bs('01000'), fcc, cpu.bs('1'), + cpu.bs('0'), soff]) mips32op("swc1", [cpu.bs('111001'), base, ft, s16imm_noarg], [ft, base]) @@ -753,3 +791,33 @@ mips32op("tlbwi", [cpu.bs('010000'), cpu.bs('1'), cpu.bs('0'*19), mips32op("teq", [cpu.bs('000000'), rs, rt, bs_code, cpu.bs('110100')], [rs, rt]) +mips32op("tne", [cpu.bs('000000'), rs, rt, bs_code, cpu.bs('110110')], + [rs, rt]) + +mips32op("clz", [cpu.bs('011100'), rs, rt, rd, cpu.bs('00000'), cpu.bs('100000')], + [rd, rs]) +mips32op("clz", [cpu.bs('000000'), rs, cpu.bs('00000'), rd, cpu.bs('00001010000')], + [rd, rs]) + +mips32op("ll", [cpu.bs('110000'), base, rt, s16imm_noarg], [rt, base]) +mips32op("ll", [cpu.bs('011111'), base, rt, s09imm_noarg, cpu.bs('0110110')], [rt, base]) + +mips32op("sc", [cpu.bs('111000'), base, rt, s16imm_noarg], [rt, base]) +mips32op("sc", [cpu.bs('011111'), base, rt, s09imm_noarg, cpu.bs('0'), cpu.bs('100110')], [rt, base]) + +mips32op("sync", [cpu.bs('000000000000000000000'), stype, cpu.bs('001111')], [stype]) + +mips32op("pref", [cpu.bs('110011'), base, hint_pref, s16imm_noarg], [hint_pref, base]) +mips32op("pref", [cpu.bs('011111'), base, hint_pref, s09imm_noarg, cpu.bs('0110101')], [hint_pref, base]) + +mips32op("tlbwr", [cpu.bs('01000010000000000000000000000110')], []) +mips32op("tlbr", [cpu.bs('01000010000000000000000000000001')], []) + +mips32op("cache", [cpu.bs('101111'), base, oper, s16imm_noarg], [oper, base]) +mips32op("cache", [cpu.bs('011111'), base, oper, s09imm_noarg, cpu.bs('0100101')], [oper, base]) + +mips32op("eret", [cpu.bs('01000010000000000000000000011000')], []) + +mips32op("mtlo", [cpu.bs('000000'), rs, cpu.bs('000000000000000'), cpu.bs('010011')], [rs]) +mips32op("mthi", [cpu.bs('000000'), rs, cpu.bs('000000000000000'), cpu.bs('010001')], [rs]) + diff --git a/miasm/arch/mips32/regs.py b/miasm/arch/mips32/regs.py index 1513e989..967b7458 100644 --- a/miasm/arch/mips32/regs.py +++ b/miasm/arch/mips32/regs.py @@ -40,15 +40,43 @@ R_HI_init = ExprId('R_HI_init', 32) cpr0_str = ["CPR0_%d"%x for x in range(0x100)] cpr0_str[0] = "INDEX" +cpr0_str[8] = "RANDOM" cpr0_str[16] = "ENTRYLO0" cpr0_str[24] = "ENTRYLO1" +cpr0_str[32] = "CONTEXT" +cpr0_str[33] = "CONTEXTCONFIG" cpr0_str[40] = "PAGEMASK" +cpr0_str[41] = "PAGEGRAIN" +cpr0_str[42] = "SEGCTL0" +cpr0_str[43] = "SEGCTL1" +cpr0_str[44] = "SEGCTL2" +cpr0_str[45] = "PWBASE" +cpr0_str[46] = "PWFIELD" +cpr0_str[47] = "PWSIZE" +cpr0_str[48] = "WIRED" +cpr0_str[54] = "PWCTL" +cpr0_str[64] = "BADVADDR" +cpr0_str[65] = "BADINSTR" +cpr0_str[66] = "BADINSTRP" cpr0_str[72] = "COUNT" cpr0_str[80] = "ENTRYHI" cpr0_str[104] = "CAUSE" cpr0_str[112] = "EPC" +cpr0_str[120] = "PRID" +cpr0_str[121] = "EBASE" cpr0_str[128] = "CONFIG" +cpr0_str[129] = "CONFIG1" +cpr0_str[130] = "CONFIG2" +cpr0_str[131] = "CONFIG3" +cpr0_str[132] = "CONFIG4" +cpr0_str[133] = "CONFIG5" cpr0_str[152] = "WATCHHI" +cpr0_str[250] = "KSCRATCH" +cpr0_str[251] = "KSCRATCH1" +cpr0_str[252] = "KSCRATCH2" +cpr0_str[253] = "KSCRATCH3" +cpr0_str[254] = "KSCRATCH4" +cpr0_str[255] = "KSCRATCH5" regs_cpr0_expr, regs_cpr0_init, regs_cpr0_info = gen_regs(cpr0_str, globals()) diff --git a/miasm/arch/mips32/sem.py b/miasm/arch/mips32/sem.py index 5fc491a7..23684a8d 100644 --- a/miasm/arch/mips32/sem.py +++ b/miasm/arch/mips32/sem.py @@ -67,6 +67,12 @@ def lbu(arg1, arg2): arg1 = mem8[arg2.ptr].zeroExtend(32) @sbuild.parse +def lh(arg1, arg2): + """A word is loaded into a register @arg1 from the + specified address @arg2.""" + arg1 = mem16[arg2.ptr].signExtend(32) + +@sbuild.parse def lhu(arg1, arg2): """A word is loaded (unsigned extended) into a register @arg1 from the specified address @arg2.""" @@ -78,6 +84,11 @@ def lb(arg1, arg2): arg1 = mem8[arg2.ptr].signExtend(32) @sbuild.parse +def ll(arg1, arg2): + "To load a word from memory for an atomic read-modify-write" + arg1 = arg2 + +@sbuild.parse def beq(arg1, arg2, arg3): "Branches on @arg3 if the quantities of two registers @arg1, @arg2 are eq" dst = arg3 if ExprOp(m2_expr.TOK_EQUAL, arg1, arg2) else ExprLoc(ir.get_next_break_loc_key(instr), ir.IRDst.size) @@ -85,6 +96,13 @@ def beq(arg1, arg2, arg3): ir.IRDst = dst @sbuild.parse +def beql(arg1, arg2, arg3): + "Branches on @arg3 if the quantities of two registers @arg1, @arg2 are eq" + dst = arg3 if ExprOp(m2_expr.TOK_EQUAL, arg1, arg2) else ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) + PC = dst + ir.IRDst = dst + +@sbuild.parse def bgez(arg1, arg2): """Branches on @arg2 if the quantities of register @arg1 is greater than or equal to zero""" @@ -93,6 +111,14 @@ def bgez(arg1, arg2): ir.IRDst = dst @sbuild.parse +def bgezl(arg1, arg2): + """Branches on @arg2 if the quantities of register @arg1 is greater than or + equal to zero""" + dst = ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) if ExprOp(m2_expr.TOK_INF_SIGNED, arg1, ExprInt(0, arg1.size)) else arg2 + PC = dst + ir.IRDst = dst + +@sbuild.parse def bne(arg1, arg2, arg3): """Branches on @arg3 if the quantities of two registers @arg1, @arg2 are NOT equal""" @@ -101,6 +127,14 @@ def bne(arg1, arg2, arg3): ir.IRDst = dst @sbuild.parse +def bnel(arg1, arg2, arg3): + """Branches on @arg3 if the quantities of two registers @arg1, @arg2 are NOT + equal""" + dst = ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) if ExprOp(m2_expr.TOK_EQUAL, arg1, arg2) else arg3 + PC = dst + ir.IRDst = dst + +@sbuild.parse def lui(arg1, arg2): """The immediate value @arg2 is shifted left 16 bits and stored in the register @arg1. The lower 16 bits are zeroes.""" @@ -111,6 +145,14 @@ def nop(): """Do nothing""" @sbuild.parse +def sync(arg1): + """Synchronize Shared Memory""" + +@sbuild.parse +def pref(arg1, arg2): + """To move data between memory and cache""" + +@sbuild.parse def j(arg1): """Jump to an address @arg1""" PC = arg1 @@ -248,6 +290,13 @@ def bltz(arg1, arg2): ir.IRDst = dst_o @sbuild.parse +def bltzl(arg1, arg2): + """Branches on @arg2 if the register @arg1 is less than zero""" + dst_o = arg2 if ExprOp(m2_expr.TOK_INF_SIGNED, arg1, ExprInt(0, arg1.size)) else ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) + PC = dst_o + ir.IRDst = dst_o + +@sbuild.parse def blez(arg1, arg2): """Branches on @arg2 if the register @arg1 is less than or equal to zero""" cond = ExprOp(m2_expr.TOK_INF_EQUAL_SIGNED, arg1, ExprInt(0, arg1.size)) @@ -256,6 +305,14 @@ def blez(arg1, arg2): ir.IRDst = dst_o @sbuild.parse +def blezl(arg1, arg2): + """Branches on @arg2 if the register @arg1 is less than or equal to zero""" + cond = ExprOp(m2_expr.TOK_INF_EQUAL_SIGNED, arg1, ExprInt(0, arg1.size)) + dst_o = arg2 if cond else ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) + PC = dst_o + ir.IRDst = dst_o + +@sbuild.parse def bgtz(arg1, arg2): """Branches on @arg2 if the register @arg1 is greater than zero""" cond = ExprOp(m2_expr.TOK_INF_EQUAL_SIGNED, arg1, ExprInt(0, arg1.size)) @@ -264,6 +321,14 @@ def bgtz(arg1, arg2): ir.IRDst = dst_o @sbuild.parse +def bgtzl(arg1, arg2): + """Branches on @arg2 if the register @arg1 is greater than zero""" + cond = ExprOp(m2_expr.TOK_INF_EQUAL_SIGNED, arg1, ExprInt(0, arg1.size)) + dst_o = ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) if cond else arg2 + PC = dst_o + ir.IRDst = dst_o + +@sbuild.parse def wsbh(arg1, arg2): arg1 = ExprCompose(arg2[8:16], arg2[0:8], arg2[24:32], arg2[16:24]) @@ -320,6 +385,14 @@ def tlbwi(): def tlbp(): "TODO XXX" +@sbuild.parse +def tlbwr(): + "TODO XXX" + +@sbuild.parse +def tlbr(): + "TODO XXX" + def ins(ir, instr, a, b, c, d): e = [] pos = int(c) @@ -327,12 +400,12 @@ def ins(ir, instr, a, b, c, d): my_slices = [] if pos != 0: - my_slices.append((a[:pos], 0, pos)) + my_slices.append(a[:pos]) if l != 0: - my_slices.append((b[:l], pos, pos+l)) + my_slices.append(b[:l]) if pos + l != 32: - my_slices.append((a[pos+l:], pos+l, 32)) - r = m2_expr.ExprCompose(my_slices) + my_slices.append(a[pos+l:]) + r = m2_expr.ExprCompose(*my_slices) e.append(m2_expr.ExprAssign(a, r)) return e, [] @@ -364,12 +437,24 @@ def bc1t(arg1, arg2): ir.IRDst = dst_o @sbuild.parse +def bc1tl(arg1, arg2): + dst_o = arg2 if arg1 else ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) + PC = dst_o + ir.IRDst = dst_o + +@sbuild.parse def bc1f(arg1, arg2): dst_o = ExprLoc(ir.get_next_break_loc_key(instr), ir.IRDst.size) if arg1 else arg2 PC = dst_o ir.IRDst = dst_o @sbuild.parse +def bc1fl(arg1, arg2): + dst_o = ExprLoc(ir.get_next_delay_loc_key(instr), ir.IRDst.size) if arg1 else arg2 + PC = dst_o + ir.IRDst = dst_o + +@sbuild.parse def cvt_d_w(arg1, arg2): # TODO XXX arg1 = 'flt_d_w'(arg2) @@ -424,6 +509,23 @@ def ei(arg1): def ehb(arg1): "NOP" +@sbuild.parse +def sc(arg1, arg2): + arg2 = arg1; + arg1 = ExprInt(0x1, 32) + +@sbuild.parse +def mthi(arg1): + R_HI = arg1 + +@sbuild.parse +def mtlo(arg1): + R_LOW = arg1 + +def clz(ir, instr, rs, rd): + e = [] + e.append(ExprAssign(rd, ExprOp('cntleadzeros', rs))) + return e, [] def teq(ir, instr, arg1, arg2): e = [] @@ -436,7 +538,7 @@ def teq(ir, instr, arg1, arg2): do_except.append(m2_expr.ExprAssign(exception_flags, m2_expr.ExprInt( EXCEPT_DIV_BY_ZERO, exception_flags.size))) do_except.append(m2_expr.ExprAssign(ir.IRDst, loc_next_expr)) - blk_except = IRBlock(loc_except.index, [AssignBlock(do_except, instr)]) + blk_except = IRBlock(loc_except, [AssignBlock(do_except, instr)]) cond = arg1 - arg2 @@ -447,6 +549,28 @@ def teq(ir, instr, arg1, arg2): return e, [blk_except] +def tne(ir, instr, arg1, arg2): + e = [] + + loc_except, loc_except_expr = ir.gen_loc_key_and_expr(ir.IRDst.size) + loc_next = ir.get_next_loc_key(instr) + loc_next_expr = m2_expr.ExprLoc(loc_next, ir.IRDst.size) + + do_except = [] + do_except.append(m2_expr.ExprAssign(exception_flags, m2_expr.ExprInt( + EXCEPT_DIV_BY_ZERO, exception_flags.size))) + do_except.append(m2_expr.ExprAssign(ir.IRDst, loc_next_expr)) + blk_except = IRBlock(loc_except, [AssignBlock(do_except, instr)]) + + cond = arg1 ^ arg2 + + + e = [] + e.append(m2_expr.ExprAssign(ir.IRDst, + m2_expr.ExprCond(cond, loc_next_expr, loc_except_expr))) + + return e, [blk_except] + mnemo_func = sbuild.functions mnemo_func.update({ @@ -473,8 +597,10 @@ mnemo_func.update({ 'subu': l_sub, 'xor': l_xor, 'xori': l_xor, + 'clz': clz, 'teq': teq, -}) + 'tne': tne + }) def get_mnemo_expr(ir, instr, *args): instr, extra_ir = mnemo_func[instr.name.lower()](ir, instr, *args) @@ -511,6 +637,9 @@ class ir_mips32l(IntermediateRepresentation): def get_next_break_loc_key(self, instr): return self.loc_db.get_or_create_offset_location(instr.offset + 8) + def get_next_delay_loc_key(self, instr): + return self.loc_db.get_or_create_offset_location(instr.offset + 16) + class ir_mips32b(ir_mips32l): def __init__(self, loc_db=None): self.addrsize = 32 diff --git a/miasm/arch/msp430/arch.py b/miasm/arch/msp430/arch.py index a700b04a..93854153 100644 --- a/miasm/arch/msp430/arch.py +++ b/miasm/arch/msp430/arch.py @@ -64,7 +64,7 @@ class msp430_arg(m_arg): def asm_ast_to_expr(self, value, loc_db): if isinstance(value, AstId): name = value.name - if isinstance(name, Expr): + if is_expr(name): return name assert isinstance(name, str) if name in gpregs.str: diff --git a/miasm/arch/ppc/arch.py b/miasm/arch/ppc/arch.py index 29550931..2b951027 100644 --- a/miasm/arch/ppc/arch.py +++ b/miasm/arch/ppc/arch.py @@ -129,9 +129,9 @@ class instruction_ppc(instruction): if not isinstance(e, ExprInt): return if name[-1] != 'A': - ad = e.arg + self.offset + ad = (int(e) + self.offset) & 0xFFFFFFFF else: - ad = e.arg + ad = int(e) loc_key = loc_db.get_or_create_offset_location(ad) s = ExprLoc(loc_key, e.size) self.args[address_index] = s @@ -175,11 +175,11 @@ class instruction_ppc(instruction): if self.name[-1] != 'A': if self.offset is None: raise ValueError('symbol not resolved %s' % self.l) - off = e.arg - (self.offset + self.l) + off = (int(e) + 0x100000000 - (self.offset + self.l)) & 0xFFFFFFFF if int(off % 4): raise ValueError('Offset %r must be a multiple of four' % off) else: - off = e.arg + off = int(e) self.args[0] = ExprInt(off, 32) def get_args_expr(self): @@ -343,7 +343,7 @@ class ppc_s14imm_branch(ppc_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if v & 0x3: return False v = v >> 2 @@ -362,7 +362,7 @@ class ppc_s24imm_branch(ppc_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if v & 0x3: return False v = v >> 2 @@ -381,7 +381,7 @@ class ppc_s16imm(ppc_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if sign_ext(v & self.lmask, 16, 32) != v: return False self.value = v & self.lmask @@ -398,7 +398,7 @@ class ppc_u16imm(ppc_imm): def encode(self): if not isinstance(self.expr, ExprInt): return False - v = self.expr.arg.arg + v = int(self.expr) if v & self.lmask != v: return False self.value = v & self.lmask @@ -416,7 +416,7 @@ class ppc_spr(ppc_imm): def encode(self, e): if not isinstance(e, ExprInt): return False - self.value = ppc_swap_10(e.arg) + self.value = ppc_swap_10(int(e)) return True class ppc_tbr(ppc_imm): @@ -428,7 +428,7 @@ class ppc_tbr(ppc_imm): def encode(self, e): if not isinstance(e, ExprInt): return False - self.value = ppc_swap_10(e.arg) + self.value = ppc_swap_10(int(e)) return True class ppc_u08imm(ppc_u16imm): @@ -443,6 +443,13 @@ class ppc_u04imm(ppc_u16imm): class ppc_u02imm_noarg(imm_noarg): pass +class ppc_float(ppc_reg): + reg_info = floatregs + parser = reg_info.parser + +class ppc_vex(ppc_reg): + reg_info = vexregs + parser = reg_info.parser def ppc_bo_bi_to_mnemo(bo, bi, prefer_taken=True, default_taken=True): bo2mnemo = { 0: 'DNZF', 2: 'DZF', 4: 'F', 8: 'DNZT', @@ -520,7 +527,7 @@ class ppc_deref32(ppc_arg): if len(addr.args) != 2: return False reg, disp = addr.args[0], addr.args[1] - v = int(disp.arg) + v = int(disp) if sign_ext(v & 0xFFFF, 16, 32) != v: return False v &= 0xFFFF @@ -566,6 +573,16 @@ dregimm = bs(l=16, cls=(ppc_deref32,)) rc_mod = bs_mod_name(l=1, mn_mod=['', '.'], fname='rc') +frd = bs(l=5, cls=(ppc_float,)) +frb = bs(l=5, cls=(ppc_float,)) +frs = bs(l=5, cls=(ppc_float,)) +fm = bs(l=8, cls=(ppc_u08imm,)) + +va = bs(l=5, cls=(ppc_vex,)) +vb = bs(l=5, cls=(ppc_vex,)) +vd = bs(l=5, cls=(ppc_vex,)) +rb_noarg = bs(l=5, cls=(ppc_gpreg_noarg,), fname="rb") + arith1_name = {"MULLI": 0b000111, "SUBFIC": 0b001000, "ADDIC": 0b001100, "ADDIC.": 0b001101 } @@ -636,6 +653,17 @@ dcb_name = {"DCBST": 0b00001, "DCBF": 0b00010, "DCBI": 0b01110, "DCBA": 0b10111, "ICBI": 0b11110, "DCBZ": 0b11111 } + +load1_name_float = {"LFS": 0b110000, "LFD": 0b110010 } +load1_name_float_u = {"LFSU": 0b110001, "LFDU": 0b110011 } +store1_name_float = {"STFS": 0b110100, "STFD": 0b110110 } +store1_name_float_u = {"STFSU": 0b110101, "STFDU": 0b110111 } + +load1_name_vex = {"LVEBX": 0b0000000111, "LVEHX": 0b0000100111, + "LVEWX": 0b0001000111, "LVSL": 0b0000000110, + "LVSR": 0b0000100110, "LVX": 0b0001100111, + "LVXL": 0b0101100111 } + class bs_mod_name_prio4(bs_mod_name): prio = 4 @@ -762,3 +790,15 @@ ppcop("SRAWI", [bs('011111'), rs, ra, sh, bs('1100111000'), rc_mod], [ra, rs, sh]) ppcop("EIEIO", [bs('011111'), bs('000000000000000'), bs('11010101100')]) + +ppcop("load1f", [bs_name(l=6, name=load1_name_float), frd, ra_noarg, dregimm]) +ppcop("load1fu", [bs_name(l=6, name=load1_name_float_u), frd, ra_noarg, dregimm]) +ppcop("store1f", [bs_name(l=6, name=store1_name_float), frd, ra_noarg, dregimm]) +ppcop("store1fu", [bs_name(l=6, name=store1_name_float_u), frd, ra_noarg, dregimm]) +ppcop("MTFSF", [bs('111111'), bs('0'), fm, bs('0'), frb, bs('10110001110')]) +ppcop("MTFSF.", [bs('111111'), bs('0'), fm, bs('0'), frb, bs('10110001111')]) +ppcop("MFFS", [bs('111111'), frd, bs('00000000001001000111'), bs('0')]) +ppcop("MFFS.", [bs('111111'), frd, bs('00000000001001000111'), bs('1')]) + +ppcop("load1vex", [bs('011111'), vd, ra, rb, bs_name(l=10, name=load1_name_vex), bs('0')]) +ppcop("mtvscr", [bs('0001000000000000'), vb, bs('11001000100')]) diff --git a/miasm/arch/ppc/regs.py b/miasm/arch/ppc/regs.py index 4b710045..00781d6a 100644 --- a/miasm/arch/ppc/regs.py +++ b/miasm/arch/ppc/regs.py @@ -35,7 +35,7 @@ xerbcreg_expr, xerbcreg_init, xerbcreg = gen_regs(xerbcreg_str, globals(), 7) -otherregs_str = ["PC", "CTR", "LR" ] +otherregs_str = ["PC", "CTR", "LR", "FPSCR", "VRSAVE", "VSCR" ] otherregs_expr, otherregs_init, otherregs = gen_regs(otherregs_str, globals(), 32) @@ -55,10 +55,18 @@ mmuregs_str = (["SR%d" % i for i in range(16)] + mmuregs_expr, mmuregs_init, mmuregs = gen_regs(mmuregs_str, globals(), 32) +floatregs_str = (["FPR%d" % i for i in range(32)]) +floatregs_expr, floatregs_init, floatregs = gen_regs(floatregs_str, + globals(), 64) + +vexregs_str = (["VR%d" % i for i in range(32)]) +vexregs_expr, vexregs_init, vexregs = gen_regs(vexregs_str, + globals(), 128) + regs_flt_expr = [] all_regs_ids = (gpregs_expr + crfbitregs_expr + xerbitregs_expr + - xerbcreg_expr + otherregs_expr + superregs_expr + mmuregs_expr + + xerbcreg_expr + otherregs_expr + superregs_expr + mmuregs_expr + floatregs_expr + vexregs_expr + [ exception_flags, spr_access, reserve, reserve_address ]) all_regs_ids_byname = dict([(x.name, x) for x in all_regs_ids]) all_regs_ids_init = [ExprId("%s_init" % x.name, x.size) for x in all_regs_ids] diff --git a/miasm/arch/ppc/sem.py b/miasm/arch/ppc/sem.py index fd6db8f3..7ca7e3e1 100644 --- a/miasm/arch/ppc/sem.py +++ b/miasm/arch/ppc/sem.py @@ -25,6 +25,20 @@ sr_dict = { 12: SR12, 13: SR13, 14: SR14, 15: SR15 } +float_dict = { + 0: FPR0, 1: FPR1, 2: FPR2, 3: FPR3, 4: FPR4, 5: FPR5, 6: FPR6, 7: FPR7, 8: FPR8, + 9: FPR9, 10: FPR10, 11: FPR11, 12: FPR12, 13: FPR13, 14: FPR14, 15: FPR15, 16: FPR16, + 17: FPR17, 18: FPR18, 19: FPR19, 20: FPR20, 21: FPR21, 22: FPR22, 23: FPR23, 24: FPR24, + 25: FPR25, 26: FPR26, 27: FPR27, 28: FPR28, 29: FPR29, 30: FPR30, 31: FPR31 +} + +vex_dict = { + 0: VR0, 1: VR1, 2: VR2, 3: VR3, 4: VR4, 5: VR5, 6: VR6, 7: VR7, 8: VR8, + 9: VR9, 10: VR10, 11: VR11, 12: VR12, 13: VR13, 14: VR14, 15: VR15, 16: VR16, + 17: VR17, 18: VR18, 19: VR19, 20: VR20, 21: VR21, 22: VR22, 23: VR23, 24: VR24, + 25: VR25, 26: VR26, 27: VR27, 28: VR28, 29: VR29, 30: VR30, 31: VR31, +} + crf_dict = dict((ExprId("CR%d" % i, 4), dict( (bit, ExprId("CR%d_%s" % (i, bit), 1)) for bit in ['LT', 'GT', 'EQ', 'SO' ] )) @@ -34,6 +48,8 @@ ctx = { 'crf_dict': crf_dict, 'spr_dict': spr_dict, 'sr_dict': sr_dict, + 'float_dict': float_dict, + 'vex_dict': vex_dict, 'expr': expr, } @@ -125,7 +141,7 @@ def mn_do_cntlzw(ir, instr, ra, rs): return ret, [] def crbit_to_reg(bit): - bit = bit.arg.arg + bit = int(bit) crid = bit // 4 bitname = [ 'LT', 'GT', 'EQ', 'SO' ][bit % 4] return all_regs_ids_byname["CR%d_%s" % (crid, bitname)] @@ -232,8 +248,8 @@ def mn_do_exts(ir, instr, ra, rs): def byte_swap(expr): nbytes = expr.size // 8 - bytes = [ expr[i*8:i*8+8] for i in range(nbytes - 1, -1, -1) ] - return ExprCompose(bytes) + lbytes = [ expr[i*8:i*8+8] for i in range(nbytes - 1, -1, -1) ] + return ExprCompose(*lbytes) def mn_do_load(ir, instr, arg1, arg2, arg3=None): assert instr.name[0] == 'L' @@ -244,6 +260,12 @@ def mn_do_load(ir, instr, arg1, arg2, arg3=None): return mn_do_lmw(ir, instr, arg1, arg2) elif instr.name[1] == 'S': raise RuntimeError("LSWI, and LSWX need implementing") + elif instr.name[1] == 'F': + print("Warning, instruction %s implemented as NOP" % instr) + return [], [] + elif instr.name[1] == 'V': + print("Warning, instruction %s implemented as NOP" % instr) + return [], [] size = {'B': 8, 'H': 16, 'W': 32}[instr.name[1]] @@ -298,7 +320,7 @@ def mn_do_load(ir, instr, arg1, arg2, arg3=None): def mn_do_lmw(ir, instr, rd, src): ret = [] - address = src.arg + address = src.ptr ri = int(rd.name[1:],10) i = 0 while ri <= 31: @@ -348,7 +370,7 @@ def mn_mfmsr(rd): rd = MSR def mn_mfspr(ir, instr, arg1, arg2): - sprid = arg2.arg.arg + sprid = int(arg2) gprid = int(arg1.name[1:]) if sprid in spr_dict: return [ ExprAssign(arg1, spr_dict[sprid]) ], [] @@ -365,7 +387,7 @@ def mn_mtcrf(ir, instr, crm, rs): ret = [] for i in range(8): - if crm.arg.arg & (1 << (7 - i)): + if int(crm) & (1 << (7 - i)): j = (28 - 4 * i) + 3 for b in ['LT', 'GT', 'EQ', 'SO']: ret.append(ExprAssign(all_regs_ids_byname["CR%d_%s" % (i, b)], @@ -379,7 +401,7 @@ def mn_mtmsr(ir, instr, rs): return [ ExprAssign(MSR, rs) ], [] def mn_mtspr(ir, instr, arg1, arg2): - sprid = arg1.arg.arg + sprid = int(arg1) gprid = int(arg2.name[1:]) if sprid in spr_dict: return [ ExprAssign(spr_dict[sprid], arg2) ], [] @@ -505,7 +527,7 @@ def mn_do_rfi(ir, instr): ret = [ ExprAssign(MSR, (MSR & ~ExprInt(0b1111111101110011, 32) | ExprCompose(SRR1[0:2], ExprInt(0, 2), - SRR1[4:7], ExprInt(0, 1), + SRR1[4:7], ExprInt(0, 1), SRR1[8:16], ExprInt(0, 16)))), ExprAssign(PC, dest), ExprAssign(ir.IRDst, dest) ] @@ -562,7 +584,7 @@ def mn_do_srawi(ir, instr, ra, rs, imm): if instr.name[-1] == '.': ret += mn_compute_flags(rvalue) - mask = ExprInt(0xFFFFFFFF >> (32 - imm.arg.arg), 32) + mask = ExprInt(0xFFFFFFFF >> (32 - int(imm)), 32) ret.append(ExprAssign(XER_CA, rs.msb() & ExprCond(rs & mask, ExprInt(1, 1), ExprInt(0, 1)))) @@ -580,7 +602,7 @@ def mn_do_srw(ir, instr, ra, rs, rb): def mn_do_stmw(ir, instr, rs, dest): ret = [] - address = dest.arg + address = dest.ptr ri = int(rs.name[1:],10) i = 0 while ri <= 31: @@ -599,6 +621,9 @@ def mn_do_store(ir, instr, arg1, arg2, arg3=None): if instr.name[2] == 'S': raise RuntimeError("STSWI, and STSWX need implementing") + elif instr.name[2] == 'F': + print("Warning, instruction %s implemented as NOP" % instr) + return [], [] size = {'B': 8, 'H': 16, 'W': 32}[instr.name[2]] @@ -650,8 +675,8 @@ def mn_do_store(ir, instr, arg1, arg2, arg3=None): ret.append(ExprAssign(ir.IRDst, loc_next)) dont = flags + [ ExprAssign(CR0_EQ, ExprInt(0,1)), ExprAssign(ir.IRDst, loc_next) ] - additional_ir = [ IRBlock(loc_do, [ AssignBlock(ret) ]), - IRBlock(loc_dont, [ AssignBlock(dont) ]) ] + additional_ir = [ IRBlock(loc_do.loc_key, [ AssignBlock(ret) ]), + IRBlock(loc_dont.loc_key, [ AssignBlock(dont) ]) ] ret = [ ExprAssign(reserve, ExprInt(0, 1)), ExprAssign(ir.IRDst, ExprCond(reserve, loc_do, loc_dont)) ] @@ -834,16 +859,21 @@ sem_dir = { 'MCRF': mn_do_mcrf, 'MCRXR': mn_do_mcrxr, 'MFCR': mn_do_mfcr, + 'MFFS': mn_do_nop_warn, + 'MFFS.': mn_do_nop_warn, 'MFMSR': mn_mfmsr, 'MFSPR': mn_mfspr, 'MFSR': mn_mfsr, 'MFSRIN': mn_do_nop_warn, - 'MFTB': mn_mfmsr, + 'MTFSF': mn_do_nop_warn, + 'MTFSF.': mn_do_nop_warn, + 'MFTB': mn_mfspr, 'MTCRF': mn_mtcrf, 'MTMSR': mn_mtmsr, 'MTSPR': mn_mtspr, 'MTSR': mn_mtsr, 'MTSRIN': mn_do_nop_warn, + 'MTVSCR': mn_do_nop_warn, 'NAND': mn_do_nand, 'NAND.': mn_do_nand, 'NOR': mn_do_nor, @@ -879,7 +909,7 @@ class ir_ppc32b(IntermediateRepresentation): def get_ir(self, instr): args = instr.args[:] if instr.name[0:5] in [ 'ADDIS', 'ORIS', 'XORIS', 'ANDIS' ]: - args[2] = ExprInt(args[2].arg << 16, 32) + args[2] = ExprInt(int(args[2]) << 16, 32) if instr.name[0:3] == 'ADD': if instr.name[0:4] == 'ADDZ': last_arg = ExprInt(0, 32) @@ -920,17 +950,17 @@ class ir_ppc32b(IntermediateRepresentation): instr_ir, extra_ir = mn_do_or(self, instr, *args) elif instr.name[0:2] == 'RL': instr_ir, extra_ir = mn_do_rotate(self, instr, args[0], args[1], - args[2], args[3].arg.arg, - args[4].arg.arg) + args[2], int(args[3]), + int(args[4])) elif instr.name == 'STMW': instr_ir, extra_ir = mn_do_stmw(self, instr, *args) elif instr.name[0:2] == 'ST': instr_ir, extra_ir = mn_do_store(self, instr, *args) elif instr.name[0:4] == 'SUBF': if instr.name[0:5] == 'SUBFZ': - last_arg = ExprInt(0) + last_arg = ExprInt(0, 32) elif instr.name[0:5] == 'SUBFM': - last_arg = ExprInt(0xFFFFFFFF) + last_arg = ExprInt(0xFFFFFFFF, 32) else: last_arg = args[2] instr_ir, extra_ir = mn_do_sub(self, instr, args[0], args[1], diff --git a/miasm/arch/x86/arch.py b/miasm/arch/x86/arch.py index 33b41236..127dded4 100644 --- a/miasm/arch/x86/arch.py +++ b/miasm/arch/x86/arch.py @@ -278,7 +278,7 @@ class x86_arg(m_arg): if value.name in ["FAR"]: return None - loc_key = loc_db.get_or_create_name_location(value.name.encode()) + loc_key = loc_db.get_or_create_name_location(value.name) return ExprLoc(loc_key, size_hint) if isinstance(value, AstOp): # First pass to retrieve fixed_size @@ -481,7 +481,7 @@ class instruction_x86(instruction): expr = self.args[0] if not expr.is_int(): return - addr = expr.arg + int(self.offset) + addr = (int(expr) + int(self.offset)) & int(expr.mask) loc_key = loc_db.get_or_create_offset_location(addr) self.args[0] = ExprLoc(loc_key, expr.size) @@ -547,7 +547,7 @@ class instruction_x86(instruction): def __str__(self): return self.to_string() - + def to_string(self, loc_db=None): o = super(instruction_x86, self).to_string(loc_db) if self.additional_info.g1.value & 1: @@ -1706,7 +1706,7 @@ def exprfindmod(e, o=None): def test_addr_size(ptr, size): if isinstance(ptr, ExprInt): - return ptr.arg < (1 << size) + return int(ptr) < (1 << size) else: return ptr.size == size @@ -1767,13 +1767,13 @@ def parse_mem(expr, parent, w8, sx=0, xmm=0, mm=0, bnd=0): value = ExprInt(int(disp), cast_size) if admode < value.size: if signed: - if int(disp.arg) != sign_ext(int(value), admode, disp.size): + if int(disp) != sign_ext(int(value), admode, disp.size): continue else: - if int(disp.arg) != int(value): + if int(disp) != int(value): continue else: - if int(disp.arg) != sign_ext(int(value), value.size, admode): + if int(disp) != sign_ext(int(value), value.size, admode): continue x1 = dict(dct_expr) x1[f_imm] = (encoding, value) @@ -1913,7 +1913,10 @@ def modrm2expr(modrm, parent, w8, sx=0, xmm=0, mm=0, bnd=0): if parent.disp.value is None: return None o.append(ExprInt(int(parent.disp.expr), admode)) - expr = ExprOp('+', *o) + if len(o) == 1: + expr = o[0] + else: + expr = ExprOp('+', *o) if w8 == 0: opmode = 8 elif sx == 1: @@ -2918,7 +2921,7 @@ class bs_rel_off(bs_cond_imm): parent_len = len(prefix) * 8 + self.parent.l + self.l assert(parent_len % 8 == 0) - v = int(self.expr.arg) - parent_len // 8 + v = int(self.expr) - parent_len // 8 if prefix is None: return mask = ((1 << self.l) - 1) diff --git a/miasm/arch/x86/sem.py b/miasm/arch/x86/sem.py index cf3539e2..86a933a0 100644 --- a/miasm/arch/x86/sem.py +++ b/miasm/arch/x86/sem.py @@ -28,7 +28,8 @@ from miasm.arch.x86.arch import mn_x86, repeat_mn, replace_regs from miasm.ir.ir import IntermediateRepresentation, IRBlock, AssignBlock from miasm.core.sembuilder import SemBuilder from miasm.jitter.csts import EXCEPT_DIV_BY_ZERO, EXCEPT_ILLEGAL_INSN, \ - EXCEPT_PRIV_INSN, EXCEPT_SOFT_BP, EXCEPT_INT_XX, EXCEPT_INT_1 + EXCEPT_PRIV_INSN, EXCEPT_SOFT_BP, EXCEPT_INT_XX, EXCEPT_INT_1, \ + EXCEPT_SYSCALL import math import struct @@ -1161,7 +1162,9 @@ def setalc(_, instr): def bswap(_, instr, dst): e = [] if dst.size == 16: - result = m2_expr.ExprCompose(dst[8:16], dst[:8]) + # BSWAP referencing a 16-bit register is undefined + # Seems to return 0 actually + result = m2_expr.ExprInt(0, 16) elif dst.size == 32: result = m2_expr.ExprCompose( dst[24:32], dst[16:24], dst[8:16], dst[:8]) @@ -3386,9 +3389,11 @@ def icebp(_, instr): def l_int(_, instr, src): e = [] # XXX - if src.arg == 1: + assert src.is_int() + value = int(src) + if value == 1: except_int = EXCEPT_INT_1 - elif src.arg == 3: + elif value == 3: except_int = EXCEPT_SOFT_BP else: except_int = EXCEPT_INT_XX @@ -3408,7 +3413,7 @@ def l_sysenter(_, instr): def l_syscall(_, instr): e = [] e.append(m2_expr.ExprAssign(exception_flags, - m2_expr.ExprInt(EXCEPT_PRIV_INSN, 32))) + m2_expr.ExprInt(EXCEPT_SYSCALL, 32))) return e, [] # XXX diff --git a/miasm/core/asmblock.py b/miasm/core/asmblock.py index 8f47947f..93ad6b13 100644 --- a/miasm/core/asmblock.py +++ b/miasm/core/asmblock.py @@ -628,6 +628,7 @@ class AsmCFG(DiGraph): This method should be called if a block's '.bto' in nodes have been modified without notifying this instance to resynchronize edges. """ + self._pendings = {} for block in self.blocks: edges = [] # Rebuild edges from bto @@ -960,7 +961,9 @@ def fix_loc_offset(loc_db, loc_key, offset, modified): loc_offset = loc_db.get_location_offset(loc_key) if loc_offset == offset: return - loc_db.set_location_offset(loc_key, offset, force=True) + if loc_offset is not None: + loc_db.unset_location_offset(loc_key) + loc_db.set_location_offset(loc_key, offset) modified.add(loc_key) @@ -1209,7 +1212,7 @@ def assemble_block(mnemo, block, loc_db, conservative=False): data = b"" for expr in instr.raw: expr_int = fix_expr_val(expr, loc_db) - data += pck[expr_int.size](expr_int.arg) + data += pck[expr_int.size](int(expr_int)) instr.data = data instr.offset = offset_i diff --git a/miasm/core/bin_stream.py b/miasm/core/bin_stream.py index 727a853d..9224053f 100644 --- a/miasm/core/bin_stream.py +++ b/miasm/core/bin_stream.py @@ -137,7 +137,7 @@ class bin_stream(object): if endianness == LITTLE_ENDIAN: return upck16le(data) else: - return upck32be(data) + return upck16be(data) def get_u32(self, addr, endianness=None): """ diff --git a/miasm/core/cpu.py b/miasm/core/cpu.py index 3dc7bd68..aee22c97 100644 --- a/miasm/core/cpu.py +++ b/miasm/core/cpu.py @@ -1587,13 +1587,8 @@ class imm_noarg(object): if e == [None]: return None, None - assert(isinstance(e, m2_expr.Expr)) - if isinstance(e, tuple): - self.expr = self.int2expr(e[1]) - elif isinstance(e, m2_expr.Expr): - self.expr = e - else: - raise TypeError('zarb expr') + assert(m2_expr.is_expr(e)) + self.expr = e if self.expr is None: log.debug('cannot fromstring int %r', text) return None, None diff --git a/miasm/core/graph.py b/miasm/core/graph.py index 01f580a3..8bb4371d 100644 --- a/miasm/core/graph.py +++ b/miasm/core/graph.py @@ -732,6 +732,49 @@ class DiGraph(object): yield scc + def compute_weakly_connected_components(self): + """ + Return the weakly connected components + """ + remaining = set(self.nodes()) + components = [] + while remaining: + node = remaining.pop() + todo = set() + todo.add(node) + component = set() + done = set() + while todo: + node = todo.pop() + if node in done: + continue + done.add(node) + remaining.discard(node) + component.add(node) + todo.update(self.predecessors(node)) + todo.update(self.successors(node)) + components.append(component) + return components + + + + def replace_node(self, node, new_node): + """ + Replace @node by @new_node + """ + + predecessors = self.predecessors(node) + successors = self.successors(node) + self.del_node(node) + for predecessor in predecessors: + if predecessor == node: + predecessor = new_node + self.add_uniq_edge(predecessor, new_node) + for successor in successors: + if successor == node: + successor = new_node + self.add_uniq_edge(new_node, successor) + class DiGraphSimplifier(object): """Wrapper on graph simplification passes. diff --git a/miasm/core/objc.py b/miasm/core/objc.py index 123d339a..117e3b7d 100644 --- a/miasm/core/objc.py +++ b/miasm/core/objc.py @@ -14,7 +14,8 @@ from functools import total_ordering from miasm.core.utils import cmp_elts from miasm.expression.expression_reduce import ExprReducer -from miasm.expression.expression import ExprInt, ExprId, ExprOp, ExprMem +from miasm.expression.expression import ExprInt, ExprId, ExprOp, ExprMem, \ + is_op_segm from miasm.core.ctypesmngr import CTypeUnion, CTypeStruct, CTypeId, CTypePtr,\ CTypeArray, CTypeOp, CTypeSizeof, CTypeEnum, CTypeFunc, CTypeEllipsis @@ -1045,7 +1046,7 @@ class ExprToAccessC(ExprReducer): def reduce_op(self, node, lvl=0, **kwargs): """Generate access for ExprOp""" - if not (node.expr.is_op("+") or node.expr.is_op_segm()) \ + if not (node.expr.is_op("+") or is_op_segm(node.expr)) \ or len(node.args) != 2: return None type_arg1 = self.get_solo_type(node.args[1]) diff --git a/miasm/core/parse_asm.py b/miasm/core/parse_asm.py index 2b4f1195..b742a2d2 100644 --- a/miasm/core/parse_asm.py +++ b/miasm/core/parse_asm.py @@ -65,7 +65,7 @@ def guess_next_new_label(loc_db): """Generate a new label @loc_db: the LocationDB instance""" i = 0 - gen_name = b"loc_%.8X" + gen_name = "loc_%.8X" while True: name = gen_name % i label = loc_db.get_name_location(name) @@ -121,7 +121,7 @@ def parse_txt(mnemo, attrib, txt, loc_db=None): # label beginning with .L match_re = LABEL_RE.match(line) if match_re: - label_name = match_re.group(1).encode() + label_name = match_re.group(1) label = loc_db.get_or_create_name_location(label_name) lines.append(label) continue diff --git a/miasm/core/utils.py b/miasm/core/utils.py index 7667a656..37248c40 100644 --- a/miasm/core/utils.py +++ b/miasm/core/utils.py @@ -81,10 +81,16 @@ def printable(string): def force_bytes(value): - try: - return value.encode() - except AttributeError: + if isinstance(value, bytes): return value + if not isinstance(value, str): + return value + out = [] + for c in value: + c = ord(c) + assert c < 0x100 + out.append(c) + return bytes(out) def force_str(value): diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py index 93094979..c2bf5b8b 100644 --- a/miasm/expression/expression.py +++ b/miasm/expression/expression.py @@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent): def str_protected_child(child, parent): return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child) -def visit_chk(visitor): - "Function decorator launching callback on Expression visit" - def wrapped(expr, callback, test_visit=lambda x: True): - if (test_visit is not None) and (not test_visit(expr)): - return expr - expr_new = visitor(expr, callback, test_visit) - if expr_new is None: - return None - expr_new2 = callback(expr_new) - return expr_new2 - return wrapped - # Expression display @@ -152,6 +140,49 @@ class DiGraphExpr(DiGraph): return "" +def is_expr(expr): + return isinstance( + expr, + ( + ExprInt, ExprId, ExprMem, + ExprSlice, ExprCompose, ExprCond, + ExprLoc, ExprOp + ) + ) + +def is_associative(expr): + "Return True iff current operation is associative" + return (expr.op in ['+', '*', '^', '&', '|']) + +def is_commutative(expr): + "Return True iff current operation is commutative" + return (expr.op in ['+', '*', '^', '&', '|']) + +def is_op_segm(expr): + """Returns True if is ExprOp and op == 'segm'""" + return expr.is_op('segm') + +def is_mem_segm(expr): + """Returns True if is ExprMem and ptr is_op_segm""" + return expr.is_mem() and is_op_segm(expr.ptr) + +def canonize_to_exprloc(locdb, expr): + """ + If expr is ExprInt, return ExprLoc with corresponding loc_key + Else, return expr + + @expr: Expr instance + """ + if expr.is_int(): + loc_key = locdb.get_or_create_offset_location(int(expr)) + ret = ExprLoc(loc_key, expr.size) + return ret + return expr + +def is_function_call(expr): + """Returns true if the considered Expr is a function call + """ + return expr.is_op() and expr.op.startswith('call') @total_ordering class LocKey(object): @@ -183,6 +214,263 @@ class LocKey(object): def __str__(self): return "loc_key_%d" % self.key + +class ExprWalkBase(object): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + """ + + def __init__(self, callback): + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + pass + elif expr.is_assign(): + ret = self.visit(expr.dst, *args, **kwargs) + if ret: + return ret + src = self.visit(expr.src, *args, **kwargs) + if ret: + return ret + elif expr.is_cond(): + ret = self.visit(expr.cond, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src1, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src2, *args, **kwargs) + if ret: + return ret + elif expr.is_mem(): + ret = self.visit(expr.ptr, *args, **kwargs) + if ret: + return ret + elif expr.is_slice(): + ret = self.visit(expr.arg, *args, **kwargs) + if ret: + return ret + elif expr.is_op(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + elif expr.is_compose(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + else: + raise TypeError("Visitor can only take Expr") + + ret = self.callback(expr, *args, **kwargs) + return ret + + +class ExprWalk(ExprWalkBase): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + Use cache mechanism. + """ + def __init__(self, callback): + self.cache = set() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = super(ExprWalk, self).visit(expr, *args, **kwargs) + if ret: + return ret + self.cache.add(expr) + return None + + +class ExprGetR(ExprWalkBase): + """ + Return ExprId/ExprMem used by a given expression + """ + def __init__(self, mem_read=False, cst_read=False): + super(ExprGetR, self).__init__(lambda x:None) + self.mem_read = mem_read + self.cst_read = cst_read + self.elements = set() + self.cache = dict() + + def get_r_leaves(self, expr): + if (expr.is_int() or expr.is_loc()) and self.cst_read: + self.elements.add(expr) + elif expr.is_mem(): + self.elements.add(expr) + elif expr.is_id(): + self.elements.add(expr) + + def visit(self, expr, *args, **kwargs): + cache_key = (expr, self.mem_read, self.cst_read) + if cache_key in self.cache: + return self.cache[cache_key] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[cache_key] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + self.get_r_leaves(expr) + if expr.is_mem() and not self.mem_read: + # Don't visit memory sons + return None + + if expr.is_assign(): + if expr.dst.is_mem() and self.mem_read: + ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs) + if expr.src.is_mem(): + self.elements.add(expr.src) + self.get_r_leaves(expr.src) + if expr.src.is_mem() and not self.mem_read: + return None + ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs) + return ret + ret = super(ExprGetR, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorBase(object): + """ + Rebuild expression by visiting sub-expressions + """ + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + ret = expr + elif expr.is_assign(): + dst = self.visit(expr.dst, *args, **kwargs) + src = self.visit(expr.src, *args, **kwargs) + ret = ExprAssign(dst, src) + elif expr.is_cond(): + cond = self.visit(expr.cond, *args, **kwargs) + src1 = self.visit(expr.src1, *args, **kwargs) + src2 = self.visit(expr.src2, *args, **kwargs) + ret = ExprCond(cond, src1, src2) + elif expr.is_mem(): + ptr = self.visit(expr.ptr, *args, **kwargs) + ret = ExprMem(ptr, expr.size) + elif expr.is_slice(): + arg = self.visit(expr.arg, *args, **kwargs) + ret = ExprSlice(arg, expr.start, expr.stop) + elif expr.is_op(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprOp(expr.op, *args) + elif expr.is_compose(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprCompose(*args) + else: + raise TypeError("Visitor can only take Expr") + return ret + + +class ExprVisitorCallbackTopToBottom(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback on each sub-expression + if @callback return non None value, replace current node with this value + Else, continue visit of sub-expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackTopToBottom, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = self.callback(expr) + if ret: + return ret + ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorCallbackBottomToTop(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback from leaves to root expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackBottomToTop, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs) + ret = self.callback(ret) + return ret + + +class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop): + def __init__(self): + super(ExprVisitorCanonize, self).__init__(self.canonize) + + def canonize(self, expr): + if not expr.is_op(): + return expr + if not expr.is_associative(): + return expr + + # ((a+b) + c) => (a + b + c) + args = [] + for arg in expr.args: + if isinstance(arg, ExprOp) and expr.op == arg.op: + args += arg.args + else: + args.append(arg) + args = canonize_expr_list(args) + new_expr = ExprOp(expr.op, *args) + return new_expr + + +class ExprVisitorContains(ExprWalkBase): + """ + Visitor to test if a needle is in an Expression + Cache results + """ + def __init__(self): + self.cache = set() + super(ExprVisitorContains, self).__init__(self.eq_expr) + + def eq_expr(self, expr, needle, *args, **kwargs): + if expr == needle: + return True + return None + + def visit(self, expr, needle, *args, **kwargs): + if (expr, needle) in self.cache: + return None + ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs) + if ret: + return ret + self.cache.add((expr, needle)) + return None + + + def contains(self, expr, needle): + return self.visit(expr, needle) + +contains_visitor = ExprVisitorContains() +canonize_visitor = ExprVisitorCanonize() + # IR definitions class Expr(object): @@ -337,36 +625,16 @@ class Expr(object): """Find and replace sub expression using dct @dct: dictionary associating replaced Expr to its new Expr value """ - return self.visit(lambda expr: dct.get(expr, expr)) + def replace(expr): + if expr in dct: + return dct[expr] + return None + visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr)) + return visitor.visit(self) def canonize(self): "Canonize the Expression" - - def must_canon(expr): - return not expr.is_canon - - def canonize_visitor(expr): - if expr.is_canon: - return expr - if isinstance(expr, ExprOp): - if expr.is_associative(): - # ((a+b) + c) => (a + b + c) - args = [] - for arg in expr.args: - if isinstance(arg, ExprOp) and expr.op == arg.op: - args += arg.args - else: - args.append(arg) - args = canonize_expr_list(args) - new_e = ExprOp(expr.op, *args) - else: - new_e = expr - else: - new_e = expr - new_e.is_canon = True - return new_e - - return self.visit(canonize_visitor, must_canon) + return canonize_visitor.visit(self) def msb(self): "Return the Most Significant Bit" @@ -424,6 +692,10 @@ class Expr(object): return False def is_aff(self): + warnings.warn('DEPRECATION WARNING: use is_assign()') + return False + + def is_assign(self): return False def is_cond(self): @@ -449,6 +721,32 @@ class Expr(object): """Returns True if is ExprMem and ptr is_op_segm""" return False + def __contains__(self, expr): + ret = contains_visitor.contains(self, expr) + return ret + + def visit(self, callback): + """ + Apply callback to all sub expression of @self + This function keeps a cache to avoid rerunning @callback on common sub + expressions. + + @callback: fn(Expr) -> Expr + """ + visitor = ExprVisitorCallbackBottomToTop(callback) + return visitor.visit(self) + + def get_r(self, mem_read=False, cst_read=False): + visitor = ExprGetR(mem_read, cst_read) + visitor.visit(self) + return visitor.elements + + + def get_w(self, mem_read=False, cst_read=False): + if self.is_assign(): + return set([self.dst]) + return set() + class ExprInt(Expr): """An ExprInt represent a constant in Miasm IR. @@ -508,12 +806,6 @@ class ExprInt(Expr): else: return str("0x%X" % self._get_int()) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -524,13 +816,6 @@ class ExprInt(Expr): return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(), self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprInt(self._arg, self._size) @@ -591,9 +876,6 @@ class ExprId(Expr): def __str__(self): return str(self._name) - def get_r(self, mem_read=False, cst_read=False): - return set([self]) - def get_w(self): return set([self]) @@ -603,13 +885,6 @@ class ExprId(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprId(self._name, self._size) @@ -653,12 +928,6 @@ class ExprLoc(Expr): def __str__(self): return str(self._loc_key) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -668,13 +937,6 @@ class ExprLoc(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprLoc(self._loc_key, self._size) @@ -745,12 +1007,6 @@ class ExprAssign(Expr): def __str__(self): return "%s = %s" % (str(self._dst), str(self._src)) - def get_r(self, mem_read=False, cst_read=False): - elements = self._src.get_r(mem_read, cst_read) - if isinstance(self._dst, ExprMem) and mem_read: - elements.update(self._dst.ptr.get_r(mem_read, cst_read)) - return elements - def get_w(self): if isinstance(self._dst, ExprMem): return set([self._dst]) # [memreg] @@ -763,19 +1019,6 @@ class ExprAssign(Expr): def _exprrepr(self): return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) - def __contains__(self, expr): - return (self == expr or - self._src.__contains__(expr) or - self._dst.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit) - if dst == self._dst and src == self._src: - return self - else: - return ExprAssign(dst, src) - def copy(self): return ExprAssign(self._dst.copy(), self._src.copy()) @@ -788,7 +1031,12 @@ class ExprAssign(Expr): arg.graph_recursive(graph) graph.add_uniq_edge(self, arg) + def is_aff(self): + warnings.warn('DEPRECATION WARNING: use is_assign()') + return True + + def is_assign(self): return True @@ -845,12 +1093,6 @@ class ExprCond(Expr): def __str__(self): return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2)) - def get_r(self, mem_read=False, cst_read=False): - out_src1 = self.src1.get_r(mem_read, cst_read) - out_src2 = self.src2.get_r(mem_read, cst_read) - return self.cond.get_r(mem_read, - cst_read).union(out_src1).union(out_src2) - def get_w(self): return set() @@ -862,21 +1104,6 @@ class ExprCond(Expr): return "%s(%r, %r, %r)" % (self.__class__.__name__, self._cond, self._src1, self._src2) - def __contains__(self, expr): - return (self == expr or - self.cond.__contains__(expr) or - self.src1.__contains__(expr) or - self.src2.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - cond = self._cond.visit(callback, test_visit) - src1 = self._src1.visit(callback, test_visit) - src2 = self._src2.visit(callback, test_visit) - if cond == self._cond and src1 == self._src1 and src2 == self._src2: - return self - return ExprCond(cond, src1, src2) - def copy(self): return ExprCond(self._cond.copy(), self._src1.copy(), @@ -953,12 +1180,6 @@ class ExprMem(Expr): def __str__(self): return "@%d[%s]" % (self.size, str(self.ptr)) - def get_r(self, mem_read=False, cst_read=False): - if mem_read: - return set(self._ptr.get_r(mem_read, cst_read).union(set([self]))) - else: - return set([self]) - def get_w(self): return set([self]) # [memreg] @@ -969,16 +1190,6 @@ class ExprMem(Expr): return "%s(%r, %r)" % (self.__class__.__name__, self._ptr, self._size) - def __contains__(self, expr): - return self == expr or self._ptr.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - ptr = self._ptr.visit(callback, test_visit) - if ptr == self._ptr: - return self - return ExprMem(ptr, self.size) - def copy(self): ptr = self.ptr.copy() return ExprMem(ptr, size=self.size) @@ -1108,10 +1319,6 @@ class ExprOp(Expr): return (self._op + '(' + ', '.join([str(arg) for arg in self._args]) + ')') - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): raise ValueError('op cannot be written!', self) @@ -1123,14 +1330,6 @@ class ExprOp(Expr): return "%s(%r, %s)" % (self.__class__.__name__, self._op, ', '.join(repr(arg) for arg in self._args)) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg.__contains__(expr): - return True - return False - def is_function_call(self): return self._op.startswith('call') @@ -1153,14 +1352,6 @@ class ExprOp(Expr): "Return True iff current operation is commutative" return (self._op in ['+', '*', '^', '&', '|']) - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) - if modified: - return ExprOp(self._op, *args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprOp(self._op, *args) @@ -1213,9 +1404,6 @@ class ExprSlice(Expr): def __str__(self): return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop) - def get_r(self, mem_read=False, cst_read=False): - return self._arg.get_r(mem_read, cst_read) - def get_w(self): return self._arg.get_w() @@ -1226,18 +1414,6 @@ class ExprSlice(Expr): return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg, self._start, self._stop) - def __contains__(self, expr): - if self == expr: - return True - return self._arg.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - arg = self._arg.visit(callback, test_visit) - if arg == self._arg: - return self - return ExprSlice(arg, self._start, self._stop) - def copy(self): return ExprSlice(self._arg.copy(), self._start, self._stop) @@ -1310,10 +1486,6 @@ class ExprCompose(Expr): def __str__(self): return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}' - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): return reduce(lambda elements, arg: elements.union(arg.get_w()), self._args, set()) @@ -1325,24 +1497,6 @@ class ExprCompose(Expr): def _exprrepr(self): return "%s%r" % (self.__class__.__name__, self._args) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg == expr: - return True - if arg.__contains__(expr): - return True - return False - - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)]) - if modified: - return ExprCompose(*args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprCompose(*args) @@ -1669,8 +1823,8 @@ def match_expr(expr, pattern, tks, result=None): return False return result - elif expr.is_aff(): - if not pattern.is_aff(): + elif expr.is_assign(): + if not pattern.is_assign(): return False if match_expr(expr.src, pattern.src, tks, result) is False: return False diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py index 03a779a6..3f54b158 100644 --- a/miasm/expression/simplifications.py +++ b/miasm/expression/simplifications.py @@ -11,6 +11,7 @@ from miasm.expression import simplifications_cond from miasm.expression import simplifications_explicit from miasm.expression.expression_helper import fast_unify import miasm.expression.expression as m2_expr +from miasm.expression.expression import ExprVisitorCallbackBottomToTop # Expression Simplifier # --------------------- @@ -22,7 +23,7 @@ log_exprsimp.addHandler(console_handler) log_exprsimp.setLevel(logging.WARNING) -class ExpressionSimplifier(object): +class ExpressionSimplifier(ExprVisitorCallbackBottomToTop): """Wrapper on expression simplification passes. @@ -49,6 +50,8 @@ class ExpressionSimplifier(object): simplifications_common.simp_double_signext, simplifications_common.simp_zeroext_eq_cst, simplifications_common.simp_ext_eq_ext, + simplifications_common.simp_ext_cond_int, + simplifications_common.simp_sub_cf_zero, simplifications_common.simp_cmp_int, simplifications_common.simp_cmp_bijective_op, @@ -118,8 +121,8 @@ class ExpressionSimplifier(object): def __init__(self): + super(ExpressionSimplifier, self).__init__(self.expr_simp_inner) self.expr_simp_cb = {} - self.simplified_exprs = set() def enable_passes(self, passes): """Add passes from @passes @@ -129,7 +132,7 @@ class ExpressionSimplifier(object): """ # Clear cache of simplifiied expressions when adding a new pass - self.simplified_exprs.clear() + self.cache.clear() for k, v in viewitems(passes): self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v) @@ -156,46 +159,29 @@ class ExpressionSimplifier(object): return expression - def expr_simp(self, expression): + def expr_simp_inner(self, expression): """Apply enabled simplifications on expression and find a stable state @expression: Expr instance Return an Expr instance""" - if expression in self.simplified_exprs: - return expression - # Find a stable state while True: # Canonize and simplify - e_new = self.apply_simp(expression.canonize()) - if e_new == expression: - break - - # Launch recursivity - expression = self.expr_simp_wrapper(e_new) - self.simplified_exprs.add(expression) - # Mark expression as simplified - self.simplified_exprs.add(e_new) - - return e_new - - def expr_simp_wrapper(self, expression, callback=None): - """Apply enabled simplifications on expression - @expression: Expr instance - @manual_callback: If set, call this function instead of normal one - Return an Expr instance""" + new_expr = self.apply_simp(expression.canonize()) + if new_expr == expression: + return new_expr + # Run recursively simplification on fresh new expression + new_expr = self.visit(new_expr) + expression = new_expr + return new_expr - if expression in self.simplified_exprs: - return expression - - if callback is None: - callback = self.expr_simp - - return expression.visit(callback, lambda e: e not in self.simplified_exprs) + def expr_simp(self, expression): + "Call simplification recursively" + return self.visit(expression) - def __call__(self, expression, callback=None): - "Wrapper on expr_simp_wrapper" - return self.expr_simp_wrapper(expression, callback) + def __call__(self, expression): + "Call simplification recursively" + return self.visit(expression) # Public ExprSimplificationPass instance with commons passes diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py index 1c0bb04c..932db49a 100644 --- a/miasm/expression/simplifications_common.py +++ b/miasm/expression/simplifications_common.py @@ -32,30 +32,30 @@ def simp_cst_propagation(e_s, expr): int2 = args.pop() int1 = args.pop() if op_name == '+': - out = int1.arg + int2.arg + out = mod_size2uint[int1.size](int(int1) + int(int2)) elif op_name == '*': - out = int1.arg * int2.arg + out = mod_size2uint[int1.size](int(int1) * int(int2)) elif op_name == '**': - out =int1.arg ** int2.arg + out = mod_size2uint[int1.size](int(int1) ** int(int2)) elif op_name == '^': - out = int1.arg ^ int2.arg + out = mod_size2uint[int1.size](int(int1) ^ int(int2)) elif op_name == '&': - out = int1.arg & int2.arg + out = mod_size2uint[int1.size](int(int1) & int(int2)) elif op_name == '|': - out = int1.arg | int2.arg + out = mod_size2uint[int1.size](int(int1) | int(int2)) elif op_name == '>>': if int(int2) > int1.size: out = 0 else: - out = int1.arg >> int2.arg + out = mod_size2uint[int1.size](int(int1) >> int(int2)) elif op_name == '<<': if int(int2) > int1.size: out = 0 else: - out = int1.arg << int2.arg + out = mod_size2uint[int1.size](int(int1) << int(int2)) elif op_name == 'a>>': - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) if tmp2 > int1.size: is_signed = int(int1) & (1 << (int1.size - 1)) if is_signed: @@ -63,55 +63,57 @@ def simp_cst_propagation(e_s, expr): else: out = 0 else: - out = mod_size2uint[int1.arg.size](tmp1 >> tmp2) + out = mod_size2uint[int1.size](tmp1 >> tmp2) elif op_name == '>>>': - shifter = int2.arg % int2.size - out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter)) + shifter = int(int2) % int2.size + out = (int(int1) >> shifter) | (int(int1) << (int2.size - shifter)) elif op_name == '<<<': - shifter = int2.arg % int2.size - out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter)) + shifter = int(int2) % int2.size + out = (int(int1) << shifter) | (int(int1) >> (int2.size - shifter)) elif op_name == '/': - out = int1.arg // int2.arg + assert int(int2), "division by 0" + out = int(int1) // int(int2) elif op_name == '%': - out = int1.arg % int2.arg + assert int(int2), "division by 0" + out = int(int1) % int(int2) elif op_name == 'sdiv': - assert int2.arg.arg - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2int[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 // tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2int[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 // tmp2) elif op_name == 'smod': - assert int2.arg.arg - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2int[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 % tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2int[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 % tmp2) elif op_name == 'umod': - assert int2.arg.arg - tmp1 = mod_size2uint[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 % tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2uint[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 % tmp2) elif op_name == 'udiv': - assert int2.arg.arg - tmp1 = mod_size2uint[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 // tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2uint[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 // tmp2) - args.append(ExprInt(out, int1.size)) + args.append(ExprInt(int(out), int1.size)) # cnttrailzeros(int) => int if op_name == "cnttrailzeros" and args[0].is_int(): i = 0 - while args[0].arg & (1 << i) == 0 and i < args[0].size: + while int(args[0]) & (1 << i) == 0 and i < args[0].size: i += 1 return ExprInt(i, args[0].size) # cntleadzeros(int) => int if op_name == "cntleadzeros" and args[0].is_int(): - if args[0].arg == 0: + if int(args[0]) == 0: return ExprInt(args[0].size, args[0].size) i = args[0].size - 1 - while args[0].arg & (1 << i) == 0: + while int(args[0]) & (1 << i) == 0: i -= 1 return ExprInt(expr.size - (i + 1), args[0].size) @@ -120,6 +122,7 @@ def simp_cst_propagation(e_s, expr): len(args[0].args) == 1): return args[0].args[0] + # -(int) => -int if op_name == '-' and len(args) == 1 and args[0].is_int(): return ExprInt(-int(args[0]), expr.size) @@ -207,13 +210,13 @@ def simp_cst_propagation(e_s, expr): j += 1 i += 1 - if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1: + if op_name in ['+', '^', '|', '&', '%', '/', '**'] and len(args) == 1: return args[0] # A <<< A.size => A if (op_name in ['<<<', '>>>'] and args[1].is_int() and - args[1].arg == args[0].size): + int(args[1]) == args[0].size): return args[0] # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow @@ -277,7 +280,10 @@ def simp_cst_propagation(e_s, expr): # ((A & A.mask) if op_name == "&" and args[-1] == expr.mask: - return ExprOp('&', *args[:-1]) + args = args[:-1] + if len(args) == 1: + return args[0] + return ExprOp('&', *args) # ((A | A.mask) if op_name == "|" and args[-1] == expr.mask: @@ -289,7 +295,7 @@ def simp_cst_propagation(e_s, expr): # ((A & mask) >> shift) with mask < 2**shift => 0 if op_name == ">>" and args[1].is_int() and args[0].is_op("&"): if (args[0].args[1].is_int() and - 2 ** args[1].arg > args[0].args[1].arg): + 2 ** int(args[1]) > int(args[0].args[1])): return ExprInt(0, args[0].size) # parity(int) => int @@ -315,7 +321,6 @@ def simp_cst_propagation(e_s, expr): args = args[0].args return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)])) - # A << int with A ExprCompose => move index if (op_name == "<<" and args[0].is_compose() and args[1].is_int() and int(args[1]) != 0): @@ -450,8 +455,8 @@ def simp_cond_factor(e_s, expr): for cond, vals in viewitems(conds): new_src1 = [x.src1 for x in vals] new_src2 = [x.src2 for x in vals] - src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1)) - src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2)) + src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1)) + src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2)) c_out.append(ExprCond(cond, src1, src2)) if len(c_out) == 1: @@ -471,7 +476,7 @@ def simp_slice(e_s, expr): if expr.arg.is_int(): total_bit = expr.stop - expr.start mask = (1 << (expr.stop - expr.start)) - 1 - return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit) + return ExprInt(int((int(expr.arg) >> expr.start) & mask), total_bit) # Slice(Slice(A, x), y) => Slice(A, z) if expr.arg.is_slice(): if expr.stop - expr.start > expr.arg.stop - expr.arg.start: @@ -521,7 +526,7 @@ def simp_slice(e_s, expr): # distributivity of slice and & # (a & int)[x:y] => 0 if int[x:y] == 0 if expr.arg.is_op("&") and expr.arg.args[-1].is_int(): - tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop]) + tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop]) if tmp.is_int(0): return tmp # distributivity of slice and exprcond @@ -536,7 +541,7 @@ def simp_slice(e_s, expr): # (a * int)[0:y] => (a[0:y] * int[0:y]) if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int(): - args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args] + args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args] return ExprOp(expr.arg.op, *args) # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size @@ -626,7 +631,7 @@ def simp_cond(_, expr): expr = expr.src1 # int ? A:B => A or B elif expr.cond.is_int(): - if expr.cond.arg == 0: + if int(expr.cond) == 0: expr = expr.src2 else: expr = expr.src1 @@ -646,8 +651,8 @@ def simp_cond(_, expr): elif (expr.cond.is_cond() and expr.cond.src1.is_int() and expr.cond.src2.is_int()): - int1 = expr.cond.src1.arg.arg - int2 = expr.cond.src2.arg.arg + int1 = int(expr.cond.src1) + int2 = int(expr.cond.src2) if int1 and int2: expr = expr.src1 elif int1 == 0 and int2 == 0: @@ -906,6 +911,15 @@ def simp_cond_flag(_, expr): return expr +def simp_sub_cf_zero(_, expr): + """FLAG_SUB_CF(0, X) => (X)?1:0""" + if not expr.is_op("FLAG_SUB_CF"): + return expr + if not expr.args[0].is_int(0): + return expr + return ExprCond(expr.args[1], ExprInt(1, 1), ExprInt(0, 1)) + + def simp_cmp_int(expr_simp, expr): """ ({X, 0} == int) => X == int[:] @@ -1069,6 +1083,13 @@ def simp_cmp_bijective_op(expr_simp, expr): args_a.remove(value) args_b.remove(value) + # a + b == a + b + c + if not args_a: + return ExprOp(TOK_EQUAL, ExprOp(op, *args_b), ExprInt(0, args_b[0].size)) + # a + b + c == a + b + if not args_b: + return ExprOp(TOK_EQUAL, ExprOp(op, *args_a), ExprInt(0, args_a[0].size)) + arg_a = ExprOp(op, *args_a) arg_b = ExprOp(op, *args_b) return ExprOp(TOK_EQUAL, arg_a, arg_b) @@ -1362,6 +1383,23 @@ def simp_ext_cst(_, expr): return ret + +def simp_ext_cond_int(e_s, expr): + """ + zeroExt(ExprCond(X, Int, Int)) => ExprCond(X, Int, Int) + """ + if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")): + return expr + arg = expr.args[0] + if not arg.is_cond(): + return expr + if not (arg.src1.is_int() and arg.src2.is_int()): + return expr + src1 = ExprOp(expr.op, arg.src1) + src2 = ExprOp(expr.op, arg.src2) + return e_s(ExprCond(arg.cond, src1, src2)) + + def simp_slice_of_ext(_, expr): """ C.zeroExt(X)[A:B] => 0 if A >= size(C) diff --git a/miasm/ir/ir.py b/miasm/ir/ir.py index 613a4bce..3219b5fc 100644 --- a/miasm/ir/ir.py +++ b/miasm/ir/ir.py @@ -44,6 +44,26 @@ def _expr_loc_to_symb(expr, loc_db): name = sorted(names)[0] return m2_expr.ExprId(name, expr.size) +def slice_rest(expr): + "Return the completion of the current slice" + size = expr.arg.size + if expr.start >= size or expr.stop > size: + raise ValueError('bad slice rest %s %s %s' % + (size, expr.start, expr.stop)) + + if expr.start == expr.stop: + return [(0, size)] + + rest = [] + if expr.start != 0: + rest.append((0, expr.start)) + if expr.stop < size: + rest.append((expr.stop, size)) + + return rest + + + class AssignBlock(object): """Represent parallel IR assignment, such as: EAX = EBX @@ -99,7 +119,7 @@ class AssignBlock(object): # Complete the source with missing slice parts new_dst = dst.arg rest = [(m2_expr.ExprSlice(dst.arg, r[0], r[1]), r[0], r[1]) - for r in dst.slice_rest()] + for r in slice_rest(dst)] all_a = [(src, dst.start, dst.stop)] + rest all_a.sort(key=lambda x: x[1]) args = [expr for (expr, _, _) in all_a] @@ -752,7 +772,7 @@ class IntermediateRepresentation(object): if loc_key is None: offset = getattr(instr, "offset", None) - loc_key = self.loc_db.add_location(offset=offset) + loc_key = self.loc_db.get_or_create_offset_location(offset) block = AsmBlock(loc_key) block.lines = [instr] self.add_asmblock_to_ircfg(block, ircfg, gen_pc_updt) @@ -865,7 +885,7 @@ class IntermediateRepresentation(object): return irblock def is_pc_written(self, block): - """Return the first Assignblk of the @blockin which PC is written + """Return the first Assignblk of the @block in which PC is written @block: IRBlock instance""" all_pc = viewvalues(self.arch.pc) for assignblk in block: diff --git a/miasm/ir/symbexec.py b/miasm/ir/symbexec.py index 943c8b03..8c6245b8 100644 --- a/miasm/ir/symbexec.py +++ b/miasm/ir/symbexec.py @@ -121,7 +121,7 @@ class MemArray(MutableMapping): content relatively to an integer offset from *base*. The value associated to a given offset is a description of the slice of a - stored expression. The slice size depends on the configutation of the + stored expression. The slice size depends on the configuration of the MemArray. For example, for a slice size of 8 bits, the assignment: - @32[EAX+0x10] = EBX @@ -567,7 +567,7 @@ class MemSparse(object): memarray = self.base_to_memarray.get(base, None) if memarray is not None: mems = memarray.read(offset, size) - ret = ExprCompose(*mems) + ret = mems[0] if len(mems) == 1 else ExprCompose(*mems) else: ret = ExprMem(ptr, size) return ret @@ -1096,7 +1096,7 @@ class SymbolicExecutionEngine(object): """ # Update value if needed - if expr.is_aff(): + if expr.is_assign(): ret = self.eval_expr(expr.src) self.eval_updt_assignblk(AssignBlock([expr])) else: diff --git a/miasm/ir/translators/C.py b/miasm/ir/translators/C.py index 6be5d961..7778abf7 100644 --- a/miasm/ir/translators/C.py +++ b/miasm/ir/translators/C.py @@ -3,7 +3,8 @@ from miasm.expression.modint import size2mask from miasm.expression.expression import ExprInt, ExprCond, ExprCompose, \ TOK_EQUAL, \ TOK_INF_SIGNED, TOK_INF_UNSIGNED, \ - TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED + TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, \ + is_associative def int_size_to_bn(value, size): if size < 32: @@ -288,7 +289,7 @@ class TranslatorC(Translator): out = "bignum_mask(%s, %d)"% (out, expr.size) return out - elif expr.is_associative(): + elif is_associative(expr): args = [self.from_expr(arg) for arg in expr.args] if expr.size <= self.NATIVE_INT_MAX_SIZE: @@ -466,7 +467,7 @@ class TranslatorC(Translator): else: raise NotImplementedError('Unknown op: %r' % expr.op) - elif len(expr.args) >= 3 and expr.is_associative(): # ????? + elif len(expr.args) >= 3 and is_associative(expr): # ????? oper = ['(%s&%s)' % ( self.from_expr(arg), self._size2mask(arg.size), diff --git a/miasm/ir/translators/z3_ir.py b/miasm/ir/translators/z3_ir.py index 3452f162..1a36e94e 100644 --- a/miasm/ir/translators/z3_ir.py +++ b/miasm/ir/translators/z3_ir.py @@ -15,7 +15,7 @@ log.addHandler(console_handler) log.setLevel(logging.WARNING) class Z3Mem(object): - """Memory abstration for TranslatorZ3. Memory elements are only accessed, + """Memory abstraction for TranslatorZ3. Memory elements are only accessed, never written. To give a concrete value for a given memory cell in a solver, add "mem32.get(address, size) == <value>" constraints to your equation. The endianness of memory accesses is handled accordingly to the "endianness" @@ -129,7 +129,7 @@ class TranslatorZ3(Translator): self.loc_db = loc_db def from_ExprInt(self, expr): - return z3.BitVecVal(expr.arg.arg, expr.size) + return z3.BitVecVal(int(expr), expr.size) def from_ExprId(self, expr): return z3.BitVec(str(expr), expr.size) diff --git a/miasm/jitter/JitCore.c b/miasm/jitter/JitCore.c index 1ba082c5..dfead5a8 100644 --- a/miasm/jitter/JitCore.c +++ b/miasm/jitter/JitCore.c @@ -41,6 +41,25 @@ PyObject * JitCpu_set_vmmngr(JitCpu *self, PyObject *value, void *closure) return 0; } + + +PyObject * JitCpu_get_vmcpu(JitCpu *self, void *closure) +{ + PyObject * ret; + uint64_t addr; + addr = (uint64_t) self->cpu; + ret = PyLong_FromUnsignedLongLong(addr); + return ret; +} + +PyObject * JitCpu_set_vmcpu(JitCpu *self, PyObject *value, void *closure) +{ + fprintf(stderr, "Set vmcpu not supported yet\n"); + exit(-1); +} + + + PyObject * JitCpu_get_jitter(JitCpu *self, void *closure) { if (self->jitter) { diff --git a/miasm/jitter/JitCore.h b/miasm/jitter/JitCore.h index 7b7f6c13..ff6ff159 100644 --- a/miasm/jitter/JitCore.h +++ b/miasm/jitter/JitCore.h @@ -203,6 +203,8 @@ void JitCpu_dealloc(JitCpu* self); PyObject * JitCpu_new(PyTypeObject *type, PyObject *args, PyObject *kwds); PyObject * JitCpu_get_vmmngr(JitCpu *self, void *closure); PyObject * JitCpu_set_vmmngr(JitCpu *self, PyObject *value, void *closure); +PyObject * JitCpu_get_vmcpu(JitCpu *self, void *closure); +PyObject * JitCpu_set_vmcpu(JitCpu *self, PyObject *value, void *closure); PyObject * JitCpu_get_jitter(JitCpu *self, void *closure); PyObject * JitCpu_set_jitter(JitCpu *self, PyObject *value, void *closure); void Resolve_dst(block_id* BlockDst, uint64_t addr, uint64_t is_local); diff --git a/miasm/jitter/arch/JitCore_mips32.h b/miasm/jitter/arch/JitCore_mips32.h index 74eb35ef..8478fb53 100644 --- a/miasm/jitter/arch/JitCore_mips32.h +++ b/miasm/jitter/arch/JitCore_mips32.h @@ -83,7 +83,7 @@ struct vm_cpu { uint32_t CPR0_5; uint32_t CPR0_6; uint32_t CPR0_7; - uint32_t CPR0_8; + uint32_t RANDOM; uint32_t CPR0_9; uint32_t CPR0_10; uint32_t CPR0_11; @@ -107,8 +107,8 @@ struct vm_cpu { uint32_t CPR0_29; uint32_t CPR0_30; uint32_t CPR0_31; - uint32_t CPR0_32; - uint32_t CPR0_33; + uint32_t CONTEXT; + uint32_t CONTEXTCONFIG; uint32_t CPR0_34; uint32_t CPR0_35; uint32_t CPR0_36; @@ -116,20 +116,20 @@ struct vm_cpu { uint32_t CPR0_38; uint32_t CPR0_39; uint32_t PAGEMASK; - uint32_t CPR0_41; - uint32_t CPR0_42; - uint32_t CPR0_43; - uint32_t CPR0_44; - uint32_t CPR0_45; - uint32_t CPR0_46; - uint32_t CPR0_47; - uint32_t CPR0_48; + uint32_t PAGEGRAIN; + uint32_t SEGCTL0; + uint32_t SEGCTL1; + uint32_t SEGCTL2; + uint32_t PWBASE; + uint32_t PWFIELD; + uint32_t PWSIZE; + uint32_t WIRED; uint32_t CPR0_49; uint32_t CPR0_50; uint32_t CPR0_51; uint32_t CPR0_52; uint32_t CPR0_53; - uint32_t CPR0_54; + uint32_t PWCTL; uint32_t CPR0_55; uint32_t CPR0_56; uint32_t CPR0_57; @@ -139,9 +139,9 @@ struct vm_cpu { uint32_t CPR0_61; uint32_t CPR0_62; uint32_t CPR0_63; - uint32_t CPR0_64; - uint32_t CPR0_65; - uint32_t CPR0_66; + uint32_t BADVADDR; + uint32_t BADINSTR; + uint32_t BADINSTRP; uint32_t CPR0_67; uint32_t CPR0_68; uint32_t CPR0_69; @@ -195,8 +195,8 @@ struct vm_cpu { uint32_t CPR0_117; uint32_t CPR0_118; uint32_t CPR0_119; - uint32_t CPR0_120; - uint32_t CPR0_121; + uint32_t PRID; + uint32_t EBASE; uint32_t CPR0_122; uint32_t CPR0_123; uint32_t CPR0_124; @@ -204,11 +204,11 @@ struct vm_cpu { uint32_t CPR0_126; uint32_t CPR0_127; uint32_t CONFIG; - uint32_t CPR0_129; - uint32_t CPR0_130; - uint32_t CPR0_131; - uint32_t CPR0_132; - uint32_t CPR0_133; + uint32_t CONFIG1; + uint32_t CONFIG2; + uint32_t CONFIG3; + uint32_t CONFIG4; + uint32_t CONFIG5; uint32_t CPR0_134; uint32_t CPR0_135; uint32_t CPR0_136; @@ -325,12 +325,12 @@ struct vm_cpu { uint32_t CPR0_247; uint32_t CPR0_248; uint32_t CPR0_249; - uint32_t CPR0_250; - uint32_t CPR0_251; - uint32_t CPR0_252; - uint32_t CPR0_253; - uint32_t CPR0_254; - uint32_t CPR0_255; + uint32_t KSCRATCH0; + uint32_t KSCRATCH1; + uint32_t KSCRATCH2; + uint32_t KSCRATCH3; + uint32_t KSCRATCH4; + uint32_t KSCRATCH5; }; _MIASM_EXPORT void dump_gpregs(struct vm_cpu* vmcpu); diff --git a/miasm/jitter/arch/JitCore_ppc32_regs.h b/miasm/jitter/arch/JitCore_ppc32_regs.h index a16d1e95..79191d32 100644 --- a/miasm/jitter/arch/JitCore_ppc32_regs.h +++ b/miasm/jitter/arch/JitCore_ppc32_regs.h @@ -121,3 +121,72 @@ JITCORE_PPC_REG_EXPAND(DBAT2L, 32) JITCORE_PPC_REG_EXPAND(DBAT3U, 32) JITCORE_PPC_REG_EXPAND(DBAT3L, 32) JITCORE_PPC_REG_EXPAND(SDR1, 32) + +JITCORE_PPC_REG_EXPAND(FPR0, 64) +JITCORE_PPC_REG_EXPAND(FPR1, 64) +JITCORE_PPC_REG_EXPAND(FPR2, 64) +JITCORE_PPC_REG_EXPAND(FPR3, 64) +JITCORE_PPC_REG_EXPAND(FPR4, 64) +JITCORE_PPC_REG_EXPAND(FPR5, 64) +JITCORE_PPC_REG_EXPAND(FPR6, 64) +JITCORE_PPC_REG_EXPAND(FPR7, 64) +JITCORE_PPC_REG_EXPAND(FPR8, 64) +JITCORE_PPC_REG_EXPAND(FPR9, 64) +JITCORE_PPC_REG_EXPAND(FPR10, 64) +JITCORE_PPC_REG_EXPAND(FPR11, 64) +JITCORE_PPC_REG_EXPAND(FPR12, 64) +JITCORE_PPC_REG_EXPAND(FPR13, 64) +JITCORE_PPC_REG_EXPAND(FPR14, 64) +JITCORE_PPC_REG_EXPAND(FPR15, 64) +JITCORE_PPC_REG_EXPAND(FPR16, 64) +JITCORE_PPC_REG_EXPAND(FPR17, 64) +JITCORE_PPC_REG_EXPAND(FPR18, 64) +JITCORE_PPC_REG_EXPAND(FPR19, 64) +JITCORE_PPC_REG_EXPAND(FPR20, 64) +JITCORE_PPC_REG_EXPAND(FPR21, 64) +JITCORE_PPC_REG_EXPAND(FPR22, 64) +JITCORE_PPC_REG_EXPAND(FPR23, 64) +JITCORE_PPC_REG_EXPAND(FPR24, 64) +JITCORE_PPC_REG_EXPAND(FPR25, 64) +JITCORE_PPC_REG_EXPAND(FPR26, 64) +JITCORE_PPC_REG_EXPAND(FPR27, 64) +JITCORE_PPC_REG_EXPAND(FPR28, 64) +JITCORE_PPC_REG_EXPAND(FPR29, 64) +JITCORE_PPC_REG_EXPAND(FPR30, 64) +JITCORE_PPC_REG_EXPAND(FPR31, 64) +JITCORE_PPC_REG_EXPAND(FPSCR, 32) + +JITCORE_PPC_REG_EXPAND(VR0, 128) +JITCORE_PPC_REG_EXPAND(VR1, 128) +JITCORE_PPC_REG_EXPAND(VR2, 128) +JITCORE_PPC_REG_EXPAND(VR3, 128) +JITCORE_PPC_REG_EXPAND(VR4, 128) +JITCORE_PPC_REG_EXPAND(VR5, 128) +JITCORE_PPC_REG_EXPAND(VR6, 128) +JITCORE_PPC_REG_EXPAND(VR7, 128) +JITCORE_PPC_REG_EXPAND(VR8, 128) +JITCORE_PPC_REG_EXPAND(VR9, 128) +JITCORE_PPC_REG_EXPAND(VR10, 128) +JITCORE_PPC_REG_EXPAND(VR11, 128) +JITCORE_PPC_REG_EXPAND(VR12, 128) +JITCORE_PPC_REG_EXPAND(VR13, 128) +JITCORE_PPC_REG_EXPAND(VR14, 128) +JITCORE_PPC_REG_EXPAND(VR15, 128) +JITCORE_PPC_REG_EXPAND(VR16, 128) +JITCORE_PPC_REG_EXPAND(VR17, 128) +JITCORE_PPC_REG_EXPAND(VR18, 128) +JITCORE_PPC_REG_EXPAND(VR19, 128) +JITCORE_PPC_REG_EXPAND(VR20, 128) +JITCORE_PPC_REG_EXPAND(VR21, 128) +JITCORE_PPC_REG_EXPAND(VR22, 128) +JITCORE_PPC_REG_EXPAND(VR23, 128) +JITCORE_PPC_REG_EXPAND(VR24, 128) +JITCORE_PPC_REG_EXPAND(VR25, 128) +JITCORE_PPC_REG_EXPAND(VR26, 128) +JITCORE_PPC_REG_EXPAND(VR27, 128) +JITCORE_PPC_REG_EXPAND(VR28, 128) +JITCORE_PPC_REG_EXPAND(VR29, 128) +JITCORE_PPC_REG_EXPAND(VR30, 128) +JITCORE_PPC_REG_EXPAND(VR31, 128) +JITCORE_PPC_REG_EXPAND(VRSAVE, 32) +JITCORE_PPC_REG_EXPAND(VSCR, 32) diff --git a/miasm/jitter/arch/JitCore_x86.c b/miasm/jitter/arch/JitCore_x86.c index 361b18b4..9081f3d8 100644 --- a/miasm/jitter/arch/JitCore_x86.c +++ b/miasm/jitter/arch/JitCore_x86.c @@ -414,7 +414,7 @@ PyObject* cpu_set_segm_base(JitCpu* self, PyObject* args) PyGetInt_uint64_t(item1, segm_num); PyGetInt_uint64_t(item2, segm_base); - ((struct vm_cpu*)self->cpu)->segm_base[segm_num] = segm_base; + ((struct vm_cpu*)self->cpu)->segm_base[segm_num & 0xFFFF] = segm_base; Py_INCREF(Py_None); return Py_None; @@ -429,13 +429,13 @@ PyObject* cpu_get_segm_base(JitCpu* self, PyObject* args) if (!PyArg_ParseTuple(args, "O", &item1)) RAISE(PyExc_TypeError,"Cannot parse arguments"); PyGetInt_uint64_t(item1, segm_num); - v = PyLong_FromLong((long)(((struct vm_cpu*)self->cpu)->segm_base[segm_num])); + v = PyLong_FromLong((long)(((struct vm_cpu*)self->cpu)->segm_base[segm_num & 0xFFFF])); return v; } uint64_t segm2addr(JitCpu* jitcpu, uint64_t segm, uint64_t addr) { - return addr + ((struct vm_cpu*)jitcpu->cpu)->segm_base[segm]; + return addr + ((struct vm_cpu*)jitcpu->cpu)->segm_base[segm & 0xFFFF]; } void MEM_WRITE_08(JitCpu* jitcpu, uint64_t addr, uint8_t src) @@ -730,6 +730,11 @@ static PyGetSetDef JitCpu_getseters[] = { "vmmngr", NULL}, + {"vmcpu", + (getter)JitCpu_get_vmcpu, (setter)JitCpu_set_vmcpu, + "vmcpu", + NULL}, + {"jitter", (getter)JitCpu_get_jitter, (setter)JitCpu_set_jitter, "jitter", diff --git a/miasm/jitter/bn.h b/miasm/jitter/bn.h index 1aa6b432..8c4a8ba1 100644 --- a/miasm/jitter/bn.h +++ b/miasm/jitter/bn.h @@ -35,7 +35,7 @@ Code slightly modified to support ast generation calculus style from Expr. #include <assert.h> -/* This macro defines the word size in bytes of the array that constitues the big-number data structure. */ +/* This macro defines the word size in bytes of the array that constitutes the big-number data structure. */ #ifndef WORD_SIZE #define WORD_SIZE 4 #endif diff --git a/miasm/jitter/codegen.py b/miasm/jitter/codegen.py index abbf65d2..0b5b7961 100644 --- a/miasm/jitter/codegen.py +++ b/miasm/jitter/codegen.py @@ -6,8 +6,8 @@ from builtins import zip from future.utils import viewitems, viewvalues -from miasm.expression.expression import Expr, ExprId, ExprLoc, ExprInt, \ - ExprMem, ExprCond, LocKey +from miasm.expression.expression import ExprId, ExprLoc, ExprInt, \ + ExprMem, ExprCond, LocKey, is_expr from miasm.ir.ir import IRBlock, AssignBlock from miasm.ir.translators.C import TranslatorC @@ -123,7 +123,7 @@ class CGen(object): def dst_to_c(self, src): """Translate Expr @src into C code""" - if not isinstance(src, Expr): + if not is_expr(src): src = ExprInt(src, self.PC.size) return self.id_to_c(src) @@ -413,7 +413,7 @@ class CGen(object): @dst: potential instruction destination""" out = [] - if isinstance(dst, Expr): + if is_expr(dst): out += self.gen_post_code(attrib, "DST_value") out.append('BlockDst->address = DST_value;') out += self.gen_post_instr_checks(attrib) diff --git a/miasm/jitter/csts.py b/miasm/jitter/csts.py index 6d40fe0d..3829ed98 100644 --- a/miasm/jitter/csts.py +++ b/miasm/jitter/csts.py @@ -9,6 +9,7 @@ EXCEPT_CODE_AUTOMOD = (1 << 0) EXCEPT_SOFT_BP = (1 << 1) EXCEPT_INT_XX = (1 << 2) EXCEPT_SPR_ACCESS = (1 << 3) +EXCEPT_SYSCALL = (1 << 4) EXCEPT_BREAKPOINT_MEMORY = (1 << 10) # Deprecated EXCEPT_BREAKPOINT_INTERN = EXCEPT_BREAKPOINT_MEMORY diff --git a/miasm/jitter/emulatedsymbexec.py b/miasm/jitter/emulatedsymbexec.py index aacfba9f..844e7d5f 100644 --- a/miasm/jitter/emulatedsymbexec.py +++ b/miasm/jitter/emulatedsymbexec.py @@ -94,10 +94,10 @@ class EmulatedSymbExec(SymbolicExecutionEngine): data = self.expr_simp(data) if not isinstance(data, m2_expr.ExprInt): raise RuntimeError("A simplification is missing: %s" % data) - to_write = data.arg.arg + to_write = int(data) # Format information - addr = dest.ptr.arg.arg + addr = int(dest.ptr) size = data.size // 8 content = hex(to_write).replace("0x", "").replace("L", "") content = "0" * (size * 2 - len(content)) + content @@ -120,7 +120,7 @@ class EmulatedSymbExec(SymbolicExecutionEngine): if not isinstance(value, m2_expr.ExprInt): raise ValueError("A simplification is missing: %s" % value) - setattr(self.cpu, symbol.name, value.arg.arg) + setattr(self.cpu, symbol.name, int(value)) else: raise NotImplementedError("Type not handled: %s" % symbol) @@ -140,7 +140,7 @@ class EmulatedSymbExec(SymbolicExecutionEngine): # CPU specific simplifications def _simp_handle_segm(self, e_s, expr): """Handle 'segm' operation""" - if not expr.is_op_segm(): + if not m2_expr.is_op_segm(expr): return expr if not expr.args[0].is_int(): return expr diff --git a/miasm/jitter/jitcore.py b/miasm/jitter/jitcore.py index ebda656f..cc531cf5 100644 --- a/miasm/jitter/jitcore.py +++ b/miasm/jitter/jitcore.py @@ -104,7 +104,7 @@ class JitCore(object): cur_block.ad_max = cur_block.lines[-1].offset + cur_block.lines[-1].l else: # 1 byte block for unknown mnemonic - offset = ir_arch.loc_db.get_location_offset(cur_block.loc_key) + offset = self.ir_arch.loc_db.get_location_offset(cur_block.loc_key) cur_block.ad_min = offset cur_block.ad_max = offset+1 @@ -198,10 +198,7 @@ class JitCore(object): """ mem_range = interval() - - for block in blocks: - mem_range += interval([(block.ad_min, block.ad_max - 1)]) - + mem_range = interval([(block.ad_min, block.ad_max - 1) for block in blocks]) return mem_range def __updt_jitcode_mem_range(self, vm): @@ -235,12 +232,6 @@ class JitCore(object): # Modified blocks modified_blocks.add(block) - # Generate interval to delete - del_interval = self.blocks_to_memrange(modified_blocks) - - # Remove interval from monitored interval list - self.blocks_mem_interval -= del_interval - # Remove modified blocks for block in modified_blocks: try: @@ -259,6 +250,9 @@ class JitCore(object): # Remove label -> block link del(self.loc_key_to_block[block.loc_key]) + # Re generate blocks intervals + self.blocks_mem_interval = self.blocks_to_memrange(self.loc_key_to_block.values()) + return modified_blocks def updt_automod_code_range(self, vm, mem_range): diff --git a/miasm/jitter/jitcore_cc_base.py b/miasm/jitter/jitcore_cc_base.py index 995c458b..afb2876c 100644 --- a/miasm/jitter/jitcore_cc_base.py +++ b/miasm/jitter/jitcore_cc_base.py @@ -1,5 +1,6 @@ #-*- coding:utf-8 -*- +import glob import os import tempfile import platform @@ -76,6 +77,12 @@ class JitCore_Cc_Base(JitCore): ext = sysconfig.get_config_var('EXT_SUFFIX') if ext is None: ext = ".so" if not is_win else ".lib" + if is_win: + # sysconfig.get_config_var('EXT_SUFFIX') is .pyd on Windows and need to be forced to .lib + # Additionally windows built libraries may have a name like VmMngr.cp38-win_amd64.lib + ext_files = glob.glob(os.path.join(lib_dir, "VmMngr.*lib")) + if len(ext_files) == 1: + ext = os.path.basename(ext_files[0]).replace("VmMngr", "") libs = [ os.path.join(lib_dir, "VmMngr" + ext), diff --git a/miasm/jitter/jitcore_gcc.py b/miasm/jitter/jitcore_gcc.py index 1520cf38..7ffef69e 100644 --- a/miasm/jitter/jitcore_gcc.py +++ b/miasm/jitter/jitcore_gcc.py @@ -1,5 +1,6 @@ #-*- coding:utf-8 -*- +import sys import os import tempfile import ctypes @@ -70,7 +71,7 @@ class JitCore_Gcc(JitCore_Cc_Base): get_python_inc(), "..", "libs", - "python27.lib" + "python%d%d.lib" % (sys.version_info.major, sys.version_info.minor) ) ) cl = [ diff --git a/miasm/jitter/jitcore_llvm.py b/miasm/jitter/jitcore_llvm.py index 46e93282..df7d5950 100644 --- a/miasm/jitter/jitcore_llvm.py +++ b/miasm/jitter/jitcore_llvm.py @@ -1,5 +1,6 @@ from __future__ import print_function import os +import glob import importlib import tempfile import sysconfig @@ -9,6 +10,9 @@ import miasm.jitter.jitcore as jitcore from miasm.jitter import Jitllvm import platform +import llvmlite +llvmlite.binding.load_library_permanently(Jitllvm.__file__) + is_win = platform.system() == "Windows" class JitCore_LLVM(jitcore.JitCore): @@ -56,10 +60,16 @@ class JitCore_LLVM(jitcore.JitCore): # Get architecture dependent Jitcore library (if any) lib_dir = os.path.dirname(os.path.realpath(__file__)) - lib_dir = os.path.join(lib_dir, 'arch') ext = sysconfig.get_config_var('EXT_SUFFIX') if ext is None: ext = ".so" if not is_win else ".pyd" + if is_win: + # sysconfig.get_config_var('EXT_SUFFIX') is .pyd on Windows and need to be forced to .lib + # Additionally windows built libraries may have a name like VmMngr.cp38-win_amd64.lib + ext_files = glob.glob(os.path.join(lib_dir, "VmMngr.*pyd")) + if len(ext_files) == 1: + ext = os.path.basename(ext_files[0]).replace("VmMngr", "") + lib_dir = os.path.join(lib_dir, 'arch') try: jit_lib = os.path.join( lib_dir, self.arch_dependent_libs[self.ir_arch.arch.name] + ext diff --git a/miasm/jitter/jitload.py b/miasm/jitter/jitload.py index 68f9c40d..85d5636f 100644 --- a/miasm/jitter/jitload.py +++ b/miasm/jitter/jitload.py @@ -393,13 +393,16 @@ class Jitter(object): self.pc = pc self.run = True - def continue_run(self, step=False): + def continue_run(self, step=False, trace=False): """PRE: init_run. Continue the run of the current session until iterator returns or run is set to False. If step is True, run only one time. + If trace is True, activate trace log option until execution stops Return the iterator value""" + if trace: + self.set_trace_log() while self.run: try: return next(self.run_iterator) @@ -409,8 +412,9 @@ class Jitter(object): self.run_iterator = self.runiter_once(self.pc) if step is True: - return None - + break + if trace: + self.set_trace_log(False, False, False) return None @@ -422,6 +426,18 @@ class Jitter(object): self.init_run(addr) return self.continue_run() + def run_until(self, addr, trace=False): + """PRE: init_run. + Continue the run of the current session until iterator returns, run is + set to False or addr is reached. + If trace is True, activate trace log option until execution stops + Return the iterator value""" + + def stop_exec(jitter): + jitter.remove_breakpoints_by_callback(stop_exec) + return False + self.add_breakpoint(addr, stop_exec) + return self.continue_run(trace=trace) def init_stack(self): self.vm.add_memory_page( diff --git a/miasm/jitter/loader/pe.py b/miasm/jitter/loader/pe.py index 961bfd93..73cb1367 100644 --- a/miasm/jitter/loader/pe.py +++ b/miasm/jitter/loader/pe.py @@ -171,7 +171,7 @@ def get_export_name_addr_list(e): return out -def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, name="", **kargs): +def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, name="", winobjs=None, **kargs): """Load a PE in memory (@vm) from a data buffer @fdata @vm: VmMngr instance @fdata: data buffer to parse @@ -207,6 +207,9 @@ def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, name="", **kargs): pe.content[:hdr_len] + max(0, (min_len - hdr_len)) * b"\x00" ) + + if winobjs: + winobjs.allocated_pages[pe.NThdr.ImageBase] = (pe.NThdr.ImageBase, len(pe_hdr)) vm.add_memory_page( pe.NThdr.ImageBase, PAGE_READ | PAGE_WRITE, @@ -237,8 +240,12 @@ def vm_load_pe(vm, fdata, align_s=True, load_hdr=True, name="", **kargs): attrib = PAGE_READ if section.flags & 0x80000000: attrib |= PAGE_WRITE + + section_addr = pe.rva2virt(section.addr) + if winobjs: + winobjs.allocated_pages[section_addr] = (section_addr, len(data)) vm.add_memory_page( - pe.rva2virt(section.addr), + section_addr, attrib, data, "%r: %r" % (name, section.name) diff --git a/miasm/jitter/vm_mngr.c b/miasm/jitter/vm_mngr.c index 0c8a0586..d0e49213 100644 --- a/miasm/jitter/vm_mngr.c +++ b/miasm/jitter/vm_mngr.c @@ -551,6 +551,46 @@ int vm_read_mem(vm_mngr_t* vm_mngr, uint64_t addr, char** buffer_ptr, size_t siz return 0; } + +/* + Try to read @size bytes from vm mmemory + Return the number of bytes consecutively read +*/ +uint64_t vm_read_mem_ret_buf(vm_mngr_t* vm_mngr, uint64_t addr, size_t size, char *buffer) +{ + size_t len; + uint64_t addr_diff; + uint64_t size_out; + size_t addr_diff_st; + + struct memory_page_node * mpn; + + size_out = 0; + /* read is multiple page wide */ + while (size){ + mpn = get_memory_page_from_address(vm_mngr, addr, 0); + if (!mpn){ + return size_out; + } + + addr_diff = addr - mpn->ad; + if (addr_diff > SIZE_MAX) { + fprintf(stderr, "Size too big\n"); + exit(EXIT_FAILURE); + } + addr_diff_st = (size_t) addr_diff; + len = MIN(size, mpn->size - addr_diff_st); + memcpy(buffer, (char*)mpn->ad_hp + (addr_diff_st), len); + buffer += len; + size_out += len; + addr += len; + size -= len; + } + + return size_out; +} + + int vm_write_mem(vm_mngr_t* vm_mngr, uint64_t addr, char *buffer, size_t size) { size_t len; diff --git a/miasm/loader/elf_init.py b/miasm/loader/elf_init.py index 14f4dc7c..72d08302 100644 --- a/miasm/loader/elf_init.py +++ b/miasm/loader/elf_init.py @@ -92,6 +92,8 @@ class WRel32(StructWrapper): wrapped._fields.append(("type", "u08")) def get_sym(self): + if isinstance(self.parent.linksection, NullSection): + return None return self.parent.linksection.symtab[self.cstr.info >> 8].name def get_type(self): diff --git a/miasm/loader/new_cstruct.py b/miasm/loader/new_cstruct.py index ec591aa8..16c947a5 100644 --- a/miasm/loader/new_cstruct.py +++ b/miasm/loader/new_cstruct.py @@ -4,6 +4,7 @@ from __future__ import print_function import re import struct +from miasm.core.utils import force_bytes from future.utils import PY3, viewitems, with_metaclass type2realtype = {} @@ -213,9 +214,10 @@ class CStruct(with_metaclass(Cstruct_Metaclass, object)): if cpt == None: if value == None: o = struct.calcsize(fmt) * b"\x00" + elif ffmt.endswith('s'): + new_value = force_bytes(value) + o = struct.pack(self.sex + fmt, new_value) else: - if isinstance(value, str): - value = value.encode() o = struct.pack(self.sex + fmt, value) else: o = b"" diff --git a/miasm/loader/pe.py b/miasm/loader/pe.py index f402e980..2d257906 100644 --- a/miasm/loader/pe.py +++ b/miasm/loader/pe.py @@ -267,7 +267,7 @@ class DescName(CStruct): return name, off + len(name) + 1 def sets(self, value): - return bytes(value) + b"\x00" + return force_bytes(value) + b"\x00" class ImportByName(CStruct): @@ -434,7 +434,7 @@ class DirImport(CStruct): # entry.firstthunk = rva # rva+=(len(entry.firstthunks)+1)*self.parent_head._wsize//8 # Rva size if entry.originalfirstthunk and entry.firstthunk: - if isinstance(entry.originalfirstthunk, struct_array): + if isinstance(entry.originalfirstthunks, struct_array): tmp_thunk = entry.originalfirstthunks elif isinstance(entry.firstthunks, struct_array): tmp_thunk = entry.firstthunks @@ -457,6 +457,11 @@ class DirImport(CStruct): rva += len(imp) def build_content(self, raw): + if self.parent_head._wsize == 32: + mask_ptr = 0x80000000 + elif self.parent_head._wsize == 64: + mask_ptr = 0x8000000000000000 + dirimp = self.parent_head.NThdr.optentries[DIRECTORY_ENTRY_IMPORT] of1 = dirimp.rva if not of1: # No Import @@ -918,7 +923,7 @@ class DirDelay(CStruct): return out, off def sete(self, entries): - return "".join(bytes(entry) for entry in entries) + b"\x00" * (4 * 8) # DelayDesc_e + return b"".join(bytes(entry) for entry in entries) + b"\x00" * (4 * 8) # DelayDesc_e def __len__(self): rva_size = self.parent_head._wsize // 8 @@ -1306,19 +1311,6 @@ class DirRes(CStruct): out = [] tmp_off = off - for _ in range(nbr): - if tmp_off >= ofend: - break - if tmp_off + length >= len(raw): - log.warn('warning bad resource offset') - break - try: - entry, length = ResEntry.unpack_l(raw, tmp_off, self.parent_head) - except RuntimeError: - log.warn('bad resentry') - return None, tmp_off - out.append(entry) - tmp_off += length resdesc.resentries = struct_array(self, raw, off, ResEntry, @@ -1334,7 +1326,7 @@ class DirRes(CStruct): # data dir off = entry.offsettodata if not 0 <= off < len(raw): - log.warn('bad resrouce entry') + log.warn('bad resource entry') continue data = ResDataEntry.unpack(raw, off, @@ -1348,7 +1340,7 @@ class DirRes(CStruct): log.warn('warning recusif subdir') continue if not 0 <= off < len(self.parent_head.img_rva): - log.warn('bad resrouce entry') + log.warn('bad resource entry') continue subdir, length = ResDesc_e.unpack_l(raw, off, @@ -1360,7 +1352,7 @@ class DirRes(CStruct): ResEntry, nbr) except RuntimeError: - log.warn('bad resrouce entry') + log.warn('bad resource entry') continue entry.subdir = subdir @@ -1372,17 +1364,21 @@ class DirRes(CStruct): return of1 = self.parent_head.NThdr.optentries[DIRECTORY_ENTRY_RESOURCE].rva raw[self.parent_head.rva2off(of1)] = bytes(self.resdesc) - dir_todo = {self.parent_head.NThdr.optentries[ - DIRECTORY_ENTRY_RESOURCE].rva: self.resdesc} + length = len(self.resdesc) + dir_todo = { + self.parent_head.NThdr.optentries[DIRECTORY_ENTRY_RESOURCE].rva: self.resdesc + } + of1 = of1 + length + raw[self.parent_head.rva2off(of1)] = bytes(self.resdesc.resentries) dir_done = {} while dir_todo: of1, my_dir = dir_todo.popitem() dir_done[of1] = my_dir raw[self.parent_head.rva2off(of1)] = bytes(my_dir) of1 += len(my_dir) + raw[self.parent_head.rva2off(of1)] = bytes(my_dir.resentries) of_base = of1 for entry in my_dir.resentries: - raw[of_base] = bytes(entry) of_base += len(entry) if entry.name_s: raw[self.parent_head.rva2off(entry.name)] = bytes(entry.name_s) diff --git a/miasm/os_dep/common.py b/miasm/os_dep/common.py index 4a92ef2a..74100817 100644 --- a/miasm/os_dep/common.py +++ b/miasm/os_dep/common.py @@ -71,15 +71,15 @@ class heap(object): self.addr &= self.mask ^ (self.align - 1) return ret - def alloc(self, jitter, size, perm=PAGE_READ | PAGE_WRITE): + def alloc(self, jitter, size, perm=PAGE_READ | PAGE_WRITE, cmt=""): """ @jitter: a jitter instance @size: the size to allocate @perm: permission flags (see vm_alloc doc) """ - return self.vm_alloc(jitter.vm, size, perm) + return self.vm_alloc(jitter.vm, size, perm=perm, cmt=cmt) - def vm_alloc(self, vm, size, perm=PAGE_READ | PAGE_WRITE): + def vm_alloc(self, vm, size, perm=PAGE_READ | PAGE_WRITE, cmt=""): """ @vm: a VmMngr instance @size: the size to allocate @@ -91,7 +91,7 @@ class heap(object): addr, perm, b"\x00" * (size), - "Heap alloc by %s" % get_caller_name(2) + "Heap alloc by %s %s" % (get_caller_name(2), cmt) ) return addr diff --git a/miasm/os_dep/linux/syscall.py b/miasm/os_dep/linux/syscall.py index fc6bbd8a..acebe2cb 100644 --- a/miasm/os_dep/linux/syscall.py +++ b/miasm/os_dep/linux/syscall.py @@ -5,7 +5,7 @@ import logging import struct import termios -from miasm.jitter.csts import EXCEPT_PRIV_INSN, EXCEPT_INT_XX +from miasm.jitter.csts import EXCEPT_INT_XX, EXCEPT_SYSCALL from miasm.core.utils import pck64 log = logging.getLogger('syscalls') @@ -401,7 +401,7 @@ def sys_x86_64_arch_prctl(jitter, linux_env): jitter.cpu.set_segm_base(jitter.cpu.FS, addr) elif code == 0x3001: # CET status (disabled) - jitter.cpu.set_mem(addr, pck64(0)) + jitter.vm.set_mem(addr, pck64(0)) else: raise RuntimeError("Not implemented") jitter.cpu.RAX = 0 @@ -681,7 +681,7 @@ def sys_x86_64_connect(jitter, linux_env): log.debug("sys_connect(%x, %r, %x)", fd, raddr, addrlen) # Stub - # Always refuse the connexion + # Always refuse the connection jitter.cpu.RAX = -1 @@ -979,16 +979,12 @@ syscall_callbacks_arml = { } def syscall_x86_64_exception_handler(linux_env, syscall_callbacks, jitter): - """Call to actually handle an EXCEPT_PRIV_INSN exception + """Call to actually handle an EXCEPT_SYSCALL exception In the case of an error raised by a SYSCALL, call the corresponding syscall_callbacks @linux_env: LinuxEnvironment_x86_64 instance @syscall_callbacks: syscall number -> func(jitter, linux_env) """ - # Ensure the jitter has break on a SYSCALL - cur_instr = jitter.jit.mdis.dis_instr(jitter.pc) - if cur_instr.name != "SYSCALL": - return True # Dispatch to SYSCALL stub syscall_number = jitter.cpu.RAX @@ -1002,14 +998,13 @@ def syscall_x86_64_exception_handler(linux_env, syscall_callbacks, jitter): # Clean exception and move pc to the next instruction, to let the jitter # continue - jitter.cpu.set_exception(jitter.cpu.get_exception() ^ EXCEPT_PRIV_INSN) - jitter.pc += cur_instr.l + jitter.cpu.set_exception(jitter.cpu.get_exception() ^ EXCEPT_SYSCALL) return True def syscall_x86_32_exception_handler(linux_env, syscall_callbacks, jitter): - """Call to actually handle an EXCEPT_PRIV_INSN exception + """Call to actually handle an EXCEPT_INT_XX exception In the case of an error raised by a SYSCALL, call the corresponding syscall_callbacks @linux_env: LinuxEnvironment_x86_32 instance @@ -1078,7 +1073,7 @@ def enable_syscall_handling(jitter, linux_env, syscall_callbacks): if arch_name == "x8664": handler = syscall_x86_64_exception_handler handler = functools.partial(handler, linux_env, syscall_callbacks) - jitter.add_exception_handler(EXCEPT_PRIV_INSN, handler) + jitter.add_exception_handler(EXCEPT_SYSCALL, handler) elif arch_name == "x8632": handler = syscall_x86_32_exception_handler handler = functools.partial(handler, linux_env, syscall_callbacks) @@ -1089,4 +1084,3 @@ def enable_syscall_handling(jitter, linux_env, syscall_callbacks): jitter.add_exception_handler(EXCEPT_INT_XX, handler) else: raise ValueError("No syscall handler implemented for %s" % arch_name) - diff --git a/miasm/os_dep/win_api_x86_32.py b/miasm/os_dep/win_api_x86_32.py index c1870d97..568a646d 100644 --- a/miasm/os_dep/win_api_x86_32.py +++ b/miasm/os_dep/win_api_x86_32.py @@ -157,6 +157,9 @@ class c_winobjs(object): self.cryptcontext_num = 0 self.cryptcontext = {} self.phhash_crypt_md5 = 0x55555 + # key used by EncodePointer and DecodePointer + # (kernel32) + self.ptr_encode_key = 0xabababab self.files_hwnd = {} self.windowlong_dw = 0x77700 self.module_cur_hwnd = 0x88800 @@ -272,7 +275,7 @@ class mdl(object): def kernel32_HeapAlloc(jitter): ret_ad, args = jitter.func_args_stdcall(["heap", "flags", "size"]) - alloc_addr = winobjs.heap.alloc(jitter, args.size) + alloc_addr = winobjs.heap.alloc(jitter, args.size, cmt=hex(ret_ad)) jitter.func_ret_stdcall(ret_ad, alloc_addr) @@ -420,6 +423,36 @@ def kernel32_CloseHandle(jitter): ret_ad, _ = jitter.func_args_stdcall(["hwnd"]) jitter.func_ret_stdcall(ret_ad, 1) +def kernel32_EncodePointer(jitter): + """ + PVOID EncodePointer( + _In_ PVOID Ptr + ); + + Encoding globally available pointers helps protect them from being + exploited. The EncodePointer function obfuscates the pointer value + with a secret so that it cannot be predicted by an external agent. + The secret used by EncodePointer is different for each process. + + A pointer must be decoded before it can be used. + + """ + ret, args = jitter.func_args_stdcall(1) + jitter.func_ret_stdcall(ret, args[0] ^ winobjs.ptr_encode_key) + return True + +def kernel32_DecodePointer(jitter): + """ + PVOID DecodePointer( + PVOID Ptr + ); + + The function returns the decoded pointer. + + """ + ret, args = jitter.func_args_stdcall(1) + jitter.func_ret_stdcall(ret, args[0] ^ winobjs.ptr_encode_key) + return True def user32_GetForegroundWindow(jitter): ret_ad, _ = jitter.func_args_stdcall(0) @@ -505,7 +538,7 @@ def advapi32_CryptHashData(jitter): data = jitter.vm.get_mem(args.pbdata, args.dwdatalen) log.debug('will hash %X', args.dwdatalen) - log.debug(repr(data[:10]) + "...") + log.debug(repr(data[:0x10]) + "...") winobjs.cryptcontext[args.hhash].h.update(data) jitter.func_ret_stdcall(ret_ad, 1) @@ -518,12 +551,18 @@ def advapi32_CryptGetHashParam(jitter): raise ValueError("unknown crypt context") if args.param == 2: + # HP_HASHVAL # XXX todo: save h state? h = winobjs.cryptcontext[args.hhash].h.digest() + jitter.vm.set_mem(args.pbdata, h) + jitter.vm.set_u32(args.dwdatalen, len(h)) + elif args.param == 4: + # HP_HASHSIZE + ret = winobjs.cryptcontext[args.hhash].h.digest_size + jitter.vm.set_u32(args.pbdata, ret) + jitter.vm.set_u32(args.dwdatalen, 4) else: raise ValueError('not impl', args.param) - jitter.vm.set_mem(args.pbdata, h) - jitter.vm.set_u32(args.dwdatalen, len(h)) jitter.func_ret_stdcall(ret_ad, 1) @@ -606,7 +645,7 @@ def kernel32_CreateFile(jitter, funcname, get_str): h = open(sb_fname, 'r+b') ret = winobjs.handle_pool.add(sb_fname, h) else: - log.warning("FILE %r DOES NOT EXIST!", fname) + log.warning("FILE %r (%s) DOES NOT EXIST!", fname, sb_fname) elif args.dwcreationdisposition == 1: # create new if os.access(sb_fname, os.R_OK): @@ -759,11 +798,13 @@ def kernel32_VirtualProtect(jitter): jitter.vm.set_u32(args.lpfloldprotect, ACCESS_DICT_INV[old]) paddr = args.lpvoid - (args.lpvoid % winobjs.alloc_align) - psize = args.dwsize + paddr_max = (args.lpvoid + args.dwsize + winobjs.alloc_align - 1) + paddr_max_round = paddr_max - (paddr_max % winobjs.alloc_align) + psize = paddr_max_round - paddr for addr, items in list(winobjs.allocated_pages.items()): alloc_addr, alloc_size = items - if not (alloc_addr <= paddr and - paddr + psize <= alloc_addr + alloc_size): + if (paddr + psize <= alloc_addr or + paddr > alloc_addr + alloc_size): continue size = jitter.vm.get_all_memory()[addr]["size"] # Page is included in Protect area @@ -1112,21 +1153,22 @@ def kernel32_GetCommandLineW(jitter): def shell32_CommandLineToArgvW(jitter): ret_ad, args = jitter.func_args_stdcall(["pcmd", "pnumargs"]) cmd = get_win_str_w(jitter, args.pcmd) + if cmd.startswith('"') and cmd.endswith('"'): + cmd = cmd[1:-1] log.info("CommandLineToArgv %r", cmd) tks = cmd.split(' ') addr = winobjs.heap.alloc(jitter, len(cmd) * 2 + 4 * len(tks)) addr_ret = winobjs.heap.alloc(jitter, 4 * (len(tks) + 1)) o = 0 for i, t in enumerate(tks): - jitter.set_win_str_w(addr + o, t) + set_win_str_w(jitter, addr + o, t) jitter.vm.set_u32(addr_ret + 4 * i, addr + o) o += len(t)*2 + 2 - jitter.vm.set_u32(addr_ret + 4 * i, 0) + jitter.vm.set_u32(addr_ret + 4 * (i+1), 0) jitter.vm.set_u32(args.pnumargs, len(tks)) jitter.func_ret_stdcall(ret_ad, addr_ret) - def cryptdll_MD5Init(jitter): ret_ad, args = jitter.func_args_stdcall(["ad_ctx"]) index = len(winobjs.cryptdll_md5_h) @@ -1333,7 +1375,7 @@ def ntoskrnl_RtlGetVersion(jitter): 0x2, # min vers 0x666, # build nbr 0x2, # platform id - ) + jitter.set_win_str_w("Service pack 4") + ) + encode_win_str_w("Service pack 4") jitter.vm.set_mem(args.ptr_version, s) jitter.func_ret_stdcall(ret_ad, 0) @@ -1519,7 +1561,7 @@ def kernel32_lstrcpy(jitter): def msvcrt__mbscpy(jitter): ret_ad, args = jitter.func_args_cdecl(["ptr_str1", "ptr_str2"]) s2 = get_win_str_w(jitter, args.ptr_str2) - jitter.set_win_str_w(args.ptr_str1, s2) + set_win_str_w(jitter, args.ptr_str1, s2) jitter.func_ret_cdecl(ret_ad, args.ptr_str1) def msvcrt_wcscpy(jitter): @@ -1533,7 +1575,7 @@ def kernel32_lstrcpyn(jitter): if len(s2) >= args.mlen: s2 = s2[:args.mlen - 1] log.info("Copy '%r'", s2) - jitter.set_win_str_a(args.ptr_str1, s2) + set_win_str_a(jitter, args.ptr_str1, s2) jitter.func_ret_stdcall(ret_ad, args.ptr_str1) @@ -1628,15 +1670,82 @@ def kernel32_GetVolumeInformationW(jitter): def kernel32_MultiByteToWideChar(jitter): + MB_ERR_INVALID_CHARS = 0x8 + CP_ACP = 0x000 + CP_1252 = 0x4e4 + ret_ad, args = jitter.func_args_stdcall(["codepage", "dwflags", "lpmultibytestr", "cbmultibyte", "lpwidecharstr", "cchwidechar"]) - src = get_win_str_a(jitter, args.lpmultibytestr) - l = len(src) + 1 - set_win_str_w(jitter, args.lpwidecharstr, src) - jitter.func_ret_stdcall(ret_ad, l) + if args.codepage != CP_ACP and args.codepage != CP_1252: + raise NotImplementedError + # according to MSDN: + # "Note that, if cbMultiByte is 0, the function fails." + if args.cbmultibyte == 0: + raise ValueError + # according to MSDN: + # "Alternatively, this parameter can be set to -1 if the string is + # null-terminated." + if args.cbmultibyte == 0xffffffff: + src_len = 0 + while jitter.vm.get_mem(args.lpmultibytestr + src_len, 1) != b'\0': + src_len += 1 + src = jitter.vm.get_mem(args.lpmultibytestr, src_len) + else: + src = jitter.vm.get_mem(args.lpmultibytestr, args.cbmultibyte) + if args.dwflags & MB_ERR_INVALID_CHARS: + # will raise an exception if decoding fails + s = src.decode("cp1252", errors="replace").encode("utf-16le") + else: + # silently replace undecodable chars with U+FFFD + s = src.decode("cp1252", errors="replace").encode("utf-16le") + if args.cchwidechar > 0: + # return value is number of bytes written + retval = min(args.cchwidechar, len(s)) + jitter.vm.set_mem(args.lpwidecharstr, s[:retval]) + else: + # return value is number of bytes to write + # i.e., size of dest. buffer to allocate + retval = len(s) + jitter.func_ret_stdcall(ret_ad, retval) + + +def kernel32_WideCharToMultiByte(jitter): + """ + int WideCharToMultiByte( + UINT CodePage, + DWORD dwFlags, + _In_NLS_string_(cchWideChar)LPCWCH lpWideCharStr, + int cchWideChar, + LPSTR lpMultiByteStr, + int cbMultiByte, + LPCCH lpDefaultChar, + LPBOOL lpUsedDefaultChar + ); + + """ + CP_ACP = 0x000 + CP_1252 = 0x4e4 + + ret, args = jitter.func_args_stdcall([ + 'CodePage', 'dwFlags', 'lpWideCharStr', 'cchWideChar', + 'lpMultiByteStr', 'cbMultiByte', 'lpDefaultChar', 'lpUsedDefaultChar', + ]) + if args.CodePage != CP_ACP and args.CodePage != CP_1252: + raise NotImplementedError + src = jitter.vm.get_mem(args.lpWideCharStr, args.cchWideChar * 2) + dst = src.decode("utf-16le").encode("cp1252", errors="replace") + if args.cbMultiByte > 0: + # return value is the number of bytes written + retval = min(args.cbMultiByte, len(dst)) + jitter.vm.set_mem(args.lpMultiByteStr, dst[:retval]) + else: + # return value is the size of the buffer to allocate + # to get the multibyte string + retval = len(dst) + jitter.func_ret_stdcall(ret, retval) def my_GetEnvironmentVariable(jitter, funcname, get_str, set_str, mylen): @@ -1870,6 +1979,7 @@ def ntdll_LdrLoadDll(jitter): libname = s.lower() ad = winobjs.runtime_dll.lib_get_add_base(libname) + log.info("Loading %r ret 0x%x", s, ad) jitter.vm.set_u32(args.modhandle, ad) jitter.func_ret_stdcall(ret_ad, 0) @@ -1911,7 +2021,7 @@ def msvcrt_memset(jitter): def msvcrt_strrchr(jitter): ret_ad, args = jitter.func_args_cdecl(['pstr','c']) s = get_win_str_a(jitter, args.pstr) - c = int_to_byte(args.c) + c = int_to_byte(args.c).decode() ret = args.pstr + s.rfind(c) log.info("strrchr(%x '%s','%s') = %x" % (args.pstr,s,c,ret)) jitter.func_ret_cdecl(ret_ad, ret) @@ -1919,7 +2029,7 @@ def msvcrt_strrchr(jitter): def msvcrt_wcsrchr(jitter): ret_ad, args = jitter.func_args_cdecl(['pstr','c']) s = get_win_str_w(jitter, args.pstr) - c = int_to_byte(args.c) + c = int_to_byte(args.c).decode() ret = args.pstr + (s.rfind(c)*2) log.info("wcsrchr(%x '%s',%s) = %x" % (args.pstr,s,c,ret)) jitter.func_ret_cdecl(ret_ad, ret) @@ -2339,13 +2449,88 @@ def user32_GetKeyboardType(jitter): jitter.func_ret_stdcall(ret_ad, ret) + +class startupinfo(object): + """ + typedef struct _STARTUPINFOA { + /* 00000000 */ DWORD cb; + /* 00000004 */ LPSTR lpReserved; + /* 00000008 */ LPSTR lpDesktop; + /* 0000000C */ LPSTR lpTitle; + /* 00000010 */ DWORD dwX; + /* 00000014 */ DWORD dwY; + /* 00000018 */ DWORD dwXSize; + /* 0000001C */ DWORD dwYSize; + /* 00000020 */ DWORD dwXCountChars; + /* 00000024 */ DWORD dwYCountChars; + /* 00000028 */ DWORD dwFillAttribute; + /* 0000002C */ DWORD dwFlags; + /* 00000030 */ WORD wShowWindow; + /* 00000032 */ WORD cbReserved2; + /* 00000034 */ LPBYTE lpReserved2; + /* 00000038 */ HANDLE hStdInput; + /* 0000003C */ HANDLE hStdOutput; + /* 00000040 */ HANDLE hStdError; + } STARTUPINFOA, *LPSTARTUPINFOA; + + """ + # TODO: fill with relevant values + # for now, struct is just a placeholder + cb = 0x0 + lpReserved = 0x0 + lpDesktop = 0x0 + lpTitle = 0x0 + dwX = 0x0 + dwY = 0x0 + dwXSize = 0x0 + dwYSize = 0x0 + dwXCountChars = 0x0 + dwYCountChars = 0x0 + dwFillAttribute = 0x0 + dwFlags = 0x0 + wShowWindow = 0x0 + cbReserved2 = 0x0 + lpReserved2 = 0x0 + hStdInput = 0x0 + hStdOutput = 0x0 + hStdError = 0x0 + + def pack(self): + return struct.pack('IIIIIIIIIIIIHHIIII', + self.cb, + self.lpReserved, + self.lpDesktop, + self.lpTitle, + self.dwX, + self.dwY, + self.dwXSize, + self.dwYSize, + self.dwXCountChars, + self.dwYCountChars, + self.dwFillAttribute, + self.dwFlags, + self.wShowWindow, + self.cbReserved2, + self.lpReserved2, + self.hStdInput, + self.hStdOutput, + self.hStdError) + def kernel32_GetStartupInfo(jitter, funcname, set_str): - ret_ad, args = jitter.func_args_stdcall(["ptr"]) + """ + void GetStartupInfo( + LPSTARTUPINFOW lpStartupInfo + ); - s = b"\x00" * 0x2c + b"\x81\x00\x00\x00" + b"\x0a" + Retrieves the contents of the STARTUPINFO structure that was specified + when the calling process was created. + + https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getstartupinfow - jitter.vm.set_mem(args.ptr, s) + """ + ret_ad, args = jitter.func_args_stdcall(["ptr"]) + jitter.vm.set_mem(args.ptr, startupinfo().pack()) jitter.func_ret_stdcall(ret_ad, args.ptr) @@ -2877,7 +3062,7 @@ class win32_find_data(object): for k, v in viewitems(kargs): setattr(self, k, v) - def toStruct(self): + def toStruct(self, encode_str=encode_win_str_w): s = struct.pack('=IQQQIIII', self.fileattrib, self.creationtime, @@ -2887,10 +3072,10 @@ class win32_find_data(object): self.filesizelow, self.dwreserved0, self.dwreserved1) - fname = self.cfilename.encode('utf-8') + b'\x00' * MAX_PATH + fname = encode_str(self.cfilename) + b'\x00' * MAX_PATH fname = fname[:MAX_PATH] s += fname - fname = self.alternamefilename.encode('utf-8') + b'\x00' * 14 + fname = encode_str(self.alternamefilename) + b'\x00' * 14 fname = fname[:14] s += fname return s @@ -2927,33 +3112,66 @@ class find_data_mngr(object): return fname - -def kernel32_FindFirstFileA(jitter): - ret_ad, args = jitter.func_args_stdcall(["pfilepattern", "pfindfiledata"]) - - filepattern = get_win_str_a(jitter, args.pfilepattern) +def my_FindFirstFile(jitter, pfilepattern, pfindfiledata, get_win_str, encode_str): + filepattern = get_win_str(jitter, pfilepattern) h = winobjs.find_data.findfirst(filepattern) fname = winobjs.find_data.findnext(h) fdata = win32_find_data(cfilename=fname) - jitter.vm.set_mem(args.pfindfiledata, fdata.toStruct()) + jitter.vm.set_mem(pfindfiledata, fdata.toStruct(encode_str=encode_str)) + return h + +def kernel32_FindFirstFileA(jitter): + ret_ad, args = jitter.func_args_stdcall(["pfilepattern", "pfindfiledata"]) + h = my_FindFirstFile(jitter, args.pfilepattern, args.pfindfiledata, + get_win_str_a, encode_win_str_a) jitter.func_ret_stdcall(ret_ad, h) +def kernel32_FindFirstFileW(jitter): + ret_ad, args = jitter.func_args_stdcall(["pfilepattern", "pfindfiledata"]) + h = my_FindFirstFile(jitter, args.pfilepattern, args.pfindfiledata, + get_win_str_w, encode_win_str_w) + jitter.func_ret_stdcall(ret_ad, h) -def kernel32_FindNextFileA(jitter): - ret_ad, args = jitter.func_args_stdcall(["handle", "pfindfiledata"]) +def kernel32_FindFirstFileExA(jitter): + ret_ad, args = jitter.func_args_stdcall([ + "lpFileName", + "fInfoLevelId", + "lpFindFileData", + "fSearchOp", + "lpSearchFilter", + "dwAdditionalFlags"]) + h = my_FindFirstFile(jitter, args.lpFileName, args.lpFindFileData, + get_win_str_a, encode_win_str_a) + jitter.func_ret_stdcall(ret_ad, h) + +def kernel32_FindFirstFileExW(jitter): + ret_ad, args = jitter.func_args_stdcall([ + "lpFileName", + "fInfoLevelId", + "lpFindFileData", + "fSearchOp", + "lpSearchFilter", + "dwAdditionalFlags"]) + h = my_FindFirstFile(jitter, args.lpFileName, args.lpFindFileData, + get_win_str_w, encode_win_str_w) + jitter.func_ret_stdcall(ret_ad, h) +def my_FindNextFile(jitter, encode_str): + ret_ad, args = jitter.func_args_stdcall(["handle", "pfindfiledata"]) fname = winobjs.find_data.findnext(args.handle) if fname is None: + winobjs.lastwin32error = 0x12 # ERROR_NO_MORE_FILES ret = 0 else: ret = 1 fdata = win32_find_data(cfilename=fname) - jitter.vm.set_mem(args.pfindfiledata, fdata.toStruct()) - + jitter.vm.set_mem(args.pfindfiledata, fdata.toStruct(encode_str=encode_str)) jitter.func_ret_stdcall(ret_ad, ret) +kernel32_FindNextFileA = lambda jitter: my_FindNextFile(jitter, encode_win_str_a) +kernel32_FindNextFileW = lambda jitter: my_FindNextFile(jitter, encode_win_str_w) def kernel32_GetNativeSystemInfo(jitter): ret_ad, args = jitter.func_args_stdcall(["sys_ptr"]) @@ -3080,7 +3298,7 @@ class FLS(object): ''' DWORD FlsAlloc( PFLS_CALLBACK_FUNCTION lpCallback - ); + ); ''' ret_ad, args = jitter.func_args_stdcall(["lpCallback"]) index = len(self.slots) @@ -3097,7 +3315,7 @@ class FLS(object): ret_ad, args = jitter.func_args_stdcall(["dwFlsIndex", "lpFlsData"]) self.slots[args.dwFlsIndex] = args.lpFlsData jitter.func_ret_stdcall(ret_ad, 1) - + def kernel32_FlsGetValue(self, jitter): ''' PVOID FlsGetValue( @@ -3105,8 +3323,8 @@ class FLS(object): ); ''' ret_ad, args = jitter.func_args_stdcall(["dwFlsIndex"]) - jitter.func_ret_stdcall(ret_ad, self.slots[args.dwFlsIndex]) - + jitter.func_ret_stdcall(ret_ad, self.slots[args.dwFlsIndex]) + fls = FLS() @@ -3129,15 +3347,15 @@ def kernel32_GetStdHandle(jitter): HANDLE WINAPI GetStdHandle( _In_ DWORD nStdHandle ); - - STD_INPUT_HANDLE (DWORD)-10 + + STD_INPUT_HANDLE (DWORD)-10 The standard input device. Initially, this is the console input buffer, CONIN$. - STD_OUTPUT_HANDLE (DWORD)-11 + STD_OUTPUT_HANDLE (DWORD)-11 The standard output device. Initially, this is the active console screen buffer, CONOUT$. - STD_ERROR_HANDLE (DWORD)-12 - The standard error device. Initially, this is the active console screen buffer, CONOUT$. + STD_ERROR_HANDLE (DWORD)-12 + The standard error device. Initially, this is the active console screen buffer, CONOUT$. ''' ret_ad, args = jitter.func_args_stdcall(["nStdHandle"]) jitter.func_ret_stdcall(ret_ad, { @@ -3146,7 +3364,7 @@ def kernel32_GetStdHandle(jitter): STD_INPUT_HANDLE: 3, }[args.nStdHandle]) - + FILE_TYPE_UNKNOWN = 0x0000 FILE_TYPE_CHAR = 0x0002 @@ -3226,13 +3444,13 @@ def kernel32_IsProcessorFeaturePresent(jitter): 17: False, }[args.ProcessorFeature]) - + def kernel32_GetACP(jitter): ''' UINT GetACP(); ''' ret_ad, args = jitter.func_args_stdcall([]) - # Windows-1252: Latin 1 / Western European Superset of ISO-8859-1 (without C1 controls). + # Windows-1252: Latin 1 / Western European Superset of ISO-8859-1 (without C1 controls). jitter.func_ret_stdcall(ret_ad, 1252) @@ -3257,7 +3475,7 @@ def kernel32_IsValidCodePage(jitter): ); ''' ret_ad, args = jitter.func_args_stdcall(["CodePage"]) - jitter.func_ret_stdcall(ret_ad, args.CodePage in VALID_CODE_PAGES) + jitter.func_ret_stdcall(ret_ad, args.CodePage in VALID_CODE_PAGES) def kernel32_GetCPInfo(jitter): @@ -3270,8 +3488,102 @@ def kernel32_GetCPInfo(jitter): ret_ad, args = jitter.func_args_stdcall(["CodePage", "lpCPInfo"]) assert args.CodePage == 1252 # ref: http://www.rensselaer.org/dept/cis/software/g77-mingw32/include/winnls.h - #define MAX_LEADBYTES 12 + #define MAX_LEADBYTES 12 #define MAX_DEFAULTCHAR 2 jitter.vm.set_mem(args.lpCPInfo, struct.pack('<I', 0x1) + b'??' + b'\x00' * 12) jitter.func_ret_stdcall(ret_ad, 1) - + + +def kernel32_GetStringTypeW(jitter): + """ + BOOL GetStringTypeW( + DWORD dwInfoType, + _In_NLS_string_(cchSrc)LPCWCH lpSrcStr, + int cchSrc, + LPWORD lpCharType + ); + + Retrieves character type information for the characters in the specified + Unicode source string. For each character in the string, the function + sets one or more bits in the corresponding 16-bit element of the output + array. Each bit identifies a given character type, for example, letter, + digit, or neither. + + """ + # These types support ANSI C and POSIX (LC_CTYPE) character typing + # functions.A bitwise-OR of these values is retrieved in the array in the + # output buffer when dwInfoType is set to CT_CTYPE1. For DBCS locales, the + # type attributes apply to both narrow characters and wide characters. The + # Japanese hiragana and katakana characters, and the kanji ideograph + # characters all have the C1_ALPHA attribute. + CT_TYPE1 = 0x01 + # TODO handle other types of information + # (CT_TYPE2, CT_TYPE3) + # for now, they raise NotImplemented + CT_TYPE2 = 0x02 + CT_TYPE3 = 0x03 + + C1_UPPER = 0x0001 # Uppercase + C1_LOWER = 0x0002 # Lowercase + C1_DIGIT = 0x0004 # Decimal digits + C1_SPACE = 0x0008 # Space characters + C1_PUNCT = 0x0010 # Punctuation + C1_CNTRL = 0x0020 # Control characters + C1_BLANK = 0x0040 # Blank characters + C1_XDIGIT = 0x0080 # Hexadecimal digits + C1_ALPHA = 0x0100 # Any linguistic character: alphabetical, syllabary, or ideographic + C1_DEFINED = 0x0200 # A defined character, but not one of the other C1_* types + + # the following sets have been generated from the Linux python library curses + # e.g., C1_PUNCT_SET = [chr(i) for i in range(256) if curses.ascii.ispunct(chr(i))] + C1_PUNCT_SET = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', + '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', + '^', '_', '`', '{', '|', '}', '~'] + C1_CNTRL_SET = ['\x00', '\x01', '\x02', '\x03', '\x04', '\x05', '\x06', + '\x07', '\x08', '\t', '\n', '\x0b', '\x0c', '\r', '\x0e', '\x0f', + '\x10', '\x11', '\x12', '\x13', '\x14', '\x15', '\x16', '\x17', + '\x18', '\x19', '\x1a', '\x1b', '\x1c', '\x1d', '\x1e', '\x1f', + '\x7f'] + C1_BLANK_SET = ['\t', ' '] + C1_XDIGIT_SET = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', + 'B', 'C', 'D', 'E', 'F', 'a', 'b', 'c', 'd', 'e', 'f'] + + ret, args = jitter.func_args_stdcall(['dwInfoType', 'lpSrcStr', 'cchSrc', + 'lpCharType']) + s = jitter.vm.get_mem(args.lpSrcStr, args.cchSrc).decode("utf-16") + if args.dwInfoType == CT_TYPE1: + # iterate over characters from the decoded W string + for i, c in enumerate(s): + # TODO handle non-ascii characters + if not c.isascii(): + continue + val = 0 + if c.isupper(): + val |= C1_UPPER + if c.islower(): + val |= C1_LOWER + if c.isdigit(): + val |= C1_DIGIT + if c.isspace(): + val |= C1_SPACE + if c in C1_PUNCT_SET: + val |= C1_PUNCT + if c in C1_CNTRL_SET: + val |= C1_CNTRL + if c in C1_BLANK_SET: + val |= C1_BLANK + if c in C1_XDIGIT_SET: + val |= C1_XDIGIT + if c.isalpha(): + val |= C1_ALPHA + if val == 0: + val = C1_DEFINED + jitter.vm.set_u16(args.lpCharType + i * 2, val) + elif args.dwInfoType == CT_TYPE2: + raise NotImplemented + elif args.dwInfoType == CT_TYPE3: + raise NotImplemented + else: + raise ValueError("CT_TYPE unknown: %i" % args.dwInfoType) + jitter.func_ret_stdcall(ret, 1) + return True diff --git a/miasm/os_dep/win_api_x86_32_seh.py b/miasm/os_dep/win_api_x86_32_seh.py index 28699d68..57416477 100644 --- a/miasm/os_dep/win_api_x86_32_seh.py +++ b/miasm/os_dep/win_api_x86_32_seh.py @@ -189,18 +189,23 @@ def build_ldr_data(jitter, modules_info): "Loader struct" ) # (ldrdata.get_size() - offset)) + last_module = modules_info.module2entry[ + modules_info.modules[-1]] + if main_pe: ldrdata.InLoadOrderModuleList.flink = main_addr_entry - ldrdata.InLoadOrderModuleList.blink = 0 + ldrdata.InLoadOrderModuleList.blink = last_module + ldrdata.InMemoryOrderModuleList.flink = main_addr_entry + \ LdrDataEntry.get_type().get_offset("InMemoryOrderLinks") - ldrdata.InMemoryOrderModuleList.blink = 0 - + ldrdata.InMemoryOrderModuleList.blink = last_module + \ + LdrDataEntry.get_type().get_offset("InMemoryOrderLinks") if ntdll_pe: ldrdata.InInitializationOrderModuleList.flink = ntdll_addr_entry + \ LdrDataEntry.get_type().get_offset("InInitializationOrderLinks") - ldrdata.InInitializationOrderModuleList.blink = 0 + ldrdata.InInitializationOrderModuleList.blink = last_module + \ + LdrDataEntry.get_type().get_offset("InInitializationOrderLinks") # Add dummy dll base jitter.vm.add_memory_page(peb_ldr_data_address + 0x24, @@ -312,9 +317,11 @@ def set_link_list_entry(jitter, loaded_modules, modules_info, offset): prev_module_entry = peb_ldr_data_address + 0xC if i == len(loaded_modules) - 1: next_module_entry = peb_ldr_data_address + 0xC - jitter.vm.set_mem(cur_module_entry + offset, - (pck32(next_module_entry + offset) + - pck32(prev_module_entry + offset))) + + list_entry = ListEntry(jitter.vm, cur_module_entry + offset) + list_entry.flink = next_module_entry + offset + list_entry.blink = prev_module_entry + offset + def fix_InLoadOrderModuleList(jitter, modules_info): diff --git a/miasm/runtime/divti3.c b/miasm/runtime/divti3.c new file mode 100644 index 00000000..fc5c1b4d --- /dev/null +++ b/miasm/runtime/divti3.c @@ -0,0 +1,36 @@ +/* ===-- divti3.c - Implement __divti3 -------------------------------------=== + * + * The LLVM Compiler Infrastructure + * + * This file is dual licensed under the MIT and the University of Illinois Open + * Source Licenses. See LICENSE.TXT for details. + * + * ===----------------------------------------------------------------------=== + * + * This file implements __divti3 for the compiler_rt library. + * + * ===----------------------------------------------------------------------=== + */ + +#if __x86_64 + +#include "int_lib.h" +#include "export.h" + +tu_int __udivmodti4(tu_int a, tu_int b, tu_int* rem); + +/* Returns: a / b */ + +ti_int +__divti3(ti_int a, ti_int b) +{ + const int bits_in_tword_m1 = (int)(sizeof(ti_int) * CHAR_BIT) - 1; + ti_int s_a = a >> bits_in_tword_m1; /* s_a = a < 0 ? -1 : 0 */ + ti_int s_b = b >> bits_in_tword_m1; /* s_b = b < 0 ? -1 : 0 */ + a = (a ^ s_a) - s_a; /* negate if s_a == -1 */ + b = (b ^ s_b) - s_b; /* negate if s_b == -1 */ + s_a ^= s_b; /* sign of quotient */ + return (__udivmodti4(a, b, (tu_int*)0) ^ s_a) - s_a; /* negate if s_a == -1 */ +} + +#endif diff --git a/miasm/runtime/export.h b/miasm/runtime/export.h new file mode 100644 index 00000000..f21a83a8 --- /dev/null +++ b/miasm/runtime/export.h @@ -0,0 +1,10 @@ +#ifndef MIASM_RT_EXPORT_H +#define MIASM_RT_EXPORT_H + +#ifdef _WIN32 +#define _MIASM_EXPORT __declspec(dllexport) +#else +#define _MIASM_EXPORT __attribute__((visibility("default"))) +#endif + +#endif diff --git a/miasm/runtime/int_endianness.h b/miasm/runtime/int_endianness.h new file mode 100644 index 00000000..def046c3 --- /dev/null +++ b/miasm/runtime/int_endianness.h @@ -0,0 +1,114 @@ +//===-- int_endianness.h - configuration header for compiler-rt -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is a configuration header for compiler-rt. +// This file is not part of the interface of this library. +// +//===----------------------------------------------------------------------===// + +#ifndef INT_ENDIANNESS_H +#define INT_ENDIANNESS_H + +#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ + defined(__ORDER_LITTLE_ENDIAN__) + +// Clang and GCC provide built-in endianness definitions. +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define _YUGA_LITTLE_ENDIAN 0 +#define _YUGA_BIG_ENDIAN 1 +#elif __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define _YUGA_LITTLE_ENDIAN 1 +#define _YUGA_BIG_ENDIAN 0 +#endif // __BYTE_ORDER__ + +#else // Compilers other than Clang or GCC. + +#if defined(__SVR4) && defined(__sun) +#include <sys/byteorder.h> + +#if defined(_BIG_ENDIAN) +#define _YUGA_LITTLE_ENDIAN 0 +#define _YUGA_BIG_ENDIAN 1 +#elif defined(_LITTLE_ENDIAN) +#define _YUGA_LITTLE_ENDIAN 1 +#define _YUGA_BIG_ENDIAN 0 +#else // !_LITTLE_ENDIAN +#error "unknown endianness" +#endif // !_LITTLE_ENDIAN + +#endif // Solaris and AuroraUX. + +// .. + +#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__DragonFly__) || \ + defined(__minix) +#include <sys/endian.h> + +#if _BYTE_ORDER == _BIG_ENDIAN +#define _YUGA_LITTLE_ENDIAN 0 +#define _YUGA_BIG_ENDIAN 1 +#elif _BYTE_ORDER == _LITTLE_ENDIAN +#define _YUGA_LITTLE_ENDIAN 1 +#define _YUGA_BIG_ENDIAN 0 +#endif // _BYTE_ORDER + +#endif // *BSD + +#if defined(__OpenBSD__) +#include <machine/endian.h> + +#if _BYTE_ORDER == _BIG_ENDIAN +#define _YUGA_LITTLE_ENDIAN 0 +#define _YUGA_BIG_ENDIAN 1 +#elif _BYTE_ORDER == _LITTLE_ENDIAN +#define _YUGA_LITTLE_ENDIAN 1 +#define _YUGA_BIG_ENDIAN 0 +#endif // _BYTE_ORDER + +#endif // OpenBSD + +// .. + +// Mac OSX has __BIG_ENDIAN__ or __LITTLE_ENDIAN__ automatically set by the +// compiler (at least with GCC) +#if defined(__APPLE__) || defined(__ellcc__) + +#ifdef __BIG_ENDIAN__ +#if __BIG_ENDIAN__ +#define _YUGA_LITTLE_ENDIAN 0 +#define _YUGA_BIG_ENDIAN 1 +#endif +#endif // __BIG_ENDIAN__ + +#ifdef __LITTLE_ENDIAN__ +#if __LITTLE_ENDIAN__ +#define _YUGA_LITTLE_ENDIAN 1 +#define _YUGA_BIG_ENDIAN 0 +#endif +#endif // __LITTLE_ENDIAN__ + +#endif // Mac OSX + +// .. + +#if defined(_WIN32) + +#define _YUGA_LITTLE_ENDIAN 1 +#define _YUGA_BIG_ENDIAN 0 + +#endif // Windows + +#endif // Clang or GCC. + +// . + +#if !defined(_YUGA_LITTLE_ENDIAN) || !defined(_YUGA_BIG_ENDIAN) +#error Unable to determine endian +#endif // Check we found an endianness correctly. + +#endif // INT_ENDIANNESS_H diff --git a/miasm/runtime/int_lib.h b/miasm/runtime/int_lib.h new file mode 100644 index 00000000..7f5eb799 --- /dev/null +++ b/miasm/runtime/int_lib.h @@ -0,0 +1,148 @@ +//===-- int_lib.h - configuration header for compiler-rt -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is a configuration header for compiler-rt. +// This file is not part of the interface of this library. +// +//===----------------------------------------------------------------------===// + +#ifndef INT_LIB_H +#define INT_LIB_H + +// Assumption: Signed integral is 2's complement. +// Assumption: Right shift of signed negative is arithmetic shift. +// Assumption: Endianness is little or big (not mixed). + +// ABI macro definitions + +#if __ARM_EABI__ +#ifdef COMPILER_RT_ARMHF_TARGET +#define COMPILER_RT_ABI +#else +#define COMPILER_RT_ABI __attribute__((__pcs__("aapcs"))) +#endif +#else +#define COMPILER_RT_ABI +#endif + +#define AEABI_RTABI __attribute__((__pcs__("aapcs"))) + +#if defined(_MSC_VER) && !defined(__clang__) +#define ALWAYS_INLINE __forceinline +#define NOINLINE __declspec(noinline) +#define NORETURN __declspec(noreturn) +#define UNUSED +#else +#define ALWAYS_INLINE __attribute__((always_inline)) +#define NOINLINE __attribute__((noinline)) +#define NORETURN __attribute__((noreturn)) +#define UNUSED __attribute__((unused)) +#endif + +#define STR(a) #a +#define XSTR(a) STR(a) +#define SYMBOL_NAME(name) XSTR(__USER_LABEL_PREFIX__) #name + +#if defined(__ELF__) || defined(__MINGW32__) || defined(__wasm__) +#define COMPILER_RT_ALIAS(name, aliasname) \ + COMPILER_RT_ABI __typeof(name) aliasname __attribute__((__alias__(#name))); +#elif defined(__APPLE__) +#if defined(VISIBILITY_HIDDEN) +#define COMPILER_RT_ALIAS_VISIBILITY(name) \ + __asm__(".private_extern " SYMBOL_NAME(name)); +#else +#define COMPILER_RT_ALIAS_VISIBILITY(name) +#endif +#define COMPILER_RT_ALIAS(name, aliasname) \ + __asm__(".globl " SYMBOL_NAME(aliasname)); \ + COMPILER_RT_ALIAS_VISIBILITY(aliasname) \ + __asm__(SYMBOL_NAME(aliasname) " = " SYMBOL_NAME(name)); \ + COMPILER_RT_ABI __typeof(name) aliasname; +#elif defined(_WIN32) +#define COMPILER_RT_ALIAS(name, aliasname) +#else +#error Unsupported target +#endif + +#if defined(__NetBSD__) && (defined(_KERNEL) || defined(_STANDALONE)) +// +// Kernel and boot environment can't use normal headers, +// so use the equivalent system headers. +// +#include <machine/limits.h> +#include <sys/stdint.h> +#include <sys/types.h> +#else +// Include the standard compiler builtin headers we use functionality from. +#include <float.h> +#include <limits.h> +#include <stdbool.h> +#include <stdint.h> +#endif + +// Include the commonly used internal type definitions. +#include "int_types.h" + +// Include internal utility function declarations. +#include "int_util.h" + +COMPILER_RT_ABI si_int __paritysi2(si_int a); +COMPILER_RT_ABI si_int __paritydi2(di_int a); + +COMPILER_RT_ABI di_int __divdi3(di_int a, di_int b); +COMPILER_RT_ABI si_int __divsi3(si_int a, si_int b); +COMPILER_RT_ABI su_int __udivsi3(su_int n, su_int d); + +COMPILER_RT_ABI su_int __udivmodsi4(su_int a, su_int b, su_int *rem); +COMPILER_RT_ABI du_int __udivmoddi4(du_int a, du_int b, du_int *rem); +#ifdef CRT_HAS_128BIT +COMPILER_RT_ABI si_int __clzti2(ti_int a); +COMPILER_RT_ABI tu_int __udivmodti4(tu_int a, tu_int b, tu_int *rem); +#endif + +// Definitions for builtins unavailable on MSVC +#if defined(_MSC_VER) && !defined(__clang__) +#include <intrin.h> + +uint32_t __inline __builtin_ctz(uint32_t value) { + unsigned long trailing_zero = 0; + if (_BitScanForward(&trailing_zero, value)) + return trailing_zero; + return 32; +} + +uint32_t __inline __builtin_clz(uint32_t value) { + unsigned long leading_zero = 0; + if (_BitScanReverse(&leading_zero, value)) + return 31 - leading_zero; + return 32; +} + +#if defined(_M_ARM) || defined(_M_X64) +uint32_t __inline __builtin_clzll(uint64_t value) { + unsigned long leading_zero = 0; + if (_BitScanReverse64(&leading_zero, value)) + return 63 - leading_zero; + return 64; +} +#else +uint32_t __inline __builtin_clzll(uint64_t value) { + if (value == 0) + return 64; + uint32_t msh = (uint32_t)(value >> 32); + uint32_t lsh = (uint32_t)(value & 0xFFFFFFFF); + if (msh != 0) + return __builtin_clz(msh); + return 32 + __builtin_clz(lsh); +} +#endif + +#define __builtin_clzl __builtin_clzll +#endif // defined(_MSC_VER) && !defined(__clang__) + +#endif // INT_LIB_H diff --git a/miasm/runtime/int_types.h b/miasm/runtime/int_types.h new file mode 100644 index 00000000..f89220d5 --- /dev/null +++ b/miasm/runtime/int_types.h @@ -0,0 +1,174 @@ +//===-- int_lib.h - configuration header for compiler-rt -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is not part of the interface of this library. +// +// This file defines various standard types, most importantly a number of unions +// used to access parts of larger types. +// +//===----------------------------------------------------------------------===// + +#ifndef INT_TYPES_H +#define INT_TYPES_H + +#include "int_endianness.h" + +// si_int is defined in Linux sysroot's asm-generic/siginfo.h +#ifdef si_int +#undef si_int +#endif +typedef int si_int; +typedef unsigned su_int; + +typedef long long di_int; +typedef unsigned long long du_int; + +typedef union { + di_int all; + struct { +#if _YUGA_LITTLE_ENDIAN + su_int low; + si_int high; +#else + si_int high; + su_int low; +#endif // _YUGA_LITTLE_ENDIAN + } s; +} dwords; + +typedef union { + du_int all; + struct { +#if _YUGA_LITTLE_ENDIAN + su_int low; + su_int high; +#else + su_int high; + su_int low; +#endif // _YUGA_LITTLE_ENDIAN + } s; +} udwords; + +#if defined(__LP64__) || defined(__wasm__) || defined(__mips64) || \ + defined(__riscv) || defined(_WIN64) +#define CRT_HAS_128BIT +#endif + +// MSVC doesn't have a working 128bit integer type. Users should really compile +// compiler-rt with clang, but if they happen to be doing a standalone build for +// asan or something else, disable the 128 bit parts so things sort of work. +#if defined(_MSC_VER) && !defined(__clang__) +#undef CRT_HAS_128BIT +#endif + +#ifdef CRT_HAS_128BIT +typedef int ti_int __attribute__((mode(TI))); +typedef unsigned tu_int __attribute__((mode(TI))); + +typedef union { + ti_int all; + struct { +#if _YUGA_LITTLE_ENDIAN + du_int low; + di_int high; +#else + di_int high; + du_int low; +#endif // _YUGA_LITTLE_ENDIAN + } s; +} twords; + +typedef union { + tu_int all; + struct { +#if _YUGA_LITTLE_ENDIAN + du_int low; + du_int high; +#else + du_int high; + du_int low; +#endif // _YUGA_LITTLE_ENDIAN + } s; +} utwords; + +static __inline ti_int make_ti(di_int h, di_int l) { + twords r; + r.s.high = h; + r.s.low = l; + return r.all; +} + +static __inline tu_int make_tu(du_int h, du_int l) { + utwords r; + r.s.high = h; + r.s.low = l; + return r.all; +} + +#endif // CRT_HAS_128BIT + +typedef union { + su_int u; + float f; +} float_bits; + +typedef union { + udwords u; + double f; +} double_bits; + +typedef struct { +#if _YUGA_LITTLE_ENDIAN + udwords low; + udwords high; +#else + udwords high; + udwords low; +#endif // _YUGA_LITTLE_ENDIAN +} uqwords; + +// Check if the target supports 80 bit extended precision long doubles. +// Notably, on x86 Windows, MSVC only provides a 64-bit long double, but GCC +// still makes it 80 bits. Clang will match whatever compiler it is trying to +// be compatible with. +#if ((defined(__i386__) || defined(__x86_64__)) && !defined(_MSC_VER)) || \ + defined(__m68k__) || defined(__ia64__) +#define HAS_80_BIT_LONG_DOUBLE 1 +#else +#define HAS_80_BIT_LONG_DOUBLE 0 +#endif + +typedef union { + uqwords u; + long double f; +} long_double_bits; + +#if __STDC_VERSION__ >= 199901L +typedef float _Complex Fcomplex; +typedef double _Complex Dcomplex; +typedef long double _Complex Lcomplex; + +#define COMPLEX_REAL(x) __real__(x) +#define COMPLEX_IMAGINARY(x) __imag__(x) +#else +typedef struct { + float real, imaginary; +} Fcomplex; + +typedef struct { + double real, imaginary; +} Dcomplex; + +typedef struct { + long double real, imaginary; +} Lcomplex; + +#define COMPLEX_REAL(x) (x).real +#define COMPLEX_IMAGINARY(x) (x).imaginary +#endif +#endif // INT_TYPES_H diff --git a/miasm/runtime/int_util.h b/miasm/runtime/int_util.h new file mode 100644 index 00000000..5fbdfb57 --- /dev/null +++ b/miasm/runtime/int_util.h @@ -0,0 +1,31 @@ +//===-- int_util.h - internal utility functions ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is not part of the interface of this library. +// +// This file defines non-inline utilities which are available for use in the +// library. The function definitions themselves are all contained in int_util.c +// which will always be compiled into any compiler-rt library. +// +//===----------------------------------------------------------------------===// + +#ifndef INT_UTIL_H +#define INT_UTIL_H + +/// \brief Trigger a program abort (or panic for kernel code). +#define compilerrt_abort() __compilerrt_abort_impl(__FILE__, __LINE__, __func__) + +NORETURN void __compilerrt_abort_impl(const char *file, int line, + const char *function); + +#define COMPILE_TIME_ASSERT(expr) COMPILE_TIME_ASSERT1(expr, __COUNTER__) +#define COMPILE_TIME_ASSERT1(expr, cnt) COMPILE_TIME_ASSERT2(expr, cnt) +#define COMPILE_TIME_ASSERT2(expr, cnt) \ + typedef char ct_assert_##cnt[(expr) ? 1 : -1] UNUSED + +#endif // INT_UTIL_H diff --git a/miasm/runtime/udivmodti4.c b/miasm/runtime/udivmodti4.c new file mode 100644 index 00000000..44a43be4 --- /dev/null +++ b/miasm/runtime/udivmodti4.c @@ -0,0 +1,196 @@ +//===-- udivmodti4.c - Implement __udivmodti4 -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements __udivmodti4 for the compiler_rt library. +// +//===----------------------------------------------------------------------===// + +#include "int_lib.h" +#include "export.h" + +#ifdef CRT_HAS_128BIT + +// Effects: if rem != 0, *rem = a % b +// Returns: a / b + +// Translated from Figure 3-40 of The PowerPC Compiler Writer's Guide + +_MIASM_EXPORT tu_int __udivmodti4(tu_int a, tu_int b, tu_int *rem) { + const unsigned n_udword_bits = sizeof(du_int) * CHAR_BIT; + const unsigned n_utword_bits = sizeof(tu_int) * CHAR_BIT; + utwords n; + n.all = a; + utwords d; + d.all = b; + utwords q; + utwords r; + unsigned sr; + // special cases, X is unknown, K != 0 + if (n.s.high == 0) { + if (d.s.high == 0) { + // 0 X + // --- + // 0 X + if (rem) + *rem = n.s.low % d.s.low; + return n.s.low / d.s.low; + } + // 0 X + // --- + // K X + if (rem) + *rem = n.s.low; + return 0; + } + // n.s.high != 0 + if (d.s.low == 0) { + if (d.s.high == 0) { + // K X + // --- + // 0 0 + if (rem) + *rem = n.s.high % d.s.low; + return n.s.high / d.s.low; + } + // d.s.high != 0 + if (n.s.low == 0) { + // K 0 + // --- + // K 0 + if (rem) { + r.s.high = n.s.high % d.s.high; + r.s.low = 0; + *rem = r.all; + } + return n.s.high / d.s.high; + } + // K K + // --- + // K 0 + if ((d.s.high & (d.s.high - 1)) == 0) /* if d is a power of 2 */ { + if (rem) { + r.s.low = n.s.low; + r.s.high = n.s.high & (d.s.high - 1); + *rem = r.all; + } + return n.s.high >> __builtin_ctzll(d.s.high); + } + // K K + // --- + // K 0 + sr = __builtin_clzll(d.s.high) - __builtin_clzll(n.s.high); + // 0 <= sr <= n_udword_bits - 2 or sr large + if (sr > n_udword_bits - 2) { + if (rem) + *rem = n.all; + return 0; + } + ++sr; + // 1 <= sr <= n_udword_bits - 1 + // q.all = n.all << (n_utword_bits - sr); + q.s.low = 0; + q.s.high = n.s.low << (n_udword_bits - sr); + // r.all = n.all >> sr; + r.s.high = n.s.high >> sr; + r.s.low = (n.s.high << (n_udword_bits - sr)) | (n.s.low >> sr); + } else /* d.s.low != 0 */ { + if (d.s.high == 0) { + // K X + // --- + // 0 K + if ((d.s.low & (d.s.low - 1)) == 0) /* if d is a power of 2 */ { + if (rem) + *rem = n.s.low & (d.s.low - 1); + if (d.s.low == 1) + return n.all; + sr = __builtin_ctzll(d.s.low); + q.s.high = n.s.high >> sr; + q.s.low = (n.s.high << (n_udword_bits - sr)) | (n.s.low >> sr); + return q.all; + } + // K X + // --- + // 0 K + sr = 1 + n_udword_bits + __builtin_clzll(d.s.low) - + __builtin_clzll(n.s.high); + // 2 <= sr <= n_utword_bits - 1 + // q.all = n.all << (n_utword_bits - sr); + // r.all = n.all >> sr; + if (sr == n_udword_bits) { + q.s.low = 0; + q.s.high = n.s.low; + r.s.high = 0; + r.s.low = n.s.high; + } else if (sr < n_udword_bits) /* 2 <= sr <= n_udword_bits - 1 */ { + q.s.low = 0; + q.s.high = n.s.low << (n_udword_bits - sr); + r.s.high = n.s.high >> sr; + r.s.low = (n.s.high << (n_udword_bits - sr)) | (n.s.low >> sr); + } else /* n_udword_bits + 1 <= sr <= n_utword_bits - 1 */ { + q.s.low = n.s.low << (n_utword_bits - sr); + q.s.high = (n.s.high << (n_utword_bits - sr)) | + (n.s.low >> (sr - n_udword_bits)); + r.s.high = 0; + r.s.low = n.s.high >> (sr - n_udword_bits); + } + } else { + // K X + // --- + // K K + sr = __builtin_clzll(d.s.high) - __builtin_clzll(n.s.high); + // 0 <= sr <= n_udword_bits - 1 or sr large + if (sr > n_udword_bits - 1) { + if (rem) + *rem = n.all; + return 0; + } + ++sr; + // 1 <= sr <= n_udword_bits + // q.all = n.all << (n_utword_bits - sr); + // r.all = n.all >> sr; + q.s.low = 0; + if (sr == n_udword_bits) { + q.s.high = n.s.low; + r.s.high = 0; + r.s.low = n.s.high; + } else { + r.s.high = n.s.high >> sr; + r.s.low = (n.s.high << (n_udword_bits - sr)) | (n.s.low >> sr); + q.s.high = n.s.low << (n_udword_bits - sr); + } + } + } + // Not a special case + // q and r are initialized with: + // q.all = n.all << (n_utword_bits - sr); + // r.all = n.all >> sr; + // 1 <= sr <= n_utword_bits - 1 + su_int carry = 0; + for (; sr > 0; --sr) { + // r:q = ((r:q) << 1) | carry + r.s.high = (r.s.high << 1) | (r.s.low >> (n_udword_bits - 1)); + r.s.low = (r.s.low << 1) | (q.s.high >> (n_udword_bits - 1)); + q.s.high = (q.s.high << 1) | (q.s.low >> (n_udword_bits - 1)); + q.s.low = (q.s.low << 1) | carry; + // carry = 0; + // if (r.all >= d.all) + // { + // r.all -= d.all; + // carry = 1; + // } + const ti_int s = (ti_int)(d.all - r.all - 1) >> (n_utword_bits - 1); + carry = s & 1; + r.all -= d.all & s; + } + q.all = (q.all << 1) | carry; + if (rem) + *rem = r.all; + return q.all; +} + +#endif // CRT_HAS_128BIT diff --git a/miasm/runtime/udivti3.c b/miasm/runtime/udivti3.c new file mode 100644 index 00000000..3844dc9d --- /dev/null +++ b/miasm/runtime/udivti3.c @@ -0,0 +1,24 @@ +//===-- udivti3.c - Implement __udivti3 -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements __udivti3 for the compiler_rt library. +// +//===----------------------------------------------------------------------===// + +#include "int_lib.h" +#include "export.h" + +#ifdef CRT_HAS_128BIT + +// Returns: a / b + +_MIASM_EXPORT tu_int __udivti3(tu_int a, tu_int b) { + return __udivmodti4(a, b, 0); +} + +#endif // CRT_HAS_128BIT |