diff options
Diffstat (limited to '')
| -rw-r--r-- | README.md | 4 | ||||
| -rw-r--r-- | arancini.py | 94 | ||||
| -rw-r--r-- | arch/arch.py | 7 | ||||
| -rw-r--r-- | arch/x86.py | 19 | ||||
| -rw-r--r-- | compare.py | 123 | ||||
| -rw-r--r-- | lldb_target.py | 118 | ||||
| -rwxr-xr-x | main.py | 129 | ||||
| -rw-r--r-- | parser.py | 124 | ||||
| -rw-r--r-- | snapshot.py | 15 | ||||
| -rw-r--r-- | symbolic.py | 171 |
10 files changed, 511 insertions, 293 deletions
diff --git a/README.md b/README.md index fcdbe90..63428b0 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ program snapshots. - `compare.py`: The central algorithms that work on snapshots. - - `arancini.py`: Functionality specific to working with arancini. Parsing of arancini's logs into our snapshot -structures. + - `parser.py`: Utilities for parsing logs from Arancini and QEMU, as well as serializing/deserializing to/from our own +log format. - `arch/`: Abstractions over different processor architectures. Will be used to integrate support for more architectures later. Currently, we only have X86. diff --git a/arancini.py b/arancini.py deleted file mode 100644 index 71f45d7..0000000 --- a/arancini.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Tools for working with arancini's output.""" - -import re -from functools import partial as bind - -from snapshot import ProgramState -from arch.arch import Arch - -def parse_break_addresses(lines: list[str]) -> set[int]: - """Parse all breakpoint addresses from an arancini log.""" - addresses = set() - for l in lines: - if l.startswith('INVOKE'): - addr = int(l.split('=')[1].strip(), base=16) - addresses.add(addr) - - return addresses - -def parse(lines: list[str], arch: Arch) -> list[ProgramState]: - """Parse an arancini log into a list of snapshots. - - :return: A list of program snapshots. - """ - - labels = get_labels() - - # The regex decides for a line whether it contains a register - # based on a match with that register's label. - regex = re.compile("|".join(labels.keys())) - - def try_parse_line(line: str) -> tuple[str, int] | None: - """Try to parse a register name and that register's value from a line. - - :return: A register name and a register value if the line contains - that information. None if parsing fails. - """ - match = regex.match(line) - if match: - label = match.group(0) - register, get_reg_value = labels[label] - return register, get_reg_value(line) - return None - - # Parse a list of program snapshots - snapshots = [] - for line in lines: - if 'Backwards' in line and len(snapshots) > 0: - # snapshots[-1].set_backwards() - continue - - match = try_parse_line(line) - if match: - reg, value = match - if reg == 'PC': - snapshots.append(ProgramState(arch)) - snapshots[-1].set(reg, value) - - return snapshots - -def get_labels(): - """Construct a helper structure for the arancini log parser.""" - split_value = lambda x,i: int(x.split()[i], 16) - - split_first = bind(split_value, i=1) - split_second = bind(split_value, i=2) - - split_equal = lambda x,i: int(x.split('=')[i], 16) - - # A mapping from regex patterns to the register name and a - # function that extracts that register's value from the line - labels = {'INVOKE': ('PC', bind(split_equal, i=1)), - 'RAX': ('RAX', split_first), - 'RBX': ('RBX', split_first), - 'RCX': ('RCX', split_first), - 'RDX': ('RDX', split_first), - 'RSI': ('RSI', split_first), - 'RDI': ('RDI', split_first), - 'RBP': ('RBP', split_first), - 'RSP': ('RSP', split_first), - 'R8': ('R8', split_first), - 'R9': ('R9', split_first), - 'R10': ('R10', split_first), - 'R11': ('R11', split_first), - 'R12': ('R12', split_first), - 'R13': ('R13', split_first), - 'R14': ('R14', split_first), - 'R15': ('R15', split_first), - 'flag ZF': ('ZF', split_second), - 'flag CF': ('CF', split_second), - 'flag OF': ('OF', split_second), - 'flag SF': ('SF', split_second), - 'flag PF': ('PF', split_second), - 'flag DF': ('DF', split_second)} - return labels diff --git a/arch/arch.py b/arch/arch.py index ba94631..f2be5cb 100644 --- a/arch/arch.py +++ b/arch/arch.py @@ -3,7 +3,7 @@ from typing import Iterable class Arch(): def __init__(self, archname: str, regnames: Iterable[str]): self.archname = archname - self.regnames = set(regnames) + self.regnames = set(name.upper() for name in regnames) def to_regname(self, name: str) -> str | None: """Transform a string into a standard register name. @@ -20,4 +20,7 @@ class Arch(): return None def __eq__(self, other): - return self.regnames == other.regnames + return self.archname == other.archname + + def __repr__(self) -> str: + return self.archname diff --git a/arch/x86.py b/arch/x86.py index 776291d..95e1a82 100644 --- a/arch/x86.py +++ b/arch/x86.py @@ -1,8 +1,10 @@ -"""Architexture-specific configuration.""" +"""Architecture-specific configuration.""" from .arch import Arch -# Names of registers in the architexture +archname = 'x86_64' + +# Names of registers in the architecture regnames = [ 'RIP', 'RAX', @@ -22,11 +24,22 @@ regnames = [ 'R14', 'R15', 'RFLAGS', + + # x87 float registers + 'ST0', 'ST1', 'ST2', 'ST3', 'ST4', 'ST5', 'ST6', 'ST7', + + # Vector registers + 'YMM0', 'YMM1', 'YMM2', 'YMM3', 'YMM4', + 'YMM5', 'YMM6', 'YMM7', 'YMM8', 'YMM9', + 'YMM10', 'YMM11', 'YMM12', 'YMM13', 'YMM14', 'YMM15', + # Segment registers 'CS', 'DS', 'SS', 'ES', 'FS', 'GS', 'FS_BASE', 'GS_BASE', + # FLAGS 'CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT', + # EFLAGS 'RF', 'VM', 'AC', 'VIF', 'VIP', 'ID', ] @@ -74,7 +87,7 @@ def decompose_rflags(rflags: int) -> dict[str, int]: class ArchX86(Arch): def __init__(self): - super().__init__("X86", regnames) + super().__init__(archname, regnames) def to_regname(self, name: str) -> str | None: """The X86 override of the standard register name lookup. diff --git a/compare.py b/compare.py index 0f144bf..e5ac244 100644 --- a/compare.py +++ b/compare.py @@ -1,4 +1,4 @@ -from snapshot import ProgramState +from snapshot import ProgramState, MemoryAccessError from symbolic import SymbolicTransform def _calc_transformation(previous: ProgramState, current: ProgramState): @@ -124,35 +124,112 @@ def compare_simple(test_states: list[ProgramState], return result -def _find_errors_symbolic(txl_from: ProgramState, +def _find_register_errors(txl_from: ProgramState, txl_to: ProgramState, transform_truth: SymbolicTransform) \ - -> list[dict]: - arch = txl_from.arch - - assert(txl_from.read('PC') == transform_truth.range[0]) - assert(txl_to.read('PC') == transform_truth.range[1]) + -> list[str]: + """Find errors in register values. + + Errors might be: + - A register value was modified, but the tested state contains no + reference value for that register. + - The tested destination state's value for a register does not match + the value expected by the symbolic transformation. + """ + # Calculate expected register values + try: + truth = transform_truth.calc_register_transform(txl_from) + except MemoryAccessError: + print(f'Transformation at {hex(transform_truth.addr)} depends on' + f' memory that is not set in the tested state. Skipping.') + return [] + # Compare expected values to actual values in the tested state errors = [] - for reg in arch.regnames: - if txl_from.read(reg) is None or txl_to.read(reg) is None: - print(f'A value for {reg} must be set in all translated states.' - ' Skipping.') + for regname, truth_val in truth.items(): + try: + txl_val = txl_to.read(regname) + except ValueError: + errors.append(f'Value of register {regname} has changed, but is' + f' not set in the tested state. Skipping.') continue + except KeyError as err: + print(f'[WARNING] {err}') + continue + + if txl_val != truth_val: + errors.append(f'Content of register {regname} is possibly false.' + f' Expected value: {hex(truth_val)}, actual' + f' value in the translation: {hex(txl_val)}.') + return errors - txl_val = txl_to.read(reg) +def _find_memory_errors(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ + -> list[str]: + """Find errors in memory values. + + Errors might be: + - A range of memory was written, but the tested state contains no + reference value for that range. + - The tested destination state's content for the tested range does not + match the value expected by the symbolic transformation. + """ + # Calculate expected register values + try: + truth = transform_truth.calc_memory_transform(txl_from) + except MemoryAccessError: + print(f'Transformation at {hex(transform_truth.addr)} depends on' + f' memory that is not set in the tested state. Skipping.') + return [] + + # Compare expected values to actual values in the tested state + errors = [] + for addr, truth_bytes in truth.items(): try: - truth = transform_truth.calc_register_transform(txl_from) - print(f'Evaluated symbolic formula to {hex(txl_val)} vs. txl {hex(txl_val)}') - if txl_val != truth: - errors.append({ - 'reg': reg, - 'expected': truth, - 'actual': txl_val, - 'equation': transform_truth.regs_diff[reg], - }) - except AttributeError: - print(f'Register {reg} does not exist.') + txl_bytes = txl_to.read_memory(addr, len(truth_bytes)) + except MemoryAccessError: + errors.append(f'Memory range [{addr}, {addr + len(truth_bytes)})' + f' is not set in the test-result state. Skipping.') + continue + + if txl_bytes != truth_bytes: + errors.append(f'Content of memory at {addr} is possibly false.' + f' Expected content: {truth_bytes.hex()}, actual' + f' content in the translation: {txl_bytes.hex()}.') + return errors + +def _find_errors_symbolic(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ + -> list[str]: + """Tries to find errors in transformations between tested states. + + Applies a transformation to a source state and tests whether the result + matches a given destination state. + + :param txl_from: Source state. This is a state from the tested + program, and is assumed as the starting point for + the transformation. + :param txl_to: Destination state. This is a possibly faulty state + from the tested program, and is tested for + correctness with respect to the source state. + :param transform_truth: The symbolic transformation that maps the source + state to the destination state. + """ + if (txl_from.read('PC') != transform_truth.range[0]) \ + or (txl_to.read('PC') != transform_truth.range[1]): + tstart, tend = transform_truth.range + print(f'[WARNING] Program counters of the tested transformation do not' + f' match the truth transformation:' + f' {hex(txl_from.read("PC"))} -> {hex(txl_to.read("PC"))} (test)' + f' vs. {hex(tstart)} -> {hex(tend)} (truth).' + f' Skipping with no errors.') + return [] + + errors = [] + errors.extend(_find_register_errors(txl_from, txl_to, transform_truth)) + errors.extend(_find_memory_errors(txl_from, txl_to, transform_truth)) return errors diff --git a/lldb_target.py b/lldb_target.py index f587b37..b96f66b 100644 --- a/lldb_target.py +++ b/lldb_target.py @@ -1,6 +1,6 @@ import lldb -from arch import x86 +from arch import supported_architectures, x86 from snapshot import ProgramState class MemoryMap: @@ -53,15 +53,7 @@ class LLDBConcreteTarget: raise RuntimeError(f'[In LLDBConcreteTarget.__init__]: Failed to' f' launch process.') - def set_breakpoint(self, addr): - command = f'b -a {addr} -s {self.module.GetFileSpec().GetFilename()}' - result = lldb.SBCommandReturnObject() - self.interpreter.HandleCommand(command, result) - - def remove_breakpoint(self, addr): - command = f'breakpoint delete {addr}' - result = lldb.SBCommandReturnObject() - self.interpreter.HandleCommand(command, result) + self.archname = self.target.GetPlatform().GetTriple().split('-')[0] def is_exited(self): """Signals whether the concrete process has exited. @@ -84,11 +76,48 @@ class LLDBConcreteTarget: thread: lldb.SBThread = self.process.GetThreadAtIndex(0) thread.StepInstruction(False) + def record_snapshot(self) -> ProgramState: + """Record the concrete target's state in a ProgramState object.""" + # Determine current arch + if self.archname not in supported_architectures: + print(f'[ERROR] LLDBConcreteTarget: Recording snapshots is not' + f' supported for architecture {self.archname}!') + raise NotImplementedError() + arch = supported_architectures[self.archname] + + state = ProgramState(arch) + + # Query and store register state + for regname in arch.regnames: + try: + conc_val = self.read_register(regname) + state.set(regname, conc_val) + except KeyError: + pass + except ConcreteRegisterError: + # Special rule for flags on X86 + if arch.archname == x86.archname: + rflags = x86.decompose_rflags(self.read_register('rflags')) + if regname in rflags: + state.set(regname, rflags[regname]) + + # Query and store memory state + for mapping in self.get_mappings(): + assert(mapping.end_address > mapping.start_address) + size = mapping.end_address - mapping.start_address + try: + data = self.read_memory(mapping.start_address, size) + state.write_memory(mapping.start_address, data) + except ConcreteMemoryError: + pass + + return state + def _get_register(self, regname: str) -> lldb.SBValue: """Find a register by name. - :raise SimConcreteRegisterError: If no register with the specified name - can be found. + :raise ConcreteRegisterError: If no register with the specified name + can be found. """ frame = self.process.GetThreadAtIndex(0).GetFrameAtIndex(0) reg = frame.FindRegister(regname) @@ -99,6 +128,12 @@ class LLDBConcreteTarget: return reg def read_register(self, regname: str) -> int: + """Read the value of a register. + + :raise ConcreteRegisterError: If `regname` is not a valid register name + or the target is otherwise unable to read + the register's value. + """ reg = self._get_register(regname) val = reg.GetValue() if val is None: @@ -109,6 +144,12 @@ class LLDBConcreteTarget: return int(val, 16) def write_register(self, regname: str, value: int): + """Read the value of a register. + + :raise ConcreteRegisterError: If `regname` is not a valid register name + or the target is otherwise unable to set + the register's value. + """ reg = self._get_register(regname) error = lldb.SBError() reg.SetValueFromCString(hex(value), error) @@ -118,19 +159,27 @@ class LLDBConcreteTarget: f' {regname} to value {hex(value)}!') def read_memory(self, addr, size): + """Read bytes from memory. + + :raise ConcreteMemoryError: If unable to read `size` bytes from `addr`. + """ err = lldb.SBError() content = self.process.ReadMemory(addr, size, err) if not err.success: raise ConcreteMemoryError(f'Error when reading {size} bytes at' - f' address {hex(addr)}: {err}') + f' address {hex(addr)}: {err}') return content - def write_memory(self, addr, value): + def write_memory(self, addr, value: bytes): + """Write bytes to memory. + + :raise ConcreteMemoryError: If unable to write at `addr`. + """ err = lldb.SBError() res = self.process.WriteMemory(addr, value, err) if not err.success or res != len(value): raise ConcreteMemoryError(f'Error when writing to address' - f' {hex(addr)}: {err}') + f' {hex(addr)}: {err}') def get_mappings(self) -> list[MemoryMap]: mmap = [] @@ -151,35 +200,12 @@ class LLDBConcreteTarget: perms)) return mmap -def record_snapshot(target: LLDBConcreteTarget) -> ProgramState: - """Record a concrete target's state in a ProgramState object. + def set_breakpoint(self, addr): + command = f'b -a {addr} -s {self.module.GetFileSpec().GetFilename()}' + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) - :param target: The target from which to query state. Currently assumes an - X86 target. - """ - state = ProgramState(x86.ArchX86()) - - # Query and store register state - rflags = x86.decompose_rflags(target.read_register('rflags')) - for regname in x86.regnames: - try: - conc_val = target.read_register(regname) - state.set(regname, conc_val) - except KeyError: - pass - except ConcreteRegisterError: - if regname in rflags: - state.set(regname, rflags[regname]) - - # Query and store memory state - for mapping in target.get_mappings(): - assert(mapping.end_address > mapping.start_address) - size = mapping.end_address - mapping.start_address - try: - data = target.read_memory(mapping.start_address, size) - state.write_memory(mapping.start_address, data) - except ConcreteMemoryError: - # Unable to read memory from mapping - pass - - return state + def remove_breakpoint(self, addr): + command = f'breakpoint delete {addr}' + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) diff --git a/main.py b/main.py index fabb05b..a51ecf7 100755 --- a/main.py +++ b/main.py @@ -4,11 +4,12 @@ import argparse import platform from typing import Iterable -import arancini from arch import x86 from compare import compare_simple, compare_symbolic -from lldb_target import LLDBConcreteTarget, record_snapshot -from symbolic import collect_symbolic_trace +from lldb_target import LLDBConcreteTarget +from parser import parse_arancini +from snapshot import ProgramState +from symbolic import SymbolicTransform, collect_symbolic_trace from utils import check_version, print_separator def run_native_execution(oracle_program: str, breakpoints: Iterable[int]): @@ -31,28 +32,58 @@ def run_native_execution(oracle_program: str, breakpoints: Iterable[int]): # Execute the native program snapshots = [] while not target.is_exited(): - snapshots.append(record_snapshot(target)) + snapshots.append(target.record_snapshot()) target.run() return snapshots -def parse_inputs(txl_path, ref_path, program): +def match_traces(test: list[ProgramState], truth: list[SymbolicTransform]): + if not test or not truth: + return [], [] + + assert(test[0].read('pc') == truth[0].addr) + + def index(seq, target, access=lambda el: el): + for i, el in enumerate(seq): + if access(el) == target: + return i + return None + + i = 0 + for next_state in test[1:]: + next_pc = next_state.read('pc') + index_in_truth = index(truth[i:], next_pc, lambda el: el.range[1]) + + # If no next element (i.e. no foldable range) is found in the truth + # trace, assume that the test trace contains excess states. Remove one + # and try again. This might skip testing some states, but covers more + # of the entire trace. + if index_in_truth is None: + test.pop(i + 1) + continue + + # Fold the range of truth states until the next test state + for _ in range(index_in_truth): + truth[i].concat(truth.pop(i + 1)) + + assert(truth[i].range[1] == truth[i + 1].addr) + + i += 1 + if len(truth) <= i: + break + + return test, truth + +def parse_inputs(txl_path, program): # Our architecture arch = x86.ArchX86() - txl = [] with open(txl_path, "r") as txl_file: - txl = arancini.parse(txl_file.readlines(), arch) + txl = parse_arancini(txl_file, arch) - ref = [] - if program is not None: - with open(txl_path, "r") as txl_file: - breakpoints = arancini.parse_break_addresses(txl_file.readlines()) + with open(txl_path, "r") as txl_file: + breakpoints = [state.read('PC') for state in txl] ref = run_native_execution(program, breakpoints) - else: - assert(ref_path is not None) - with open(ref_path, "r") as native_file: - ref = arancini.parse(native_file.readlines(), arch) return txl, ref @@ -60,23 +91,18 @@ def parse_arguments(): parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference') parser.add_argument('-p', '--program', type=str, + required=True, help='Path to oracle program') - parser.add_argument('-r', '--ref', + parser.add_argument('-a', '--program-arg', type=str, - required=True, - help='Path to the reference log (gathered with run.sh)') + required=False, + default=[], + action='append', + help='Arguments to the program specified with --program.') parser.add_argument('-t', '--txl', type=str, required=True, help='Path to the translation log (gathered via Arancini)') - parser.add_argument('-s', '--stats', - action='store_true', - default=False, - help='Run statistics on comparisons') - parser.add_argument('-v', '--verbose', - action='store_true', - default=True, - help='Path to oracle program') parser.add_argument('--symbolic', action='store_true', default=False, @@ -90,46 +116,17 @@ def main(): args = parse_arguments() txl_path = args.txl - reference_path = args.ref program = args.program - - stats = args.stats - verbose = args.verbose - - if verbose: - print("Enabling verbose program output") - print(f"Verbose: {verbose}") - print(f"Statistics: {stats}") - print(f"Symbolic: {args.symbolic}") - - if program is None and reference_path is None: - raise ValueError('Either program or path to native file must be' - 'provided') - - txl, ref = parse_inputs(txl_path, reference_path, program) - - if program != None and reference_path != None: - with open(reference_path, 'w') as w: - for snapshot in ref: - print(snapshot, file=w) + prog_args = args.program_arg + txl, ref = parse_inputs(txl_path, program) if args.symbolic: assert(program is not None) - transforms = collect_symbolic_trace(program, [program]) - - new = transforms[0] \ - .concat(transforms[1]) \ - .concat(transforms[2]) \ - .concat(transforms[3]) \ - .concat(transforms[4]) - print(f'New transform: {new}') - exit(0) - # TODO: Transform the traces so that the states match + print(f'Tracing {program} with arguments {prog_args}...') + transforms = collect_symbolic_trace(program, [program, *prog_args]) + txl, transforms = match_traces(txl, transforms) result = compare_symbolic(txl, transforms) - - raise NotImplementedError('The symbolic comparison algorithm is not' - ' supported yet.') else: result = compare_simple(txl, ref) @@ -140,15 +137,13 @@ def main(): print(f'For PC={hex(pc)}') print_separator() - txl = res['txl'] ref = res['ref'] for err in res['errors']: - reg = err['reg'] - print(f'Content of register {reg} is possibly false.' - f' Expected difference: {err["expected"]}, actual difference' - f' in the translation: {err["actual"]}.\n' - f' (txl) {reg}: {hex(txl.read(reg))}\n' - f' (ref) {reg}: {hex(ref.read(reg))}') + print(f' - {err}') + if res['errors']: + print(ref) + else: + print('No errors found.') print() print('#' * 60) diff --git a/parser.py b/parser.py new file mode 100644 index 0000000..d2fcf13 --- /dev/null +++ b/parser.py @@ -0,0 +1,124 @@ +"""Parsing of JSON files containing snapshot data.""" + +import json +import re +from typing import TextIO + +from arch import supported_architectures, Arch +from snapshot import ProgramState + +class ParseError(Exception): + """A parse error.""" + +def _get_or_throw(obj: dict, key: str): + """Get a value from a dict or throw a ParseError if not present.""" + val = obj.get(key) + if val is not None: + return val + raise ParseError(f'Expected value at key {key}, but found none.') + +def parse_snapshots(json_stream: TextIO) -> list[ProgramState]: + """Parse snapshots from our JSON format.""" + json_data = json.load(json_stream) + + arch = supported_architectures[_get_or_throw(json_data, 'architecture')] + snapshots = [] + for snapshot in _get_or_throw(json_data, 'snapshots'): + state = ProgramState(arch) + for reg, val in _get_or_throw(snapshot, 'registers').items(): + state.set(reg, val) + for mem in _get_or_throw(snapshot, 'memory'): + start, end = _get_or_throw(mem, 'range') + data = _get_or_throw(mem, 'data').encode() + assert(len(data) == end - start) + state.write_memory(start, data) + + snapshots.append(state) + + return snapshots + +def serialize_snapshots(snapshots: list[ProgramState], out_stream: TextIO): + """Serialize a list of snapshots to out JSON format.""" + if not snapshots: + return json.dump({}, out_stream) + + arch = snapshots[0].arch + res = { 'architecture': arch.archname, 'snapshots': [] } + for snapshot in snapshots: + assert(snapshot.arch == arch) + regs = {r: v for r, v in snapshot.regs.items() if v is not None} + mem = [] + for addr, data in snapshot.mem._pages.items(): + mem.append({ + 'range': [addr, addr + len(data)], + 'data': data.decode(), + }) + res['snapshots'].append({ 'registers': regs, 'memory': mem }) + + json.dump(res, out_stream) + +def parse_qemu(stream: TextIO, arch: Arch) -> list[ProgramState]: + states = [] + for line in stream: + if line.startswith('Trace'): + states.append(ProgramState(arch)) + continue + + line = line.strip() + + # Remove padding spaces around equality signs + line = re.sub(' =', '=', line) + line = re.sub('= +', '=', line) + + # Standardize register names + line = re.sub('YMM0([0-9])', lambda m: f'YMM{m.group(1)}', line) + line = re.sub('FPR([0-9])', lambda m: f'ST{m.group(1)}', line) + + # Bring each register assignment into a new line + line = re.sub(' ([A-Z0-9]+)=', lambda m: f'\n{m.group(1)}=', line) + + # Remove all trailing information from register assignments + line = re.sub('^([A-Z0-9]+)=([0-9a-f ]+).*$', + lambda m: f'{m.group(1)}={m.group(2)}', + line, + 0, re.MULTILINE) + + # Now parse registers and their values from the resulting lines + lines = line.split('\n') + for line in lines: + split = line.split('=') + if len(split) == 2: + regname, value = split + value = value.replace(' ', '') + regname = arch.to_regname(regname) + if regname is not None: + states[-1].set(regname, int(value, 16)) + + return states + +def parse_arancini(stream: TextIO, arch: Arch) -> list[ProgramState]: + aliases = { + 'Program counter': 'RIP', + 'flag ZF': 'ZF', + 'flag CF': 'CF', + 'flag OF': 'OF', + 'flag SF': 'SF', + 'flag PF': 'PF', + 'flag DF': 'DF', + } + + states = [] + for line in stream: + if line.startswith('INVOKE PC='): + states.append(ProgramState(arch)) + continue + + # Parse a register assignment + split = line.split(':') + if len(split) == 2 and states: + regname, value = split + regname = arch.to_regname(aliases.get(regname, regname)) + if regname is not None: + states[-1].set(regname, int(value, 16)) + + return states diff --git a/snapshot.py b/snapshot.py index 80c1ac5..ed94a75 100644 --- a/snapshot.py +++ b/snapshot.py @@ -104,9 +104,24 @@ class ProgramState: self.regs[regname] = value def read_memory(self, addr: int, size: int) -> bytes: + """Read a number of bytes from memory. + + :param addr: The address from which to read data. + :param data: Number of bytes to read, starting at `addr`. Must be + at least zero. + + :raise MemoryAccessError: If `[addr, addr + size)` is not entirely + contained in the set of stored bytes. + :raise ValueError: If `size < 0`. + """ return self.mem.read(addr, size) def write_memory(self, addr: int, data: bytes): + """Write a number of bytes to memory. + + :param addr: The address at which to store the data. + :param data: The data to store at `addr`. + """ self.mem.write(addr, data) def __repr__(self): diff --git a/symbolic.py b/symbolic.py index 2a328fd..b005c5e 100644 --- a/symbolic.py +++ b/symbolic.py @@ -8,9 +8,10 @@ from miasm.analysis.machine import Machine from miasm.core.asmblock import AsmCFG from miasm.core.locationdb import LocationDB from miasm.ir.symbexec import SymbolicExecutionEngine +from miasm.ir.ir import IRBlock from miasm.expression.expression import Expr, ExprId, ExprMem, ExprInt -from lldb_target import LLDBConcreteTarget, record_snapshot +from lldb_target import LLDBConcreteTarget from miasm_util import MiasmConcreteState, eval_expr from snapshot import ProgramState from arch import Arch, supported_architectures @@ -36,11 +37,14 @@ class SymbolicTransform: -> dict[int, bytes]: raise NotImplementedError('calc_memory_transform is abstract.') + def __repr__(self) -> str: + start, end = self.range + return f'Symbolic state transformation {hex(start)} -> {hex(end)}' + class MiasmSymbolicTransform(SymbolicTransform): def __init__(self, - transform: dict[ExprId, Expr], + transform: dict[Expr, Expr], arch: Arch, - loc_db: LocationDB, start_addr: int, end_addr: int): """ @@ -55,6 +59,8 @@ class MiasmSymbolicTransform(SymbolicTransform): self.regs_diff: dict[str, Expr] = {} self.mem_diff: dict[ExprMem, Expr] = {} for dst, expr in transform.items(): + assert(isinstance(dst, ExprMem) or isinstance(dst, ExprId)) + if isinstance(dst, ExprMem): self.mem_diff[dst] = expr else: @@ -64,14 +70,16 @@ class MiasmSymbolicTransform(SymbolicTransform): self.regs_diff[regname] = expr self.arch = arch - self.loc_db = loc_db def concat(self, other: MiasmSymbolicTransform) -> Self: - class MiasmSymbolicState: + class MiasmSymbolicState(MiasmConcreteState): """Drop-in replacement for MiasmConcreteState in eval_expr that returns the current transform's symbolic equations instead of symbolic values. Calling eval_expr with this effectively nests the transformation into the concatenated transformation. + + We inherit from `MiasmSymbolicTransform` only for the purpose of + correct type checking. """ def __init__(self, transform: MiasmSymbolicTransform): self.transform = transform @@ -110,7 +118,9 @@ class MiasmSymbolicTransform(SymbolicTransform): def calc_register_transform(self, conc_state: ProgramState) \ -> dict[str, int]: - ref_state = MiasmConcreteState(conc_state, self.loc_db) + # Construct a dummy location DB. At this point, expressions should + # never contain IR locations. + ref_state = MiasmConcreteState(conc_state, LocationDB()) res = {} for regname, expr in self.regs_diff.items(): @@ -119,7 +129,9 @@ class MiasmSymbolicTransform(SymbolicTransform): def calc_memory_transform(self, conc_state: ProgramState) \ -> dict[int, bytes]: - ref_state = MiasmConcreteState(conc_state, self.loc_db) + # Construct a dummy location DB. At this point, expressions should + # never contain IR locations. + ref_state = MiasmConcreteState(conc_state, LocationDB()) res = {} for addr, expr in self.mem_diff.items(): @@ -149,8 +161,58 @@ def _step_until(target: LLDBConcreteTarget, addr: int) -> list[int]: target.step() return trace -def _run_block(pc: int, conc_state: MiasmConcreteState, lifter, ircfg, mdis) \ - -> tuple[int | None, list]: +class DisassemblyContext: + def __init__(self, binary): + self.loc_db = LocationDB() + + # Load the binary + with open(binary, 'rb') as bin_file: + cont = ContainerELF.from_stream(bin_file, self.loc_db) + + self.machine = Machine(cont.arch) + self.entry_point = cont.entry_point + + # Create disassembly/lifting context + self.lifter = self.machine.lifter(self.loc_db) + self.mdis = self.machine.dis_engine(cont.bin_stream, loc_db=self.loc_db) + self.mdis.follow_call = True + self.asmcfg = AsmCFG(self.loc_db) + self.ircfg = self.lifter.new_ircfg_from_asmcfg(self.asmcfg) + + def get_irblock(self, addr: int) -> IRBlock | None: + irblock = self.ircfg.get_block(addr) + + # Initial disassembly might not find all blocks in the binary. + # Disassemble code ad-hoc if the current address has not yet been + # disassembled. + if irblock is None: + cfg = self.mdis.dis_multiblock(addr) + for asmblock in cfg.blocks: + try: + self.lifter.add_asmblock_to_ircfg(asmblock, self.ircfg) + except NotImplementedError as err: + print(f'[WARNING] Unable to disassemble block at' + f' {hex(asmblock.get_range()[0])}:' + f' [Not implemented] {err}') + pass + print(f'Disassembled {len(cfg.blocks):5} new blocks at {hex(int(addr))}.') + irblock = self.ircfg.get_block(addr) + + # Might still be None if disassembly/lifting failed for the block + # at `addr`. + return irblock + +class DisassemblyError(Exception): + def __init__(self, + partial_trace: list[tuple[int, MiasmSymbolicTransform]], + faulty_pc: int, + err_msg: str): + self.partial_trace = partial_trace + self.faulty_pc = faulty_pc + self.err_msg = err_msg + +def _run_block(pc: int, conc_state: MiasmConcreteState, ctx: DisassemblyContext) \ + -> tuple[int | None, list[dict]]: """Run a basic block. Tries to run IR blocks until the end of an ASM block/basic block is @@ -166,37 +228,20 @@ def _run_block(pc: int, conc_state: MiasmConcreteState, lifter, ircfg, mdis) \ found. This happens when an error occurs or when the program exits. """ - global disasm_time - global symb_exec_time - # Start with a clean, purely symbolic state - engine = SymbolicExecutionEngine(lifter) + engine = SymbolicExecutionEngine(ctx.lifter) # A list of symbolic transformation for each single instruction symb_trace = [] while True: - irblock = ircfg.get_block(pc) - - # Initial disassembly might not find all blocks in the binary. - # Disassemble code ad-hoc if the current PC has not yet been - # disassembled. + irblock = ctx.get_irblock(pc) if irblock is None: - cfg = mdis.dis_multiblock(pc) - for asmblock in cfg.blocks: - try: - lifter.add_asmblock_to_ircfg(asmblock, ircfg) - except NotImplementedError as err: - print(f'[ERROR] Unable to disassemble block at' - f' {hex(asmblock.get_range()[0])}:' - f' [Not implemented] {err}') - pass - - irblock = ircfg.get_block(pc) - if irblock is None: - print(f'[ERROR] Unable to disassemble block(s) at {hex(pc)}.') - raise RuntimeError() - print(f'Disassembled {len(cfg.blocks):4} new blocks at {hex(int(pc))}.') + raise DisassemblyError( + symb_trace, + pc, + f'[ERROR] Unable to disassemble block at {hex(pc)}.' + ) # Execute each instruction in the current basic block and record the # resulting change in program state. @@ -240,27 +285,17 @@ def collect_symbolic_trace(binary: str, :param binary: The binary to trace. """ - loc_db = LocationDB() - with open(binary, 'rb') as bin_file: - cont = ContainerELF.from_stream(bin_file, loc_db) - machine = Machine(cont.arch) + ctx = DisassemblyContext(binary) # Find corresponding architecture - if machine.name not in supported_architectures: - print(f'[ERROR] {machine.name} is not supported. Returning.') + mach_name = ctx.machine.name + if mach_name not in supported_architectures: + print(f'[ERROR] {mach_name} is not supported. Returning.') return [] - arch = supported_architectures[machine.name] - - # Create disassembly/lifting context - mdis = machine.dis_engine(cont.bin_stream, loc_db=loc_db) - mdis.follow_call = True - asmcfg = AsmCFG(loc_db) - - lifter = machine.lifter(loc_db) - ircfg = lifter.new_ircfg_from_asmcfg(asmcfg) + arch = supported_architectures[mach_name] if start_addr is None: - pc = cont.entry_point + pc = ctx.entry_point else: pc = start_addr @@ -273,16 +308,40 @@ def collect_symbolic_trace(binary: str, symb_trace = [] # The resulting list of symbolic transforms per instruction # Run until no more states can be reached - initial_state = record_snapshot(target) + initial_state = target.record_snapshot() while pc is not None: assert(target.read_register('pc') == pc) # Run symbolic execution # It uses the concrete state to resolve symbolic program counters to # concrete values. - pc, strace = _run_block( - pc, MiasmConcreteState(initial_state, loc_db), - lifter, ircfg, mdis) + try: + pc, strace = _run_block( + pc, + MiasmConcreteState(initial_state, ctx.loc_db), + ctx + ) + except DisassemblyError as err: + # This happens if we encounter an instruction that is not + # implemented by Miasm. Try to skip that instruction and continue + # at the next one. + print(f'[WARNING] Skipping instruction at {hex(err.faulty_pc)}...') + + # First, catch up to symbolic trace if required + if err.faulty_pc != pc: + ctrace = _step_until(target, err.faulty_pc) + symb_trace.extend(err.partial_trace) + assert(len(ctrace) - 1 == len(err.partial_trace)) # no ghost instr + + # Now step one more time to skip the faulty instruction + target.step() + if target.is_exited(): + break + + symb_trace.append((err.faulty_pc, {})) # Generate empty transform + pc = target.read_register('pc') + initial_state = target.record_snapshot() + continue if pc is None: break @@ -309,12 +368,12 @@ def collect_symbolic_trace(binary: str, break # Query the new reference state for symbolic execution - initial_state = record_snapshot(target) + initial_state = target.record_snapshot() res = [] for (start, diff), (end, _) in zip(symb_trace[:-1], symb_trace[1:]): - res.append(MiasmSymbolicTransform(diff, arch, loc_db, start, end)) + res.append(MiasmSymbolicTransform(diff, arch, start, end)) start, diff = symb_trace[-1] - res.append(MiasmSymbolicTransform(diff, arch, loc_db, start, start)) + res.append(MiasmSymbolicTransform(diff, arch, start, start)) return res |