| -rw-r--r-- | .gitignore | 9 |
| -rw-r--r-- | .gitmodules | 3 |
| -rw-r--r-- | README.md | 80 |
| -rwxr-xr-x | compare.py | 354 |
| m--------- | cpuid | 0 |
| -rwxr-xr-x | focaccia.py | 222 |
| -rw-r--r-- | focaccia/__init__.py | 0 |
| -rw-r--r-- | focaccia/arch/__init__.py | 14 |
| -rw-r--r-- | focaccia/arch/aarch64.py | 123 |
| -rw-r--r-- | focaccia/arch/arch.py | 84 |
| -rw-r--r-- | focaccia/arch/x86.py | 200 |
| -rw-r--r-- | focaccia/compare.py | 302 |
| -rw-r--r-- | focaccia/lldb_target.py | 315 |
| -rw-r--r-- | focaccia/match.py | 105 |
| -rw-r--r-- | focaccia/miasm_util.py | 253 |
| -rw-r--r-- | focaccia/parser.py | 172 |
| -rw-r--r-- | focaccia/reproducer.py | 172 |
| -rw-r--r-- | focaccia/snapshot.py | 180 |
| -rw-r--r-- | focaccia/symbolic.py | 692 |
| -rw-r--r-- | focaccia/trace.py | 74 |
| -rw-r--r-- | focaccia/utils.py | 116 |
| -rw-r--r-- | nix.shell | 12 |
| -rw-r--r-- | requirements.txt | 1 |
| -rwxr-xr-x | run.py | 214 |
| -rw-r--r-- | test/test_snapshot.py | 74 |
| -rw-r--r-- | test/test_sparse_memory.py | 33 |
| -rw-r--r-- | tools/_qemu_tool.py | 314 |
| -rwxr-xr-x | tools/capture_transforms.py | 27 |
| -rwxr-xr-x | tools/convert.py | 49 |
| -rwxr-xr-x | tools/verify_qemu.py | 106 |
| -rw-r--r-- | utils.py | 18 |
31 files changed, 3730 insertions, 588 deletions
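For orientation before the hunks: a minimal sketch of how the modules added in this commit fit together, mirroring `collect_concrete_trace` and the non-symbolic `compare_simple` path shown in `focaccia.py` below. The binary path, arguments, and the parsed `test_states` are hypothetical placeholders, not part of the commit.

```
# Sketch only -- mirrors collect_concrete_trace() and the compare_simple()
# path from focaccia.py in this commit. Paths and inputs are hypothetical.
from focaccia.compare import compare_simple, ErrorTypes
from focaccia.lldb_target import LLDBConcreteTarget
from focaccia.utils import print_result

def trace_concrete(binary: str, argv: list[str], breakpoints: list[int]):
    """Run `binary` under LLDB and record a snapshot at every stop."""
    target = LLDBConcreteTarget(binary, argv)
    for addr in breakpoints:
        target.set_breakpoint(addr)
    snapshots = []
    while not target.is_exited():
        snapshots.append(target.record_snapshot())
        target.run()
    return snapshots

# test_states: list[ProgramState] parsed from an emulator log (hypothetical)
# breakpoints = [s.read_register('PC') for s in test_states]
# truth_states = trace_concrete('/path/to/oracle', [], breakpoints)
# result = compare_simple(test_states, truth_states)
# print_result(result, ErrorTypes.POSSIBLE)
```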
diff --git a/.gitignore b/.gitignore index ee32d85..ea2880a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,16 @@ build* *.md *.out *.txt +!requirements.txt *.bin *.dot build*/ out-*/ -__pycache__/* +__pycache__/ +# Dev environment +.gdbinit + +# Focaccia files +qemu.sym +qemu.trace diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..a6d7f14 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cpuid"] + path = cpuid + url = https://github.com/flababah/cpuid.py.git diff --git a/README.md b/README.md index 7cf64cd..67db62c 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,83 @@ -# DBT Testing +# Focaccia This repository contains initial code for comprehensive testing of binary translators. +## Requirements + +For Python dependencies, see the `requirements.txt`. We also require at least LLDB version 17 for `fs_base`/`gs_base` +register support. + +I had to compile LLDB myself; these are the steps I had to take (you also need swig version >= 4): + +``` +git clone https://github.com/llvm/llvm-project <llvm-path> +cd <llvm-path> +cmake -S llvm -B build -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="clang;lldb" -DLLDB_ENABLE_PYTHON=TRUE -DLLDB_ENABLE_SWIG=TRUE +cmake --build build/ --parallel $(nproc) + +# Add the built LLDB python bindings to your PYTHONPATH: +PYTHONPATH="$PYTHONPATH:$(./build/bin/lldb -P)" +``` + +It will take a while to compile. + +## How To Use + +`focaccia.py` is the main executable. Invoke `focaccia.py --help` to see what you can do with it. + +## Tools + +The `tools/` directory contains additional utility scripts to work with focaccia. + + - `convert.py`: Convert logs from QEMU or Arancini to focaccia's snapshot log format. + +## Project Overview (for developers) + +### Snapshots and comparison + +The following files belong to a rough framework for the snapshot comparison engine: + + - `focaccia/snapshot.py`: Structures used to work with snapshots. The `ProgramState` class is our primary +representation of program snapshots. + + - `focaccia/compare.py`: The central algorithms that work on snapshots. + + - `focaccia/arch/`: Abstractions over different processor architectures. Currently we have x86 and aarch64. + +### Concolic execution + +The following files belong to a prototype of a data-dependency generator based on symbolic +execution: + + - `focaccia/symbolic.py`: Algorithms and data structures to compute and manipulate symbolic program transformations. +This handles the symbolic part of "concolic" execution. + + - `focaccia/lldb_target.py`: Tools for executing a program concretely and tracking its execution using +[LLDB](https://lldb.llvm.org/). This handles the concrete part of "concolic" execution. + + - `focaccia/miasm_util.py`: Tools to evaluate Miasm's symbolic expressions based on a concrete state. Ties the symbolic +and concrete parts together into "concolic" execution. + +### Helpers + + - `focaccia/parser.py`: Utilities for parsing logs from Arancini and QEMU, as well as serializing/deserializing to/from +our own log format. + + - `focaccia/match.py`: Algorithms for trace matching. + +### Supporting new architectures + +To add support for an architecture <arch>, do the following: + + - Add a file `focaccia/arch/<arch>.py`. This module declares the architecture's description, such as register names and +an architecture class. The convention is to declare state flags (e.g. flags in RFLAGS for x86) as separate registers. + + - Add the class to the `supported_architectures` dict in `focaccia/arch/__init__.py`. 
+ + - Depending on Miasm's support for <arch>, add register name aliases to the `MiasmSymbolResolver.miasm_flag_aliases` +dict in `focaccia/miasm_util.py`. + + - Depending on the existence of a flags register in <arch>, implement conversion from the flags register's value to +values of single logical flags (e.g. implement the operation `RFLAGS['OF']`) in the respective concrete targets (LLDB, +GDB, ...). diff --git a/compare.py b/compare.py deleted file mode 100755 index ffd1e93..0000000 --- a/compare.py +++ /dev/null @@ -1,354 +0,0 @@ -#! /bin/python3 -import re -import sys -import shutil -import argparse -from typing import List -from functools import partial as bind - -from utils import check_version -from utils import print_separator - -from run import Runner - -progressive = False - -class ContextBlock: - regnames = ['PC', - 'RAX', - 'RBX', - 'RCX', - 'RDX', - 'RSI', - 'RDI', - 'RBP', - 'RSP', - 'R8', - 'R9', - 'R10', - 'R11', - 'R12', - 'R13', - 'R14', - 'R15', - 'flag ZF', - 'flag CF', - 'flag OF', - 'flag SF', - 'flag PF', - 'flag DF'] - - def __init__(self): - self.regs = {reg: None for reg in ContextBlock.regnames} - self.has_backwards = False - self.matched = False - - def set_backwards(self): - self.has_backwards = True - - def set(self, idx: int, value: int): - self.regs[list(self.regs.keys())[idx]] = value - - def __repr__(self): - return self.regs.__repr__() - -class Constructor: - def __init__(self, structure: dict): - self.cblocks = [] - self.structure = structure - self.patterns = list(self.structure.keys()) - - def match(self, line: str): - # find patterns that match it - regex = re.compile("|".join(self.patterns)) - match = regex.match(line) - - idx = self.patterns.index(match.group(0)) if match else 0 - - pattern = self.patterns[idx] - register = ContextBlock.regnames[idx] - - return register, self.structure[pattern](line) - - def add_backwards(self): - self.cblocks[-1].set_backwards() - - def add(self, key: str, value: int): - if key == 'PC': - self.cblocks.append(ContextBlock()) - - if self.cblocks[-1].regs[key] != None: - raise RuntimeError("Reassigning register") - - self.cblocks[-1].regs[key] = value - -class Transformations: - def __init__(self, previous: ContextBlock, current: ContextBlock): - self.transformation = ContextBlock() - for el1 in current.regs.keys(): - for el2 in previous.regs.keys(): - if el1 != el2: - continue - self.transformation.regs[el1] = current.regs[el1] - previous.regs[el2] - -def parse(lines: list, labels: list): - ctor = Constructor(labels) - - patterns = ctor.patterns.copy() - patterns.append('Backwards') - regex = re.compile("|".join(patterns)) - lines = [l for l in lines if regex.match(l) is not None] - - for line in lines: - if 'Backwards' in line: - ctor.add_backwards() - continue - - key, value = ctor.match(line) - ctor.add(key, value) - - return ctor.cblocks - -def get_labels(): - split_value = lambda x,i: int(x.split()[i], 16) - - split_first = bind(split_value, i=1) - split_second = bind(split_value, i=2) - - split_equal = lambda x,i: int(x.split('=')[i], 16) - labels = {'INVOKE': bind(split_equal, i=1), - 'RAX': split_first, - 'RBX': split_first, - 'RCX': split_first, - 'RDX': split_first, - 'RSI': split_first, - 'RDI': split_first, - 'RBP': split_first, - 'RSP': split_first, - 'R8': split_first, - 'R9': split_first, - 'R10': split_first, - 'R11': split_first, - 'R12': split_first, - 'R13': split_first, - 'R14': split_first, - 'R15': split_first, - 'flag ZF': split_second, - 'flag CF': split_second, - 'flag OF': split_second, - 
'flag SF': split_second, - 'flag PF': split_second, - 'flag DF': split_second} - return labels - -def equivalent(val1, val2, transformation, previous_translation): - if val1 == val2: - return True - - # TODO: maybe incorrect - return val1 - previous_translation == transformation - -def verify(translation: ContextBlock, reference: ContextBlock, - transformation: Transformations, previous_translation: ContextBlock): - if translation.regs["PC"] != reference.regs["PC"]: - return 1 - - print_separator() - print(f'For PC={hex(translation.regs["PC"])}') - print_separator() - for el1 in translation.regs.keys(): - for el2 in reference.regs.keys(): - if el1 != el2: - continue - - if translation.regs[el1] is None: - print(f'Element not available in translation: {el1}') - continue - - if reference.regs[el2] is None: - print(f'Element not available in reference: {el2}') - continue - - if not equivalent(translation.regs[el1], reference.regs[el2], - transformation.regs[el1], previous_translation.regs[el1]): - print(f'Difference for {el1}: {hex(translation.regs[el1])} != {hex(reference.regs[el2])}') - return 0 - -def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = False): - txl = parse(txl, get_labels()) - native = parse(native, get_labels()) - - if len(txl) != len(native): - print(f'Different number of blocks discovered translation: {len(txl)} vs. ' - f'reference: {len(native)}', file=sys.stdout) - - previous_reference = native[0] - previous_translation = txl[0] - - unmatched_pcs = {} - pc_to_skip = "" - if progressive: - i = 0 - for translation in txl: - previous = i - - while i < len(native): - reference = native[i] - transformations = Transformations(previous_reference, reference) - if verify(translation, reference, transformations.transformation, previous_translation) == 0: - reference.matched = True - break - - i += 1 - - matched = True - - # Didn't find anything - if i == len(native): - matched = False - # TODO: add verbose output - print_separator(stream=sys.stdout) - print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout) - if translation.regs['PC'] not in unmatched_pcs: - unmatched_pcs[translation.regs['PC']] = 0 - unmatched_pcs[translation.regs['PC']] += 1 - - i = previous - - # Necessary since we may have run out of native BBs to check and - # previous becomes len(native) - # - # We continue checking to report unmatched translation PCs - if i < len(native): - previous_reference = native[i] - - previous_translation = translation - - # Skip next reference when there is a backwards branch - # NOTE: if a reference was skipped, don't skip it again - # necessary for loops which may have multiple backwards - # branches - if translation.has_backwards and translation.regs['PC'] != pc_to_skip: - pc_to_skip = translation.regs['PC'] - i += 1 - - if matched: - i += 1 - else: - txl = iter(txl) - native = iter(native) - for translation, reference in zip(txl, native): - transformations = Transformations(previous_reference, reference) - if verify(translation, reference, transformations.transformation, previous_translation) == 1: - # TODO: add verbose output - print_separator(stream=sys.stdout) - print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout) - if translation.regs['PC'] not in unmatched_pcs: - unmatched_pcs[translation.regs['PC']] = 0 - unmatched_pcs[translation.regs['PC']] += 1 - else: - reference.matched = True - - if translation.has_backwards: - next(native) - - previous_reference = reference - previous_translation = translation - - if 
stats: - print_separator() - print('Statistics:') - print_separator() - - for pc in unmatched_pcs: - print(f'PC {hex(pc)} unmatched {unmatched_pcs[pc]} times') - - # NOTE: currently doesn't handle mismatched due backward branches - current = "" - unmatched_count = 0 - for ref in native: - ref_pc = ref.regs['PC'] - if ref_pc != current: - if unmatched_count: - print(f'Reference PC {hex(current)} unmatched {unmatched_count} times') - current = ref_pc - - if ref.matched == False: - unmatched_count += 1 - return 0 - -def read_logs(txl_path, native_path, program): - txl = [] - with open(txl_path, "r") as txl_file: - txl = txl_file.readlines() - - native = [] - if program is not None: - runner = Runner(txl, program) - native = runner.run() - else: - with open(native_path, "r") as native_file: - native = native_file.readlines() - - return txl, native - -def parse_arguments(): - parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference') - parser.add_argument('-p', '--program', - type=str, - help='Path to oracle program') - parser.add_argument('-r', '--ref', - type=str, - required=True, - help='Path to the reference log (gathered with run.sh)') - parser.add_argument('-t', '--txl', - type=str, - required=True, - help='Path to the translation log (gathered via Arancini)') - parser.add_argument('-s', '--stats', - action='store_true', - default=False, - help='Run statistics on comparisons') - parser.add_argument('-v', '--verbose', - action='store_true', - default=True, - help='Path to oracle program') - parser.add_argument('--progressive', - action='store_true', - default=False, - help='Try to match exhaustively before declaring \ - mismatch') - args = parser.parse_args() - return args - -if __name__ == "__main__": - check_version('3.7') - - args = parse_arguments() - - txl_path = args.txl - native_path = args.ref - program = args.program - - stats = args.stats - verbose = args.verbose - progressive = args.progressive - - if verbose: - print("Enabling verbose program output") - print(f"Verbose: {verbose}") - print(f"Statistics: {stats}") - print(f"Progressive: {progressive}") - - if program is None and native_path is None: - raise ValueError('Either program or path to native file must be' - 'provided') - - txl, native = read_logs(txl_path, native_path, program) - - if program != None and native_path != None: - with open(native_path, 'w') as w: - w.write(''.join(native)) - - compare(txl, native, stats) - diff --git a/cpuid b/cpuid new file mode 160000 +Subproject 335f97a08af46dda14a09f2e825dddbbe7e8177 diff --git a/focaccia.py b/focaccia.py new file mode 100755 index 0000000..f0c6efe --- /dev/null +++ b/focaccia.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +import argparse +import platform +from typing import Callable, Iterable + +import focaccia.parser as parser +from focaccia.arch import supported_architectures, Arch +from focaccia.compare import compare_simple, compare_symbolic, ErrorTypes +from focaccia.lldb_target import LLDBConcreteTarget +from focaccia.match import fold_traces, match_traces +from focaccia.snapshot import ProgramState +from focaccia.symbolic import collect_symbolic_trace, SymbolicTransform +from focaccia.utils import print_result, get_envp +from focaccia.reproducer import Reproducer +from focaccia.compare import ErrorSeverity +from focaccia.trace import Trace, TraceEnvironment + +verbosity = { + 'info': ErrorTypes.INFO, + 'warning': ErrorTypes.POSSIBLE, + 'error': ErrorTypes.CONFIRMED, +} + +concrete_trace_parsers = { + 'focaccia': lambda f, _: 
parser.parse_snapshots(f), + 'qemu': parser.parse_qemu, + 'arancini': parser.parse_arancini, +} + +_MatchingAlgorithm = Callable[ + [list[ProgramState], list[SymbolicTransform]], + tuple[list[ProgramState], list[SymbolicTransform]] +] + +matching_algorithms: dict[str, _MatchingAlgorithm] = { + 'none': lambda c, s: (c, s), + 'simple': match_traces, + 'fold': fold_traces, +} + +def collect_concrete_trace(env: TraceEnvironment, breakpoints: Iterable[int]) \ + -> list[ProgramState]: + """Gather snapshots from a native execution via an external debugger. + + :param env: Program to execute and the environment in which to execute it. + :param breakpoints: List of addresses at which to break and record the + program's state. + + :return: A list of snapshots gathered from the execution. + """ + target = LLDBConcreteTarget(env.binary_name, env.argv, env.envp) + + # Set breakpoints + for address in breakpoints: + target.set_breakpoint(address) + + # Execute the native program + snapshots = [] + while not target.is_exited(): + snapshots.append(target.record_snapshot()) + target.run() + + return snapshots + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.description = '''An emulator tester and verifier. + +You can pre-record symbolic traces with `tools/capture_transforms.py`, then pass +them to the verifier with the --oracle-trace argument. +''' + + # Specification of the symbolic truth trace + symb_trace = parser.add_mutually_exclusive_group(required=True) + symb_trace.add_argument('-p', '--oracle-program', + help='A program from which a symbolic truth will be' + ' recorded.') + symb_trace.add_argument('-o', '--oracle-trace', '--symb-trace', + help='A symbolic trace that serves as a truth state' + ' for comparison.') + parser.add_argument('-a', '--oracle-args', + nargs='*', + default=[], + help='Arguments to the oracle program.') + parser.add_argument('-e', '--oracle-env', + nargs='*', + help='Override the oracle program\'s environment during' + ' symbolic and concrete execution.') + + # Specification of the concrete test trace + parser.add_argument('-t', '--test-trace', + required=True, + help='The concrete test states to test against the' + ' symbolic truth.') + parser.add_argument('--test-trace-type', + default='focaccia', + choices=list(concrete_trace_parsers.keys()), + help='Log file format of the tested program trace.' + ' [Default: focaccia]') + + # Algorithm and output control + parser.add_argument('--match', + choices=list(matching_algorithms.keys()), + default='simple', + help='Select an algorithm to match the test trace to' + ' the truth trace. Only applicable if --symbolic' + ' is enabled.' + ' [Default: simple]') + parser.add_argument('--symbolic', + action='store_true', + default=False, + help='Use an advanced algorithm that uses symbolic' + ' execution to determine accurate data' + ' transformations. This improves the quality of' + ' generated errors significantly, but will take' + ' more time to complete.') + parser.add_argument('--error-level', + default='warning', + choices=list(verbosity.keys()), + help='Verbosity of reported errors. \'error\' only' + ' reports mismatches that have been detected as' + ' errors in the translation with certainty.' + ' \'warning\' will report possible errors that' + ' may as well stem from incomplete input data.' + ' \'info\' will report absolutely everything.' 
+ ' [Default: warning]') + parser.add_argument('--no-verifier', + action='store_true', + default=False, + help='Don\'t print verifier output.') + + # Reproducer + parser.add_argument('--reproducer', + action='store_true', + default=False, + help='Generate repoducer executables for detected' + ' errors.') + + return parser.parse_args() + +def print_reproducer(result, min_severity: ErrorSeverity, oracle, oracle_args): + for res in result: + errs = [e for e in res['errors'] if e.severity >= min_severity] + #breakpoint() + if errs: + rep = Reproducer(oracle, oracle_args, res['snap'], res['ref']) + print(rep.asm()) + return + +def get_test_trace(args, arch: Arch) -> Trace[ProgramState]: + path = args.test_trace + parser = concrete_trace_parsers[args.test_trace_type] + with open(path, 'r') as txl_file: + return parser(txl_file, arch) + +def get_truth_env(args) -> TraceEnvironment: + oracle = args.oracle_program + oracle_args = args.oracle_args + if args.oracle_env: + oracle_env = args.oracle_env + else: + oracle_env = get_envp() + return TraceEnvironment(oracle, oracle_args, oracle_env) + +def get_symbolic_trace(args): + if args.oracle_program: + env = get_truth_env(args) + print('Tracing', env) + return collect_symbolic_trace(env) + elif args.oracle_trace: + with open(args.oracle_trace, 'r') as file: + return parser.parse_transformations(file) + raise AssertionError() + +def main(): + args = parse_arguments() + + # Determine the current machine's architecture. The log type must match the + # architecture on which focaccia is executed because focaccia wants to + # execute the reference program concretely. + if platform.machine() not in supported_architectures: + print(f'Machine {platform.machine()} is not supported! Exiting.') + exit(1) + arch = supported_architectures[platform.machine()] + + # Parse reference trace + test_trace = get_test_trace(args, arch) + + # Compare reference trace to a truth + if args.symbolic: + symb_trace = get_symbolic_trace(args) + match = matching_algorithms[args.match] + conc, symb = match(test_trace.states, symb_trace.states) + + result = compare_symbolic(conc, symb) + oracle_env = symb_trace.env + else: + if not args.oracle_program: + print('Argument --oracle-program is required for non-symbolic' + ' verification!') + exit(1) + + # Record truth states from a concrete execution of the oracle + breakpoints = [state.read_register('PC') for state in test_trace] + env = get_truth_env(args) + truth_trace = collect_concrete_trace(env, breakpoints) + + result = compare_simple(test_trace.states, truth_trace) + oracle_env = env + + if not args.no_verifier: + print_result(result, verbosity[args.error_level]) + + if args.reproducer: + print_reproducer(result, + verbosity[args.error_level], + oracle_env.binary_name, + oracle_env.argv) + +if __name__ == '__main__': + main() diff --git a/focaccia/__init__.py b/focaccia/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/focaccia/__init__.py diff --git a/focaccia/arch/__init__.py b/focaccia/arch/__init__.py new file mode 100644 index 0000000..1797176 --- /dev/null +++ b/focaccia/arch/__init__.py @@ -0,0 +1,14 @@ +from .arch import Arch +from . import x86, aarch64 + +supported_architectures: dict[str, Arch] = { + 'x86_64': x86.ArchX86(), + 'aarch64': aarch64.ArchAArch64('little'), + 'aarch64l': aarch64.ArchAArch64('little'), + 'aarch64b': aarch64.ArchAArch64('big'), +} +"""A dictionary containing all supported architectures at their names. 
+ +The arch names (keys) should be compatible with the string returned from +`platform.machine()`. +""" diff --git a/focaccia/arch/aarch64.py b/focaccia/arch/aarch64.py new file mode 100644 index 0000000..2e510b9 --- /dev/null +++ b/focaccia/arch/aarch64.py @@ -0,0 +1,123 @@ +"""Description of 64-bit ARM.""" + +from .arch import Arch, RegisterDescription as _Reg + +archname = 'aarch64' + +registers = [ + _Reg(('R0', 0, 64), ('X0', 0, 64), ('W0', 0, 32)), + _Reg(('R1', 0, 64), ('X1', 0, 64), ('W1', 0, 32)), + _Reg(('R2', 0, 64), ('X2', 0, 64), ('W2', 0, 32)), + _Reg(('R3', 0, 64), ('X3', 0, 64), ('W3', 0, 32)), + _Reg(('R4', 0, 64), ('X4', 0, 64), ('W4', 0, 32)), + _Reg(('R5', 0, 64), ('X5', 0, 64), ('W5', 0, 32)), + _Reg(('R6', 0, 64), ('X6', 0, 64), ('W6', 0, 32)), + _Reg(('R7', 0, 64), ('X7', 0, 64), ('W7', 0, 32)), + _Reg(('R8', 0, 64), ('X8', 0, 64), ('W8', 0, 32)), + _Reg(('R9', 0, 64), ('X9', 0, 64), ('W9', 0, 32)), + _Reg(('R10', 0, 64), ('X10', 0, 64), ('W10', 0, 32)), + _Reg(('R11', 0, 64), ('X11', 0, 64), ('W11', 0, 32)), + _Reg(('R12', 0, 64), ('X12', 0, 64), ('W12', 0, 32)), + _Reg(('R13', 0, 64), ('X13', 0, 64), ('W13', 0, 32)), + _Reg(('R14', 0, 64), ('X14', 0, 64), ('W14', 0, 32)), + _Reg(('R15', 0, 64), ('X15', 0, 64), ('W15', 0, 32)), + _Reg(('R16', 0, 64), ('X16', 0, 64), ('W16', 0, 32)), + _Reg(('R17', 0, 64), ('X17', 0, 64), ('W17', 0, 32)), + _Reg(('R18', 0, 64), ('X18', 0, 64), ('W18', 0, 32)), + _Reg(('R19', 0, 64), ('X19', 0, 64), ('W19', 0, 32)), + _Reg(('R20', 0, 64), ('X20', 0, 64), ('W20', 0, 32)), + _Reg(('R21', 0, 64), ('X21', 0, 64), ('W21', 0, 32)), + _Reg(('R22', 0, 64), ('X22', 0, 64), ('W22', 0, 32)), + _Reg(('R23', 0, 64), ('X23', 0, 64), ('W23', 0, 32)), + _Reg(('R24', 0, 64), ('X24', 0, 64), ('W24', 0, 32)), + _Reg(('R25', 0, 64), ('X25', 0, 64), ('W25', 0, 32)), + _Reg(('R26', 0, 64), ('X26', 0, 64), ('W26', 0, 32)), + _Reg(('R27', 0, 64), ('X27', 0, 64), ('W27', 0, 32)), + _Reg(('R28', 0, 64), ('X28', 0, 64), ('W28', 0, 32)), + _Reg(('R29', 0, 64), ('X29', 0, 64), ('W29', 0, 32)), + _Reg(('R30', 0, 64), ('X30', 0, 64), ('W30', 0, 32), ('LR', 0, 64)), + + _Reg(('RZR', 0, 64), ('XZR', 0, 64), ('WZR', 0, 32)), + _Reg(('SP', 0, 64), ('RSP', 0, 64)), + _Reg(('PC', 0, 64)), + + _Reg(('V0', 0, 128), ('Q0', 0, 128), ('D0', 0, 64), ('S0', 0, 32), ('H0', 0, 16), ('B0', 0, 8)), + _Reg(('V1', 0, 128), ('Q1', 0, 128), ('D1', 0, 64), ('S1', 0, 32), ('H1', 0, 16), ('B1', 0, 8)), + _Reg(('V2', 0, 128), ('Q2', 0, 128), ('D2', 0, 64), ('S2', 0, 32), ('H2', 0, 16), ('B2', 0, 8)), + _Reg(('V3', 0, 128), ('Q3', 0, 128), ('D3', 0, 64), ('S3', 0, 32), ('H3', 0, 16), ('B3', 0, 8)), + _Reg(('V4', 0, 128), ('Q4', 0, 128), ('D4', 0, 64), ('S4', 0, 32), ('H4', 0, 16), ('B4', 0, 8)), + _Reg(('V5', 0, 128), ('Q5', 0, 128), ('D5', 0, 64), ('S5', 0, 32), ('H5', 0, 16), ('B5', 0, 8)), + _Reg(('V6', 0, 128), ('Q6', 0, 128), ('D6', 0, 64), ('S6', 0, 32), ('H6', 0, 16), ('B6', 0, 8)), + _Reg(('V7', 0, 128), ('Q7', 0, 128), ('D7', 0, 64), ('S7', 0, 32), ('H7', 0, 16), ('B7', 0, 8)), + _Reg(('V8', 0, 128), ('Q8', 0, 128), ('D8', 0, 64), ('S8', 0, 32), ('H8', 0, 16), ('B8', 0, 8)), + _Reg(('V9', 0, 128), ('Q9', 0, 128), ('D9', 0, 64), ('S9', 0, 32), ('H9', 0, 16), ('B9', 0, 8)), + _Reg(('V10', 0, 128), ('Q10', 0, 128), ('D10', 0, 64), ('S10', 0, 32), ('H10', 0, 16), ('B10', 0, 8)), + _Reg(('V11', 0, 128), ('Q11', 0, 128), ('D11', 0, 64), ('S11', 0, 32), ('H11', 0, 16), ('B11', 0, 8)), + _Reg(('V12', 0, 128), ('Q12', 0, 128), ('D12', 0, 64), ('S12', 0, 32), ('H12', 0, 16), ('B12', 0, 8)), 
+ _Reg(('V13', 0, 128), ('Q13', 0, 128), ('D13', 0, 64), ('S13', 0, 32), ('H13', 0, 16), ('B13', 0, 8)), + _Reg(('V14', 0, 128), ('Q14', 0, 128), ('D14', 0, 64), ('S14', 0, 32), ('H14', 0, 16), ('B14', 0, 8)), + _Reg(('V15', 0, 128), ('Q15', 0, 128), ('D15', 0, 64), ('S15', 0, 32), ('H15', 0, 16), ('B15', 0, 8)), + _Reg(('V16', 0, 128), ('Q16', 0, 128), ('D16', 0, 64), ('S16', 0, 32), ('H16', 0, 16), ('B16', 0, 8)), + _Reg(('V17', 0, 128), ('Q17', 0, 128), ('D17', 0, 64), ('S17', 0, 32), ('H17', 0, 16), ('B17', 0, 8)), + _Reg(('V18', 0, 128), ('Q18', 0, 128), ('D18', 0, 64), ('S18', 0, 32), ('H18', 0, 16), ('B18', 0, 8)), + _Reg(('V19', 0, 128), ('Q19', 0, 128), ('D19', 0, 64), ('S19', 0, 32), ('H19', 0, 16), ('B19', 0, 8)), + _Reg(('V20', 0, 128), ('Q20', 0, 128), ('D20', 0, 64), ('S20', 0, 32), ('H20', 0, 16), ('B20', 0, 8)), + _Reg(('V21', 0, 128), ('Q21', 0, 128), ('D21', 0, 64), ('S21', 0, 32), ('H21', 0, 16), ('B21', 0, 8)), + _Reg(('V22', 0, 128), ('Q22', 0, 128), ('D22', 0, 64), ('S22', 0, 32), ('H22', 0, 16), ('B22', 0, 8)), + _Reg(('V23', 0, 128), ('Q23', 0, 128), ('D23', 0, 64), ('S23', 0, 32), ('H23', 0, 16), ('B23', 0, 8)), + _Reg(('V24', 0, 128), ('Q24', 0, 128), ('D24', 0, 64), ('S24', 0, 32), ('H24', 0, 16), ('B24', 0, 8)), + _Reg(('V25', 0, 128), ('Q25', 0, 128), ('D25', 0, 64), ('S25', 0, 32), ('H25', 0, 16), ('B25', 0, 8)), + _Reg(('V26', 0, 128), ('Q26', 0, 128), ('D26', 0, 64), ('S26', 0, 32), ('H26', 0, 16), ('B26', 0, 8)), + _Reg(('V27', 0, 128), ('Q27', 0, 128), ('D27', 0, 64), ('S27', 0, 32), ('H27', 0, 16), ('B27', 0, 8)), + _Reg(('V28', 0, 128), ('Q28', 0, 128), ('D28', 0, 64), ('S28', 0, 32), ('H28', 0, 16), ('B28', 0, 8)), + _Reg(('V29', 0, 128), ('Q29', 0, 128), ('D29', 0, 64), ('S29', 0, 32), ('H29', 0, 16), ('B29', 0, 8)), + _Reg(('V30', 0, 128), ('Q30', 0, 128), ('D30', 0, 64), ('S30', 0, 32), ('H30', 0, 16), ('B30', 0, 8)), + _Reg(('V31', 0, 128), ('Q31', 0, 128), ('D31', 0, 64), ('S31', 0, 32), ('H31', 0, 16), ('B31', 0, 8)), + + _Reg(('CPSR', 0, 64), + ('N', 31, 32), + ('Z', 30, 31), + ('C', 29, 30), + ('V', 28, 29), + ('Q', 27, 28), + ('SSBS', 23, 24), + ('PAN', 22, 23), + ('DIT', 21, 22), + ('GE', 16, 20), + ('E', 9, 10), + ('A', 8, 9), + ('I', 7, 8), + ('F', 6, 7), + ('M', 0, 4), + ), +] + +# Names of registers in the architecture +regnames = [desc.base.base_reg for desc in registers] + +def decompose_cpsr(cpsr: int) -> dict[str, int]: + """Extract individual flag values from the CPSR register.""" + return { + 'N': (cpsr & (1 << 31)) != 0, + 'Z': (cpsr & (1 << 30)) != 0, + 'C': (cpsr & (1 << 29)) != 0, + 'V': (cpsr & (1 << 28)) != 0, + 'Q': (cpsr & (1 << 27)) != 0, + # Reserved: [26:24] + 'SSBS': (cpsr & (1 << 23)) != 0, + 'PAN': (cpsr & (1 << 22)) != 0, + 'DIT': (cpsr & (1 << 21)) != 0, + # Reserved: [20] + 'GE': (cpsr & (0b1111 << 16)) != 0, + # Reserved: [15:10] + 'E': (cpsr & (1 << 9)) != 0, + 'A': (cpsr & (1 << 8)) != 0, + 'I': (cpsr & (1 << 7)) != 0, + 'F': (cpsr & (1 << 6)) != 0, + # Reserved: [5:4] + 'M': (cpsr & 0b1111) != 0, + } + +class ArchAArch64(Arch): + def __init__(self, endianness: Arch.Endianness): + super().__init__(archname, registers, 64, endianness) diff --git a/focaccia/arch/arch.py b/focaccia/arch/arch.py new file mode 100644 index 0000000..ce5e532 --- /dev/null +++ b/focaccia/arch/arch.py @@ -0,0 +1,84 @@ +from typing import Literal + +class RegisterAccessor: + def __init__(self, regname: str, start_bit: int, end_bit: int): + """An accessor that describes a range of bits. 
+ + Builds a bit range [start_bit, end_bit), meaning `end_bit` is excluded + from the range. + + Example: An object `RegisterAccessor(0, 1)` accesses exactly the first + bit of a value. `RegisterAccessor(0, 0)` is invalid as it references + a range of zero bits. + + :param start_bit: First bit included in the range. This is the least + significant bit in the range. + :param end_bit: First bit *not* included in the range. This is the most + significant bit of the range. + """ + assert(start_bit < end_bit) + self.base_reg = regname + self.start = start_bit + self.end = end_bit + + self.num_bits = end_bit - start_bit + self.mask = 0 + for i in range(start_bit, end_bit): + self.mask |= 1 << i + + def __repr__(self) -> str: + return f'{self.base_reg}[{self.start}:{self.end - 1}]' + +class RegisterDescription: + def __init__(self, base: tuple[str, int, int], *subsets: tuple[str, int, int]): + self.base = RegisterAccessor(*base) + self.subsets = [(name, RegisterAccessor(base[0], s, e)) for name, s, e in subsets] + +class Arch(): + Endianness = Literal['little', 'big'] + + def __init__(self, + archname: str, + registers: list[RegisterDescription], + ptr_size: int, + endianness: Endianness = 'little'): + self.archname = archname + self.ptr_size = ptr_size + self.endianness: Literal['little', 'big'] = endianness + + self._accessors = {} + for desc in registers: + self._accessors[desc.base.base_reg.upper()] = desc.base + self._accessors |= {name: acc for name, acc in desc.subsets} + + self.regnames = set(desc.base.base_reg.upper() for desc in registers) + """Names of the architecture's base registers.""" + + self.all_regnames = set(self._accessors.keys()) + """Names of the architecture's registers, including register aliases.""" + + def to_regname(self, name: str) -> str | None: + """Transform a string into a standard register name. + + :param name: The possibly non-standard name to look up. + :return: The 'corrected' register name, or None if `name` cannot be + transformed into a register name. + """ + name = name.upper() + if name in self._accessors: + return name + return None + + def get_reg_accessor(self, regname: str) -> RegisterAccessor | None: + """Get an accessor for a register name, which may be an alias. + + Is used internally by ProgramState to access aliased registers. 
+ """ + _regname = self.to_regname(regname) + return self._accessors.get(_regname, None) + + def __eq__(self, other): + return self.archname == other.archname + + def __repr__(self) -> str: + return self.archname diff --git a/focaccia/arch/x86.py b/focaccia/arch/x86.py new file mode 100644 index 0000000..fefab37 --- /dev/null +++ b/focaccia/arch/x86.py @@ -0,0 +1,200 @@ +"""Architecture-specific configuration.""" + +from .arch import Arch, RegisterDescription as _Reg + +archname = 'x86_64' + +registers = [ + # General-purpose registers + _Reg(('RIP', 0, 64), ('EIP', 0, 32), ('IP', 0, 16)), + _Reg(('RAX', 0, 64), ('EAX', 0, 32), ('AX', 0, 16), ('AL', 0, 8), ('AH', 8, 16)), + _Reg(('RBX', 0, 64), ('EBX', 0, 32), ('BX', 0, 16), ('BL', 0, 8), ('BH', 8, 16)), + _Reg(('RCX', 0, 64), ('ECX', 0, 32), ('CX', 0, 16), ('CL', 0, 8), ('CH', 8, 16)), + _Reg(('RDX', 0, 64), ('EDX', 0, 32), ('DX', 0, 16), ('DL', 0, 8), ('DH', 8, 16)), + _Reg(('RSI', 0, 64), ('ESI', 0, 32), ('SI', 0, 16), ('SIL', 0, 8)), + _Reg(('RDI', 0, 64), ('EDI', 0, 32), ('DI', 0, 16), ('DIL', 0, 8)), + _Reg(('RBP', 0, 64), ('EBP', 0, 32), ('BP', 0, 16), ('BPL', 0, 8)), + _Reg(('RSP', 0, 64), ('ESP', 0, 32), ('SP', 0, 16), ('SPL', 0, 8)), + _Reg(('R8', 0, 64)), + _Reg(('R9', 0, 64)), + _Reg(('R10', 0, 64)), + _Reg(('R11', 0, 64)), + _Reg(('R12', 0, 64)), + _Reg(('R13', 0, 64)), + _Reg(('R14', 0, 64)), + _Reg(('R15', 0, 64)), + + # RFLAGS + _Reg(('RFLAGS', 0, 64), ('EFLAGS', 0, 32), ('FLAGS', 0, 16), + ('CF', 0, 1), + ('PF', 2, 3), + ('AF', 4, 5), + ('ZF', 6, 7), + ('SF', 7, 8), + ('TF', 8, 9), + ('IF', 9, 10), + ('DF', 10, 11), + ('OF', 11, 12), + ('IOPL', 12, 14), + ('NT', 14, 15), + ('MD', 15, 16), + + ('RF', 16, 17), + ('VM', 17, 18), + ('AC', 18, 19), + ('VIF', 19, 20), + ('VIP', 20, 21), + ('ID', 21, 22), + ('AI', 31, 32), + ), + + # Segment registers + _Reg(('CS', 0, 16)), + _Reg(('DS', 0, 16)), + _Reg(('SS', 0, 16)), + _Reg(('ES', 0, 16)), + _Reg(('FS', 0, 16)), + _Reg(('GS', 0, 16)), + _Reg(('FS_BASE', 0, 64)), + _Reg(('GS_BASE', 0, 64)), + + # x87 floating-point registers + _Reg(('ST0', 0, 80)), + _Reg(('ST1', 0, 80)), + _Reg(('ST2', 0, 80)), + _Reg(('ST3', 0, 80)), + _Reg(('ST4', 0, 80)), + _Reg(('ST5', 0, 80)), + _Reg(('ST6', 0, 80)), + _Reg(('ST7', 0, 80)), + + # Vector registers + _Reg(('ZMM0', 0, 512), ('YMM0', 0, 256), ('XMM0', 0, 128), ('MM0', 0, 64)), + _Reg(('ZMM1', 0, 512), ('YMM1', 0, 256), ('XMM1', 0, 128), ('MM1', 0, 64)), + _Reg(('ZMM2', 0, 512), ('YMM2', 0, 256), ('XMM2', 0, 128), ('MM2', 0, 64)), + _Reg(('ZMM3', 0, 512), ('YMM3', 0, 256), ('XMM3', 0, 128), ('MM3', 0, 64)), + _Reg(('ZMM4', 0, 512), ('YMM4', 0, 256), ('XMM4', 0, 128), ('MM4', 0, 64)), + _Reg(('ZMM5', 0, 512), ('YMM5', 0, 256), ('XMM5', 0, 128), ('MM5', 0, 64)), + _Reg(('ZMM6', 0, 512), ('YMM6', 0, 256), ('XMM6', 0, 128), ('MM6', 0, 64)), + _Reg(('ZMM7', 0, 512), ('YMM7', 0, 256), ('XMM7', 0, 128), ('MM7', 0, 64)), + _Reg(('ZMM8', 0, 512), ('YMM8', 0, 256), ('XMM8', 0, 128)), + _Reg(('ZMM9', 0, 512), ('YMM9', 0, 256), ('XMM9', 0, 128)), + _Reg(('ZMM10', 0, 512), ('YMM10', 0, 256), ('XMM10', 0, 128)), + _Reg(('ZMM11', 0, 512), ('YMM11', 0, 256), ('XMM11', 0, 128)), + _Reg(('ZMM12', 0, 512), ('YMM12', 0, 256), ('XMM12', 0, 128)), + _Reg(('ZMM13', 0, 512), ('YMM13', 0, 256), ('XMM13', 0, 128)), + _Reg(('ZMM14', 0, 512), ('YMM14', 0, 256), ('XMM14', 0, 128)), + _Reg(('ZMM15', 0, 512), ('YMM15', 0, 256), ('XMM15', 0, 128)), + + _Reg(('ZMM16', 0, 512), ('YMM16', 0, 256), ('XMM16', 0, 128)), + _Reg(('ZMM17', 0, 512), ('YMM17', 0, 256), ('XMM17', 0, 
128)), + _Reg(('ZMM18', 0, 512), ('YMM18', 0, 256), ('XMM18', 0, 128)), + _Reg(('ZMM19', 0, 512), ('YMM19', 0, 256), ('XMM19', 0, 128)), + _Reg(('ZMM20', 0, 512), ('YMM20', 0, 256), ('XMM20', 0, 128)), + _Reg(('ZMM21', 0, 512), ('YMM21', 0, 256), ('XMM21', 0, 128)), + _Reg(('ZMM22', 0, 512), ('YMM22', 0, 256), ('XMM22', 0, 128)), + _Reg(('ZMM23', 0, 512), ('YMM23', 0, 256), ('XMM23', 0, 128)), + _Reg(('ZMM24', 0, 512), ('YMM24', 0, 256), ('XMM24', 0, 128)), + _Reg(('ZMM25', 0, 512), ('YMM25', 0, 256), ('XMM25', 0, 128)), + _Reg(('ZMM26', 0, 512), ('YMM26', 0, 256), ('XMM26', 0, 128)), + _Reg(('ZMM27', 0, 512), ('YMM27', 0, 256), ('XMM27', 0, 128)), + _Reg(('ZMM28', 0, 512), ('YMM28', 0, 256), ('XMM28', 0, 128)), + _Reg(('ZMM29', 0, 512), ('YMM29', 0, 256), ('XMM29', 0, 128)), + _Reg(('ZMM30', 0, 512), ('YMM30', 0, 256), ('XMM30', 0, 128)), + _Reg(('ZMM31', 0, 512), ('YMM31', 0, 256), ('XMM31', 0, 128)), +] + +# Names of registers in the architecture +regnames = [desc.base.base_reg for desc in registers] + +# A dictionary mapping aliases to standard register names. +regname_aliases = { + 'PC': 'RIP', + 'NF': 'SF', # negative flag == sign flag in Miasm? +} + +def decompose_rflags(rflags: int) -> dict[str, int]: + """Decompose the RFLAGS register's value into its separate flags. + + Uses flag name abbreviation conventions from + `https://en.wikipedia.org/wiki/FLAGS_register`. + + :param rflags: The RFLAGS register value. + :return: A dictionary mapping Miasm's flag names to their values. + """ + return { + # FLAGS + 'CF': (rflags & 0x0001) != 0, + # 0x0002 reserved + 'PF': (rflags & 0x0004) != 0, + # 0x0008 reserved + 'AF': (rflags & 0x0010) != 0, + # 0x0020 reserved + 'ZF': (rflags & 0x0040) != 0, + 'SF': (rflags & 0x0080) != 0, + 'TF': (rflags & 0x0100) != 0, + 'IF': (rflags & 0x0200) != 0, + 'DF': (rflags & 0x0400) != 0, + 'OF': (rflags & 0x0800) != 0, + 'IOPL': (rflags & 0x3000) != 0, + 'NT': (rflags & 0x4000) != 0, + + # EFLAGS + 'RF': (rflags & 0x00010000) != 0, + 'VM': (rflags & 0x00020000) != 0, + 'AC': (rflags & 0x00040000) != 0, + 'VIF': (rflags & 0x00080000) != 0, + 'VIP': (rflags & 0x00100000) != 0, + 'ID': (rflags & 0x00200000) != 0, + } + +def compose_rflags(rflags: dict[str, int]) -> int: + """Compose separate flags into RFLAGS register's value. + + Uses flag name abbreviation conventions from + `https://en.wikipedia.org/wiki/FLAGS_register`. + + :param rflags: A dictionary mapping Miasm's flag names to their alues. + :return: The RFLAGS register value. 
+ """ + return ( + # FLAGS + (0x0001 if rflags.get('CF', 0) else 0) | + # 0x0002 reserved + (0x0004 if rflags.get('PF', 0) else 0) | + # 0x0008 reserved + (0x0010 if rflags.get('AF', 0) else 0) | + # 0x0020 reserved + (0x0040 if rflags.get('ZF', 0) else 0) | + (0x0080 if rflags.get('SF', 0) else 0) | + (0x0100 if rflags.get('TF', 0) else 0) | + (0x0200 if rflags.get('IF', 0) else 0) | + (0x0400 if rflags.get('DF', 0) else 0) | + (0x0800 if rflags.get('OF', 0) else 0) | + (0x3000 if rflags.get('IOPL', 0) else 0) | + (0x4000 if rflags.get('NT', 0) else 0) | + + # EFLAGS + (0x00010000 if rflags.get('RF', 0) else 0) | + (0x00020000 if rflags.get('VM', 0) else 0) | + (0x00040000 if rflags.get('AC', 0) else 0) | + (0x00080000 if rflags.get('VIF', 0) else 0) | + (0x00100000 if rflags.get('VIP', 0) else 0) | + (0x00200000 if rflags.get('ID', 0) else 0) + ) + +class ArchX86(Arch): + def __init__(self): + super().__init__(archname, registers, 64) + + def to_regname(self, name: str) -> str | None: + """The X86 override of the standard register name lookup. + + Applies certain register name aliases. + """ + reg = super().to_regname(name) + if reg is not None: + return reg + + # Apply custom register alias rules + return regname_aliases.get(name.upper(), None) diff --git a/focaccia/compare.py b/focaccia/compare.py new file mode 100644 index 0000000..13f965c --- /dev/null +++ b/focaccia/compare.py @@ -0,0 +1,302 @@ +from __future__ import annotations +from typing import Iterable + +from .snapshot import ProgramState, MemoryAccessError, RegisterAccessError +from .symbolic import SymbolicTransform +from .utils import ErrorSeverity + +class ErrorTypes: + INFO = ErrorSeverity(0, 'INFO') + INCOMPLETE = ErrorSeverity(2, 'INCOMPLETE DATA') + POSSIBLE = ErrorSeverity(4, 'UNCONFIRMED ERROR') + CONFIRMED = ErrorSeverity(5, 'ERROR') + +class Error: + """A state comparison error.""" + def __init__(self, severity: ErrorSeverity, msg: str): + self.severity = severity + self.error_msg = msg + + def __repr__(self) -> str: + return f'{self.severity} {self.error_msg}' + +def _calc_transformation(previous: ProgramState, current: ProgramState): + """Calculate the difference between two context blocks. + + :return: A context block that contains in its registers the difference + between the corresponding input blocks' register values. + """ + assert(previous.arch == current.arch) + + arch = previous.arch + transformation = ProgramState(arch) + for reg in arch.regnames: + try: + prev_val = previous.read_register(reg) + cur_val = current.read_register(reg) + transformation.set_register(reg, cur_val - prev_val) + except RegisterAccessError: + # Register is not set in either state + pass + + return transformation + +def _find_errors(transform_txl: ProgramState, transform_truth: ProgramState) \ + -> list[Error]: + """Find possible errors between a reference and a tested state. + + :param txl_state: The translated state to check for errors. + :param prev_txl_state: The translated snapshot immediately preceding + `txl_state`. + :param truth_state: The reference state against which to check the + translated state `txl_state` for errors. + :param prev_truth_state: The reference snapshot immediately preceding + `prev_truth_state`. + + :return: A list of errors; one entry for each register that may have + faulty contents. Is empty if no errors were found. 
+ """ + assert(transform_truth.arch == transform_txl.arch) + + errors = [] + for reg in transform_truth.arch.regnames: + try: + diff_txl = transform_txl.read_register(reg) + diff_truth = transform_truth.read_register(reg) + except RegisterAccessError: + errors.append(Error(ErrorTypes.INFO, + f'Unable to calculate difference:' + f' Value for register {reg} is not set in' + f' either the tested or the reference state.')) + continue + + if diff_txl != diff_truth: + errors.append(Error( + ErrorTypes.CONFIRMED, + f'Transformation of register {reg} is false.' + f' Expected difference: {hex(diff_truth)},' + f' actual difference in the translation: {hex(diff_txl)}.')) + + return errors + +def compare_simple(test_states: list[ProgramState], + truth_states: list[ProgramState]) -> list[dict]: + """Simple comparison of programs. + + :param test_states: A program flow to check for errors. + :param truth_states: A reference program flow that defines a correct + program execution. + + :return: Information, including possible errors, about each processed + snapshot. + """ + PC_REGNAME = 'PC' + + if len(test_states) == 0: + print('No states to compare. Exiting.') + return [] + + # No errors in initial snapshot because we can't perform difference + # calculations on it + result = [{ + 'pc': test_states[0].read_register(PC_REGNAME), + 'txl': test_states[0], 'ref': truth_states[0], + 'errors': [] + }] + + it_prev = zip(iter(test_states), iter(truth_states)) + it_cur = zip(iter(test_states[1:]), iter(truth_states[1:])) + + for txl, truth in it_cur: + prev_txl, prev_truth = next(it_prev) + + pc_txl = txl.read_register(PC_REGNAME) + pc_truth = truth.read_register(PC_REGNAME) + + if pc_txl != pc_truth: + print(f'Unmatched program counter {hex(pc_txl)}' + f' in translated code!') + continue + + transform_truth = _calc_transformation(prev_truth, truth) + transform_txl = _calc_transformation(prev_txl, txl) + errors = _find_errors(transform_txl, transform_truth) + result.append({ + 'pc': pc_txl, + 'txl': transform_txl, 'ref': transform_truth, + 'errors': errors + }) + + return result + +def _find_register_errors(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ + -> list[Error]: + """Find errors in register values. + + Errors might be: + - A register value was modified, but the tested state contains no + reference value for that register. + - The tested destination state's value for a register does not match + the value expected by the symbolic transformation. 
+ """ + # Calculate expected register values + try: + truth = transform_truth.eval_register_transforms(txl_from) + except MemoryAccessError as err: + s, e = transform_truth.range + return [Error( + ErrorTypes.POSSIBLE, + f'Register transformations {hex(s)} -> {hex(e)} depend on' + f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}' + f' that are not entirely present in the tested state' + f' {hex(txl_from.read_register("pc"))}.', + )] + except RegisterAccessError as err: + s, e = transform_truth.range + return [Error(ErrorTypes.INCOMPLETE, + f'Register transformations {hex(s)} -> {hex(e)} depend' + f' on the value of register {err.regname}, which is not' + f' set in the tested state.')] + + # Compare expected values to actual values in the tested state + errors = [] + for regname, truth_val in truth.items(): + try: + txl_val = txl_to.read_register(regname) + except RegisterAccessError: + errors.append(Error(ErrorTypes.INCOMPLETE, + f'Value of register {regname} has changed, but' + f' is not set in the tested state.')) + continue + + if txl_val != truth_val: + errors.append(Error(ErrorTypes.CONFIRMED, + f'Content of register {regname} is false.' + f' Expected value: {hex(truth_val)}, actual' + f' value in the translation: {hex(txl_val)}.')) + return errors + +def _find_memory_errors(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ + -> list[Error]: + """Find errors in memory values. + + Errors might be: + - A range of memory was written, but the tested state contains no + reference value for that range. + - The tested destination state's content for the tested range does not + match the value expected by the symbolic transformation. + """ + # Calculate expected register values + try: + truth = transform_truth.eval_memory_transforms(txl_from) + except MemoryAccessError as err: + s, e = transform_truth.range + return [Error(ErrorTypes.INCOMPLETE, + f'Memory transformations {hex(s)} -> {hex(e)} depend on' + f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}' + f' that are not entirely present in the tested state at' + f' {hex(txl_from.read_register("pc"))}.')] + except RegisterAccessError as err: + s, e = transform_truth.range + return [Error(ErrorTypes.INCOMPLETE, + f'Memory transformations {hex(s)} -> {hex(e)} depend on' + f' the value of register {err.regname}, which is not' + f' set in the tested state.')] + + # Compare expected values to actual values in the tested state + errors = [] + for addr, truth_bytes in truth.items(): + size = len(truth_bytes) + try: + txl_bytes = txl_to.read_memory(addr, size) + except MemoryAccessError: + errors.append(Error(ErrorTypes.POSSIBLE, + f'Memory range [{hex(addr)}, {hex(addr + size)})' + f' is not set in the tested result state at' + f' {hex(txl_to.read_register("pc"))}. This is' + f' either an error in the translation or' + f' the recorded test state is missing data.')) + continue + + if txl_bytes != truth_bytes: + errors.append(Error(ErrorTypes.CONFIRMED, + f'Content of memory at {hex(addr)} is false.' + f' Expected content: {truth_bytes.hex()},' + f' actual content in the translation:' + f' {txl_bytes.hex()}.')) + return errors + +def _find_errors_symbolic(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ + -> list[Error]: + """Tries to find errors in transformations between tested states. + + Applies a transformation to a source state and tests whether the result + matches a given destination state. + + :param txl_from: Source state. 
This is a state from the tested + program, and is assumed as the starting point for + the transformation. + :param txl_to: Destination state. This is a possibly faulty state + from the tested program, and is tested for + correctness with respect to the source state. + :param transform_truth: The symbolic transformation that maps the source + state to the destination state. + """ + from_pc = txl_from.read_register('PC') + to_pc = txl_to.read_register('PC') + assert((from_pc, to_pc) == transform_truth.range) + + errors = [] + errors.extend(_find_register_errors(txl_from, txl_to, transform_truth)) + errors.extend(_find_memory_errors(txl_from, txl_to, transform_truth)) + + return errors + +def compare_symbolic(test_states: Iterable[ProgramState], + transforms: Iterable[SymbolicTransform]) \ + -> list[dict]: + test_states = iter(test_states) + transforms = iter(transforms) + + result = [] + + cur_state = next(test_states) # The state before the transformation + transform = next(transforms) # Transform that operates on `cur_state` + while True: + try: + next_state = next(test_states) # The state after the transformation + + pc_cur = cur_state.read_register('PC') + pc_next = next_state.read_register('PC') + if (pc_cur, pc_next) != transform.range: + repr_range = lambda r: f'[{hex(r[0])} -> {hex(r[1])}]' + print(f'[WARNING] Test states {repr_range((pc_cur, pc_next))}' + f' do not match the symbolic transformation' + f' {repr_range(transform.range)} against which they are' + f' tested! Skipping.') + cur_state = next_state + transform = next(transforms) + continue + + errors = _find_errors_symbolic(cur_state, next_state, transform) + result.append({ + 'pc': pc_cur, + 'txl': _calc_transformation(cur_state, next_state), + 'ref': transform, + 'errors': errors, + 'snap': cur_state, + }) + + # Step forward + cur_state = next_state + transform = next(transforms) + except StopIteration: + break + + return result diff --git a/focaccia/lldb_target.py b/focaccia/lldb_target.py new file mode 100644 index 0000000..1f31337 --- /dev/null +++ b/focaccia/lldb_target.py @@ -0,0 +1,315 @@ +import os + +import lldb + +from .arch import supported_architectures +from .snapshot import ProgramState + +class MemoryMap: + """Description of a range of mapped memory. + + Inspired by https://github.com/angr/angr-targets/blob/master/angr_targets/memory_map.py, + meaning we initially used angr and I wanted to keep the interface when we + switched to a different tool. + """ + def __init__(self, start_address, end_address, name, perms): + self.start_address = start_address + self.end_address = end_address + self.name = name + self.perms = perms + + def __str__(self): + return f'MemoryMap[0x{self.start_address:x}, 0x{self.end_address:x}]' \ + f': {self.name}' + +class ConcreteRegisterError(Exception): + pass + +class ConcreteMemoryError(Exception): + pass + +class ConcreteSectionError(Exception): + pass + +class LLDBConcreteTarget: + from focaccia.arch import aarch64, x86 + + flag_register_names = { + aarch64.archname: 'cpsr', + x86.archname: 'rflags', + } + + flag_register_decompose = { + aarch64.archname: aarch64.decompose_cpsr, + x86.archname: x86.decompose_rflags, + } + + def __init__(self, + executable: str, + argv: list[str] = [], + envp: list[str] | None = None): + """Construct an LLDB concrete target. Stop at entry. + + :param argv: List of arguements. Does NOT include the conventional + executable name as the first entry. + :param envp: List of environment entries. Defaults to current + `os.environ` if `None`. 
+ :raises RuntimeError: If the process is unable to launch. + """ + if envp is None: + envp = [f'{k}={v}' for k, v in os.environ.items()] + + self.debugger = lldb.SBDebugger.Create() + self.debugger.SetAsync(False) + self.target = self.debugger.CreateTargetWithFileAndArch(executable, + lldb.LLDB_ARCH_DEFAULT) + self.module = self.target.FindModule(self.target.GetExecutable()) + self.interpreter = self.debugger.GetCommandInterpreter() + + # Set up objects for process execution + self.error = lldb.SBError() + self.listener = self.debugger.GetListener() + self.process = self.target.Launch(self.listener, + argv, envp, # argv, envp + None, None, None, # stdin, stdout, stderr + None, # working directory + 0, + True, self.error) + if not self.process.IsValid(): + raise RuntimeError(f'[In LLDBConcreteTarget.__init__]: Failed to' + f' launch process.') + + # Determine current arch + self.archname = self.target.GetPlatform().GetTriple().split('-')[0] + if self.archname not in supported_architectures: + err = f'LLDBConcreteTarget: Architecture {self.archname} is not' \ + f' supported by Focaccia.' + print(f'[ERROR] {err}') + raise NotImplementedError(err) + self.arch = supported_architectures[self.archname] + + def is_exited(self): + """Signals whether the concrete process has exited. + + :return: True if the process has exited. False otherwise. + """ + return self.process.GetState() == lldb.eStateExited + + def run(self): + """Continue execution of the concrete process.""" + state = self.process.GetState() + if state == lldb.eStateExited: + raise RuntimeError(f'Tried to resume process execution, but the' + f' process has already exited.') + assert(state == lldb.eStateStopped) + self.process.Continue() + + def step(self): + """Step forward by a single instruction.""" + thread: lldb.SBThread = self.process.GetThreadAtIndex(0) + thread.StepInstruction(False) + + def run_until(self, address: int) -> None: + """Continue execution until the address is arrived, ignores other breakpoints""" + bp = self.target.BreakpointCreateByAddress(address) + while self.read_register("pc") != address: + self.run() + self.target.BreakpointDelete(bp.GetID()) + + def record_snapshot(self) -> ProgramState: + """Record the concrete target's state in a ProgramState object.""" + state = ProgramState(self.arch) + + # Query and store register state + for regname in self.arch.regnames: + try: + conc_val = self.read_register(regname) + state.set_register(regname, conc_val) + except KeyError: + pass + except ConcreteRegisterError: + pass + + # Query and store memory state + for mapping in self.get_mappings(): + assert(mapping.end_address > mapping.start_address) + size = mapping.end_address - mapping.start_address + try: + data = self.read_memory(mapping.start_address, size) + state.write_memory(mapping.start_address, data) + except ConcreteMemoryError: + pass + + return state + + def _get_register(self, regname: str) -> lldb.SBValue: + """Find a register by name. + + :raise ConcreteRegisterError: If no register with the specified name + can be found. + """ + frame = self.process.GetThreadAtIndex(0).GetFrameAtIndex(0) + reg = frame.FindRegister(regname) + if not reg.IsValid(): + raise ConcreteRegisterError( + f'[In LLDBConcreteTarget._get_register]: Register {regname}' + f' not found.') + return reg + + def read_flags(self) -> dict[str, int | bool]: + """Read the current state flags. + + If the concrete target's architecture has state flags, read and return + their current values. 
+ + This handles the conversion from implementation details like flags + registers to the logical flag values. For example: On X86, this reads + the RFLAGS register and extracts the flag bits from its value. + + :return: Dictionary mapping flag names to values. The values may be + booleans in the case of true binary flags or integers in the + case of multi-byte flags. Is empty if the current architecture + does not have state flags of the access is not implemented for + it. + """ + if self.archname not in self.flag_register_names: + return {} + + flags_reg = self.flag_register_names[self.archname] + flags_val = self._get_register(flags_reg).GetValueAsUnsigned() + return self.flag_register_decompose[self.archname](flags_val) + + def read_register(self, regname: str) -> int: + """Read the value of a register. + + :raise ConcreteRegisterError: If `regname` is not a valid register name + or the target is otherwise unable to read + the register's value. + """ + try: + reg = self._get_register(regname) + assert(reg.IsValid()) + if reg.size > 8: # reg is a vector register + reg.data.byte_order = lldb.eByteOrderLittle + val = 0 + for ui64 in reversed(reg.data.uint64s): + val <<= 64 + val |= ui64 + return val + return reg.GetValueAsUnsigned() + except ConcreteRegisterError as err: + flags = self.read_flags() + if regname in flags: + return flags[regname] + raise ConcreteRegisterError( + f'[In LLDBConcreteTarget.read_register]: Unable to read' + f' register {regname}: {err}') + + def write_register(self, regname: str, value: int): + """Read the value of a register. + + :raise ConcreteRegisterError: If `regname` is not a valid register name + or the target is otherwise unable to set + the register's value. + """ + reg = self._get_register(regname) + error = lldb.SBError() + reg.SetValueFromCString(hex(value), error) + if not error.success: + raise ConcreteRegisterError( + f'[In LLDBConcreteTarget.write_register]: Unable to set' + f' {regname} to value {hex(value)}!') + + def read_memory(self, addr, size): + """Read bytes from memory. + + :raise ConcreteMemoryError: If unable to read `size` bytes from `addr`. + """ + err = lldb.SBError() + content = self.process.ReadMemory(addr, size, err) + if not err.success: + raise ConcreteMemoryError(f'Error when reading {size} bytes at' + f' address {hex(addr)}: {err}') + if self.arch.endianness == 'little': + return content + else: + return bytes(reversed(content)) + + def write_memory(self, addr, value: bytes): + """Write bytes to memory. + + :raise ConcreteMemoryError: If unable to write at `addr`. 
+ """ + err = lldb.SBError() + res = self.process.WriteMemory(addr, value, err) + if not err.success or res != len(value): + raise ConcreteMemoryError(f'Error when writing to address' + f' {hex(addr)}: {err}') + + def get_mappings(self) -> list[MemoryMap]: + mmap = [] + + region_list = self.process.GetMemoryRegions() + for i in range(region_list.GetSize()): + region = lldb.SBMemoryRegionInfo() + region_list.GetMemoryRegionAtIndex(i, region) + + perms = f'{"r" if region.IsReadable() else "-"}' \ + f'{"w" if region.IsWritable() else "-"}' \ + f'{"x" if region.IsExecutable() else "-"}' + name = region.GetName() + + mmap.append(MemoryMap(region.GetRegionBase(), + region.GetRegionEnd(), + name if name is not None else '<none>', + perms)) + return mmap + + def set_breakpoint(self, addr): + command = f'b -a {addr} -s {self.module.GetFileSpec().GetFilename()}' + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) + + def remove_breakpoint(self, addr): + command = f'breakpoint delete {addr}' + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) + + def get_basic_block(self, addr: int) -> list[lldb.SBInstruction]: + """Returns a basic block pointed by addr + a code section is considered a basic block only if + the last instruction is a brach, e.g. JUMP, CALL, RET + """ + block = [] + while not self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0].is_branch: + block.append(self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0]) + addr += self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0].size + block.append(self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0]) + + return block + + def get_basic_block_inst(self, addr: int) -> list[str]: + inst = [] + for bb in self.get_basic_block(addr): + inst.append(f'{bb.GetMnemonic(self.target)} {bb.GetOperands(self.target)}') + return inst + + def get_next_basic_block(self) -> list[lldb.SBInstruction]: + return self.get_basic_block(self.read_register("pc")) + + def get_symbol(self, addr: int) -> lldb.SBSymbol: + """Returns the symbol that belongs to the addr + """ + for s in self.module.symbols: + if (s.GetType() == lldb.eSymbolTypeCode and s.GetStartAddress().GetLoadAddress(self.target) <= addr < s.GetEndAddress().GetLoadAddress(self.target)): + return s + raise ConcreteSectionError(f'Error getting the symbol to which address {hex(addr)} belongs to') + + def get_symbol_limit(self) -> int: + """Returns the address after all the symbols""" + addr = 0 + for s in self.module.symbols: + if s.GetStartAddress().IsValid(): + if s.GetStartAddress().GetLoadAddress(self.target) > addr: + addr = s.GetEndAddress().GetLoadAddress(self.target) + return addr diff --git a/focaccia/match.py b/focaccia/match.py new file mode 100644 index 0000000..be88176 --- /dev/null +++ b/focaccia/match.py @@ -0,0 +1,105 @@ +from typing import Iterable + +from .snapshot import ProgramState +from .symbolic import SymbolicTransform + +def _find_index(seq: Iterable, target, access_seq_elem=lambda el: el): + for i, el in enumerate(seq): + if access_seq_elem(el) == target: + return i + return None + +def fold_traces(ctrace: list[ProgramState], + strace: list[SymbolicTransform]): + """Try to fold a higher-granularity symbolic trace to match a lower- + granularity concrete trace. + + Modifies the inputs in-place. + + :param ctrace: A concrete trace. Is assumed to have lower granularity than + `truth`. + :param strace: A symbolic trace. 
Is assumed to have higher granularity than + `test`. We assume that because we control the symbolic trace + generation algorithm, and it produces traces on the level of + single instructions, which is the highest granularity + possible. + """ + if not ctrace or not strace: + return [], [] + + assert(ctrace[0].read_register('pc') == strace[0].addr) + + i = 0 + for next_state in ctrace[1:]: + next_pc = next_state.read_register('pc') + index_in_truth = _find_index(strace[i:], next_pc, lambda el: el.range[1]) + + # If no next element (i.e. no foldable range) is found in the truth + # trace, assume that the test trace contains excess states. Remove one + # and try again. This might skip testing some states, but covers more + # of the entire trace. + if index_in_truth is None: + ctrace.pop(i + 1) + continue + + # Fold the range of truth states until the next test state + for _ in range(index_in_truth): + strace[i].concat(strace.pop(i + 1)) + + i += 1 + if len(strace) <= i: + break + + # Fold remaining symbolic transforms into one + while i + 1 < len(strace): + strace[i].concat(strace.pop(i + 1)) + + return ctrace, strace + +def match_traces(ctrace: list[ProgramState], \ + strace: list[SymbolicTransform]): + """Try to match traces that don't follow the same program flow. + + This algorithm is useful if traces of the same binary mismatch due to + differences in environment during their recording. + + Does not modify the arguments. Creates and returns new lists. + + :param test: A concrete trace. + :param truth: A symbolic trace. + + :return: The modified traces. + """ + if not strace: + return [], [] + + states = [] + matched_transforms = [] + + state_iter = iter(ctrace) + symb_i = 0 + for cur_state in state_iter: + pc = cur_state.read_register('pc') + + if pc != strace[symb_i].addr: + next_i = _find_index(strace[symb_i+1:], pc, lambda t: t.addr) + + # Drop the concrete state if no address in the symbolic trace + # matches + if next_i is None: + continue + + # Otherwise, jump to the next matching symbolic state + symb_i += next_i + 1 + + # Append the now matching state/transform pair to the traces + assert(cur_state.read_register('pc') == strace[symb_i].addr) + states.append(cur_state) + matched_transforms.append(strace[symb_i]) + + # Step forward + symb_i += 1 + + assert(len(states) == len(matched_transforms)) + + return states, matched_transforms diff --git a/focaccia/miasm_util.py b/focaccia/miasm_util.py new file mode 100644 index 0000000..a2cd025 --- /dev/null +++ b/focaccia/miasm_util.py @@ -0,0 +1,253 @@ +from typing import Callable + +from miasm.analysis.machine import Machine +from miasm.core.locationdb import LocationDB, LocKey +from miasm.expression.expression import Expr, ExprOp, ExprId, ExprLoc, \ + ExprInt, ExprMem, ExprCompose, \ + ExprSlice, ExprCond +from miasm.expression.simplifications import expr_simp_explicit + +from . import arch +from .arch import Arch +from .snapshot import ReadableProgramState, \ + RegisterAccessError, MemoryAccessError + +def make_machine(_arch: Arch) -> Machine: + """Create a Miasm `Machine` object corresponding to an `Arch`.""" + machines = { + arch.x86.archname: lambda _: Machine('x86_64'), + # Miasm only has ARM machine names with the l/b suffix: + arch.aarch64.archname: lambda a: Machine(f'aarch64{a.endianness[0]}'), + } + return machines[_arch.archname](_arch) + +def simp_segm(expr_simp, expr: ExprOp): + """Simplify a segmentation expression to an addition of the segment + register's base value and the address argument. 
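In other words, the pass turns x86 segment-relative addressing into plain arithmetic on the `fs_base`/`gs_base` pseudo-registers; schematically (Miasm memory notation, shown for illustration only):

```
# An access like          mov rax, fs:[0x28]
# is lifted by Miasm to   @64[segm(FS, 0x28)]
# and rewritten here to   @64[fs_base + 0x28]
# so the resolver only ever has to supply fs_base/gs_base values.
```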
+ """ + import miasm.arch.x86.regs as regs + + base_regs = { + regs.FS: ExprId('fs_base', 64), + regs.GS: ExprId('gs_base', 64), + } + + if expr.op == 'segm': + segm, addr = expr.args + assert(segm == regs.FS or segm == regs.GS) + return expr_simp(base_regs[segm] + addr) + return expr + +def simp_fadd(expr_simp, expr: ExprOp): + from .utils import float_bits_to_uint, uint_bits_to_float, \ + double_bits_to_uint, uint_bits_to_double + + if expr.op != 'fadd': + return expr + + assert(len(expr.args) == 2) + lhs, rhs = expr.args + if lhs.is_int() and rhs.is_int(): + assert(lhs.size == rhs.size) + if lhs.size == 32: + uint_to_float = uint_bits_to_float + float_to_uint = float_bits_to_uint + elif lhs.size == 64: + uint_to_float = uint_bits_to_double + float_to_uint = double_bits_to_uint + else: + raise NotImplementedError('fadd on values of size not in {32, 64}') + + res = float_to_uint(uint_to_float(lhs.arg) + uint_to_float(rhs.arg)) + return expr_simp(ExprInt(res, expr.size)) + return expr + +# The expression simplifier used in this module +expr_simp = expr_simp_explicit +expr_simp.enable_passes({ + ExprOp: [simp_segm, simp_fadd], +}) + +class MiasmSymbolResolver: + """Resolves atomic symbols to some state.""" + + miasm_flag_aliases = { + arch.x86.archname: { + 'NF': 'SF', + 'I_F': 'IF', + 'IOPL_F': 'IOPL', + 'I_D': 'ID', + }, + arch.aarch64.archname: { + 'NF': 'N', + 'SF': 'N', + 'ZF': 'Z', + 'CF': 'C', + 'VF': 'V', + 'OF': 'V', + 'QF': 'Q', + + 'AF': 'A', + 'EF': 'E', + 'IF': 'I', + 'FF': 'F', + } + } + + def __init__(self, + state: ReadableProgramState, + loc_db: LocationDB): + self._state = state + self._loc_db = loc_db + self._arch = state.arch + self.endianness: Arch.Endianness = self._arch.endianness + + def _miasm_to_regname(self, regname: str) -> str: + """Convert a register name as used by Miasm to one that follows + Focaccia's naming conventions.""" + regname = regname.upper() + if self._arch.archname in self.miasm_flag_aliases: + aliases = self.miasm_flag_aliases[self._arch.archname] + return aliases.get(regname, regname) + return regname + + def resolve_register(self, regname: str) -> int | None: + try: + return self._state.read_register(self._miasm_to_regname(regname)) + except RegisterAccessError as err: + print(f'Not a register: {regname} ({err})') + return None + + def resolve_memory(self, addr: int, size: int) -> bytes | None: + try: + return self._state.read_memory(addr, size) + except MemoryAccessError: + return None + + def resolve_location(self, loc: LocKey) -> int | None: + return self._loc_db.get_location_offset(loc) + +def eval_expr(expr: Expr, conc_state: MiasmSymbolResolver) -> Expr: + """Evaluate a symbolic expression with regard to a concrete reference + state. + + :param expr: An expression to evaluate. + :param conc_state: The concrete reference state from which symbolic + register and memory state is resolved. + + :return: The most simplified and concrete representation of `expr` that + is producible with the values from `conc_state`. Is guaranteed to + be either an `ExprInt` or an `ExprLoc` *if* `conc_state` only + returns concrete register- and memory values. + """ + # Most of these implementation are just copy-pasted members of + # `SymbolicExecutionEngine`. 
+ expr_to_visitor: dict[type[Expr], Callable] = { + ExprInt: _eval_exprint, + ExprId: _eval_exprid, + ExprLoc: _eval_exprloc, + ExprMem: _eval_exprmem, + ExprSlice: _eval_exprslice, + ExprCond: _eval_exprcond, + ExprOp: _eval_exprop, + ExprCompose: _eval_exprcompose, + } + + visitor = expr_to_visitor.get(expr.__class__, None) + if visitor is None: + raise TypeError("Unknown expr type") + + ret = visitor(expr, conc_state) + ret = expr_simp(ret) + assert(ret is not None) + + return ret + +def _eval_exprint(expr: ExprInt, _): + """Evaluate an ExprInt using the current state""" + return expr + +def _eval_exprid(expr: ExprId, state: MiasmSymbolResolver): + """Evaluate an ExprId using the current state""" + val = state.resolve_register(expr.name) + if val is None: + return expr + if isinstance(val, int): + return ExprInt(val, expr.size) + return val + +def _eval_exprloc(expr: ExprLoc, state: MiasmSymbolResolver): + """Evaluate an ExprLoc using the current state""" + offset = state.resolve_location(expr.loc_key) + if offset is None: + return expr + return ExprInt(offset, expr.size) + +def _eval_exprmem(expr: ExprMem, state: MiasmSymbolResolver): + """Evaluate an ExprMem using the current state. + This function first evaluates the memory pointer value. + """ + assert(expr.size % 8 == 0) + + addr = eval_expr(expr.ptr, state) + if not addr.is_int(): + return expr + + assert(isinstance(addr, ExprInt)) + mem = state.resolve_memory(int(addr), expr.size // 8) + if mem is None: + return expr + + assert(len(mem) * 8 == expr.size) + return ExprInt(int.from_bytes(mem, byteorder=state.endianness), expr.size) + +def _eval_exprcond(expr, state: MiasmSymbolResolver): + """Evaluate an ExprCond using the current state""" + cond = eval_expr(expr.cond, state) + src1 = eval_expr(expr.src1, state) + src2 = eval_expr(expr.src2, state) + return ExprCond(cond, src1, src2) + +def _eval_exprslice(expr, state: MiasmSymbolResolver): + """Evaluate an ExprSlice using the current state""" + arg = eval_expr(expr.arg, state) + return ExprSlice(arg, expr.start, expr.stop) + +def _eval_cpuid(rax: ExprInt, out_reg: ExprInt): + """Evaluate the `x86_cpuid` operator by performing a real invocation of + the CPUID instruction. + + :param rax: The current value of RAX. Must be concrete. + :param out_reg: An index in `[0, 4)` signaling which register's value + shall be returned. Must be concrete. + """ + from cpuid import cpuid + + regs = cpuid.CPUID()(int(rax)) + + if int(out_reg) >= len(regs): + raise ValueError(f'Output register may not be {out_reg}.') + return ExprInt(regs[int(out_reg)], out_reg.size) + +def _eval_exprop(expr, state: MiasmSymbolResolver): + """Evaluate an ExprOp using the current state""" + args = [eval_expr(arg, state) for arg in expr.args] + + # Special case: CPUID instruction + # Evaluate the expression to a value obtained from an an actual call to + # the CPUID instruction. Can't do this in an expression simplifier plugin + # because the arguments must be concrete. 
+ if expr.op == 'x86_cpuid': + if args[0].is_int() and args[1].is_int(): + assert(isinstance(args[0], ExprInt) and isinstance(args[1], ExprInt)) + return _eval_cpuid(args[0], args[1]) + return expr + + return ExprOp(expr.op, *args) + +def _eval_exprcompose(expr, state: MiasmSymbolResolver): + """Evaluate an ExprCompose using the current state""" + args = [] + for arg in expr.args: + args.append(eval_expr(arg, state)) + return ExprCompose(*args) diff --git a/focaccia/parser.py b/focaccia/parser.py new file mode 100644 index 0000000..c37c07a --- /dev/null +++ b/focaccia/parser.py @@ -0,0 +1,172 @@ +"""Parsing of JSON files containing snapshot data.""" + +import base64 +import json +import re +from typing import TextIO + +from .arch import supported_architectures, Arch +from .snapshot import ProgramState +from .symbolic import SymbolicTransform +from .trace import Trace, TraceEnvironment + +class ParseError(Exception): + """A parse error.""" + +def _get_or_throw(obj: dict, key: str): + """Get a value from a dict or throw a ParseError if not present.""" + val = obj.get(key) + if val is not None: + return val + raise ParseError(f'Expected value at key {key}, but found none.') + +def parse_transformations(json_stream: TextIO) -> Trace[SymbolicTransform]: + """Parse symbolic transformations from a text stream.""" + data = json.load(json_stream) + + env = TraceEnvironment.from_json(_get_or_throw(data, 'env')) + strace = [SymbolicTransform.from_json(item) \ + for item in _get_or_throw(data, 'states')] + + return Trace(strace, env) + +def serialize_transformations(transforms: Trace[SymbolicTransform], + out_stream: TextIO): + """Serialize symbolic transformations to a text stream.""" + json.dump({ + 'env': transforms.env.to_json(), + 'states': [t.to_json() for t in transforms], + }, out_stream) + +def parse_snapshots(json_stream: TextIO) -> Trace[ProgramState]: + """Parse snapshots from our JSON format.""" + json_data = json.load(json_stream) + + arch = supported_architectures[_get_or_throw(json_data, 'architecture')] + env = TraceEnvironment.from_json(_get_or_throw(json_data, 'env')) + snapshots = [] + for snapshot in _get_or_throw(json_data, 'snapshots'): + state = ProgramState(arch) + for reg, val in _get_or_throw(snapshot, 'registers').items(): + state.set_register(reg, val) + for mem in _get_or_throw(snapshot, 'memory'): + start, end = _get_or_throw(mem, 'range') + data = base64.b64decode(_get_or_throw(mem, 'data')) + assert(len(data) == end - start) + state.write_memory(start, data) + + snapshots.append(state) + + return Trace(snapshots, env) + +def serialize_snapshots(snapshots: Trace[ProgramState], out_stream: TextIO): + """Serialize a list of snapshots to out JSON format.""" + if not snapshots: + return json.dump({}, out_stream) + + arch = snapshots[0].arch + res = { + 'architecture': arch.archname, + 'env': snapshots.env.to_json(), + 'snapshots': [] + } + for snapshot in snapshots: + assert(snapshot.arch == arch) + regs = {r: v for r, v in snapshot.regs.items() if v is not None} + mem = [] + for addr, data in snapshot.mem._pages.items(): + mem.append({ + 'range': [addr, addr + len(data)], + 'data': base64.b64encode(data).decode('ascii') + }) + res['snapshots'].append({ 'registers': regs, 'memory': mem }) + + json.dump(res, out_stream) + +def _make_unknown_env() -> TraceEnvironment: + return TraceEnvironment('', [], [], '?') + +def parse_qemu(stream: TextIO, arch: Arch) -> Trace[ProgramState]: + """Parse a QEMU log from a stream. 
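End to end, the intended use is roughly the following sketch; the file names are placeholders and the `'x86_64'` architecture key is assumed:

```
# 1. Record a log:  qemu-x86_64 -d exec,cpu,fpu,vpu,nochain -D qemu.log ./a.out
# 2. Parse it:
from focaccia.arch import supported_architectures
from focaccia.parser import parse_qemu

with open('qemu.log') as f:
    trace = parse_qemu(f, supported_architectures['x86_64'])
print(f'{len(trace)} snapshots parsed')
```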
+ + Recommended QEMU log option: `qemu -d exec,cpu,fpu,vpu,nochain`. The `exec` + flag is strictly necessary for the log to be parseable. + + :return: A list of parsed program states, in order of occurrence in the + log. + """ + states = [] + for line in stream: + if line.startswith('Trace'): + states.append(ProgramState(arch)) + continue + if states: + _parse_qemu_line(line, states[-1]) + + return Trace(states, _make_unknown_env()) + +def _parse_qemu_line(line: str, cur_state: ProgramState): + """Try to parse a single register-assignment line from a QEMU log. + + Set all registers for which the line specified values in a `ProgramState` + object. + + :param line: The log line to parse. + :param cur_state: The state on which to set parsed register values. + """ + line = line.strip() + + # Remove padding spaces around equality signs + line = re.sub(' =', '=', line) + line = re.sub('= +', '=', line) + + # Standardize register names + line = re.sub('YMM0([0-9])', lambda m: f'YMM{m.group(1)}', line) + line = re.sub('FPR([0-9])', lambda m: f'ST{m.group(1)}', line) + + # Bring each register assignment into a new line + line = re.sub(' ([A-Z0-9]+)=', lambda m: f'\n{m.group(1)}=', line) + + # Remove all trailing information from register assignments + line = re.sub('^([A-Z0-9]+)=([0-9a-f ]+).*$', + lambda m: f'{m.group(1)}={m.group(2)}', + line, + 0, re.MULTILINE) + + # Now parse registers and their values from the resulting lines + lines = line.split('\n') + for line in lines: + split = line.split('=') + if len(split) == 2: + regname, value = split + value = value.replace(' ', '') + regname = cur_state.arch.to_regname(regname) + if regname is not None: + cur_state.set_register(regname, int(value, 16)) + +def parse_arancini(stream: TextIO, arch: Arch) -> Trace[ProgramState]: + aliases = { + 'Program counter': 'RIP', + 'flag ZF': 'ZF', + 'flag CF': 'CF', + 'flag OF': 'OF', + 'flag SF': 'SF', + 'flag PF': 'PF', + 'flag DF': 'DF', + } + + states = [] + for line in stream: + if line.startswith('INVOKE PC='): + states.append(ProgramState(arch)) + continue + + # Parse a register assignment + split = line.split(':') + if len(split) == 2 and states: + regname, value = split + regname = arch.to_regname(aliases.get(regname, regname)) + if regname is not None: + states[-1].set_register(regname, int(value, 16)) + + return Trace(states, _make_unknown_env()) diff --git a/focaccia/reproducer.py b/focaccia/reproducer.py new file mode 100644 index 0000000..90e1378 --- /dev/null +++ b/focaccia/reproducer.py @@ -0,0 +1,172 @@ + +from .lldb_target import LLDBConcreteTarget +from .snapshot import ProgramState +from .symbolic import SymbolicTransform, eval_symbol +from .arch import x86 + +class ReproducerMemoryError(Exception): + pass +class ReproducerBasicBlockError(Exception): + pass +class ReproducerRegisterError(Exception): + pass + +class Reproducer(): + def __init__(self, oracle: str, argv: str, snap: ProgramState, sym: SymbolicTransform) -> None: + + target = LLDBConcreteTarget(oracle) + + self.pc = snap.read_register("pc") + self.bb = target.get_basic_block_inst(self.pc) + self.sl = target.get_symbol_limit() + self.snap = snap + self.sym = sym + + def get_bb(self) -> str: + try: + asm = "" + asm += f'_bb_{hex(self.pc)}:\n' + for i in self.bb[:-1]: + asm += f'{i}\n' + asm += f'ret\n' + asm += f'\n' + + return asm + except: + raise ReproducerBasicBlockError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_regs(self) -> str: + general_regs = ['RIP', 'RAX', 'RBX','RCX','RDX', 
'RSI','RDI','RBP','RSP','R8','R9','R10','R11','R12','R13','R14','R15',] + flag_regs = ['CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT',] + eflag_regs = ['RF', 'VM', 'AC', 'VIF', 'VIP', 'ID',] + + try: + asm = "" + asm += f'_setup_regs:\n' + for reg in self.sym.get_used_registers(): + if reg in general_regs: + asm += f'mov ${hex(self.snap.read_register(reg))}, %{reg.lower()}\n' + + if 'RFLAGS' in self.sym.get_used_registers(): + asm += f'pushfq ${hex(self.snap.read_register("RFLAGS"))}\n' + + if any(reg in self.sym.get_used_registers() for reg in flag_regs+eflag_regs): + asm += f'pushfd ${hex(x86.compose_rflags(self.snap.regs))}\n' + asm += f'ret\n' + asm += f'\n' + + return asm + except: + raise ReproducerRegisterError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_mem(self) -> str: + try: + asm = "" + asm += f'_setup_mem:\n' + for mem in self.sym.get_used_memory_addresses(): + addr = eval_symbol(mem.ptr, self.snap) + val = self.snap.read_memory(addr, int(mem.size/8)) + + if addr < self.sl: + asm += f'.org {hex(addr)}\n' + for b in val: + asm += f'.byte ${hex(b)}\n' + asm += f'\n' + + return asm + except: + raise ReproducerMemoryError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_dyn(self) -> str: + try: + asm = "" + asm += f'_setup_dyn:\n' + for mem in self.sym.get_used_memory_addresses(): + addr = eval_symbol(mem.ptr, self.snap) + val = self.snap.read_memory(addr, int(mem.size/8)) + + if addr >= self.sl: + asm += f'mov ${hex(addr)}, %rdi\n' + asm += f'call _alloc\n' + for b in val: + asm += f'mov ${hex(addr)}, %rax\n' + asm += f'movb ${hex(b)}, (%rax)\n' + addr += 1 + asm += f'ret\n' + asm += f'\n' + + return asm + except: + raise ReproducerMemoryError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_start(self) -> str: + asm = "" + asm += f'_start:\n' + asm += f'call _setup_dyn\n' + asm += f'call _setup_regs\n' + asm += f'call _bb_{hex(self.pc)}\n' + asm += f'call _exit\n' + asm += f'\n' + + return asm + + def get_exit(self) -> str: + asm = "" + asm += f'_exit:\n' + asm += f'movq $0, %rdi\n' + asm += f'movq $60, %rax\n' + asm += f'syscall\n' + asm += f'\n' + + return asm + + def get_alloc(self) -> str: + asm = "" + asm += f'_alloc:\n' + asm += f'movq $4096, %rsi\n' + asm += f'movq $(PROT_READ | PROT_WRITE), %rdx\n' + asm += f'movq $(MAP_PRIVATE | MAP_ANONYMOUS), %r10\n' + asm += f'movq $-1, %r8\n' + asm += f'movq $0, %r9\n' + asm += f'movq $syscall_mmap, %rax\n' + asm += f'syscall\n' + asm += f'ret\n' + asm += f'\n' + + return asm + + def get_code(self) -> str: + asm = "" + asm += f'.section .text\n' + asm += f'.global _start\n' + asm += f'\n' + asm += f'.org {hex(self.pc)}\n' + asm += self.get_bb() + asm += self.get_start() + asm += self.get_exit() + asm += self.get_alloc() + asm += self.get_regs() + asm += self.get_dyn() + + return asm + + def get_data(self) -> str: + asm = "" + asm += f'.section .data\n' + asm += f'PROT_READ = 0x1\n' + asm += f'PROT_WRITE = 0x2\n' + asm += f'MAP_PRIVATE = 0x2\n' + asm += f'MAP_ANONYMOUS = 0x20\n' + asm += f'syscall_mmap = 9\n' + asm += f'\n' + + asm += self.get_mem() + + return asm + + def asm(self) -> str: + asm = "" + asm += self.get_code() + asm += self.get_data() + + return asm diff --git a/focaccia/snapshot.py b/focaccia/snapshot.py new file mode 100644 index 0000000..1945d71 --- /dev/null +++ b/focaccia/snapshot.py @@ -0,0 +1,180 @@ +from .arch.arch import Arch + +class RegisterAccessError(Exception): + """Raised when a register access fails.""" + def __init__(self, 
regname: str, msg: str): + super().__init__(msg) + self.regname = regname + +class MemoryAccessError(Exception): + """Raised when a memory access fails.""" + def __init__(self, addr: int, size: int, msg: str): + super().__init__(msg) + self.mem_addr = addr + self.mem_size = size + +class SparseMemory: + """Sparse memory. + + Note that out-of-bound reads are possible when performed on unwritten + sections of existing pages and that there is no safeguard check for them. + """ + def __init__(self, page_size=256): + self.page_size = page_size + self._pages: dict[int, bytes] = {} + + def _to_page_addr_and_offset(self, addr: int) -> tuple[int, int]: + off = addr % self.page_size + return addr - off, off + + def read(self, addr: int, size: int) -> bytes: + """Read a number of bytes from memory. + :param addr: The offset from where to read. + :param size: The number of bytes to read, starting at at `addr`. + + :return: `size` bytes of data. + :raise MemoryAccessError: If `[addr, addr + size)` is not entirely + contained in the set of stored bytes. + :raise ValueError: If `size < 0`. + """ + if size < 0: + raise ValueError(f'A negative size is not allowed!') + + res = bytes() + while size > 0: + page_addr, off = self._to_page_addr_and_offset(addr) + if page_addr not in self._pages: + raise MemoryAccessError(addr, size, + f'Address {hex(addr)} is not contained' + f' in the sparse memory.') + data = self._pages[page_addr] + assert(len(data) == self.page_size) + read_size = min(size, self.page_size - off) + res += data[off:off+read_size] + + size -= read_size + addr += read_size + return res + + def write(self, addr: int, data: bytes): + """Store bytes in the memory. + :param addr: The address at which to store the data. + :param data: The data to store at `addr`. + """ + offset = 0 # Current offset into `data` + while offset < len(data): + page_addr, off = self._to_page_addr_and_offset(addr) + if page_addr not in self._pages: + self._pages[page_addr] = bytes(self.page_size) + page = self._pages[page_addr] + assert(len(page) == self.page_size) + + write_size = min(len(data) - offset, self.page_size - off) + new_page = page[:off] + data[offset:offset + write_size] + page[off+write_size:] + assert(len(new_page) == self.page_size) + self._pages[page_addr] = new_page + + offset += write_size + addr += write_size + + assert(len(data) == offset) # Exactly all data was written + +class ReadableProgramState: + """Interface for read-only program states.""" + def __init__(self, arch: Arch): + self.arch = arch + + def read_register(self, reg: str) -> int: + """Read a register's value. + + :raise RegisterAccessError: If `reg` is not a register name, or if the + register has no value. + """ + raise NotImplementedError('ReadableProgramState.read_register is abstract.') + + def read_memory(self, addr: int, size: int) -> bytes: + """Read a number of bytes from memory. + + :param addr: The address from which to read data. + :param data: Number of bytes to read, starting at `addr`. Must be + at least zero. + + :raise MemoryAccessError: If `[addr, addr + size)` is not entirely + contained in the set of stored bytes. + :raise ValueError: If `size < 0`. 
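The `SparseMemory` paging scheme above can be exercised on its own; a small sketch showing that reads and writes may straddle page boundaries transparently:

```
from focaccia.snapshot import SparseMemory

mem = SparseMemory(page_size=256)
mem.write(0x7fff0000, b'\xde\xad\xbe\xef')
assert mem.read(0x7fff0002, 2) == b'\xbe\xef'
mem.write(0x7fff00fe, b'\x01\x02\x03\x04')   # straddles a page boundary
assert mem.read(0x7fff00fe, 4) == b'\x01\x02\x03\x04'
```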
+ """ + raise NotImplementedError('ReadableProgramState.read_memory is abstract.') + +class ProgramState(ReadableProgramState): + """A snapshot of the program's state.""" + def __init__(self, arch: Arch): + super().__init__(arch=arch) + + self.regs: dict[str, int | None] = {reg: None for reg in arch.regnames} + self.mem = SparseMemory() + + def read_register(self, reg: str) -> int: + """Read a register's value. + + :raise RegisterAccessError: If `reg` is not a register name, or if the + register has no value. + """ + acc = self.arch.get_reg_accessor(reg) + if acc is None: + raise RegisterAccessError(reg, f'Not a register name: {reg}') + + assert(acc.base_reg in self.regs) + regval = self.regs[acc.base_reg] + if regval is None: + raise RegisterAccessError( + acc.base_reg, + f'[In ProgramState.read_register]: Unable to read value of' + f' register {reg} (a.k.a. {acc}): The register is not set.' + f' Full state: {self}') + + return (regval & acc.mask) >> acc.start + + def set_register(self, reg: str, value: int): + """Assign a value to a register. + + :raise RegisterAccessError: If `reg` is not a register name. + """ + acc = self.arch.get_reg_accessor(reg) + if acc is None: + raise RegisterAccessError(reg, f'Not a register name: {reg}') + + assert(acc.base_reg in self.regs) + base_reg_size = self.arch.get_reg_accessor(acc.base_reg).num_bits + + val = self.regs[acc.base_reg] + if val is None: + val = 0 + val &= (~acc.mask & ((1 << base_reg_size) - 1)) # Clear bits in range + val |= (value << acc.start) & acc.mask # Set bits in range + + self.regs[acc.base_reg] = val + + def read_memory(self, addr: int, size: int) -> bytes: + """Read a number of bytes from memory. + + :param addr: The address from which to read data. + :param data: Number of bytes to read, starting at `addr`. Must be + at least zero. + + :raise MemoryAccessError: If `[addr, addr + size)` is not entirely + contained in the set of stored bytes. + :raise ValueError: If `size < 0`. + """ + return self.mem.read(addr, size) + + def write_memory(self, addr: int, data: bytes): + """Write a number of bytes to memory. + + :param addr: The address at which to store the data. + :param data: The data to store at `addr`. 
+ """ + self.mem.write(addr, data) + + def __repr__(self): + regs = {r: hex(v) for r, v in self.regs.items() if v is not None} + return f'Snapshot ({self.arch.archname}): {regs}' diff --git a/focaccia/symbolic.py b/focaccia/symbolic.py new file mode 100644 index 0000000..9aeff56 --- /dev/null +++ b/focaccia/symbolic.py @@ -0,0 +1,692 @@ +"""Tools and utilities for symbolic execution with Miasm.""" + +from __future__ import annotations +from typing import Iterable +import logging +import sys + +from miasm.analysis.binary import ContainerELF +from miasm.analysis.machine import Machine +from miasm.core.cpu import instruction as miasm_instr +from miasm.core.locationdb import LocationDB +from miasm.expression.expression import Expr, ExprId, ExprMem, ExprInt +from miasm.ir.ir import Lifter +from miasm.ir.symbexec import SymbolicExecutionEngine + +from .arch import Arch, supported_architectures +from .lldb_target import LLDBConcreteTarget, \ + ConcreteRegisterError, \ + ConcreteMemoryError +from .miasm_util import MiasmSymbolResolver, eval_expr, make_machine +from .snapshot import ProgramState, ReadableProgramState, \ + RegisterAccessError, MemoryAccessError +from .trace import Trace, TraceEnvironment + +logger = logging.getLogger('focaccia-symbolic') +warn = logger.warn + +# Disable Miasm's disassembly logger +logging.getLogger('asmblock').setLevel(logging.CRITICAL) + +def eval_symbol(symbol: Expr, conc_state: ReadableProgramState) -> int: + """Evaluate a symbol based on a concrete reference state. + + :param conc_state: A concrete state. + :return: The resolved value. + + :raise ValueError: If the concrete state does not contain a register value + that is referenced by the symbolic expression. + :raise MemoryAccessError: If the concrete state does not contain memory + that is referenced by the symbolic expression. 
+ """ + class ConcreteStateWrapper(MiasmSymbolResolver): + """Extend the state resolver with assumptions about the expressions + that may be resolved with `eval_symbol`.""" + def __init__(self, conc_state: ReadableProgramState): + super().__init__(conc_state, LocationDB()) + + def resolve_register(self, regname: str) -> int: + return self._state.read_register(self._miasm_to_regname(regname)) + + def resolve_memory(self, addr: int, size: int) -> bytes: + return self._state.read_memory(addr, size) + + def resolve_location(self, loc): + raise ValueError(f'[In eval_symbol]: Unable to evaluate symbols' + f' that contain IR location expressions.') + + res = eval_expr(symbol, ConcreteStateWrapper(conc_state)) + assert(isinstance(res, ExprInt)) # Must be either ExprInt or ExprLoc, + # but ExprLocs are disallowed by the + # ConcreteStateWrapper + return int(res) + +class Instruction: + """An instruction.""" + def __init__(self, + instr: miasm_instr, + machine: Machine, + arch: Arch, + loc_db: LocationDB | None = None): + self.arch = arch + self.machine = machine + + if loc_db is not None: + instr.args = instr.resolve_args_with_symbols(loc_db) + self.instr: miasm_instr = instr + """The underlying Miasm instruction object.""" + + assert(instr.offset is not None) + assert(instr.l is not None) + self.addr: int = instr.offset + self.length: int = instr.l + + @staticmethod + def from_bytecode(asm: bytes, arch: Arch) -> Instruction: + """Disassemble an instruction.""" + machine = make_machine(arch) + assert(machine.mn is not None) + _instr = machine.mn.dis(asm, arch.ptr_size) + return Instruction(_instr, machine, arch, None) + + @staticmethod + def from_string(s: str, arch: Arch, offset: int = 0, length: int = 0) -> Instruction: + machine = make_machine(arch) + assert(machine.mn is not None) + _instr = machine.mn.fromstring(s, LocationDB(), arch.ptr_size) + _instr.offset = offset + _instr.l = length + return Instruction(_instr, machine, arch, None) + + def to_bytecode(self) -> bytes: + """Assemble the instruction to byte code.""" + assert(self.machine.mn is not None) + return self.machine.mn.asm(self.instr)[0] + + def to_string(self) -> str: + """Convert the instruction to an Intel-syntax assembly string.""" + return str(self.instr) + + def __repr__(self): + return self.to_string() + +class SymbolicTransform: + """A symbolic transformation mapping one program state to another.""" + def __init__(self, + transform: dict[Expr, Expr], + instrs: list[Instruction], + arch: Arch, + from_addr: int, + to_addr: int): + """ + :param state: The symbolic transformation in the form of a SimState + object. + :param first_inst: An instruction address. The transformation + represents the modifications to the program state + performed by this instruction. + """ + self.arch = arch + + self.addr = from_addr + """The instruction address of the program state on which the + transformation operates. Equivalent to `self.range[0]`.""" + + self.range = (from_addr, to_addr) + """The range of addresses that the transformation covers. + The transformation `t` maps the program state at instruction + `t.range[0]` to the program state at instruction `t.range[1]`.""" + + self.changed_regs: dict[str, Expr] = {} + """Maps register names to expressions for the register's content. + + Contains only registers that are changed by the transformation. + Register names are already normalized to a respective architecture's + naming conventions.""" + + self.changed_mem: dict[Expr, Expr] = {} + """Maps memory addresses to memory content. 
+ + For a dict tuple `(addr, value)`, `value.size` is the number of *bits* + written to address `addr`. Memory addresses may depend on other + symbolic values, such as register content, and are therefore symbolic + themselves.""" + + self.instructions: list[Instruction] = instrs + """The sequence of instructions that comprise this transformation.""" + + for dst, expr in transform.items(): + assert(isinstance(dst, ExprMem) or isinstance(dst, ExprId)) + + if isinstance(dst, ExprMem): + assert(dst.size == expr.size) + assert(expr.size % 8 == 0) + self.changed_mem[dst.ptr] = expr + else: + assert(isinstance(dst, ExprId)) + regname = arch.to_regname(dst.name) + if regname is not None: + self.changed_regs[regname] = expr + + def concat(self, other: SymbolicTransform) -> SymbolicTransform: + """Concatenate two transformations. + + The symbolic transform on which `concat` is called is the transform + that is applied first, meaning: `(a.concat(b))(state) == b(a(state))`. + + Note that if transformation are concatenated that write to the same + memory location when applied to a specific starting state, the + concatenation may not recognize equivalence of syntactically different + symbolic address expressions. In this case, if you calculate all memory + values and store them at their address, the final result will depend on + the random iteration order over the `changed_mem` dict. + + :param other: The transformation to concatenate to `self`. + + :return: Returns `self`. `self` is modified in-place. + :raise ValueError: If the two transformations don't span a contiguous + range of instructions. + """ + from typing import Callable + from miasm.expression.expression import ExprLoc, ExprSlice, ExprCond, \ + ExprOp, ExprCompose + from miasm.expression.simplifications import expr_simp_explicit + + if self.range[1] != other.range[0]: + repr_range = lambda r: f'[{hex(r[0])} -> {hex(r[1])}]' + raise ValueError( + f'Unable to concatenate transformation' + f' {repr_range(self.range)} with {repr_range(other.range)};' + f' the concatenated transformations must span a' + f' contiguous range of instructions.') + + def _eval_exprslice(expr: ExprSlice): + arg = _concat_to_self(expr.arg) + return ExprSlice(arg, expr.start, expr.stop) + + def _eval_exprcond(expr: ExprCond): + cond = _concat_to_self(expr.cond) + src1 = _concat_to_self(expr.src1) + src2 = _concat_to_self(expr.src2) + return ExprCond(cond, src1, src2) + + def _eval_exprop(expr: ExprOp): + args = [_concat_to_self(arg) for arg in expr.args] + return ExprOp(expr.op, *args) + + def _eval_exprcompose(expr: ExprCompose): + args = [_concat_to_self(arg) for arg in expr.args] + return ExprCompose(*args) + + expr_to_visitor: dict[type[Expr], Callable] = { + ExprInt: lambda e: e, + ExprId: lambda e: self.changed_regs.get(e.name, e), + ExprLoc: lambda e: e, + ExprMem: lambda e: ExprMem(_concat_to_self(e.ptr), e.size), + ExprSlice: _eval_exprslice, + ExprCond: _eval_exprcond, + ExprOp: _eval_exprop, + ExprCompose: _eval_exprcompose, + } + + def _concat_to_self(expr: Expr): + visitor = expr_to_visitor[expr.__class__] + return expr_simp_explicit(visitor(expr)) + + new_regs = self.changed_regs.copy() + for reg, expr in other.changed_regs.items(): + new_regs[reg] = _concat_to_self(expr) + + new_mem = self.changed_mem.copy() + for addr, expr in other.changed_mem.items(): + new_addr = _concat_to_self(addr) + new_expr = _concat_to_self(expr) + new_mem[new_addr] = new_expr + + self.changed_regs = new_regs + self.changed_mem = new_mem + self.range = (self.range[0], 
other.range[1]) + self.instructions.extend(other.instructions) + + return self + + def get_used_registers(self) -> list[str]: + """Find all registers used by the transformation as input. + + :return: A list of register names. + """ + accessed_regs = set[str]() + + class RegisterCollector(MiasmSymbolResolver): + def __init__(self, arch: Arch): + self._arch = arch # MiasmSymbolResolver needs this + def resolve_register(self, regname: str) -> int | None: + accessed_regs.add(self._miasm_to_regname(regname)) + return None + def resolve_memory(self, addr: int, size: int): pass + def resolve_location(self, loc): assert(False) + + resolver = RegisterCollector(self.arch) + for expr in self.changed_regs.values(): + eval_expr(expr, resolver) + for addr_expr, mem_expr in self.changed_mem.items(): + eval_expr(addr_expr, resolver) + eval_expr(mem_expr, resolver) + + return list(accessed_regs) + + def get_used_memory_addresses(self) -> list[ExprMem]: + """Find all memory addresses used by the transformation as input. + + :return: A list of memory access expressions. + """ + from typing import Callable + from miasm.expression.expression import ExprLoc, ExprSlice, ExprCond, \ + ExprOp, ExprCompose + + accessed_mem = set[ExprMem]() + + def _eval(expr: Expr): + def _eval_exprmem(expr: ExprMem): + accessed_mem.add(expr) # <-- this is the only important line! + _eval(expr.ptr) + def _eval_exprcond(expr: ExprCond): + _eval(expr.cond) + _eval(expr.src1) + _eval(expr.src2) + def _eval_exprop(expr: ExprOp): + for arg in expr.args: + _eval(arg) + def _eval_exprcompose(expr: ExprCompose): + for arg in expr.args: + _eval(arg) + + expr_to_visitor: dict[type[Expr], Callable] = { + ExprInt: lambda e: e, + ExprId: lambda e: e, + ExprLoc: lambda e: e, + ExprMem: _eval_exprmem, + ExprSlice: lambda e: _eval(e.arg), + ExprCond: _eval_exprcond, + ExprOp: _eval_exprop, + ExprCompose: _eval_exprcompose, + } + visitor = expr_to_visitor[expr.__class__] + visitor(expr) + + for expr in self.changed_regs.values(): + _eval(expr) + for addr_expr, mem_expr in self.changed_mem.items(): + _eval(addr_expr) + _eval(mem_expr) + + return list(accessed_mem) + + def eval_register_transforms(self, conc_state: ReadableProgramState) \ + -> dict[str, int]: + """Calculate register transformations when applied to a concrete state. + + :param conc_state: A concrete program state that serves as the input + state on which the transformation operates. + + :return: A map from register names to the register values that were + changed by the transformation. + :raise MemoryError: + :raise ValueError: + """ + res = {} + for regname, expr in self.changed_regs.items(): + res[regname] = eval_symbol(expr, conc_state) + return res + + def eval_memory_transforms(self, conc_state: ReadableProgramState) \ + -> dict[int, bytes]: + """Calculate memory transformations when applied to a concrete state. + + :param conc_state: A concrete program state that serves as the input + state on which the transformation operates. + + :return: A map from memory addresses to the bytes that were changed by + the transformation. + :raise MemoryError: + :raise ValueError: + """ + res = {} + for addr, expr in self.changed_mem.items(): + addr = eval_symbol(addr, conc_state) + length = int(expr.size / 8) + res[addr] = eval_symbol(expr, conc_state) \ + .to_bytes(length, byteorder=self.arch.endianness) + return res + + @classmethod + def from_json(cls, data: dict) -> SymbolicTransform: + """Parse a symbolic transformation from a JSON object. + + :raise KeyError: if a parse error occurs. 
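The accepted object mirrors what `to_json` below produces; an illustrative payload (the architecture key, instruction text, and expression strings are assumptions for the sketch, not normative values):

```
payload = {
    'arch': 'x86_64',                        # assumed key into supported_architectures
    'from_addr': 0x401000,
    'to_addr': 0x401005,
    'instructions': [[5, 'MOV RAX, 0x2A']],  # [byte length, instruction text]
    'regs': {'RAX': 'ExprInt(0x2A, 64)'},    # Miasm expression reprs
    'mem': {},
}
t = SymbolicTransform.from_json(payload)
```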
+ """ + from miasm.expression.parser import str_to_expr as parse + + def decode_inst(obj: list, arch: Arch): + length, text = obj + try: + return Instruction.from_string(text, arch, offset=0, length=length) + except Exception as err: + warn(f'[In SymbolicTransform.from_json] Unable to parse' + f' instruction string "{text}": {err}.') + return None + + arch = supported_architectures[data['arch']] + start_addr = int(data['from_addr']) + end_addr = int(data['to_addr']) + + t = SymbolicTransform({}, [], arch, start_addr, end_addr) + t.changed_regs = { name: parse(val) for name, val in data['regs'].items() } + t.changed_mem = { parse(addr): parse(val) for addr, val in data['mem'].items() } + instrs = [decode_inst(b, arch) for b in data['instructions']] + t.instructions = [inst for inst in instrs if inst is not None] + + # Recover the instructions' address information + addr = t.addr + for inst in t.instructions: + inst.addr = addr + addr += inst.length + + return t + + def to_json(self) -> dict: + """Serialize a symbolic transformation as a JSON object.""" + def encode_inst(inst: Instruction): + try: + return [inst.length, inst.to_string()] + except Exception as err: + warn(f'[In SymbolicTransform.to_json] Unable to serialize' + f' "{inst}" as string: {err}. This instruction will not' + f' be serialized.') + return None + + instrs = [encode_inst(inst) for inst in self.instructions] + instrs = [inst for inst in instrs if inst is not None] + return { + 'arch': self.arch.archname, + 'from_addr': self.range[0], + 'to_addr': self.range[1], + 'instructions': instrs, + 'regs': { name: repr(expr) for name, expr in self.changed_regs.items() }, + 'mem': { repr(addr): repr(val) for addr, val in self.changed_mem.items() }, + } + + def __repr__(self) -> str: + start, end = self.range + res = f'Symbolic state transformation {hex(start)} -> {hex(end)}:\n' + res += ' [Symbols]\n' + for reg, expr in self.changed_regs.items(): + res += f' {reg:6s} = {expr}\n' + for addr, expr in self.changed_mem.items(): + res += f' {ExprMem(addr, expr.size)} = {expr}\n' + res += ' [Instructions]\n' + for inst in self.instructions: + res += f' {inst}\n' + + return res[:-1] # Remove trailing newline + +class MemoryBinstream: + """A binary stream interface that reads bytes from a program state's + memory.""" + def __init__(self, state: ReadableProgramState): + self._state = state + + def __len__(self): + return 0xffffffff + + def __getitem__(self, key: int | slice): + if isinstance(key, slice): + return self._state.read_memory(key.start, key.stop - key.start) + return self._state.read_memory(key, 1) + +class DisassemblyContext: + def __init__(self, target: ReadableProgramState): + self.loc_db = LocationDB() + + # Determine the binary's architecture + self.machine = make_machine(target.arch) + self.arch = target.arch + + # Create disassembly/lifting context + assert(self.machine.dis_engine is not None) + binstream = MemoryBinstream(target) + self.mdis = self.machine.dis_engine(binstream, loc_db=self.loc_db) + self.mdis.follow_call = True + self.lifter = self.machine.lifter(self.loc_db) + +def run_instruction(instr: miasm_instr, + conc_state: MiasmSymbolResolver, + lifter: Lifter) \ + -> tuple[ExprInt | None, dict[Expr, Expr]]: + """Compute the symbolic equation of a single instruction. + + The concolic engine tries to express the instruction's equation as + independent of the concrete state as possible. + + May fail if the instruction is not supported. Failure is signalled by + returning `None` as the next program counter. 
+ + :param instr: The instruction to run. + :param conc_state: A concrete reference state at `pc = instr.offset`. Used + to resolve symbolic program counters, i.e. to 'guide' + the symbolic execution on the correct path. This is the + concrete part of our concolic execution. + :param lifter: A lifter of the appropriate architecture. Get this from + a `DisassemblyContext` or a `Machine`. + + :return: The next program counter and a symbolic state. The PC is None if + an error occurs or when the program exits. The returned state + is `instr`'s symbolic transformation. + """ + from miasm.expression.expression import ExprCond, LocKey + from miasm.expression.simplifications import expr_simp + + def create_cond_state(cond: Expr, iftrue: dict, iffalse: dict) -> dict: + """Combines states that are to be reached conditionally. + + Example: + State A: + RAX = 0x42 + @[RBP - 0x4] = 0x123 + State B: + RDI = -0x777 + @[RBP - 0x4] = 0x5c32 + Condition: + RCX > 0x4 ? A : B + + Result State: + RAX = (RCX > 0x4) ? 0x42 : RAX + RDI = (RCX > 0x4) ? RDI : -0x777 + @[RBP - 0x4] = (RCX > 0x4) ? 0x123 : 0x5c32 + """ + res = {} + for dst, v in iftrue.items(): + if dst not in iffalse: + res[dst] = expr_simp(ExprCond(cond, v, dst)) + else: + res[dst] = expr_simp(ExprCond(cond, v, iffalse[dst])) + for dst, v in iffalse.items(): + if dst not in iftrue: + res[dst] = expr_simp(ExprCond(cond, dst, v)) + return res + + def _execute_location(loc, base_state: dict | None) \ + -> tuple[Expr, dict]: + """Execute a single IR block via symbolic engine. No fancy stuff.""" + # Query the location's IR block + irblock = ircfg.get_block(loc) + if irblock is None: + return loc, base_state if base_state is not None else {} + + # Apply IR block to the current state + engine = SymbolicExecutionEngine(lifter, state=base_state) + new_pc = engine.eval_updt_irblock(irblock) + modified = dict(engine.modified()) + return new_pc, modified + + def execute_location(loc: Expr | LocKey) -> tuple[ExprInt, dict]: + """Execute chains of IR blocks until a concrete program counter is + reached.""" + seen_locs = set() # To break out of loop instructions + new_pc, modified = _execute_location(loc, None) + + # Run chained IR blocks until a real program counter is reached. + # This used to be recursive (and much more elegant), but large RCX + # values for 'REP ...' instructions could make the stack overflow. + while not new_pc.is_int(): + seen_locs.add(new_pc) + + if new_pc.is_loc(): + # Jump to the next location. + new_pc, modified = _execute_location(new_pc, modified) + elif new_pc.is_cond(): + # Explore conditional paths manually by constructing + # conditional states based on the possible outcomes. + assert(isinstance(new_pc, ExprCond)) + cond = new_pc.cond + pc_iftrue, pc_iffalse = new_pc.src1, new_pc.src2 + + pc_t, state_t = _execute_location(pc_iftrue, modified.copy()) + pc_f, state_f = _execute_location(pc_iffalse, modified.copy()) + modified = create_cond_state(cond, state_t, state_f) + new_pc = expr_simp(ExprCond(cond, pc_t, pc_f)) + else: + # Concretisize PC in case it is, e.g., a memory expression + new_pc = eval_expr(new_pc, conc_state) + + # Avoid infinite loops for loop instructions (REP ...) by making + # the jump to the next loop iteration (or exit) concrete. 
+ if new_pc in seen_locs: + new_pc = eval_expr(new_pc, conc_state) + seen_locs.clear() + + assert(isinstance(new_pc, ExprInt)) + return new_pc, modified + + # Lift instruction to IR + ircfg = lifter.new_ircfg() + try: + loc = lifter.add_instr_to_ircfg(instr, ircfg, None, False) + assert(isinstance(loc, Expr) or isinstance(loc, LocKey)) + except NotImplementedError as err: + warn(f'[WARNING] Unable to lift instruction {instr}: {err}. Skipping.') + return None, {} # Create an empty transform for the instruction + + # Execute instruction symbolically + new_pc, modified = execute_location(loc) + modified[lifter.pc] = new_pc # Add PC update to state + + return new_pc, modified + +class _LLDBConcreteState(ReadableProgramState): + """A wrapper around `LLDBConcreteTarget` that provides access via a + `ReadableProgramState` interface. Reads values directly from an LLDB + target. This saves us the trouble of recording a full program state, and + allows us instead to read values from LLDB on demand. + """ + def __init__(self, target: LLDBConcreteTarget): + super().__init__(target.arch) + self._target = target + + def read_register(self, reg: str) -> int: + regname = self.arch.to_regname(reg) + if regname is None: + raise RegisterAccessError(reg, f'Not a register name: {reg}') + + try: + return self._target.read_register(regname) + except ConcreteRegisterError: + raise RegisterAccessError(regname, '') + + def read_memory(self, addr: int, size: int) -> bytes: + try: + return self._target.read_memory(addr, size) + except ConcreteMemoryError: + raise MemoryAccessError(addr, size, 'Unable to read memory from LLDB.') + +def collect_symbolic_trace(env: TraceEnvironment, + start_addr: int | None = None + ) -> Trace[SymbolicTransform]: + """Execute a program and compute state transformations between executed + instructions. + + :param binary: The binary to trace. + :param args: Arguments to the program. + """ + binary = env.binary_name + + # Set up concrete reference state + target = LLDBConcreteTarget(binary, env.argv, env.envp) + if start_addr is not None: + target.run_until(start_addr) + lldb_state = _LLDBConcreteState(target) + + ctx = DisassemblyContext(lldb_state) + arch = ctx.arch + + # Trace concolically + strace: list[SymbolicTransform] = [] + while not target.is_exited(): + pc = target.read_register('pc') + + # Disassemble instruction at the current PC + try: + instr = ctx.mdis.dis_instr(pc) + except: + err = sys.exc_info()[1] + warn(f'Unable to disassemble instruction at {hex(pc)}: {err}.' + f' Skipping.') + target.step() + continue + + # Run instruction + conc_state = MiasmSymbolResolver(lldb_state, ctx.loc_db) + new_pc, modified = run_instruction(instr, conc_state, ctx.lifter) + + # Create symbolic transform + instruction = Instruction(instr, ctx.machine, ctx.arch, ctx.loc_db) + if new_pc is None: + new_pc = pc + instruction.length + else: + new_pc = int(new_pc) + transform = SymbolicTransform(modified, [instruction], arch, pc, new_pc) + strace.append(transform) + + # Predict next concrete state. + # We verify the symbolic execution backend on the fly for some + # additional protection from bugs in the backend. + predicted_regs = transform.eval_register_transforms(lldb_state) + predicted_mems = transform.eval_memory_transforms(lldb_state) + + # Step forward + target.step() + if target.is_exited(): + break + + # Verify last generated transform by comparing concrete state against + # predicted values. 
+ assert(len(strace) > 0) + for reg, val in predicted_regs.items(): + conc_val = lldb_state.read_register(reg) + if conc_val != val: + warn(f'Symbolic execution backend generated false equation for' + f' [{hex(instruction.addr)}]: {instruction}:' + f' Predicted {reg} = {hex(val)}, but the' + f' concrete state has value {reg} = {hex(conc_val)}.' + f'\nFaulty transformation: {transform}') + for addr, data in predicted_mems.items(): + conc_data = lldb_state.read_memory(addr, len(data)) + if conc_data != data: + warn(f'Symbolic execution backend generated false equation for' + f' [{hex(instruction.addr)}]: {instruction}: Predicted' + f' mem[{hex(addr)}:{hex(addr+len(data))}] = {data},' + f' but the concrete state has value' + f' mem[{hex(addr)}:{hex(addr+len(data))}] = {conc_data}.' + f'\nFaulty transformation: {transform}') + raise Exception() + + return Trace(strace, env) diff --git a/focaccia/trace.py b/focaccia/trace.py new file mode 100644 index 0000000..094358f --- /dev/null +++ b/focaccia/trace.py @@ -0,0 +1,74 @@ +from __future__ import annotations +from typing import Generic, TypeVar + +from .utils import file_hash + +T = TypeVar('T') + +class TraceEnvironment: + """Data that defines the environment in which a trace was recorded.""" + def __init__(self, + binary: str, + argv: list[str], + envp: list[str], + binary_hash: str | None = None): + self.argv = argv + self.envp = envp + self.binary_name = binary + if binary_hash is None: + self.binary_hash = file_hash(binary) + else: + self.binary_hash = binary_hash + + @classmethod + def from_json(cls, json: dict) -> TraceEnvironment: + """Parse a JSON object into a TraceEnvironment.""" + return cls( + json['binary_name'], + json['argv'], + json['envp'], + json['binary_hash'], + ) + + def to_json(self) -> dict: + """Serialize a TraceEnvironment to a JSON object.""" + return { + 'binary_name': self.binary_name, + 'binary_hash': self.binary_hash, + 'argv': self.argv, + 'envp': self.envp, + } + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TraceEnvironment): + return False + + return self.binary_name == other.binary_name \ + and self.binary_hash == other.binary_hash \ + and self.argv == other.argv \ + and self.envp == other.envp + + def __repr__(self) -> str: + return f'{self.binary_name} {" ".join(self.argv)}' \ + f'\n bin-hash={self.binary_hash}' \ + f'\n envp={repr(self.envp)}' + +class Trace(Generic[T]): + def __init__(self, + trace_states: list[T], + env: TraceEnvironment): + self.states = trace_states + self.env = env + + def __len__(self) -> int: + return len(self.states) + + def __getitem__(self, i: int) -> T: + return self.states[i] + + def __iter__(self): + return iter(self.states) + + def __repr__(self) -> str: + return f'Trace with {len(self.states)} trace points.' \ + f' Environment: {repr(self.env)}' diff --git a/focaccia/utils.py b/focaccia/utils.py new file mode 100644 index 0000000..c4f6a74 --- /dev/null +++ b/focaccia/utils.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import ctypes +import os +import shutil +import sys +from functools import total_ordering +from hashlib import sha256 + +@total_ordering +class ErrorSeverity: + def __init__(self, num: int, name: str): + """Construct an error severity. + + :param num: A numerical value that orders the severity with respect + to other `ErrorSeverity` objects. Smaller values are less + severe. + :param name: A descriptive name for the error severity, e.g. 'fatal' + or 'info'. 
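This ordering is what `print_result` further down relies on to filter low-priority findings; a brief sketch with made-up severity levels:

```
from focaccia.utils import ErrorSeverity

info = ErrorSeverity(10, 'info')
error = ErrorSeverity(30, 'error')
fatal = ErrorSeverity(40, 'fatal')
assert info < error < fatal              # total_ordering derives <=, >, >=
assert max(info, fatal).name == 'fatal'
```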
+ """ + self._numeral = num + self.name = name + + def __repr__(self) -> str: + return f'[{self.name}]' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ErrorSeverity): + return False + return self._numeral == other._numeral + + def __lt__(self, other: ErrorSeverity) -> bool: + return self._numeral < other._numeral + + def __hash__(self) -> int: + return hash(self._numeral) + +def float_bits_to_uint(v: float) -> int: + """Bit-cast a float to a 32-bit integer.""" + return ctypes.c_uint32.from_buffer(ctypes.c_float(v)).value + +def uint_bits_to_float(v: int) -> float: + """Bit-cast a 32-bit integer to a float.""" + return ctypes.c_float.from_buffer(ctypes.c_uint32(v)).value + +def double_bits_to_uint(v: float) -> int: + """Bit-cast a double to a 64-bit integer.""" + return ctypes.c_uint64.from_buffer(ctypes.c_double(v)).value + +def uint_bits_to_double(v: int) -> float: + """Bit-cast a 64-bit integer to a double.""" + return ctypes.c_double.from_buffer(ctypes.c_uint64(v)).value + +def file_hash(filename: str, hash = sha256(), chunksize: int = 65536) -> str: + """Calculate a file's hash. + + :param filename: Name of the file to hash. + :param hash: The hash algorithm to use. + :param chunksize: Optimization option. Size of contiguous chunks to read + from the file and feed into the hashing algorithm. + :return: A hex digest. + """ + with open(filename, 'rb') as file: + while True: + data = file.read(chunksize) + if not data: + break + hash.update(data) + return hash.hexdigest() + +def get_envp() -> list[str]: + """Return current environment array. + + Merge dict-like `os.environ` struct to the traditional list-like + environment array. + """ + return [f'{k}={v}' for k, v in os.environ.items()] + +def print_separator(separator: str = '-', stream=sys.stdout, count: int = 80): + maxtermsize = count + termsize = shutil.get_terminal_size((80, 20)).columns + print(separator * min(termsize, maxtermsize), file=stream) + +def print_result(result, min_severity: ErrorSeverity): + """Print a comparison result.""" + shown = 0 + suppressed = 0 + + for res in result: + # Filter errors by severity + errs = [e for e in res['errors'] if e.severity >= min_severity] + suppressed += len(res['errors']) - len(errs) + shown += len(errs) + + if errs: + pc = res['pc'] + print_separator() + print(f'For PC={hex(pc)}') + print_separator() + + # Print all non-suppressed errors + for n, err in enumerate(errs, start=1): + print(f' {n:2}. {err}') + + if errs: + print() + print(f'Expected transformation: {res["ref"]}') + print(f'Actual difference: {res["txl"]}') + + print() + print('#' * 60) + print(f'Found {shown} errors.') + print(f'Suppressed {suppressed} low-priority errors' + f' (showing {min_severity} and higher).') + print('#' * 60) + print() diff --git a/nix.shell b/nix.shell new file mode 100644 index 0000000..00fef51 --- /dev/null +++ b/nix.shell @@ -0,0 +1,12 @@ +{ pkgs ? import <nixpkgs> {} }: +pkgs.mkShell { + nativeBuildInputs = with pkgs; [ + python311 + python311Packages.pip + virtualenv + + gcc gnumake binutils cmake ninja pkg-config + musl qemu swig4 + gdb + ]; +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..220ba8b --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +miasm diff --git a/run.py b/run.py deleted file mode 100755 index f1f1060..0000000 --- a/run.py +++ /dev/null @@ -1,214 +0,0 @@ -#! 
/bin/python3 -import os -import re -import sys -import lldb -import shutil -import argparse - -from utils import print_separator - -verbose = False - -regnames = ['PC', - 'RAX', - 'RBX', - 'RCX', - 'RDX', - 'RSI', - 'RDI', - 'RBP', - 'RSP', - 'R8', - 'R9', - 'R10', - 'R11', - 'R12', - 'R13', - 'R14', - 'R15', - 'RFLAGS'] - -class DebuggerCallback: - def __init__(self, ostream=sys.stdout, skiplist: set = {}): - self.stream = ostream - self.regex = re.compile('(' + '|'.join(regnames) + ')$') - self.skiplist = skiplist - - @staticmethod - def parse_flags(flag_reg: int): - flags = {'ZF': 0, - 'CF': 0, - 'OF': 0, - 'SF': 0, - 'PF': 0, - 'DF': 0} - - # CF (Carry flag) Bit 0 - # PF (Parity flag) Bit 2 - # ZF (Zero flag) Bit 6 - # SF (Sign flag) Bit 7 - # TF (Trap flag) Bit 8 - # IF (Interrupt enable flag) Bit 9 - # DF (Direction flag) Bit 10 - # OF (Overflow flag) Bit 11 - flags['CF'] = int(0 != flag_reg & 1) - flags['ZF'] = int(0 != flag_reg & (1 << 6)) - flags['OF'] = int(0 != flag_reg & (1 << 11)) - flags['SF'] = int(0 != flag_reg & (1 << 7)) - flags['DF'] = int(0 != flag_reg & (1 << 10)) - flags['PF'] = int(0 != flag_reg & (1 << 1)) - return flags - - - def print_regs(self, frame): - for reg in frame.GetRegisters(): - for sub_reg in reg: - match = self.regex.match(sub_reg.GetName().upper()) - if match and match.group() == 'RFLAGS': - flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(), - base=16)) - for flag in flags: - print(f'flag {flag}:\t{hex(flags[flag])}', - file=self.stream) - elif match: - print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}", - file=self.stream) - - def print_stack(self, frame, element_count: int): - first = True - for i in range(element_count): - addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize() - error = lldb.SBError() - stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error)) - if error.Success() and not first: - print(f'{hex(stack_value)}', file=self.stream) - elif error.Success(): - print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream) - else: - print(f"Error reading memory at address 0x{addr:x}", - file=self.stream) - first=False - - def __call__(self, frame): - pc = frame.GetPC() - - # Skip this PC - if pc in self.skiplist: - self.skiplist.discard(pc) - return False - - print_separator('=', stream=self.stream, count=20) - print(f'INVOKE PC={hex(pc)}', file=self.stream) - print_separator('=', stream=self.stream, count=20) - - print("Register values:", file=self.stream) - self.print_regs(frame) - print_separator(stream=self.stream) - - print("STACK:", file=self.stream) - self.print_stack(frame, 20) - - return True # Continue execution - -class Debugger: - def __init__(self, program): - self.debugger = lldb.SBDebugger.Create() - self.debugger.SetAsync(False) - self.target = self.debugger.CreateTargetWithFileAndArch(program, - lldb.LLDB_ARCH_DEFAULT) - self.module = self.target.FindModule(self.target.GetExecutable()) - self.interpreter = self.debugger.GetCommandInterpreter() - - def set_breakpoint_by_addr(self, address: int): - command = f"b -a {address} -s {self.module.GetFileSpec().GetFilename()}" - result = lldb.SBCommandReturnObject() - self.interpreter.HandleCommand(command, result) - - if verbose: - print(f'Set breakpoint at address {hex(address)}') - - def get_breakpoints_count(self): - return self.target.GetNumBreakpoints() - - def execute(self, callback: callable): - error = lldb.SBError() - listener = self.debugger.GetListener() - process = 
self.target.Launch(listener, None, None, None, None, None, None, 0, - True, error) - - # Check if the process has launched successfully - if process.IsValid(): - print(f'Launched process: {process}') - else: - print('Failed to launch process', file=sys.stderr) - - while True: - state = process.GetState() - if state == lldb.eStateStopped: - for thread in process: - callback(thread.GetFrameAtIndex(0)) - process.Continue() - if state == lldb.eStateExited: - break - - self.debugger.Terminate() - - print(f'Process state: {process.GetState()}') - print('Program output:') - print(process.GetSTDOUT(1024)) - print(process.GetSTDERR(1024)) - -class ListWriter: - def __init__(self): - self.data = [] - - def write(self, s): - self.data.append(s) - - def __str__(self): - return "".join(self.data) - -class Runner: - def __init__(self, dbt_log: list, oracle_program: str): - self.log = dbt_log - self.program = oracle_program - self.debugger = Debugger(self.program) - self.writer = ListWriter() - - @staticmethod - def get_addresses(lines: list): - addresses = [] - - backlist = [] - backlist_regex = re.compile(r'^\s\s\d*:') - - skiplist = set() - for l in lines: - if l.startswith('INVOKE'): - addresses.append(int(l.split('=')[1].strip(), base=16)) - - if addresses[-1] in backlist: - skiplist.add(addresses[-1]) - backlist = [] - - if backlist_regex.match(l): - backlist.append(int(l.split()[0].split(':')[0], base=16)) - - return set(addresses), skiplist - - def run(self): - # Get all addresses to stop at - addresses, skiplist = Runner.get_addresses(self.log) - - # Set breakpoints - for address in addresses: - self.debugger.set_breakpoint_by_addr(address) - - # Sanity check - assert(self.debugger.get_breakpoints_count() == len(addresses)) - - self.debugger.execute(DebuggerCallback(self.writer, skiplist)) - - return self.writer.data - diff --git a/test/test_snapshot.py b/test/test_snapshot.py new file mode 100644 index 0000000..ddad410 --- /dev/null +++ b/test/test_snapshot.py @@ -0,0 +1,74 @@ +import unittest + +from focaccia.arch import x86 +from focaccia.snapshot import ProgramState, RegisterAccessError + +class TestProgramState(unittest.TestCase): + def setUp(self): + self.arch = x86.ArchX86() + + def test_register_access_empty_state(self): + state = ProgramState(self.arch) + for reg in x86.regnames: + self.assertRaises(RegisterAccessError, state.read_register, reg) + + def test_register_read_write(self): + state = ProgramState(self.arch) + for reg in x86.regnames: + state.set_register(reg, 0x42) + for reg in x86.regnames: + val = state.read_register(reg) + self.assertEqual(val, 0x42) + + def test_register_aliases_empty_state(self): + state = ProgramState(self.arch) + for reg in self.arch.all_regnames: + self.assertRaises(RegisterAccessError, state.read_register, reg) + + def test_register_aliases_read_write(self): + state = ProgramState(self.arch) + for reg in ['EAX', 'EBX', 'ECX', 'EDX']: + state.set_register(reg, 0xa0ff0) + + for reg in ['AH', 'BH', 'CH', 'DH']: + self.assertEqual(state.read_register(reg), 0xf, reg) + for reg in ['AL', 'BL', 'CL', 'DL']: + self.assertEqual(state.read_register(reg), 0xf0, reg) + for reg in ['AX', 'BX', 'CX', 'DX']: + self.assertEqual(state.read_register(reg), 0x0ff0, reg) + for reg in ['EAX', 'EBX', 'ECX', 'EDX', + 'RAX', 'RBX', 'RCX', 'RDX']: + self.assertEqual(state.read_register(reg), 0xa0ff0, reg) + + def test_flag_aliases(self): + flags = ['CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', + 'IOPL', 'NT', 'RF', 'VM', 'AC', 'VIF', 'VIP', 'ID'] + state = 
ProgramState(self.arch) + + state.set_register('RFLAGS', 0) + for flag in flags: + self.assertEqual(state.read_register(flag), 0) + + state.set_register('RFLAGS', + x86.compose_rflags({'ZF': 1, 'PF': 1, 'OF': 0})) + self.assertEqual(state.read_register('ZF'), 1, self.arch.get_reg_accessor('ZF')) + self.assertEqual(state.read_register('PF'), 1) + self.assertEqual(state.read_register('OF'), 0) + self.assertEqual(state.read_register('AF'), 0) + self.assertEqual(state.read_register('ID'), 0) + self.assertEqual(state.read_register('SF'), 0) + + for flag in flags: + state.set_register(flag, 1) + for flag in flags: + self.assertEqual(state.read_register(flag), 1) + + state.set_register('OF', 1) + state.set_register('AF', 1) + state.set_register('SF', 1) + self.assertEqual(state.read_register('OF'), 1) + self.assertEqual(state.read_register('AF'), 1) + self.assertEqual(state.read_register('SF'), 1) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_sparse_memory.py b/test/test_sparse_memory.py new file mode 100644 index 0000000..4fd9cba --- /dev/null +++ b/test/test_sparse_memory.py @@ -0,0 +1,33 @@ +import unittest + +from focaccia.snapshot import SparseMemory, MemoryAccessError + +class TestSparseMemory(unittest.TestCase): + def test_oob_read(self): + mem = SparseMemory() + for addr in range(mem.page_size): + self.assertRaises(MemoryAccessError, mem.read, addr, 1) + self.assertRaises(MemoryAccessError, mem.read, addr, 30) + self.assertRaises(MemoryAccessError, mem.read, addr + 0x10, 30) + self.assertRaises(MemoryAccessError, mem.read, addr, mem.page_size) + self.assertRaises(MemoryAccessError, mem.read, addr, mem.page_size - 1) + self.assertRaises(MemoryAccessError, mem.read, addr, mem.page_size + 1) + + def test_basic_read_write(self): + mem = SparseMemory() + + data = b'a' * mem.page_size * 2 + mem.write(0x300, data) + self.assertEqual(mem.read(0x300, len(data)), data) + self.assertEqual(mem.read(0x300, 1), b'a') + self.assertEqual(mem.read(0x400, 1), b'a') + self.assertEqual(mem.read(0x299 + mem.page_size * 2, 1), b'a') + self.assertEqual(mem.read(0x321, 12), b'aaaaaaaaaaaa') + + mem.write(0x321, b'Hello World!') + self.assertEqual(mem.read(0x321, 12), b'Hello World!') + + self.assertRaises(MemoryAccessError, mem.read, 0x300, mem.page_size * 3) + +if __name__ == '__main__': + unittest.main() diff --git a/tools/_qemu_tool.py b/tools/_qemu_tool.py new file mode 100644 index 0000000..b365d39 --- /dev/null +++ b/tools/_qemu_tool.py @@ -0,0 +1,314 @@ +"""Invocable like this: + + gdb -n --batch -x qemu_tool.py + +But please use `tools/verify_qemu.py` instead because we have some more setup +work to do. 
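+
+For example, to check a QEMU user-mode process that was started with `-g 1234`
+against a symbolic trace previously recorded into `qemu.sym` by
+`tools/capture_transforms.py` (the port and file name are only illustrative),
+the wrapper would be invoked roughly like this:
+
+    ./tools/verify_qemu.py localhost 1234 --symb-trace qemu.sym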
+""" + +import gdb +from typing import Iterable + +import focaccia.parser as parser +from focaccia.arch import supported_architectures, Arch +from focaccia.compare import compare_symbolic +from focaccia.snapshot import ProgramState, ReadableProgramState, \ + RegisterAccessError, MemoryAccessError +from focaccia.symbolic import SymbolicTransform, eval_symbol, ExprMem +from focaccia.trace import Trace, TraceEnvironment +from focaccia.utils import print_result + +from verify_qemu import make_argparser, verbosity + +class GDBProgramState(ReadableProgramState): + from focaccia.arch import aarch64, x86 + + flag_register_names = { + aarch64.archname: 'cpsr', + x86.archname: 'eflags', + } + + flag_register_decompose = { + aarch64.archname: aarch64.decompose_cpsr, + x86.archname: x86.decompose_rflags, + } + + def __init__(self, process: gdb.Inferior, frame: gdb.Frame, arch: Arch): + super().__init__(arch) + self._proc = process + self._frame = frame + + @staticmethod + def _read_vector_reg_aarch64(val, size) -> int: + return int(str(val['u']), 10) + + @staticmethod + def _read_vector_reg_x86(val, size) -> int: + num_longs = size // 64 + vals = val[f'v{num_longs}_int64'] + res = 0 + for i in range(num_longs): + val = int(vals[i].cast(gdb.lookup_type('unsigned long'))) + res += val << i * 64 + return res + + read_vector_reg = { + aarch64.archname: _read_vector_reg_aarch64, + x86.archname: _read_vector_reg_x86, + } + + def read_register(self, reg: str) -> int: + if reg == 'RFLAGS': + reg = 'EFLAGS' + + try: + val = self._frame.read_register(reg.lower()) + size = val.type.sizeof * 8 + + # For vector registers, we need to apply architecture-specific + # logic because GDB's interface is not consistent. + if size >= 128: # Value is a vector + if self.arch.archname not in self.read_vector_reg: + raise NotImplementedError( + f'Reading vector registers is not implemented for' + f' architecture {self.arch.archname}.') + return self.read_vector_reg[self.arch.archname](val, size) + elif size < 64: + return int(val.cast(gdb.lookup_type('unsigned int'))) + # For non-vector values, just return the 64-bit value + return int(val.cast(gdb.lookup_type('unsigned long'))) + except ValueError as err: + # Try to access the flags register with `reg` as a logical flag name + if self.arch.archname in self.flag_register_names: + flags_reg = self.flag_register_names[self.arch.archname] + flags = int(self._frame.read_register(flags_reg)) + flags = self.flag_register_decompose[self.arch.archname](flags) + if reg in flags: + return flags[reg] + raise RegisterAccessError(reg, + f'[GDB] Unable to access {reg}: {err}') + + def read_memory(self, addr: int, size: int) -> bytes: + try: + mem = self._proc.read_memory(addr, size).tobytes() + if self.arch.endianness == 'little': + return mem + else: + return bytes(reversed(mem)) # Convert to big endian + except gdb.MemoryError as err: + raise MemoryAccessError(addr, size, str(err)) + +class GDBServerStateIterator: + def __init__(self, address: str, port: int): + gdb.execute('set pagination 0') + gdb.execute('set sysroot') + gdb.execute(f'target remote {address}:{port}') + self._process = gdb.selected_inferior() + self._first_next = True + + # Try to determine the guest architecture. This is a bit hacky and + # tailored to GDB's naming for the x86-64 architecture. 
+ split = self._process.architecture().name().split(':') + archname = split[1] if len(split) > 1 else split[0] + archname = archname.replace('-', '_') + if archname not in supported_architectures: + print(f'Error: Current platform ({archname}) is not' + f' supported by Focaccia. Exiting.') + exit(1) + + self.arch = supported_architectures[archname] + self.binary = self._process.progspace.filename + + def __iter__(self): + return self + + def __next__(self): + # The first call to __next__ should yield the first program state, + # i.e. before stepping the first time + if self._first_next: + self._first_next = False + return GDBProgramState(self._process, gdb.selected_frame(), self.arch) + + # Step + pc = gdb.selected_frame().read_register('pc') + new_pc = pc + while pc == new_pc: # Skip instruction chains from REP STOS etc. + gdb.execute('si', to_string=True) + if not self._process.is_valid() or len(self._process.threads()) == 0: + raise StopIteration + new_pc = gdb.selected_frame().read_register('pc') + + return GDBProgramState(self._process, gdb.selected_frame(), self.arch) + +def record_minimal_snapshot(prev_state: ReadableProgramState, + cur_state: ReadableProgramState, + prev_transform: SymbolicTransform, + cur_transform: SymbolicTransform) \ + -> ProgramState: + """Record a minimal snapshot. + + A minimal snapshot must include values (registers and memory) that are + accessed by two transformations: + 1. The values produced by the previous transformation (the + transformation that is producing this snapshot) to check these + values against expected values calculated from the previous + program state. + 2. The values that act as inputs to the transformation acting on this + snapshot, to calculate the expected values of the next snapshot. + + :param prev_transform: The symbolic transformation generating, or + leading to, `cur_state`. Values generated by + this transformation are included in the + snapshot. + :param transform: The symbolic transformation operating on this + snapshot. Input values to this transformation are + included in the snapshot. + """ + assert(cur_state.read_register('pc') == cur_transform.addr) + assert(prev_transform.arch == cur_transform.arch) + + def get_written_addresses(t: SymbolicTransform): + """Get all output memory accesses of a symbolic transformation.""" + return [ExprMem(a, v.size) for a, v in t.changed_mem.items()] + + def set_values(regs: Iterable[str], mems: Iterable[ExprMem], + cur_state: ReadableProgramState, + prev_state: ReadableProgramState, + out_state: ProgramState): + """ + :param prev_state: Addresses of memory included in the snapshot are + resolved relative to this state. + """ + for regname in regs: + try: + regval = cur_state.read_register(regname) + out_state.set_register(regname, regval) + except RegisterAccessError: + pass + for mem in mems: + assert(mem.size % 8 == 0) + addr = eval_symbol(mem.ptr, prev_state) + try: + mem = cur_state.read_memory(addr, int(mem.size / 8)) + out_state.write_memory(addr, mem) + except MemoryAccessError: + pass + + state = ProgramState(cur_transform.arch) + state.set_register('PC', cur_transform.addr) + + set_values(prev_transform.changed_regs.keys(), + get_written_addresses(prev_transform), + cur_state, + prev_state, # Evaluate memory addresses based on previous + # state because they are that state's output + # addresses. 
+ state) + set_values(cur_transform.get_used_registers(), + cur_transform.get_used_memory_addresses(), + cur_state, + cur_state, + state) + return state + +def collect_conc_trace(gdb: GDBServerStateIterator, \ + strace: list[SymbolicTransform]) \ + -> tuple[list[ProgramState], list[SymbolicTransform]]: + """Collect a trace of concrete states from GDB. + + Records minimal concrete states from GDB by using symbolic trace + information to determine which register/memory values are required to + verify the correctness of the program running in GDB. + + May drop symbolic transformations if the symbolic trace and the GDB trace + diverge (e.g. because of differences in environment, etc.). Returns the + new, possibly modified, symbolic trace that matches the returned concrete + trace. + + :return: A list of concrete states and a list of corresponding symbolic + transformations. The lists are guaranteed to have the same length. + """ + def find_index(seq, target, access=lambda el: el): + for i, el in enumerate(seq): + if access(el) == target: + return i + return None + + if not strace: + return [], [] + + states = [] + matched_transforms = [] + + state_iter = iter(gdb) + cur_state = next(state_iter) + symb_i = 0 + + # An online trace matching algorithm. + while True: + try: + pc = cur_state.read_register('pc') + + while pc != strace[symb_i].addr: + next_i = find_index(strace[symb_i+1:], pc, lambda t: t.addr) + + # Drop the concrete state if no address in the symbolic trace + # matches + if next_i is None: + print(f'Warning: Dropping concrete state {hex(pc)}, as no' + f' matching instruction can be found in the symbolic' + f' reference trace.') + cur_state = next(state_iter) + pc = cur_state.read_register('pc') + continue + + # Otherwise, jump to the next matching symbolic state + symb_i += next_i + 1 + + assert(cur_state.read_register('pc') == strace[symb_i].addr) + states.append(record_minimal_snapshot( + states[-1] if states else cur_state, + cur_state, + matched_transforms[-1] if matched_transforms else strace[symb_i], + strace[symb_i])) + matched_transforms.append(strace[symb_i]) + cur_state = next(state_iter) + symb_i += 1 + except StopIteration: + break + + return states, matched_transforms + +def main(): + args = make_argparser().parse_args() + + gdbserver_addr = 'localhost' + gdbserver_port = args.port + gdb_server = GDBServerStateIterator(gdbserver_addr, gdbserver_port) + + executable = gdb_server.binary + argv = [] # QEMU's GDB stub does not support 'info proc cmdline' + envp = [] # Can't get the remote target's environment + env = TraceEnvironment(executable, argv, envp, '?') + + # Read pre-computed symbolic trace + with open(args.symb_trace, 'r') as strace: + symb_transforms = parser.parse_transformations(strace) + + # Use symbolic trace to collect concrete trace from QEMU + conc_states, matched_transforms = collect_conc_trace( + gdb_server, + symb_transforms.states) + + # Verify and print result + if not args.quiet: + res = compare_symbolic(conc_states, matched_transforms) + print_result(res, verbosity[args.error_level]) + + if args.output: + from focaccia.parser import serialize_snapshots + with open(args.output, 'w') as file: + serialize_snapshots(Trace(conc_states, env), file) + +if __name__ == "__main__": + main() diff --git a/tools/capture_transforms.py b/tools/capture_transforms.py new file mode 100755 index 0000000..552b855 --- /dev/null +++ b/tools/capture_transforms.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import argparse + +from focaccia import parser, utils +from 
focaccia.symbolic import collect_symbolic_trace +from focaccia.trace import TraceEnvironment + +def main(): + prog = argparse.ArgumentParser() + prog.description = 'Trace an executable concolically to capture symbolic' \ + ' transformations among instructions.' + prog.add_argument('binary', help='The program to analyse.') + prog.add_argument('args', action='store', nargs=argparse.REMAINDER, + help='Arguments to the program.') + prog.add_argument('-o', '--output', + default='trace.out', + help='Name of output file. (default: trace.out)') + args = prog.parse_args() + + env = TraceEnvironment(args.binary, args.args, utils.get_envp()) + trace = collect_symbolic_trace(env, None) + with open(args.output, 'w') as file: + parser.serialize_transformations(trace, file) + +if __name__ == "__main__": + main() diff --git a/tools/convert.py b/tools/convert.py new file mode 100755 index 0000000..f21a2fa --- /dev/null +++ b/tools/convert.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +import argparse +import sys + +import focaccia.parser as parser +from focaccia.arch import supported_architectures + +convert_funcs = { + 'qemu': parser.parse_qemu, + 'arancini': parser.parse_arancini, +} + +def main(): + """Main.""" + prog = argparse.ArgumentParser() + prog.description = 'Convert other programs\' logs to focaccia\'s log format.' + prog.add_argument('file', help='The log to convert.') + prog.add_argument('--type', + required=True, + choices=convert_funcs.keys(), + help='The log type of `file`') + prog.add_argument('--output', '-o', + help='Output file (default is stdout)') + prog.add_argument('--arch', + default='x86_64', + choices=supported_architectures.keys(), + help='Processor architecture of input log (default is x86)') + args = prog.parse_args() + + # Parse arancini log + arch = supported_architectures[args.arch] + parse_log = convert_funcs[args.type] + with open(args.file, 'r') as in_file: + try: + snapshots = parse_log(in_file, arch) + except parser.ParseError as err: + print(f'Parse error: {err}. Exiting.', file=sys.stderr) + exit(1) + + # Write log in focaccia's format + if args.output: + with open(args.output, 'w') as out_file: + parser.serialize_snapshots(snapshots, out_file) + else: + parser.serialize_snapshots(snapshots, sys.stdout) + +if __name__ == '__main__': + main() diff --git a/tools/verify_qemu.py b/tools/verify_qemu.py new file mode 100755 index 0000000..df9f83d --- /dev/null +++ b/tools/verify_qemu.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +""" +Spawn GDB, connect to QEMU's GDB server, and read test states from that. + +We need two scripts (this one and the primary `qemu_tool.py`) because we can't +pass arguments to scripts executed via `gdb -x <script>`. + +This script (`verify_qemu.py`) is the one the user interfaces with. It +eventually calls `execv` to spawn a GDB process that calls the main +`qemu_tool.py` script; `python verify_qemu.py` essentially behaves as if +something like `gdb --batch -x qemu_tool.py` were executed instead. Before it +starts GDB, though, it parses command line arguments and applies some weird but +necessary logic to pass them to `qemu_tool.py`. +""" + +import argparse +import os +import subprocess +import sys + +from focaccia.compare import ErrorTypes + +verbosity = { + 'info': ErrorTypes.INFO, + 'warning': ErrorTypes.POSSIBLE, + 'error': ErrorTypes.CONFIRMED, +} + +def make_argparser(): + """This is also used by the GDB-invoked script to parse its args.""" + prog = argparse.ArgumentParser() + prog.description = """Use Focaccia to test QEMU. 
+ +Uses QEMU's GDB-server feature to read QEMU's emulated state and test its +transformation during emulation against a symbolic truth. + +In fact, this tool could be used to test any emulator that provides a +GDB-server interface. The server must support reading registers, reading +memory, and stepping forward by single instructions. +""" + prog.add_argument('hostname', + help='The hostname at which to find the GDB server.') + prog.add_argument('port', + type=int, + help='The port at which to find the GDB server.') + prog.add_argument('--symb-trace', + required=True, + help='A pre-computed symbolic transformation trace to' \ + ' be used for verification. Generate this with' \ + ' the `tools/capture_transforms.py` tool.') + prog.add_argument('-q', '--quiet', + default=False, + action='store_true', + help='Don\'t print a verification result.') + prog.add_argument('-o', '--output', + help='If specified with a file name, the recorded' + ' emulator states will be written to that file.') + prog.add_argument('--error-level', + default='warning', + choices=list(verbosity.keys())) + return prog + +def quoted(s: str) -> str: + return f'"{s}"' + +def try_remove(l: list, v): + try: + l.remove(v) + except ValueError: + pass + +if __name__ == "__main__": + prog = make_argparser() + prog.add_argument('--gdb', default='/bin/gdb', + help='GDB binary to invoke.') + args = prog.parse_args() + + filepath = os.path.realpath(__file__) + qemu_tool_path = os.path.join(os.path.dirname(filepath), '_qemu_tool.py') + + # We have to remove all arguments we don't want to pass to the qemu tool + # manually here. Not nice, but what can you do.. + argv = sys.argv + try_remove(argv, '--gdb') + try_remove(argv, args.gdb) + + # Assemble the argv array passed to the qemu tool. GDB does not have a + # mechanism to pass arguments to a script that it executes, so we + # overwrite `sys.argv` manually before invoking the script. + argv_str = f'[{", ".join(quoted(a) for a in argv)}]' + path_str = f'[{", ".join(quoted(s) for s in sys.path)}]' + + gdb_cmd = [ + args.gdb, + '-nx', # Don't parse any .gdbinits + '--batch', + '-ex', f'py import sys', + '-ex', f'py sys.argv = {argv_str}', + '-ex', f'py sys.path = {path_str}', + '-x', qemu_tool_path + ] + proc = subprocess.Popen(gdb_cmd) + + ret = proc.wait() + exit(ret) diff --git a/utils.py b/utils.py deleted file mode 100644 index d841c7c..0000000 --- a/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -#! /bin/python3 - -import sys -import shutil - -def print_separator(separator: str = '-', stream=sys.stdout, count: int = 80): - maxtermsize = count - termsize = shutil.get_terminal_size((80, 20)).columns - print(separator * min(termsize, maxtermsize), file=stream) - -def check_version(version: str): - # Script depends on ordered dicts in default dict() - split = version.split('.') - major = int(split[0]) - minor = int(split[1]) - if sys.version_info.major < major and sys.version_info.minor < minor: - raise EnvironmentError("Expected at least Python 3.7") - |