about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--arch/x86.py2
-rw-r--r--compare.py215
-rw-r--r--gen_trace.py57
-rw-r--r--lldb_target.py9
-rwxr-xr-xmain.py33
-rw-r--r--snapshot.py9
-rw-r--r--symbolic.py31
-rw-r--r--trace_symbols.py45
8 files changed, 210 insertions, 191 deletions
diff --git a/arch/x86.py b/arch/x86.py
index 01c1631..25213a0 100644
--- a/arch/x86.py
+++ b/arch/x86.py
@@ -22,6 +22,8 @@ regnames = [
     'R14',
     'R15',
     'RFLAGS',
+    # Segment registers
+    'CS', 'DS', 'SS', 'ES', 'FS', 'GS',
     # FLAGS
     'CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT',
     # EFLAGS
diff --git a/compare.py b/compare.py
index a191025..8a25d8a 100644
--- a/compare.py
+++ b/compare.py
@@ -1,4 +1,5 @@
-from snapshot import ProgramState
+from snapshot import ProgramState, SnapshotSymbolResolver
+from symbolic import SymbolicTransform
 from utils import print_separator
 
 def calc_transformation(previous: ProgramState, current: ProgramState):
@@ -117,144 +118,76 @@ def compare_simple(test_states: list[ProgramState],
 
     return result
 
-def equivalent(val1, val2, transformation, previous_translation):
-    if val1 == val2:
-        return True
-
-    # TODO: maybe incorrect
-    return val1 - previous_translation == transformation
-
-def verify(translation: ProgramState, reference: ProgramState,
-           transformation: ProgramState, previous_translation: ProgramState):
-    assert(translation.arch == reference.arch)
-
-    if translation.regs["PC"] != reference.regs["PC"]:
-        return 1
-
-    print_separator()
-    print(f'For PC={translation.as_repr("PC")}')
-    print_separator()
-    for reg in translation.arch.regnames:
-        if translation.regs[reg] is None:
-            print(f'Element not available in translation: {reg}')
-        elif reference.regs[reg] is None:
-            print(f'Element not available in reference: {reg}')
-        elif not equivalent(translation.regs[reg], reference.regs[reg],
-                            transformation.regs[reg],
-                            previous_translation.regs[reg]):
-            txl = translation.as_repr(reg)
-            ref = reference.as_repr(reg)
-            print(f'Difference for {reg}: {txl} != {ref}')
-
-    return 0
-
-def compare(txl: list[ProgramState],
-            native: list[ProgramState],
-            progressive: bool = False,
-            stats: bool = False):
-    """Compare two lists of snapshots and output the differences.
-
-    :param txl: The translated, and possibly faulty, state of the program.
-    :param native: The 'correct' reference state of the program.
-    :param progressive:
-    :param stats:
-    """
+def find_errors_symbolic(txl_from: ProgramState,
+                         txl_to: ProgramState,
+                         transform_truth: SymbolicTransform) \
+        -> list[dict]:
+    arch = txl_from.arch
+    resolver = SnapshotSymbolResolver(txl_from)
 
-    if len(txl) != len(native):
-        print(f'Different numbers of blocks discovered: '
-              f'{len(txl)} in translation vs. {len(native)} in reference.')
-
-    previous_reference = native[0]
-    previous_translation = txl[0]
-
-    unmatched_pcs = {}
-    pc_to_skip = ""
-    if progressive:
-        i = 0
-        for translation in txl:
-            previous = i
-
-            while i < len(native):
-                reference = native[i]
-                transformation = calc_transformation(previous_reference, reference)
-                if verify(translation, reference, transformation, previous_translation) == 0:
-                    reference.matched = True
-                    break
-
-                i += 1
-
-            matched = True
-
-            # Didn't find anything
-            if i == len(native):
-                matched = False
-                # TODO: add verbose output
-                print_separator()
-                print(f'No match for PC {hex(translation.regs["PC"])}')
-                if translation.regs['PC'] not in unmatched_pcs:
-                    unmatched_pcs[translation.regs['PC']] = 0
-                unmatched_pcs[translation.regs['PC']] += 1
-
-                i = previous
-
-            # Necessary since we may have run out of native BBs to check and
-            # previous becomes len(native)
-            #
-            # We continue checking to report unmatched translation PCs
-            if i < len(native):
-                previous_reference = native[i]
-
-            previous_translation = translation
-
-            # Skip next reference when there is a backwards branch
-            # NOTE: if a reference was skipped, don't skip it again
-            #       necessary for loops which may have multiple backwards
-            #       branches
-            if translation.has_backwards and translation.regs['PC'] != pc_to_skip:
-                pc_to_skip = translation.regs['PC']
-                i += 1
-
-            if matched:
-                i += 1
-    else:
-        txl = iter(txl)
-        native = iter(native)
-        for translation, reference in zip(txl, native):
-            transformation = calc_transformation(previous_reference, reference)
-            if verify(translation, reference, transformation, previous_translation) == 1:
-                # TODO: add verbose output
-                print_separator()
-                print(f'No match for PC {hex(translation.regs["PC"])}')
-                if translation.regs['PC'] not in unmatched_pcs:
-                    unmatched_pcs[translation.regs['PC']] = 0
-                unmatched_pcs[translation.regs['PC']] += 1
-            else:
-                reference.matched = True
-
-            if translation.has_backwards:
-                next(native)
-
-            previous_reference = reference
-            previous_translation = translation
-
-    if stats:
-        print_separator()
-        print('Statistics:')
-        print_separator()
-
-        for pc in unmatched_pcs:
-            print(f'PC {hex(pc)} unmatched {unmatched_pcs[pc]} times')
-
-        # NOTE: currently doesn't handle mismatched due backward branches
-        current = ""
-        unmatched_count = 0
-        for ref in native:
-            ref_pc = ref.regs['PC']
-            if ref_pc != current:
-                if unmatched_count:
-                    print(f'Reference PC {hex(current)} unmatched {unmatched_count} times')
-                current = ref_pc
-
-            if ref.matched == False:
-                unmatched_count += 1
-    return 0
+    assert(txl_from.read('PC') == transform_truth.start_addr)
+    assert(txl_to.read('PC') == transform_truth.end_addr)
+
+    errors = []
+    for reg in arch.regnames:
+        if txl_from.read(reg) is None or txl_to.read(reg) is None:
+            print(f'A value for {reg} must be set in all translated states.'
+                  ' Skipping.')
+            continue
+
+        txl_val = txl_to.read(reg)
+        try:
+            truth = transform_truth.eval_register_transform(reg.lower(), resolver)
+            print(f'Evaluated symbolic formula to {hex(txl_val)} vs. txl {hex(txl_val)}')
+            if txl_val != truth:
+                errors.append({
+                    'reg': reg,
+                    'expected': truth,
+                    'actual': txl_val,
+                    'equation': transform_truth.state.regs.get(reg),
+                })
+        except AttributeError:
+            print(f'Register {reg} does not exist.')
+
+    return errors
+
+def compare_symbolic(test_states: list[ProgramState],
+                     transforms: list[SymbolicTransform]):
+    #assert(len(test_states) == len(transforms) - 1)
+    PC_REGNAME = 'PC'
+
+    result = [{
+        'pc': test_states[0].regs[PC_REGNAME],
+        'txl': test_states[0],
+        'ref': transforms[0],
+        'errors': []
+    }]
+
+    _list = zip(test_states[:-1], test_states[1:], transforms)
+    for cur_state, next_state, transform in _list:
+        pc_cur = cur_state.read(PC_REGNAME)
+        pc_next = next_state.read(PC_REGNAME)
+
+        # The program counter should always be set on a snapshot
+        assert(pc_cur is not None and pc_next is not None)
+
+        if pc_cur != transform.start_addr:
+            print(f'Program counter {hex(pc_cur)} in translated code has no'
+                  f' corresponding reference state! Skipping.'
+                  f' (reference: {hex(transform.start_addr)})')
+            continue
+        if pc_next != transform.end_addr:
+            print(f'Tested state transformation is {hex(pc_cur)} ->'
+                  f' {hex(pc_next)}, but reference transform is'
+                  f' {hex(transform.start_addr)} -> {hex(transform.end_addr)}!'
+                  f' Skipping.')
+
+        errors = find_errors_symbolic(cur_state, next_state, transform)
+        result.append({
+            'pc': pc_cur,
+            'txl': calc_transformation(cur_state, next_state),
+            'ref': transform,
+            'errors': errors
+        })
+
+    return result
diff --git a/gen_trace.py b/gen_trace.py
index 64fcf8f..ec5cb86 100644
--- a/gen_trace.py
+++ b/gen_trace.py
@@ -2,42 +2,55 @@ import argparse
 import lldb
 import lldb_target
 
-def parse_args():
-    prog = argparse.ArgumentParser()
-    prog.add_argument('binary',
-                      help='The executable to trace.')
-    prog.add_argument('-o', '--output',
-                      default='breakpoints',
-                      type=str,
-                      help='File to which the recorded trace is written.')
-    prog.add_argument('--args',
-                      default=[],
-                      nargs='+',
-                      help='Arguments to the executable.')
-    return prog.parse_args()
-
-def record_trace(binary: str, args: list[str] = []) -> list[int]:
+def record_trace(binary: str,
+                 args: list[str] = [],
+                 func_name: str | None = 'main') -> list[int]:
+    """
+    :param binary:    The binary file to execute.
+    :param args:      Arguments to the program. Should *not* include the
+                      executable's location as the usual first argument.
+    :param func_name: Only record trace of a specific function.
+    """
     # Set up LLDB target
     target = lldb_target.LLDBConcreteTarget(binary, args)
 
     # Skip to first instruction in `main`
-    result = lldb.SBCommandReturnObject()
-    break_at_main = f'b -b main -s {target.module.GetFileSpec().GetFilename()}'
-    target.interpreter.HandleCommand(break_at_main, result)
-    target.run()
+    if func_name is not None:
+        result = lldb.SBCommandReturnObject()
+        break_at_func = f'b -b {func_name} -s {target.module.GetFileSpec().GetFilename()}'
+        target.interpreter.HandleCommand(break_at_func, result)
+        target.run()
 
     # Run until main function is exited
     trace = []
     while not target.is_exited():
         thread = target.process.GetThreadAtIndex(0)
-        func_names = [thread.GetFrameAtIndex(i).GetFunctionName() for i in range(0, thread.GetNumFrames())]
-        if 'main' not in func_names:
-            break
+
+        # Break if the traced function is exited
+        if func_name is not None:
+            func_names = [thread.GetFrameAtIndex(i).GetFunctionName() \
+                          for i in range(0, thread.GetNumFrames())]
+            if func_name not in func_names:
+                break
         trace.append(target.read_register('pc'))
         thread.StepInstruction(False)
 
     return trace
 
+def parse_args():
+    prog = argparse.ArgumentParser()
+    prog.add_argument('binary',
+                      help='The executable to trace.')
+    prog.add_argument('-o', '--output',
+                      default='breakpoints',
+                      type=str,
+                      help='File to which the recorded trace is written.')
+    prog.add_argument('--args',
+                      default=[],
+                      nargs='+',
+                      help='Arguments to the executable.')
+    return prog.parse_args()
+
 def main():
     args = parse_args()
     trace = record_trace(args.binary, args.args)
diff --git a/lldb_target.py b/lldb_target.py
index 5477ab7..dd0d543 100644
--- a/lldb_target.py
+++ b/lldb_target.py
@@ -124,7 +124,7 @@ class LLDBConcreteTarget(ConcreteTarget):
             raise SimConcreteMemoryError(f'Error when writing to address'
                                          f' {hex(addr)}: {err}')
 
-    def get_mappings(self):
+    def get_mappings(self) -> list[MemoryMap]:
         mmap = []
 
         region_list = self.process.GetMemoryRegions()
@@ -134,11 +134,12 @@ class LLDBConcreteTarget(ConcreteTarget):
 
             perms = f'{"r" if region.IsReadable() else "-"}' \
                     f'{"w" if region.IsWritable() else "-"}' \
-                    f'{"x" if region.IsExecutable() else "-"}' \
+                    f'{"x" if region.IsExecutable() else "-"}'
+            name = region.GetName()
 
             mmap.append(MemoryMap(region.GetRegionBase(),
                                   region.GetRegionEnd(),
-                                  0,             # offset?
-                                  "<no-name>",   # name?
+                                  0,    # offset?
+                                  name if name is not None else '<none>',
                                   perms))
         return mmap
diff --git a/main.py b/main.py
index 9451e42..b0aeb36 100755
--- a/main.py
+++ b/main.py
@@ -4,8 +4,10 @@ import argparse
 
 import arancini
 from arch import x86
-from compare import compare_simple
+from compare import compare_simple, compare_symbolic
+from gen_trace import record_trace
 from run import run_native_execution
+from symbolic import collect_symbolic_trace
 from utils import check_version, print_separator
 
 def parse_inputs(txl_path, ref_path, program):
@@ -49,11 +51,12 @@ def parse_arguments():
                         action='store_true',
                         default=True,
                         help='Path to oracle program')
-    parser.add_argument('--progressive',
+    parser.add_argument('--symbolic',
                         action='store_true',
                         default=False,
-                        help='Try to match exhaustively before declaring \
-                        mismatch')
+                        help='Use an advanced algorithm that uses symbolic'
+                             ' execution to determine accurate data'
+                             ' transformations')
     args = parser.parse_args()
     return args
 
@@ -66,13 +69,12 @@ def main():
 
     stats = args.stats
     verbose = args.verbose
-    progressive = args.progressive
 
     if verbose:
         print("Enabling verbose program output")
         print(f"Verbose: {verbose}")
         print(f"Statistics: {stats}")
-        print(f"Progressive: {progressive}")
+        print(f"Symbolic: {args.symbolic}")
 
     if program is None and reference_path is None:
         raise ValueError('Either program or path to native file must be'
@@ -85,7 +87,18 @@ def main():
             for snapshot in ref:
                 print(snapshot, file=w)
 
-    result = compare_simple(txl, ref)
+    if args.symbolic:
+        assert(program is not None)
+
+        full_trace = record_trace(program, args=[])
+        transforms = collect_symbolic_trace(program, full_trace)
+        # TODO: Transform the traces so that the states match
+        result = compare_symbolic(txl, transforms)
+
+        raise NotImplementedError('The symbolic comparison algorithm is not'
+                                  ' supported yet.')
+    else:
+        result = compare_simple(txl, ref)
 
     # Print results
     for res in result:
@@ -104,6 +117,12 @@ def main():
                   f'    (txl) {reg}: {hex(txl.regs[reg])}\n'
                   f'    (ref) {reg}: {hex(ref.regs[reg])}')
 
+    print()
+    print('#' * 60)
+    print(f'Found {sum(len(res["errors"]) for res in result)} errors.')
+    print('#' * 60)
+    print()
+
 if __name__ == "__main__":
     check_version('3.7')
     main()
diff --git a/snapshot.py b/snapshot.py
index 01c6446..3170649 100644
--- a/snapshot.py
+++ b/snapshot.py
@@ -38,3 +38,12 @@ class ProgramState:
 
     def __repr__(self):
         return repr(self.regs)
+
+class SnapshotSymbolResolver(SymbolResolver):
+    def __init__(self, snapshot: ProgramState):
+        self._state = snapshot
+
+    def resolve(self, symbol: str):
+        if symbol not in self._state.arch.regnames:
+            raise SymbolResolveError(symbol, 'Symbol is not a register name.')
+        return self._state.read(symbol)
diff --git a/symbolic.py b/symbolic.py
index 53e1bbf..56857d7 100644
--- a/symbolic.py
+++ b/symbolic.py
@@ -5,7 +5,7 @@ import claripy as cp
 from angr.exploration_techniques import Symbion
 
 from arch import Arch, x86
-from interpreter import SymbolResolver
+from interpreter import eval as eval_symbol, SymbolResolver
 from lldb_target import LLDBConcreteTarget
 
 def symbolize_state(state: angr.SimState,
@@ -28,7 +28,7 @@ def symbolize_state(state: angr.SimState,
 
     if stack_name not in _exclude:
         symb_stack = cp.BVS(stack_name, stack_size * 8, explicit_name=True)
-        state.memory.store(state.regs.rbp - stack_size, symb_stack)
+        state.memory.store(state.regs.rsp - stack_size, symb_stack)
 
     for reg in arch.regnames:
         if reg not in _exclude:
@@ -68,7 +68,15 @@ class SymbolicTransform:
         self.end_addr = end_inst
 
     def eval_register_transform(self, regname: str, resolver: SymbolResolver):
-        raise NotImplementedError('TODO')
+        """
+        :param regname:  Name of the register to evaluate.
+        :param resolver: A provider for the values to be plugged into the
+                         symbolic equation.
+
+        :raise angr.SimConcreteRegisterError: If the state contains no register
+                                              named `regname`.
+        """
+        return eval_symbol(resolver, self.state.regs.get(regname))
 
     def __repr__(self) -> str:
         return f'Symbolic state transformation: \
@@ -87,7 +95,7 @@ def collect_symbolic_trace(binary: str, trace: list[int]) \
                         concrete_target=target,
                         use_sim_procedures=False)
 
-    entry_state = proj.factory.entry_state()
+    entry_state = proj.factory.entry_state(addr=trace[0])
     entry_state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC)
     entry_state.options.add(angr.options.SYMBION_SYNC_CLE)
 
@@ -105,13 +113,26 @@ def collect_symbolic_trace(binary: str, trace: list[int]) \
         symbion = proj.factory.simgr(entry_state)
         symbion.use_technique(Symbion(find=[cur_inst]))
 
-        conc_exploration = symbion.run()
+        try:
+            if cur_inst != entry_state.addr:
+                conc_exploration = symbion.run()
+            else:
+                symbion.move('active', 'found')
+                conc_exploration = symbion
+        except angr.AngrError as err:
+            print(f'Angr error: {err} Returning partial result.')
+            return result
         conc_state = conc_exploration.found[0]
+        entry_state = conc_state
 
         concrete_states[conc_state.addr] = conc_state.copy()
 
         # Start symbolic execution with the concrete ('truth') state and try
         # to reach the next instruction in the trace
+        #
+        # -- Notes --
+        # It does not even work when I feed the entirely concrete state
+        # `conc_state` that I receive from Symbion into the symbolic simgr.
         simgr = proj.factory.simgr(symbolize_state(conc_state))
         symb_exploration = simgr.explore(find=next_inst)
 
diff --git a/trace_symbols.py b/trace_symbols.py
index e529522..6e7cb3b 100644
--- a/trace_symbols.py
+++ b/trace_symbols.py
@@ -9,7 +9,7 @@ from arch import x86
 from gen_trace import record_trace
 from interpreter import eval, SymbolResolver, SymbolResolveError
 from lldb_target import LLDBConcreteTarget
-from symbolic import symbolize_state, collect_symbolic_trace
+from symbolic import collect_symbolic_trace
 
 # Size of the memory region on the stack that is tracked symbolically
 # We track [rbp - STACK_SIZE, rbp).
@@ -95,12 +95,7 @@ def print_state(state: angr.SimState, file=sys.stdout, conc_state=None):
         print('<unable to read stack memory>', file=file)
     print('-' * 80, file=file)
 
-def parse_args():
-    prog = argparse.ArgumentParser()
-    prog.add_argument('binary', type=str)
-    return prog.parse_args()
-
-def collect_concrete_trace(binary: str) -> list[angr.SimState]:
+def collect_concrete_trace(binary: str, trace: list[int]) -> list[angr.SimState]:
     target = LLDBConcreteTarget(binary)
     proj = angr.Project(binary,
                         concrete_target=target,
@@ -110,29 +105,53 @@ def collect_concrete_trace(binary: str) -> list[angr.SimState]:
     state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC)
     state.options.add(angr.options.SYMBION_SYNC_CLE)
 
+    # Remove first address from trace if it is the entry point.
+    # Symbion doesn't find an address if it's the current state.
+    if len(trace) > 0 and trace[0] == state.addr:
+        trace = trace[1:]
+
     result = []
 
-    trace = record_trace(binary)
     for inst in trace:
         symbion = proj.factory.simgr(state)
         symbion.use_technique(Symbion(find=[inst]))
 
-        conc_exploration = symbion.run()
+        try:
+            conc_exploration = symbion.run()
+        except angr.AngrError:
+            assert(target.is_exited())
+            break
         state = conc_exploration.found[0]
         result.append(state.copy())
 
     return result
 
+def parse_args():
+    prog = argparse.ArgumentParser()
+    prog.add_argument('binary', type=str)
+    prog.add_argument('--only-main', action='store_true', default=False)
+    return prog.parse_args()
+
 def main():
     args = parse_args()
     binary = args.binary
+    only_main = args.only_main
 
     # Generate a program trace from a real execution
-    concrete_trace = collect_concrete_trace(binary)
-    trace = [int(state.addr) for state in concrete_trace]
+    print('Collecting a program trace from a concrete execution...')
+    trace = record_trace(binary, [],
+                         func_name='main' if only_main else None)
     print(f'Found {len(trace)} trace points.')
 
-    symbolic_trace = collect_symbolic_trace(binary, trace)
+    print('Executing the trace to collect concrete program states...')
+    concrete_trace = collect_concrete_trace(binary, trace)
+
+    print('Re-tracing symbolically...')
+    try:
+        symbolic_trace = collect_symbolic_trace(binary, trace)
+    except KeyboardInterrupt:
+        print('Keyboard interrupt. Exiting.')
+        exit(0)
 
     with open('concrete.log', 'w') as conc_log:
         for state in concrete_trace:
@@ -141,6 +160,8 @@ def main():
         for conc, symb in zip(concrete_trace, symbolic_trace):
             print_state(symb.state, file=symb_log, conc_state=conc)
 
+    print('Written symbolic trace to "symbolic.log".')
+
 if __name__ == "__main__":
     main()
     print('\nDone.')