about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--README.md7
-rw-r--r--interpreter.py125
-rw-r--r--lldb_target.py21
-rw-r--r--symbolic.py153
-rw-r--r--trace_symbols.py170
-rw-r--r--utils.py22
6 files changed, 427 insertions, 71 deletions
diff --git a/README.md b/README.md
index 04ef446..0397afe 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,9 @@ more architectures later. Currently, we only have X86.
 The following files belong to a prototype of a data-dependency generator based on symbolic
 execution:
 
+ - `symbolic.py`: Algorithms and data structures to compute and manipulate symbolic program
+transformations.
+
  - `gen_trace.py`: An invokable tool that generates an instruction trace for an executable's native
 execution. Is imported into `trace_symbols.py`, which uses the core function that records a trace.
 
@@ -45,8 +48,8 @@ changes it has made to the program state before that instruction.
     3. Writes the program state at each instruction to log files; writes the concrete state of the
 real execution to 'concrete.log' and the symbolic difference to 'symbolic.log'.
 
-    This first version is very fragile. As soon as angr can't handle a branch instruction (which is
-the case for almost any branch instruction), it aborts with an error.
+ - `interpreter.py`: Contains an algorithm that evaluates a symbolic expression to a concrete value,
+using a reference state as input.
 
 ## Helpers
 
diff --git a/interpreter.py b/interpreter.py
new file mode 100644
index 0000000..8e876f5
--- /dev/null
+++ b/interpreter.py
@@ -0,0 +1,125 @@
+"""Interpreter for claripy ASTs"""
+
+from inspect import signature
+from logging import debug
+
+import claripy as cp
+
+class SymbolResolver:
+    def __init__(self):
+        pass
+
+    def resolve(self, symbol_name: str) -> cp.ast.Base:
+        raise NotImplementedError()
+
+class SymbolResolveError(Exception):
+    def __init__(self, symbol, reason: str = ""):
+        super().__init__(f'Unable to resolve symbol name \"{symbol}\" to a'
+                         ' concrete value'
+                         + f': {reason}' if len(reason) > 0 else '.')
+
+def eval(resolver: SymbolResolver, expr) -> int:
+    """Evaluate a claripy expression to a concrete value.
+
+    :param resolver: A `SymbolResolver` implementation that can resolve symbol
+                     names to concrete values.
+    :param expr:     The claripy AST to evaluate. Should be a subclass of
+                     `claripy.ast.Base`.
+
+    :return: A concrete value if the expression was resolved successfully.
+             If `expr` is not a claripy AST, `expr` is returned immediately.
+    :raise NotImplementedError:
+    :raise SymbolResolveError: If `resolver` is not able to resolve a symbol.
+    """
+    if not issubclass(type(expr), cp.ast.Base):
+        return expr
+
+    if expr.depth == 1:
+        if expr.symbolic:
+            name = expr._encoded_name.decode()
+            val = resolver.resolve(name)
+            if val is None:
+                raise SymbolResolveError(name)
+            return eval(resolver, val)
+        else: # if expr.concrete
+            assert(expr.concrete)
+            return expr.v
+
+    # Expression is a non-trivial AST, i.e. a function
+    return _eval_op(resolver, expr.op, *expr.args)
+
+def _eval_op(resolver: SymbolResolver, op, *args) -> int:
+    """Evaluate a claripy operator expression.
+
+    :param *args: Arguments to the function `op`. These are NOT evaluated yet!
+    """
+    assert(type(op) is str)
+
+    def concat(*vals):
+        res = 0
+        for val in vals:
+            assert(type(val) is cp.ast.BV)
+            res = res << val.length
+            res = res | eval(resolver, val)
+        return res
+
+    # Handle claripy's operators
+    if op == 'Concat':
+        res = concat(*args)
+        debug(f'Concatenating {args} to {hex(res)}')
+        return res
+    if op == 'Extract':
+        assert(len(args) == 3)
+        start, end, val = (eval(resolver, arg) for arg in args)
+        size = start - end + 1
+        res = (val >> end) & ((1 << size) - 1)
+        debug(f'Extracing range [{start}, {end}] from {hex(val)}: {hex(res)}')
+        return res
+    if op == 'If':
+        assert(len(args) == 3)
+        cond, iftrue, iffalse = (eval(resolver, arg) for arg in args)
+        debug(f'Evaluated branch condition {args[0]} to {cond}')
+        return iftrue if bool(cond) else iffalse
+    if op == 'Reverse':
+        assert(len(args) == 1)
+        return concat(*reversed(args[0].chop(8)))
+
+    # `op` is not one of claripy's special operators, so treat it as the name
+    # of a python operator function (because that is how claripy names its OR,
+    # EQ, etc.)
+
+    # Convert some of the non-python names to magic names
+    # NOTE: We use python's signed comparison operators for unsigned
+    #       comparisons. I'm not sure that this is legal.
+    if op in ['SGE', 'SGT', 'SLE', 'SLT', 'UGE', 'UGT', 'ULE', 'ULT']:
+        op = '__' + op[1:].lower() + '__'
+
+    if op in ['And', 'Or']:
+        op =  '__' + op.lower() + '__'
+
+    resolved_args = [eval(resolver, arg) for arg in args]
+    try:
+        func = getattr(int, op)
+    except AttributeError:
+        raise NotImplementedError(op)
+
+    # Sometimes claripy doesn't build its AST in an arity-respecting way if
+    # adjacent operations are associative. For example, it might pass five
+    # arguments to an XOR function instead of nesting the AST deeper.
+    #
+    # That's why we have to check with the python function's signature for its
+    # number of arguments and manually apply parentheses.
+    sig = signature(func)
+    assert(len(args) >= len(sig.parameters))
+
+    debug(f'Trying to evaluate function {func} with arguments {resolved_args}')
+    if len(sig.parameters) == len(args):
+        return func(*resolved_args)
+    else:
+        # Fold parameters from left by successively applying `op` to a
+        # subset of them
+        return _eval_op(resolver,
+                       op,
+                       func(*resolved_args[0:len(sig.parameters)]),
+                       *resolved_args[len(sig.parameters):]
+                       )
diff --git a/lldb_target.py b/lldb_target.py
index 1ff9f53..93efb7d 100644
--- a/lldb_target.py
+++ b/lldb_target.py
@@ -72,14 +72,22 @@ class LLDBConcreteTarget(ConcreteTarget):
         self.debugger.Terminate()
         print(f'Program exited with status {self.process.GetState()}')
 
-    def read_register(self, regname: str) -> int:
+    def _get_register(self, regname: str) -> lldb.SBValue:
+        """Find a register by name.
+
+        :raise SimConcreteRegisterError: If no register with the specified name
+                                         can be found.
+        """
         frame = self.process.GetThreadAtIndex(0).GetFrameAtIndex(0)
         reg = frame.FindRegister(regname)
         if reg is None:
             raise SimConcreteRegisterError(
-                f'[In LLDBConcreteTarget.read_register]: Register {regname}'
+                f'[In LLDBConcreteTarget._get_register]: Register {regname}'
                 f' not found.')
+        return reg
 
+    def read_register(self, regname: str) -> int:
+        reg = self._get_register(regname)
         val = reg.GetValue()
         if val is None:
             raise SimConcreteRegisterError(
@@ -88,6 +96,15 @@ class LLDBConcreteTarget(ConcreteTarget):
 
         return int(val, 16)
 
+    def write_register(self, regname: str, value: int):
+        reg = self._get_register(regname)
+        error = lldb.SBError()
+        reg.SetValueFromCString(hex(value), error)
+        if not error.success:
+            raise SimConcreteRegisterError(
+                f'[In LLDBConcreteTarget.write_register]: Unable to set'
+                f' {regname} to value {hex(value)}!')
+
     def read_memory(self, addr, size):
         err = lldb.SBError()
         content = self.process.ReadMemory(addr, size, err)
diff --git a/symbolic.py b/symbolic.py
new file mode 100644
index 0000000..a8d45d0
--- /dev/null
+++ b/symbolic.py
@@ -0,0 +1,153 @@
+"""Tools and utilities for symbolic execution with angr."""
+
+import angr
+import claripy as cp
+from angr.exploration_techniques import Symbion
+
+from arch import Arch, x86
+from interpreter import SymbolResolver
+from lldb_target import LLDBConcreteTarget
+
+def symbolize_state(state: angr.SimState,
+                    arch: Arch = x86.ArchX86(),
+                    exclude: list[str] = ['PC', 'RBP', 'RSP'],
+                    stack_name: str = 'stack',
+                    stack_size: int = 0x10) \
+        -> angr.SimState:
+    """Create a copy of a SimState and replace most of it with symbolic
+    values.
+
+    Leaves pc, rbp, and rsp concrete by default. This can be configured with
+    the `exclude` parameter. Add the string 'stack' to the exclude list to
+    prevent stack memory from being replaced with a symbolic buffer.
+
+    :return: A symbolized SymState object.
+    """
+    _exclude = set(exclude)
+    state = state.copy()
+
+    if stack_name not in _exclude:
+        symb_stack = cp.BVS(stack_name, stack_size * 8, explicit_name=True)
+        state.memory.store(state.regs.rbp - stack_size, symb_stack)
+
+    for reg in arch.regnames:
+        if reg not in _exclude:
+            symb_val = cp.BVS(reg, 64, explicit_name=True)
+            try:
+                state.regs.__setattr__(reg.lower(), symb_val)
+            except AttributeError:
+                pass
+    return state
+
+class SymbolicTransform:
+    def __init__(self,
+                 state: angr.SimState,
+                 first_inst: int,
+                 last_inst: int,
+                 end_inst: int):
+        """
+        :param state: The symbolic transformation in the form of a SimState
+                      object.
+        :param first_inst: An instruction address. The transformation operates
+                           on the program state *before* this instruction is
+                           executed.
+        :param last_inst:  An instruction address. The last instruction that
+                           is included in the transformation. This may be equal
+                           to `prev_state` if the `SymbolicTransform`
+                           represents the work done by a single instruction.
+                           The transformation includes all instructions in the
+                           range `[first_inst, last_inst]` (note the inclusive
+                           right bound) of the specific program trace.
+        :param end_inst:   An instruction address. The address of the *next*
+                           instruction executed on the state that results from
+                           the transformation.
+        """
+        self.state = state
+        self.start_addr = first_inst
+        self.last_inst = last_inst
+        self.end_addr = end_inst
+
+    def eval_register_transform(self, regname: str, resolver: SymbolResolver):
+        raise NotImplementedError('TODO')
+
+    def __repr__(self) -> str:
+        return f'Symbolic state transformation: \
+                 {hex(self.start_addr)} -> {hex(self.end_addr)}'
+
+def collect_symbolic_trace(binary: str, trace: list[int]) \
+    -> list[SymbolicTransform]:
+    """Execute a program and compute state transformations between executed
+    instructions.
+
+    :param binary: The binary to trace.
+    :param trace:  A program trace that symbolic execution shall follow.
+    """
+    target = LLDBConcreteTarget(binary)
+    proj = angr.Project(binary,
+                        concrete_target=target,
+                        use_sim_procedures=False)
+
+    entry_state = proj.factory.entry_state()
+    entry_state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC)
+    entry_state.options.add(angr.options.SYMBION_SYNC_CLE)
+
+    # We keep a history of concrete states at their addresses because of the
+    # backtracking approach described below.
+    concrete_states = {}
+
+    # All recorded symbolic transformations
+    result = []
+
+    for (cur_idx, cur_inst), next_inst in zip(enumerate(trace[0:-1]), trace[1:]):
+        # The last instruction included in the generated transformation
+        last_inst = cur_inst
+
+        symbion = proj.factory.simgr(entry_state)
+        symbion.use_technique(Symbion(find=[cur_inst]))
+
+        conc_exploration = symbion.run()
+        conc_state = conc_exploration.found[0]
+
+        concrete_states[conc_state.addr] = conc_state.copy()
+
+        # Start symbolic execution with the concrete ('truth') state and try
+        # to reach the next instruction in the trace
+        simgr = proj.factory.simgr(symbolize_state(conc_state))
+        symb_exploration = simgr.explore(find=next_inst)
+
+        # Symbolic execution can't handle starting at some jump instructions.
+        # When this occurs, we re-start symbolic execution at an earlier
+        # instruction.
+        #
+        # Example:
+        #   0x401155      cmp   -0x4(%rbp),%eax
+        #   0x401158      jle   0x401162
+        #   ...
+        #   0x401162      addl  $0x1337,-0xc(%rbp)
+        #
+        # Here, symbolic execution can't find a valid state at `0x401162` when
+        # starting at `0x401158`, but it finds it successfully when starting at
+        # `0x401155`.
+        while len(symb_exploration.found) == 0 and cur_idx > 0:
+            print(f'[INFO] Symbolic execution can\'t reach address'
+                  f' {hex(next_inst)} from {hex(cur_inst)}.'
+                  f' Attempting to reach it from {hex(trace[cur_idx - 1])}...')
+            cur_idx -= 1
+            cur_inst = trace[cur_idx]
+            conc_state = concrete_states[cur_inst]
+            simgr = proj.factory.simgr(symbolize_state(conc_state))
+            symb_exploration = simgr.explore(find=next_inst)
+
+        if len(symb_exploration.found) == 0:
+            print(f'Symbolic execution can\'t reach address {hex(next_inst)}.'
+                  ' Exiting.')
+            exit(1)
+
+        result.append(SymbolicTransform(
+            symb_exploration.found[0],
+            cur_inst,
+            last_inst,
+            next_inst
+        ))
+
+    return result
diff --git a/trace_symbols.py b/trace_symbols.py
index c16cd6e..e529522 100644
--- a/trace_symbols.py
+++ b/trace_symbols.py
@@ -1,110 +1,146 @@
-import angr
 import argparse
-import claripy as cp
 import sys
 
+import angr
+import claripy as cp
 from angr.exploration_techniques import Symbion
+
 from arch import x86
 from gen_trace import record_trace
+from interpreter import eval, SymbolResolver, SymbolResolveError
 from lldb_target import LLDBConcreteTarget
+from symbolic import symbolize_state, collect_symbolic_trace
+
+# Size of the memory region on the stack that is tracked symbolically
+# We track [rbp - STACK_SIZE, rbp).
+STACK_SIZE = 0x10
+
+STACK_SYMBOL_NAME = 'stack'
+
+class SimStateResolver(SymbolResolver):
+    """A symbol resolver that resolves symbol names to program state in
+    `angr.SimState` objects.
+    """
+    def __init__(self, state: angr.SimState):
+        self._state = state
+
+    def resolve(self, symbol_name: str) -> cp.ast.Base:
+        # Process special (non-register) symbol names
+        if symbol_name == STACK_SYMBOL_NAME:
+            assert(self._state.regs.rbp.concrete)
+            assert(type(self._state.regs.rbp.v) is int)
+            rbp = self._state.regs.rbp.v
+            return self._state.memory.load(rbp - STACK_SIZE, STACK_SIZE)
+
+        # Try to interpret the symbol as a register name
+        try:
+            return self._state.regs.get(symbol_name.lower())
+        except AttributeError:
+            raise SymbolResolveError(symbol_name,
+                                     f'[SimStateResolver]: No attribute'
+                                     f' {symbol_name} in program state.')
+
+def print_state(state: angr.SimState, file=sys.stdout, conc_state=None):
+    """Print a program state in a fancy way.
+
+    :param conc_state: Provide a concrete program state as a reference to
+                       evaluate all symbolic values in `state` and print their
+                       concrete values in addition to the symbolic expression.
+    """
+    if conc_state is not None:
+        resolver = SimStateResolver(conc_state)
+    else:
+        resolver = None
 
-def print_state(state: angr.SimState, file=sys.stdout):
-    """Print a program state in a fancy way."""
     print('-' * 80, file=file)
     print(f'State at {hex(state.addr)}:', file=file)
     print('-' * 80, file=file)
     for reg in x86.regnames:
         try:
-            val = state.regs.__getattr__(reg.lower())
+            val = state.regs.get(reg.lower())
+        except angr.SimConcreteRegisterError: val = '<inaccessible>'
+        except angr.SimConcreteMemoryError:   val = '<inaccessible>'
+        except AttributeError:                val = '<inaccessible>'
+        except KeyError:                      val = '<inaccessible>'
+        if resolver is not None:
+            concrete_value = eval(resolver, val)
+            if type(concrete_value) is int:
+                concrete_value = hex(concrete_value)
+            print(f'{reg} = {val} ({concrete_value})', file=file)
+        else:
             print(f'{reg} = {val}', file=file)
-        except angr.SimConcreteRegisterError: pass
-        except angr.SimConcreteMemoryError: pass
-        except AttributeError: pass
-        except KeyError: pass
 
     # Print some of the stack
     print('\nStack:', file=file)
     try:
+        # Ensure that the base pointer is concrete
         rbp = state.regs.rbp
-        stack_size = 0xc
-        stack_mem = state.memory.load(rbp - stack_size, stack_size)
+        if not rbp.concrete:
+            if resolver is None:
+                raise SymbolResolveError(rbp,
+                                         '[In print_state]: rbp is symbolic,'
+                                         ' but no resolver is defined. Can\'t'
+                                         ' print stack.')
+            else:
+                rbp = eval(resolver, rbp)
+
+        stack_mem = state.memory.load(rbp - STACK_SIZE, STACK_SIZE)
+
+        if resolver is not None:
+            print(hex(eval(resolver, stack_mem)), file=file)
         print(stack_mem, file=file)
         stack = state.solver.eval(stack_mem, cast_to=bytes)
         print(' '.join(f'{b:02x}' for b in stack[::-1]), file=file)
     except angr.SimConcreteMemoryError:
-        print('<unable to read memory>', file=file)
+        print('<unable to read stack memory>', file=file)
     print('-' * 80, file=file)
 
-def symbolize_state(state: angr.SimState,
-                    exclude: list[str] = ['PC', 'RBP', 'RSP']) \
-        -> angr.SimState:
-    """Create a copy of a SimState and replace most of it with symbolic
-    values.
-
-    Leaves pc, rbp, and rsp concrete by default. This can be configured with
-    the `exclude` parameter.
-
-    :return: A symbolized SymState object.
-    """
-    state = state.copy()
-
-    stack_size = 0xc
-    symb_stack = cp.BVS('stack', stack_size * 8)
-    state.memory.store(state.regs.rbp - stack_size, symb_stack)
-
-    _exclude = set(exclude)
-    for reg in x86.regnames:
-        if reg not in _exclude:
-            symb_val = cp.BVS(reg, 64)
-            try:
-                state.regs.__setattr__(reg.lower(), symb_val)
-            except AttributeError:
-                pass
-    return state
-
 def parse_args():
     prog = argparse.ArgumentParser()
     prog.add_argument('binary', type=str)
     return prog.parse_args()
 
-def main():
-    args = parse_args()
-    binary = args.binary
-
-    conc_log = open('concrete.log', 'w')
-    symb_log = open('symbolic.log', 'w')
-
-    # Generate a program trace from a real execution
-    trace = record_trace(binary)
-    print(f'Found {len(trace)} trace points.')
-
+def collect_concrete_trace(binary: str) -> list[angr.SimState]:
     target = LLDBConcreteTarget(binary)
     proj = angr.Project(binary,
                         concrete_target=target,
                         use_sim_procedures=False)
 
-    entry_state = proj.factory.entry_state()
-    entry_state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC)
-    entry_state.options.add(angr.options.SYMBION_SYNC_CLE)
+    state = proj.factory.entry_state()
+    state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC)
+    state.options.add(angr.options.SYMBION_SYNC_CLE)
+
+    result = []
 
-    for cur_inst, next_inst in zip(trace[0:-1], trace[1:]):
-        symbion = proj.factory.simgr(entry_state)
-        symbion.use_technique(Symbion(find=[cur_inst]))
+    trace = record_trace(binary)
+    for inst in trace:
+        symbion = proj.factory.simgr(state)
+        symbion.use_technique(Symbion(find=[inst]))
 
         conc_exploration = symbion.run()
-        conc_state = conc_exploration.found[0]
+        state = conc_exploration.found[0]
+        result.append(state.copy())
+
+    return result
+
+def main():
+    args = parse_args()
+    binary = args.binary
+
+    # Generate a program trace from a real execution
+    concrete_trace = collect_concrete_trace(binary)
+    trace = [int(state.addr) for state in concrete_trace]
+    print(f'Found {len(trace)} trace points.')
 
-        # Start symbolic execution with the concrete ('truth') state and try
-        # to reach the next instruction in the trace
-        simgr = proj.factory.simgr(symbolize_state(conc_state))
-        symb_exploration = simgr.explore(find=next_inst)
-        if len(symb_exploration.found) == 0:
-            print(f'Symbolic execution can\'t reach address {hex(next_inst)}'
-                  f' from {hex(cur_inst)}. Exiting.')
-            exit(1)
+    symbolic_trace = collect_symbolic_trace(binary, trace)
 
-        print_state(conc_state, file=conc_log)
-        print_state(symb_exploration.found[0], file=symb_log)
+    with open('concrete.log', 'w') as conc_log:
+        for state in concrete_trace:
+            print_state(state, file=conc_log)
+    with open('symbolic.log', 'w') as symb_log:
+        for conc, symb in zip(concrete_trace, symbolic_trace):
+            print_state(symb.state, file=symb_log, conc_state=conc)
 
 if __name__ == "__main__":
     main()
+    print('\nDone.')
diff --git a/utils.py b/utils.py
index 1390283..f2c2256 100644
--- a/utils.py
+++ b/utils.py
@@ -14,3 +14,25 @@ def check_version(version: str):
     if sys.version_info.major < major and sys.version_info.minor < minor:
         raise EnvironmentError("Expected at least Python 3.7")
 
+def to_str(expr):
+    """Convert a claripy expression to a nice string representation.
+
+    Actually, the resulting representation is not very nice at all. It mostly
+    serves debugging purposes.
+    """
+    import claripy
+
+    if not issubclass(type(expr), claripy.ast.Base):
+        return f'{type(expr)}[{str(expr)}]'
+
+    assert(expr.depth > 0)
+    if expr.depth == 1:
+        if expr.symbolic:
+            name = expr._encoded_name.decode()
+            return f'symbol[{name}]'
+        else:
+            assert(expr.concrete)
+            return f'value{expr.length}[{hex(expr.v)}]'
+
+    args = [to_str(child) for child in expr.args]
+    return f'expr[{str(expr.op)}({", ".join(args)})]'