diff options
| -rw-r--r-- | src/focaccia/tools/qemu/_qemu_tool.py | 337 |
1 files changed, 170 insertions, 167 deletions
diff --git a/src/focaccia/tools/qemu/_qemu_tool.py b/src/focaccia/tools/qemu/_qemu_tool.py index 6735d9e..a0f1b90 100644 --- a/src/focaccia/tools/qemu/_qemu_tool.py +++ b/src/focaccia/tools/qemu/_qemu_tool.py @@ -272,180 +272,186 @@ class GDBServerStateIterator: archname = split[1] if len(split) > 1 else split[0] return archname.replace('-', '_') -def record_minimal_snapshot(prev_state: ReadableProgramState, - cur_state: ReadableProgramState, - prev_transform: SymbolicTransform, - cur_transform: SymbolicTransform) \ - -> ProgramState: - """Record a minimal snapshot. - - A minimal snapshot must include values (registers and memory) that are - accessed by two transformations: - 1. The values produced by the previous transformation (the - transformation that is producing this snapshot) to check these - values against expected values calculated from the previous - program state. - 2. The values that act as inputs to the transformation acting on this - snapshot, to calculate the expected values of the next snapshot. - - :param prev_transform: The symbolic transformation generating, or - leading to, `cur_state`. Values generated by - this transformation are included in the - snapshot. - :param transform: The symbolic transformation operating on this - snapshot. Input values to this transformation are - included in the snapshot. - """ - assert(cur_state.read_register('pc') == cur_transform.addr) - assert(prev_transform.arch == cur_transform.arch) - - def get_written_addresses(t: SymbolicTransform): - """Get all output memory accesses of a symbolic transformation.""" - return [ExprMem(a, v.size) for a, v in t.changed_mem.items()] - - def set_values(regs: Iterable[str], mems: Iterable[ExprMem], - cur_state: ReadableProgramState, - prev_state: ReadableProgramState, - out_state: ProgramState): +class ConcolicTracer: + def __init__(self, + target: GDBServerStateIterator, + strace: Trace[SymbolicTransform], + deterministic_log: DeterministicLog): + self.target = target + self.symbolic_trace = strace + self.deterministic_log = deterministic_log + + def record_snapshot(self, + prev_state: ReadableProgramState, + cur_state: ReadableProgramState, + prev_transform: SymbolicTransform, + cur_transform: SymbolicTransform) -> ProgramState: + """Record a minimal snapshot. + + A minimal snapshot must include values (registers and memory) that are + accessed by two transformations: + 1. The values produced by the previous transformation (the + transformation that is producing this snapshot) to check these + values against expected values calculated from the previous + program state. + 2. The values that act as inputs to the transformation acting on this + snapshot, to calculate the expected values of the next snapshot. + + :param prev_transform: The symbolic transformation generating, or + leading to, `cur_state`. Values generated by + this transformation are included in the + snapshot. + :param transform: The symbolic transformation operating on this + snapshot. Input values to this transformation are + included in the snapshot. """ - :param prev_state: Addresses of memory included in the snapshot are - resolved relative to this state. + assert(cur_state.read_register('pc') == cur_transform.addr) + assert(prev_transform.arch == cur_transform.arch) + + def get_written_addresses(t: SymbolicTransform): + """Get all output memory accesses of a symbolic transformation.""" + return [ExprMem(a, v.size) for a, v in t.changed_mem.items()] + + def set_values(regs: Iterable[str], mems: Iterable[ExprMem], + cur_state: ReadableProgramState, + prev_state: ReadableProgramState, + out_state: ProgramState): + """ + :param prev_state: Addresses of memory included in the snapshot are + resolved relative to this state. + """ + for regname in regs: + try: + regval = cur_state.read_register(regname) + out_state.set_register(regname, regval) + except RegisterAccessError: + pass + for mem in mems: + assert(mem.size % 8 == 0) + addr = eval_symbol(mem.ptr, prev_state) + try: + mem = cur_state.read_memory(addr, int(mem.size / 8)) + out_state.write_memory(addr, mem) + except MemoryAccessError: + pass + + state = ProgramState(cur_transform.arch) + state.set_register('PC', cur_transform.addr) + + set_values(prev_transform.changed_regs.keys(), + get_written_addresses(prev_transform), + cur_state, + prev_state, # Evaluate memory addresses based on previous + # state because they are that state's output + # addresses. + state) + set_values(cur_transform.get_used_registers(), + cur_transform.get_used_memory_addresses(), + cur_state, + cur_state, + state) + return state + + def trace(self) -> tuple[list[ProgramState], list[SymbolicTransform]]: + """Collect a trace of concrete states from GDB. + + Records minimal concrete states from GDB by using symbolic trace + information to determine which register/memory values are required to + verify the correctness of the program running in GDB. + + May drop symbolic transformations if the symbolic trace and the GDB trace + diverge (e.g. because of differences in environment, etc.). Returns the + new, possibly modified, symbolic trace that matches the returned concrete + trace. + + :return: A list of concrete states and a list of corresponding symbolic + transformations. The lists are guaranteed to have the same length. """ - for regname in regs: - try: - regval = cur_state.read_register(regname) - out_state.set_register(regname, regval) - except RegisterAccessError: - pass - for mem in mems: - assert(mem.size % 8 == 0) - addr = eval_symbol(mem.ptr, prev_state) - try: - mem = cur_state.read_memory(addr, int(mem.size / 8)) - out_state.write_memory(addr, mem) - except MemoryAccessError: - pass - - state = ProgramState(cur_transform.arch) - state.set_register('PC', cur_transform.addr) - - set_values(prev_transform.changed_regs.keys(), - get_written_addresses(prev_transform), - cur_state, - prev_state, # Evaluate memory addresses based on previous - # state because they are that state's output - # addresses. - state) - set_values(cur_transform.get_used_registers(), - cur_transform.get_used_memory_addresses(), - cur_state, - cur_state, - state) - return state - -def collect_conc_trace(gdb: GDBServerStateIterator, \ - strace: list[SymbolicTransform], - start_addr: int | None = None, - stop_addr: int | None = None) \ - -> tuple[list[ProgramState], list[SymbolicTransform]]: - """Collect a trace of concrete states from GDB. - - Records minimal concrete states from GDB by using symbolic trace - information to determine which register/memory values are required to - verify the correctness of the program running in GDB. - - May drop symbolic transformations if the symbolic trace and the GDB trace - diverge (e.g. because of differences in environment, etc.). Returns the - new, possibly modified, symbolic trace that matches the returned concrete - trace. - - :return: A list of concrete states and a list of corresponding symbolic - transformations. The lists are guaranteed to have the same length. - """ - def find_index(seq, target, access=lambda el: el): - for i, el in enumerate(seq): - if access(el) == target: - return i - return None - - if not strace: - return [], [] - - states = [] - matched_transforms = [] - - state_iter = iter(gdb) - cur_state = next(state_iter) - symb_i = 0 - - # Skip to start - try: - pc = cur_state.read_register('pc') - if start_addr and pc != start_addr: - info(f'Tracing QEMU from starting address: {hex(start_addr)}') - cur_state = state_iter.run_until(start_addr) - except Exception as e: - if start_addr: - raise Exception(f'Unable to reach start address {hex(start_addr)}: {e}') - raise Exception(f'Unable to trace: {e}') + def find_index(seq, target, access=lambda el: el): + for i, el in enumerate(seq): + if access(el) == target: + return i + return None - # An online trace matching algorithm. - while True: - try: - pc = cur_state.read_register('pc') + states = [] + matched_transforms = [] - while pc != strace[symb_i].addr: - info(f'PC {hex(pc)} does not match next symbolic reference {hex(strace[symb_i].addr)}') + strace = self.symbolic_trace.states + start_addr = self.symbolic_trace.env.start_address + stop_addr = self.symbolic_trace.env.stop_address - next_i = find_index(strace[symb_i+1:], pc, lambda t: t.addr) + state_iter = iter(self.target) + cur_state = next(state_iter) + symb_i = 0 - # Drop the concrete state if no address in the symbolic trace - # matches - if next_i is None: - warn(f'Dropping concrete state {hex(pc)}, as no' - f' matching instruction can be found in the symbolic' - f' reference trace.') - cur_state = next(state_iter) - pc = cur_state.read_register('pc') - continue + # Skip to start + try: + pc = cur_state.read_register('pc') + if start_addr and pc != start_addr: + info(f'Tracing QEMU from starting address: {hex(start_addr)}') + cur_state = state_iter.run_until(start_addr) + except Exception as e: + if start_addr: + raise Exception(f'Unable to reach start address {hex(start_addr)}: {e}') + raise Exception(f'Unable to trace: {e}') - # Otherwise, jump to the next matching symbolic state - symb_i += next_i + 1 + # An online trace matching algorithm. + while True: + try: + pc = cur_state.read_register('pc') + + while pc != strace[symb_i].addr: + info(f'PC {hex(pc)} does not match next symbolic reference {hex(strace[symb_i].addr)}') + + next_i = find_index(strace[symb_i+1:], pc, lambda t: t.addr) + + # Drop the concrete state if no address in the symbolic trace + # matches + if next_i is None: + warn(f'Dropping concrete state {hex(pc)}, as no' + f' matching instruction can be found in the symbolic' + f' reference trace.') + cur_state = next(state_iter) + pc = cur_state.read_register('pc') + continue + + # Otherwise, jump to the next matching symbolic state + symb_i += next_i + 1 + if symb_i >= len(strace): + break + + assert(cur_state.read_register('pc') == strace[symb_i].addr) + info(f'Validating instruction at address {hex(pc)}') + states.append(self.record_snapshot( + states[-1] if states else cur_state, + cur_state, + matched_transforms[-1] if matched_transforms else strace[symb_i], + strace[symb_i])) + matched_transforms.append(strace[symb_i]) + cur_state = next(state_iter) + symb_i += 1 if symb_i >= len(strace): break - - assert(cur_state.read_register('pc') == strace[symb_i].addr) - info(f'Validating instruction at address {hex(pc)}') - states.append(record_minimal_snapshot( - states[-1] if states else cur_state, - cur_state, - matched_transforms[-1] if matched_transforms else strace[symb_i], - strace[symb_i])) - matched_transforms.append(strace[symb_i]) - cur_state = next(state_iter) - symb_i += 1 - if symb_i >= len(strace): + except StopIteration: + # TODO: The conditions may test for the same + if stop_addr and pc != stop_addr: + raise Exception(f'QEMU stopped at {hex(pc)} before reaching the stop address' + f' {hex(stop_addr)}') + if symb_i+1 < len(strace): + qemu_crash["crashed"] = True + qemu_crash["pc"] = strace[symb_i].addr + qemu_crash["ref"] = strace[symb_i] + qemu_crash["snap"] = states[-1] break - except StopIteration: - # TODO: The conditions may test for the same - if stop_addr and pc != stop_addr: - raise Exception(f'QEMU stopped at {hex(pc)} before reaching the stop address' - f' {hex(stop_addr)}') - if symb_i+1 < len(strace): - qemu_crash["crashed"] = True - qemu_crash["pc"] = strace[symb_i].addr - qemu_crash["ref"] = strace[symb_i] - qemu_crash["snap"] = states[-1] - break - except Exception as e: - print(traceback.format_exc()) - raise e + except Exception as e: + print(traceback.format_exc()) + raise e - # Note: this may occur when symbolic traces were gathered with a stop address - if symb_i >= len(strace): - warn(f'QEMU executed more states than native execution: {symb_i} vs {len(strace)-1}') + # Note: this may occur when symbolic traces were gathered with a stop address + if symb_i >= len(strace): + warn(f'QEMU executed more states than native execution: {symb_i} vs {len(strace)-1}') - return states, matched_transforms + return states, matched_transforms def main(): args = make_argparser().parse_args() @@ -485,11 +491,8 @@ def main(): # Use symbolic trace to collect concrete trace from QEMU try: - conc_states, matched_transforms = collect_conc_trace( - gdb_server, - symb_transforms.states, - symb_transforms.env.start_address, - symb_transforms.env.stop_address) + tracer = ConcolicTracer(gdb_server, symb_transforms, detlog) + conc_states, matched_transforms = tracer.trace() except Exception as e: raise Exception(f'Failed to collect concolic trace from QEMU: {e}') |