about summary refs log tree commit diff stats
path: root/run.py
blob: 768a73d2553772a1073bf033ed24d6ee0f04954b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Functionality to execute native programs and collect snapshots via lldb."""

import platform
import sys
import lldb
from typing import Callable

# TODO: The debugger callback is currently specific to a single architecture.
#       We should make it generic.
from arch import Arch, x86
from snapshot import ProgramState

class SnapshotBuilder:
    """Debugger callback that records a `ProgramState` each time it is called.

    Every invocation (see `__call__`) reads the registers of the given lldb
    frame and appends the resulting snapshot to `self.states`, in call order.
    """
    def __init__(self, arch: Arch):
        self.arch = arch
        # Collected ProgramState snapshots, oldest first.
        self.states = []
        # The architecture's register names as a set (not read inside this class).
        self.regnames = set(arch.regnames)

    def create_snapshot(self, frame: lldb.SBFrame):
        """Build a `ProgramState` from the registers visible in `frame`.

        Stores the PC and every register listed in `self.arch.regnames`. For
        'RFLAGS', the individual flags produced by `x86.decompose_rflags` are
        stored as well.

        :param frame: The lldb frame to read registers from.
        :raise RuntimeError: If lldb yields no value for one of the registers.
        :return: The populated `ProgramState`.
        """
        state = ProgramState(self.arch)
        state.set('PC', frame.GetPC())
        for regname in self.arch.regnames:
            reg = frame.FindRegister(regname)
            text = reg.GetValue()
            # GetValue() returns None when lldb has no value for the register
            # (e.g. a name it does not know). Report which register failed
            # instead of letting int(None, base=16) raise an opaque TypeError.
            if text is None:
                raise RuntimeError(f'Unable to read register {regname} from frame')
            regval = int(text, base=16)
            state.set(regname, regval)
            if regname == 'RFLAGS':
                # Also expose each flag individually.
                flags = x86.decompose_rflags(regval)
                for flag_name, val in flags.items():
                    state.set(flag_name, val)
        return state

    def __call__(self, frame: lldb.SBFrame):
        """Snapshot `frame` and append the result to `self.states`."""
        snapshot = self.create_snapshot(frame)
        self.states.append(snapshot)

class Debugger:
    """Drives a single lldb debugging session for one program.

    The session is put in synchronous mode (`SetAsync(False)`). Per lldb's
    synchronous-mode semantics, process-control calls such as
    `SBProcess.Continue` therefore return only once the process has stopped
    or exited again.
    """
    def __init__(self, program):
        """Create an lldb target for the executable `program`.

        :param program: Path of the executable to debug.
        :raise RuntimeError: If lldb cannot create a target for `program`.
        """
        self.debugger = lldb.SBDebugger.Create()
        self.debugger.SetAsync(False)
        self.target = self.debugger.CreateTargetWithFileAndArch(program,
                                                                lldb.LLDB_ARCH_DEFAULT)
        # Without this check, an invalid target would only surface later as a
        # failed breakpoint command or launch, far from the real cause.
        if not self.target.IsValid():
            raise RuntimeError(f'Failed to create debug target for {program}')
        self.module = self.target.FindModule(self.target.GetExecutable())
        self.interpreter = self.debugger.GetCommandInterpreter()

    def set_breakpoint_by_addr(self, address: int):
        """Set a breakpoint at `address` in the target's executable module.

        The breakpoint is created with lldb's `b -a <address> -s <module>`
        command. Per `help breakpoint set`, combining `-a` with `-s` makes lldb
        treat the address as a file address within that module.

        :param address: Address at which to break.
        :raise RuntimeError: If lldb rejects the breakpoint command.
        """
        command = f"b -a {address} -s {self.module.GetFileSpec().GetFilename()}"
        result = lldb.SBCommandReturnObject()
        self.interpreter.HandleCommand(command, result)
        # Previously failures were silently dropped here.
        if not result.Succeeded():
            raise RuntimeError(
                f'Failed to set breakpoint at {hex(address)}: {result.GetError()}')

    def get_breakpoints_count(self):
        """Return the number of breakpoints currently set on the target."""
        return self.target.GetNumBreakpoints()

    def execute(self, callback: Callable):
        """Launch the target and call `callback` on every stop until it ends.

        At each stop, `callback` receives frame 0 of every thread in the
        process, and execution then continues. The loop ends once the process
        reaches a state from which it will not stop again. Afterwards, the
        final state and up to 1024 bytes of stdout and stderr are printed.

        :param callback: Called with an `lldb.SBFrame` for each thread at each stop.
        :raise RuntimeError: If the process cannot be launched.
        """
        error = lldb.SBError()
        listener = self.debugger.GetListener()
        # NOTE(review): the `True` argument is `stop_at_entry`, so the first
        # stop (and thus a callback invocation) is presumably at the program's
        # entry point rather than at a breakpoint. Confirm this is intended.
        process = self.target.Launch(listener, None, None, None, None, None, None, 0,
                                     True, error)

        # Fail loudly: polling an invalid process would loop forever.
        if not process.IsValid():
            raise RuntimeError(f'Failed to launch process: {error.GetCString()}')
        print(f'Launched process: {process}')

        # States after which the process will not stop again; waiting on any
        # of them (not only eStateExited) would spin forever.
        terminal_states = (lldb.eStateExited, lldb.eStateCrashed,
                           lldb.eStateDetached, lldb.eStateInvalid)
        while True:
            state = process.GetState()
            if state == lldb.eStateStopped:
                # Every thread is reported, whatever its stop reason.
                for thread in process:
                    callback(thread.GetFrameAtIndex(0))
                process.Continue()
            elif state in terminal_states:
                break

        print(f'Process state: {process.GetState()}')
        print('Program output:')
        print(process.GetSTDOUT(1024))
        print(process.GetSTDERR(1024))

def run_native_execution(oracle_program: str, breakpoints: set[int]):
    """Gather snapshots from a native execution via an external debugger.

    :param oracle_program: Program to execute.
    :param breakpoints: Addresses at which to break and record the
                        program's state.

    :raise RuntimeError: If the host is not x86_64, or if the number of
                         breakpoints set differs from `len(breakpoints)`.
    :return: A list of snapshots gathered from the execution.
    """
    # The snapshot builder below is hard-wired to the x86 architecture.
    # Use an explicit check rather than `assert`, which is stripped under -O.
    host = platform.machine()
    if host != "x86_64":
        raise RuntimeError(f'Native execution requires an x86_64 host, got {host!r}')

    debugger = Debugger(oracle_program)

    # Set breakpoints
    for address in breakpoints:
        debugger.set_breakpoint_by_addr(address)
    expected = len(breakpoints)
    actual = debugger.get_breakpoints_count()
    if actual != expected:
        raise RuntimeError(f'Expected {expected} breakpoints, but {actual} were set')

    # Execute the native program
    builder = SnapshotBuilder(x86.ArchX86())
    debugger.execute(builder)

    return builder.states