about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-10-11 16:21:21 +0200
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-10-11 16:21:21 +0200
commit69c55d68d68c00007afa1af76a1d06f74ee72fe6 (patch)
tree991b92b4a5ba447b9fb5f77db4377bd9d14fbdf9
parentb9c08cadc158b18d7cab14a830a9e11f590ec7bd (diff)
downloadfocaccia-69c55d68d68c00007afa1af76a1d06f74ee72fe6.tar.gz
focaccia-69c55d68d68c00007afa1af76a1d06f74ee72fe6.zip
Refactor file structure
- main.py: focaccia user-interface

- snapshot.py: state trace snapshots handling

- compare.py: snapshot comparison algorithms

- run.py: native execution tracer

- arancini.py: Arancini log handling

- arch/: per-architecture abstractions

Co-authored-by: Theofilos Augoustis <theofilos.augoustis@gmail.com>
Co-authored-by: Nicola Crivellin <nicola.crivellin98@gmail.com>
-rw-r--r--.gitignore2
-rw-r--r--arancini.py94
-rw-r--r--arch/arch.py6
-rw-r--r--arch/x86.py33
-rw-r--r--[-rwxr-xr-x]compare.py265
-rwxr-xr-xmain.py92
-rw-r--r--[-rwxr-xr-x]run.py100
-rw-r--r--snapshot.py38
-rw-r--r--utils.py2
9 files changed, 326 insertions, 306 deletions
diff --git a/.gitignore b/.gitignore
index ee32d85..4586156 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,5 @@ build*
 *.dot
 build*/
 out-*/
-__pycache__/*
+__pycache__/
 
diff --git a/arancini.py b/arancini.py
new file mode 100644
index 0000000..adbe6c8
--- /dev/null
+++ b/arancini.py
@@ -0,0 +1,94 @@
+"""Tools for working with arancini's output."""
+
+import re
+from functools import partial as bind
+
+from snapshot import ProgramState
+from arch.arch import Arch
+
+def parse_break_addresses(lines: list[str]) -> set[int]:
+    """Parse all breakpoint addresses from an arancini log."""
+    addresses = set()
+    for l in lines:
+        if l.startswith('INVOKE'):
+            addr = int(l.split('=')[1].strip(), base=16)
+            addresses.add(addr)
+
+    return addresses
+
+def parse(lines: list[str], arch: Arch) -> list[ProgramState]:
+    """Parse an arancini log into a list of snapshots.
+
+    :return: A list of program snapshots.
+    """
+
+    labels = get_labels()
+
+    # The regex decides for a line whether it contains a register
+    # based on a match with that register's label.
+    regex = re.compile("|".join(labels.keys()))
+
+    def try_parse_line(line: str) -> tuple[str, int] | None:
+        """Try to parse a register name and that register's value from a line.
+
+        :return: A register name and a register value if the line contains
+                 that information. None if parsing fails.
+        """
+        match = regex.match(line)
+        if match:
+            label = match.group(0)
+            register, get_reg_value = labels[label]
+            return register, get_reg_value(line)
+        return None
+
+    # Parse a list of program snapshots
+    snapshots = []
+    for line in lines:
+        if 'Backwards' in line and len(snapshots) > 0:
+            snapshots[-1].set_backwards()
+            continue
+
+        match = try_parse_line(line)
+        if match:
+            reg, value = match
+            if reg == 'PC':
+                snapshots.append(ProgramState(arch))
+            snapshots[-1].set(reg, value)
+
+    return snapshots
+
+def get_labels():
+    """Construct a helper structure for the arancini log parser."""
+    split_value = lambda x,i: int(x.split()[i], 16)
+
+    split_first = bind(split_value, i=1)
+    split_second = bind(split_value, i=2)
+
+    split_equal = lambda x,i: int(x.split('=')[i], 16)
+
+    # A mapping from regex patterns to the register name and a
+    # function that extracts that register's value from the line
+    labels = {'INVOKE':  ('PC',      bind(split_equal, i=1)),
+              'RAX':     ('RAX',     split_first),
+              'RBX':     ('RBX',     split_first),
+              'RCX':     ('RCX',     split_first),
+              'RDX':     ('RDX',     split_first),
+              'RSI':     ('RSI',     split_first),
+              'RDI':     ('RDI',     split_first),
+              'RBP':     ('RBP',     split_first),
+              'RSP':     ('RSP',     split_first),
+              'R8':      ('R8',      split_first),
+              'R9':      ('R9',      split_first),
+              'R10':     ('R10',     split_first),
+              'R11':     ('R11',     split_first),
+              'R12':     ('R12',     split_first),
+              'R13':     ('R13',     split_first),
+              'R14':     ('R14',     split_first),
+              'R15':     ('R15',     split_first),
+              'flag ZF': ('flag ZF', split_second),
+              'flag CF': ('flag CF', split_second),
+              'flag OF': ('flag OF', split_second),
+              'flag SF': ('flag SF', split_second),
+              'flag PF': ('flag PF', split_second),
+              'flag DF': ('flag DF', split_second)}
+    return labels
diff --git a/arch/arch.py b/arch/arch.py
new file mode 100644
index 0000000..36a4e3f
--- /dev/null
+++ b/arch/arch.py
@@ -0,0 +1,6 @@
+class Arch():
+    def __init__(self, regnames: list[str]):
+        self.regnames = regnames
+
+    def __eq__(self, other):
+        return self.regnames == other.regnames
diff --git a/arch/x86.py b/arch/x86.py
new file mode 100644
index 0000000..0f60457
--- /dev/null
+++ b/arch/x86.py
@@ -0,0 +1,33 @@
+"""Architecture-specific configuration."""
+
+from .arch import Arch
+
+# Names of registers in the architecture
+regnames = ['PC',
+            'RAX',
+            'RBX',
+            'RCX',
+            'RDX',
+            'RSI',
+            'RDI',
+            'RBP',
+            'RSP',
+            'R8',
+            'R9',
+            'R10',
+            'R11',
+            'R12',
+            'R13',
+            'R14',
+            'R15',
+            'RFLAGS',
+            'flag ZF',
+            'flag CF',
+            'flag OF',
+            'flag SF',
+            'flag PF',
+            'flag DF']
+
+class ArchX86(Arch):
+    def __init__(self):
+        super().__init__(regnames)
diff --git a/compare.py b/compare.py
index f4576dd..df8c378 100755..100644
--- a/compare.py
+++ b/compare.py
@@ -1,150 +1,17 @@
-#! /bin/python3
-import re
-import sys
-import shutil
-import argparse
-from typing import List, Callable
-from functools import partial as bind
-
-from utils import check_version
+from snapshot import ProgramState
 from utils import print_separator
 
-from run import Runner
-
-progressive = False
-
-class ContextBlock:
-    regnames = ['PC',
-                'RAX',
-                'RBX',
-                'RCX',
-                'RDX',
-                'RSI',
-                'RDI',
-                'RBP',
-                'RSP',
-                'R8',
-                'R9',
-                'R10',
-                'R11',
-                'R12',
-                'R13',
-                'R14',
-                'R15',
-                'flag ZF',
-                'flag CF',
-                'flag OF',
-                'flag SF',
-                'flag PF',
-                'flag DF']
-
-    def __init__(self):
-        dict_type = dict[str, int|None]  # A register may not have a value
-        self.regs = dict_type({reg: None for reg in ContextBlock.regnames})
-        self.has_backwards = False
-        self.matched = False
-
-    def set_backwards(self):
-        self.has_backwards = True
-
-    def set(self, reg: str, value: int):
-        """Assign a value to a register.
-
-        :raises RuntimeError: if the register already has a value.
-        """
-        if self.regs[reg] != None:
-            raise RuntimeError("Reassigning register")
-        self.regs[reg] = value
-
-    def __repr__(self):
-        return self.regs.__repr__()
-
-class Constructor:
-    """Builds a list of context blocks."""
-    def __init__(self, structure: dict[str, tuple[str, Callable[[str], int]]]):
-        self.cblocks = list[ContextBlock]()
-        self.labels = structure
-        self.regex = re.compile("|".join(structure.keys()))
-
-    def match(self, line: str) -> (tuple[str, int] | None):
-        """Find a register name and that register's value in a line.
-
-        :return: A register name and a register value.
-        """
-        match = self.regex.match(line)
-        if match:
-            label = match.group(0)
-            register, get_reg_value = self.labels[label]
-            return register, get_reg_value(line)
-
-        return None
-
-    def add_backwards(self):
-        self.cblocks[-1].set_backwards()
-
-    def add(self, reg: str, value: int):
-        if reg == 'PC':
-            self.cblocks.append(ContextBlock())
-        self.cblocks[-1].set(reg, value)
-
-def parse(lines: list[str], labels: dict):
-    """Parse a list of lines into a list of cblocks."""
-    ctor = Constructor(labels)
-    for line in lines:
-        if 'Backwards' in line:
-            ctor.add_backwards()
-            continue
-
-        match = ctor.match(line)
-        if match:
-            key, value = match
-            ctor.add(key, value)
-
-    return ctor.cblocks
-
-def get_labels():
-    split_value = lambda x,i: int(x.split()[i], 16)
-
-    split_first = bind(split_value, i=1)
-    split_second = bind(split_value, i=2)
-
-    split_equal = lambda x,i: int(x.split('=')[i], 16)
-
-    # A mapping from regex patterns to the register name and a
-    # function that extracts that register's value from the line
-    labels = {'INVOKE':  ('PC',      bind(split_equal, i=1)),
-              'RAX':     ('RAX',     split_first),
-              'RBX':     ('RBX',     split_first),
-              'RCX':     ('RCX',     split_first),
-              'RDX':     ('RDX',     split_first),
-              'RSI':     ('RSI',     split_first),
-              'RDI':     ('RDI',     split_first),
-              'RBP':     ('RBP',     split_first),
-              'RSP':     ('RSP',     split_first),
-              'R8':      ('R8',      split_first),
-              'R9':      ('R9',      split_first),
-              'R10':     ('R10',     split_first),
-              'R11':     ('R11',     split_first),
-              'R12':     ('R12',     split_first),
-              'R13':     ('R13',     split_first),
-              'R14':     ('R14',     split_first),
-              'R15':     ('R15',     split_first),
-              'flag ZF': ('flag ZF', split_second),
-              'flag CF': ('flag CF', split_second),
-              'flag OF': ('flag OF', split_second),
-              'flag SF': ('flag SF', split_second),
-              'flag PF': ('flag PF', split_second),
-              'flag DF': ('flag DF', split_second)}
-    return labels
-
-def calc_transformation(previous: ContextBlock, current: ContextBlock):
+def calc_transformation(previous: ProgramState, current: ProgramState):
     """Calculate the difference between two context blocks.
 
     :return: A context block that contains in its registers the difference
              between the corresponding input blocks' register values.
     """
-    transformation = ContextBlock()
-    for reg in ContextBlock.regnames:
+    assert(previous.arch == current.arch)
+
+    arch = previous.arch
+    transformation = ProgramState(arch)
+    for reg in arch.regnames:
         prev_val, cur_val = previous.regs[reg], current.regs[reg]
         if prev_val is not None and cur_val is not None:
             transformation.regs[reg] = cur_val - prev_val
@@ -158,15 +25,17 @@ def equivalent(val1, val2, transformation, previous_translation):
     # TODO: maybe incorrect
     return val1 - previous_translation == transformation
 
-def verify(translation: ContextBlock, reference: ContextBlock,
-           transformation: ContextBlock, previous_translation: ContextBlock):
+def verify(translation: ProgramState, reference: ProgramState,
+           transformation: ProgramState, previous_translation: ProgramState):
+    assert(translation.arch == reference.arch)
+
     if translation.regs["PC"] != reference.regs["PC"]:
         return 1
 
     print_separator()
-    print(f'For PC={hex(translation.regs["PC"])}')
+    print(f'For PC={translation.as_repr("PC")}')
     print_separator()
-    for reg in ContextBlock.regnames:
+    for reg in translation.arch.regnames:
         if translation.regs[reg] is None:
             print(f'Element not available in translation: {reg}')
         elif reference.regs[reg] is None:
@@ -174,16 +43,27 @@ def verify(translation: ContextBlock, reference: ContextBlock,
         elif not equivalent(translation.regs[reg], reference.regs[reg],
                             transformation.regs[reg],
                             previous_translation.regs[reg]):
-            txl = hex(translation.regs[reg])
-            ref = hex(reference.regs[reg])
+            txl = translation.as_repr(reg)
+            ref = reference.as_repr(reg)
             print(f'Difference for {reg}: {txl} != {ref}')
 
     return 0
 
-def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = False):
+def compare(txl: list[ProgramState],
+            native: list[ProgramState],
+            progressive: bool = False,
+            stats: bool = False):
+    """Compare two lists of snapshots and output the differences.
+
+    :param txl: The translated, and possibly faulty, state of the program.
+    :param native: The 'correct' reference state of the program.
+    :param progressive:
+    :param stats:
+    """
+
     if len(txl) != len(native):
-        print(f'Different number of blocks discovered translation: {len(txl)} vs. '
-              f'reference: {len(native)}', file=sys.stdout)
+        print(f'Different numbers of blocks discovered: '
+              f'{len(txl)} in translation vs. {len(native)} in reference.')
 
     previous_reference = native[0]
     previous_translation = txl[0]
@@ -210,8 +90,8 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F
             if i == len(native):
                 matched = False
                 # TODO: add verbose output
-                print_separator(stream=sys.stdout)
-                print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout)
+                print_separator()
+                print(f'No match for PC {hex(translation.regs["PC"])}')
                 if translation.regs['PC'] not in unmatched_pcs:
                     unmatched_pcs[translation.regs['PC']] = 0
                 unmatched_pcs[translation.regs['PC']] += 1
@@ -238,12 +118,14 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F
             if matched:
                 i += 1
     else:
+        txl = iter(txl)
+        native = iter(native)
         for translation, reference in zip(txl, native):
             transformation = calc_transformation(previous_reference, reference)
             if verify(translation, reference, transformation, previous_translation) == 1:
                 # TODO: add verbose output
-                print_separator(stream=sys.stdout)
-                print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout)
+                print_separator()
+                print(f'No match for PC {hex(translation.regs["PC"])}')
                 if translation.regs['PC'] not in unmatched_pcs:
                     unmatched_pcs[translation.regs['PC']] = 0
                 unmatched_pcs[translation.regs['PC']] += 1
@@ -277,80 +159,3 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F
             if ref.matched == False:
                 unmatched_count += 1
     return 0
-
-def read_logs(txl_path, native_path, program):
-    txl = []
-    with open(txl_path, "r") as txl_file:
-        txl = txl_file.readlines()
-
-    native = []
-    if program is not None:
-        runner = Runner(txl, program)
-        native = runner.run()
-    else:
-        with open(native_path, "r") as native_file:
-            native = native_file.readlines()
-
-    return txl, native
-
-def parse_arguments():
-    parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference')
-    parser.add_argument('-p', '--program',
-                        type=str,
-                        help='Path to oracle program')
-    parser.add_argument('-r', '--ref',
-                        type=str,
-                        required=True,
-                        help='Path to the reference log (gathered with run.sh)')
-    parser.add_argument('-t', '--txl',
-                        type=str,
-                        required=True,
-                        help='Path to the translation log (gathered via Arancini)')
-    parser.add_argument('-s', '--stats',
-                        action='store_true',
-                        default=False,
-                        help='Run statistics on comparisons')
-    parser.add_argument('-v', '--verbose',
-                        action='store_true',
-                        default=True,
-                        help='Path to oracle program')
-    parser.add_argument('--progressive',
-                        action='store_true',
-                        default=False,
-                        help='Try to match exhaustively before declaring \
-                        mismatch')
-    args = parser.parse_args()
-    return args
-
-if __name__ == "__main__":
-    check_version('3.7')
-
-    args = parse_arguments()
-
-    txl_path = args.txl
-    native_path = args.ref
-    program = args.program
-
-    stats = args.stats
-    verbose = args.verbose
-    progressive = args.progressive
-
-    if verbose:
-        print("Enabling verbose program output")
-        print(f"Verbose: {verbose}")
-        print(f"Statistics: {stats}")
-        print(f"Progressive: {progressive}")
-
-    if program is None and native_path is None:
-        raise ValueError('Either program or path to native file must be'
-                         'provided')
-
-    txl, native = read_logs(txl_path, native_path, program)
-
-    if program != None and native_path != None:
-        with open(native_path, 'w') as w:
-            w.write(''.join(native))
-
-    txl = parse(txl, get_labels())
-    native = parse(native, get_labels())
-    compare(txl, native, stats)
diff --git a/main.py b/main.py
new file mode 100755
index 0000000..076dc0e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,92 @@
+#! /bin/python3
+
+import argparse
+
+import arancini
+from arch import x86
+from compare import compare
+from run import run_native_execution
+from utils import check_version
+
+def read_logs(txl_path, native_path, program):
+    txl = []
+    with open(txl_path, "r") as txl_file:
+        txl = txl_file.readlines()
+
+    native = []
+    if program is not None:
+        breakpoints = arancini.parse_break_addresses(txl)
+        native = run_native_execution(program, breakpoints)
+    else:
+        assert(native_path is not None)
+        with open(native_path, "r") as native_file:
+            native = native_file.readlines()
+
+    return txl, native
+
+def parse_arguments():
+    """Build the command-line parser and return the parsed arguments."""
+    parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference')
+    parser.add_argument('-p', '--program',
+                        type=str,
+                        help='Path to oracle program')
+    parser.add_argument('-r', '--ref',
+                        type=str,
+                        required=True,
+                        help='Path to the reference log (gathered with run.sh)')
+    parser.add_argument('-t', '--txl',
+                        type=str,
+                        required=True,
+                        help='Path to the translation log (gathered via Arancini)')
+    parser.add_argument('-s', '--stats',
+                        action='store_true',
+                        default=False,
+                        help='Run statistics on comparisons')
+    parser.add_argument('-v', '--verbose',
+                        action='store_true',
+                        default=True,  # NOTE: already True, so -v changes nothing
+                        help='Enable verbose program output')
+    parser.add_argument('--progressive',
+                        action='store_true',
+                        default=False,
+                        help='Try to match exhaustively before declaring \
+                        mismatch')
+    return parser.parse_args()
+
+def main():
+    """Parse arguments, load both logs, and compare the parsed snapshots."""
+    args = parse_arguments()
+
+    txl_path = args.txl
+    native_path = args.ref
+    program = args.program
+    stats = args.stats
+    verbose = args.verbose
+    progressive = args.progressive
+
+    # Our architecture
+    arch = x86.ArchX86()
+
+    if verbose:
+        print("Enabling verbose program output")
+        print(f"Verbose: {verbose}")
+        print(f"Statistics: {stats}")
+        print(f"Progressive: {progressive}")
+
+    if program is None and native_path is None:
+        raise ValueError('Either program or path to native file must be '
+                         'provided')
+
+    txl, native = read_logs(txl_path, native_path, program)
+
+    if program is not None and native_path is not None:
+        with open(native_path, 'w') as w:
+            w.write(''.join(native))
+
+    txl = arancini.parse(txl, arch)
+    native = arancini.parse(native, arch)
+    compare(txl, native, progressive=progressive, stats=stats)
+
+if __name__ == "__main__":
+    check_version('3.7')
+    main()
diff --git a/run.py b/run.py
index f1f1060..9b51fb5 100755..100644
--- a/run.py
+++ b/run.py
@@ -1,39 +1,23 @@
-#! /bin/python3
-import os
+"""Functionality to execute native programs and collect snapshots via lldb."""
+
 import re
 import sys
 import lldb
-import shutil
-import argparse
+from typing import Callable
 
+# TODO: The debugger callback is currently specific to a single architecture.
+#       We should make it generic.
+from arch import x86
 from utils import print_separator
 
 verbose = False
 
-regnames = ['PC',
-            'RAX',
-            'RBX',
-            'RCX',
-            'RDX',
-            'RSI',
-            'RDI',
-            'RBP',
-            'RSP',
-            'R8',
-            'R9',
-            'R10',
-            'R11',
-            'R12',
-            'R13',
-            'R14',
-            'R15',
-            'RFLAGS']
-
 class DebuggerCallback:
-    def __init__(self, ostream=sys.stdout, skiplist: set = {}):
+    """At every breakpoint, writes register contents to a stream."""
+
+    def __init__(self, ostream=sys.stdout):
         self.stream = ostream
-        self.regex = re.compile('(' + '|'.join(regnames) + ')$')
-        self.skiplist = skiplist
+        self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$')
 
     @staticmethod
     def parse_flags(flag_reg: int):
@@ -60,7 +44,6 @@ class DebuggerCallback:
         flags['PF'] = int(0 != flag_reg & (1 << 1))
         return flags
 
-
     def print_regs(self, frame):
         for reg in frame.GetRegisters():
             for sub_reg in reg:
@@ -93,11 +76,6 @@ class DebuggerCallback:
     def __call__(self, frame):
         pc = frame.GetPC()
 
-        # Skip this PC
-        if pc in self.skiplist:
-            self.skiplist.discard(pc)
-            return False
-
         print_separator('=', stream=self.stream, count=20)
         print(f'INVOKE PC={hex(pc)}', file=self.stream)
         print_separator('=', stream=self.stream, count=20)
@@ -109,8 +87,6 @@ class DebuggerCallback:
         print("STACK:", file=self.stream)
         self.print_stack(frame, 20)
 
-        return True  # Continue execution
-
 class Debugger:
     def __init__(self, program):
         self.debugger = lldb.SBDebugger.Create()
@@ -131,7 +107,7 @@ class Debugger:
     def get_breakpoints_count(self):
         return self.target.GetNumBreakpoints()
 
-    def execute(self, callback: callable):
+    def execute(self, callback: Callable):
         error = lldb.SBError()
         listener = self.debugger.GetListener()
         process = self.target.Launch(listener, None, None, None, None, None, None, 0,
@@ -169,46 +145,24 @@ class ListWriter:
     def __str__(self):
         return "".join(self.data)
 
-class Runner:
-    def __init__(self, dbt_log: list, oracle_program: str):
-        self.log = dbt_log
-        self.program = oracle_program
-        self.debugger = Debugger(self.program)
-        self.writer = ListWriter()
-
-    @staticmethod
-    def get_addresses(lines: list):
-        addresses = []
-
-        backlist = []
-        backlist_regex = re.compile(r'^\s\s\d*:')
-
-        skiplist = set()
-        for l in lines:
-            if l.startswith('INVOKE'):
-                addresses.append(int(l.split('=')[1].strip(), base=16))
-
-                if addresses[-1] in backlist:
-                    skiplist.add(addresses[-1])
-                backlist = []
-
-            if backlist_regex.match(l):
-                backlist.append(int(l.split()[0].split(':')[0], base=16))
-
-        return set(addresses), skiplist
-
-    def run(self):
-        # Get all addresses to stop at
-        addresses, skiplist = Runner.get_addresses(self.log)
+def run_native_execution(oracle_program: str, breakpoints: set[int]):
+    """Gather snapshots from a native execution via an external debugger.
 
-        # Set breakpoints
-        for address in addresses:
-            self.debugger.set_breakpoint_by_addr(address)
+    :param oracle_program: Program to execute.
+    :param breakpoints: List of addresses at which to break and record the
+                        program's state.
 
-        # Sanity check
-        assert(self.debugger.get_breakpoints_count() == len(addresses))
+    :return: A textual log of the program's execution in arancini's log format.
+    """
+    debugger = Debugger(oracle_program)
+    writer = ListWriter()
 
-        self.debugger.execute(DebuggerCallback(self.writer, skiplist))
+    # Set breakpoints
+    for address in breakpoints:
+        debugger.set_breakpoint_by_addr(address)
+    assert(debugger.get_breakpoints_count() == len(breakpoints))
 
-        return self.writer.data
+    # Execute the native program
+    debugger.execute(DebuggerCallback(writer))
 
+    return writer.data
diff --git a/snapshot.py b/snapshot.py
new file mode 100644
index 0000000..d5136ad
--- /dev/null
+++ b/snapshot.py
@@ -0,0 +1,38 @@
+from arch.arch import Arch
+
+class ProgramState():
+    """A snapshot of the program's state."""
+    def __init__(self, arch: Arch):
+        self.arch = arch
+
+        dict_t = dict[str, int | None]  # A register may not have a value
+        self.regs = dict_t({ reg: None for reg in arch.regnames })
+        self.has_backwards = False
+        self.matched = False
+
+    def set_backwards(self) -> None:
+        self.has_backwards = True
+
+    def set(self, reg: str, value: int) -> None:
+        """Assign a value to a register.
+
+        :raises RuntimeError: if the register already has a value.
+        """
+        assert(reg in self.arch.regnames)
+
+        if self.regs[reg] is not None:
+            raise RuntimeError("Reassigning register")
+        self.regs[reg] = value
+
+    def as_repr(self, reg: str) -> str:
+        """Get a representational string of a register's value."""
+        assert(reg in self.arch.regnames)
+
+        value = self.regs[reg]
+        if value is not None:
+            return hex(value)
+        else:
+            return "<none>"
+
+    def __repr__(self) -> str:
+        return self.regs.__repr__()
diff --git a/utils.py b/utils.py
index d841c7c..1390283 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,3 @@
-#! /bin/python3
-
 import sys
 import shutil