about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/analysis/data_flow.py250
1 files changed, 248 insertions, 2 deletions
diff --git a/miasm2/analysis/data_flow.py b/miasm2/analysis/data_flow.py
index 22a07687..a0ff867b 100644
--- a/miasm2/analysis/data_flow.py
+++ b/miasm2/analysis/data_flow.py
@@ -3,7 +3,9 @@
 from collections import namedtuple
 from miasm2.core.graph import DiGraph
 from miasm2.ir.ir import AssignBlock, IRBlock
-from miasm2.expression.expression import ExprLoc, ExprMem, ExprId
+from miasm2.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt
+from miasm2.expression.simplifications import expr_simp
+from miasm2.core.interval import interval
 
 class ReachingDefinitions(dict):
     """
@@ -636,7 +638,6 @@ def expr_has_call(expr):
 
 class PropagateExpr(object):
 
-
     def assignblk_is_propagation_barrier(self, assignblk):
         for dst, src in assignblk.iteritems():
             if expr_has_call(src):
@@ -766,3 +767,248 @@ class PropagateExpr(object):
                 new_block = IRBlock(block.loc_key, assignblks)
                 ssa.graph.blocks[block.loc_key] = new_block
         return modified
+
+
+def stack_to_reg(expr):
+    if expr.is_mem():
+        ptr = expr.arg
+        SP = ir_arch_a.sp
+        if ptr == SP:
+            return ExprId("STACK.0", expr.size)
+        elif (ptr.is_op('+') and
+              len(ptr.args) == 2 and
+              ptr.args[0] == SP and
+              ptr.args[1].is_int()):
+            diff = int(ptr.args[1])
+            assert diff % 4 == 0
+            diff = (0 - diff) & 0xFFFFFFFF
+            return ExprId("STACK.%d" % (diff / 4), expr.size)
+    return False
+
+
+def is_stack_access(ir_arch_a, expr):
+    if not expr.is_mem():
+        return False
+    ptr = expr.arg
+    diff = expr_simp(ptr - ir_arch_a.sp)
+    if not diff.is_int():
+        return False
+    return expr
+
+
+def visitor_get_stack_accesses(ir_arch_a, expr, stack_vars):
+    if is_stack_access(ir_arch_a, expr):
+        stack_vars.add(expr)
+    return expr
+
+
+def get_stack_accesses(ir_arch_a, expr):
+    result = set()
+    expr.visit(lambda expr:visitor_get_stack_accesses(ir_arch_a, expr, result))
+    return result
+
+
+def get_interval_length(interval_in):
+    length = 0
+    for start, stop in interval_in.intervals:
+        length += stop + 1 - start
+    return length
+
+
+def check_expr_below_stack(ir_arch_a, expr):
+    """
+    Return False if expr pointer is below original stack pointer
+    @ir_arch_a: ira instance
+    @expr: Expression instance
+    """
+    ptr = expr.arg
+    diff = expr_simp(ptr - ir_arch_a.sp)
+    if not diff.is_int():
+        return True
+    if int(diff) == 0 or int(expr_simp(diff.msb())) == 0:
+        return False
+    return True
+
+
+def retrieve_stack_accesses(ir_arch_a, ssa):
+    """
+    Walk the ssa graph and find stack based variables.
+    Return a dictionnary linking stack base address to its size/name
+    @ir_arch_a: ira instance
+    @ssa: SSADiGraph instance
+    """
+    stack_vars = set()
+    for block in ssa.graph.blocks.itervalues():
+        for assignblk in block:
+            for dst, src in assignblk.iteritems():
+                stack_vars.update(get_stack_accesses(ir_arch_a, dst))
+                stack_vars.update(get_stack_accesses(ir_arch_a, src))
+    stack_vars = filter(lambda expr: check_expr_below_stack(ir_arch_a, expr), stack_vars)
+
+    base_to_var = {}
+    for var in stack_vars:
+        base_to_var.setdefault(var.arg, set()).add(var)
+
+
+    base_to_interval = {}
+    for addr, vars in base_to_var.iteritems():
+        var_interval = interval()
+        for var in vars:
+            offset = expr_simp(addr - ir_arch_a.sp)
+            if not offset.is_int():
+                # skip non linear stack offset
+                continue
+
+            start = int(offset)
+            stop = int(expr_simp(offset + ExprInt(var.size / 8, offset.size)))
+            mem = interval([(start, stop-1)])
+            var_interval += mem
+        base_to_interval[addr] = var_interval
+    if not base_to_interval:
+        return {}
+    # Check if not intervals overlap
+    _, tmp = base_to_interval.popitem()
+    while base_to_interval:
+        addr, mem = base_to_interval.popitem()
+        assert (tmp & mem).empty
+        tmp += mem
+
+    base_to_info = {}
+    base_to_name = {}
+    for addr, vars in base_to_var.iteritems():
+        name = "var_%d" % (len(base_to_info))
+        size = max([var.size for var in vars])
+        base_to_info[addr] = size, name
+    return base_to_info
+
+
+def fix_stack_vars(expr, base_to_info):
+    """
+    Replace local stack accesses in expr using informations in @base_to_info
+    @expr: Expression instance
+    @base_to_info: dictionnary linking stack base address to its size/name
+    """
+    if not expr.is_mem():
+        return expr
+    ptr = expr.arg
+    if ptr not in base_to_info:
+        return expr
+    size, name = base_to_info[ptr]
+    var = ExprId(name, size)
+    if size == expr.size:
+        return var
+    assert expr.size < size
+    return var[:expr.size]
+
+
+def replace_mem_stack_vars(expr, base_to_info):
+    return expr.visit(lambda expr:fix_stack_vars(expr, base_to_info))
+
+
+def replace_stack_vars(ir_arch_a, ssa):
+    """
+    Try to replace stack based memory accesses by variables.
+    WARNING: may fail
+
+    @ir_arch_a: ira instance
+    @ssa: SSADiGraph instance
+    """
+    defuse = SSADefUse.from_ssa(ssa)
+
+    base_to_info = retrieve_stack_accesses(ir_arch_a, ssa)
+    stack_vars = {}
+    modified = False
+    for block in ssa.graph.blocks.itervalues():
+        assignblks = []
+        for assignblk in block:
+            out = {}
+            for dst, src in assignblk.iteritems():
+                new_dst = dst.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info))
+                new_src = src.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info))
+                if new_dst != dst or new_src != src:
+                    modified |= True
+
+                out[new_dst] = new_src
+
+            out = AssignBlock(out, assignblk.instr)
+            assignblks.append(out)
+        new_block = IRBlock(block.loc_key, assignblks)
+        ssa.graph.blocks[block.loc_key] = new_block
+    return modified
+
+
+def memlookup_test(expr, bs, is_addr_ro_variable, result):
+    if expr.is_mem() and expr.arg.is_int():
+        ptr = int(expr.arg)
+        if is_addr_ro_variable(bs, ptr, expr.size):
+            result.add(expr)
+        return False
+    return True
+
+
+def memlookup_visit(expr, bs, is_addr_ro_variable):
+    result = set()
+    expr.visit(lambda expr: expr,
+               lambda expr: memlookup_test(expr, bs, is_addr_ro_variable, result))
+    return result
+
+
+def get_memlookup(expr, bs, is_addr_ro_variable):
+    return memlookup_visit(expr, bs, is_addr_ro_variable)
+
+
+def read_mem(bs, expr):
+    ptr = int(expr.arg)
+    var_bytes = bs.getbytes(ptr, expr.size / 8)[::-1]
+    try:
+        value = int(var_bytes.encode('hex'), 16)
+    except:
+        return expr
+    return ExprInt(value, expr.size)
+
+
+def load_from_int(ir_arch, bs, is_addr_ro_variable):
+    """
+    Replace memory read based on constant with static value
+    @ir_arch: ira instance
+    @bs: binstream instance
+    @is_addr_ro_variable: callback(addr, size) to test memory candidate
+    """
+
+    modified = False
+    for label, block in ir_arch.blocks.iteritems():
+        assignblks = list()
+        for assignblk in block:
+            out = {}
+            for dst, src in assignblk.iteritems():
+                # Test src
+                mems = get_memlookup(src, bs, is_addr_ro_variable)
+                src_new = src
+                if mems:
+                    replace = {}
+                    for mem in mems:
+                        value = read_mem(bs, mem)
+                        replace[mem] = value
+                    src_new = src.replace_expr(replace)
+                    if src_new != src:
+                        modified = True
+                # Test dst pointer if dst is mem
+                if dst.is_mem():
+                    ptr = dst.arg
+                    mems = get_memlookup(ptr, bs, is_addr_ro_variable)
+                    ptr_new = ptr
+                    if mems:
+                        replace = {}
+                        for mem in mems:
+                            value = read_mem(bs, mem)
+                            replace[mem] = value
+                        ptr_new = ptr.replace_expr(replace)
+                        if ptr_new != ptr:
+                            modified = True
+                            dst = ExprMem(ptr_new, dst.size)
+                out[dst] = src_new
+            out = AssignBlock(out, assignblk.instr)
+            assignblks.append(out)
+        block = IRBlock(block.loc_key, assignblks)
+        ir_arch.blocks[block.loc_key] = block
+    return modified