about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--focaccia/compare.py27
-rw-r--r--focaccia/miasm_util.py43
-rw-r--r--focaccia/snapshot.py39
-rw-r--r--focaccia/symbolic.py59
4 files changed, 97 insertions, 71 deletions
diff --git a/focaccia/compare.py b/focaccia/compare.py
index d89a41a..43a0133 100644
--- a/focaccia/compare.py
+++ b/focaccia/compare.py
@@ -21,7 +21,9 @@ class ErrorSeverity:
     def __repr__(self) -> str:
         return f'[{self.name}]'
 
-    def __eq__(self, other: Self) -> bool:
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Self):
+            return False
         return self._numeral == other._numeral
 
     def __lt__(self, other: Self) -> bool:
@@ -195,9 +197,6 @@ def _find_register_errors(txl_from: ProgramState,
                                 f'Value of register {regname} has changed, but'
                                 f' is not set in the tested state.'))
             continue
-        except KeyError as err:
-            print(f'[WARNING] {err}')
-            continue
 
         if txl_val != truth_val:
             errors.append(Error(ErrorTypes.CONFIRMED,
@@ -293,28 +292,24 @@ def compare_symbolic(test_states: Iterable[ProgramState],
     transforms = iter(transforms)
 
     result = []
-    cur_state = next(test_states)   # The state before the transformation
-    transform = next(transforms)    # Operates on `cur_state`
 
+    cur_state = next(test_states)   # The state before the transformation
+    transform = next(transforms)    # Transform that operates on `cur_state`
     while True:
         try:
             next_state = next(test_states) # The state after the transformation
 
             pc_cur = cur_state.read_register('PC')
             pc_next = next_state.read_register('PC')
-            start_addr, end_addr = transform.range
-            if pc_cur != start_addr:
-                print(f'Program counter {hex(pc_cur)} in translated code has'
-                      f' no corresponding reference state! Skipping.'
-                      f' (reference: {hex(start_addr)})')
+            if (pc_cur, pc_next) != transform.range:
+                repr_range = lambda r: f'[{hex(r[0])} -> {hex(r[1])}]'
+                print(f'[WARNING] Test states {repr_range((pc_cur, pc_next))}'
+                      f' do not match the symbolic transformation'
+                      f' {repr_range(transform.range)} against which they are'
+                      f' tested! Skipping.')
                 cur_state = next_state
                 transform = next(transforms)
                 continue
-            if pc_next != end_addr:
-                print(f'Tested state transformation is {hex(pc_cur)} ->'
-                      f' {hex(pc_next)}, but reference transform is'
-                      f' {hex(start_addr)} -> {hex(end_addr)}!'
-                      f' Skipping.')
 
             errors = _find_errors_symbolic(cur_state, next_state, transform)
             result.append({
diff --git a/focaccia/miasm_util.py b/focaccia/miasm_util.py
index f43c151..7ab9c87 100644
--- a/focaccia/miasm_util.py
+++ b/focaccia/miasm_util.py
@@ -6,7 +6,8 @@ from miasm.expression.expression import Expr, ExprOp, ExprId, ExprLoc, \
                                         ExprSlice, ExprCond
 from miasm.expression.simplifications import expr_simp_explicit
 
-from .snapshot import ProgramState
+from .snapshot import ReadableProgramState, \
+                      RegisterAccessError, MemoryAccessError
 
 def simp_segm(expr_simp, expr: ExprOp):
     """Simplify a segmentation expression to an addition of the segment
@@ -29,7 +30,7 @@ def simp_segm(expr_simp, expr: ExprOp):
 expr_simp = expr_simp_explicit
 expr_simp.enable_passes({ExprOp: [simp_segm]})
 
-class MiasmConcreteState:
+class MiasmSymbolResolver:
     """Resolves atomic symbols to some state."""
 
     miasm_flag_aliases = {
@@ -39,23 +40,33 @@ class MiasmConcreteState:
         'I_D':    'ID',
     }
 
-    def __init__(self, state: ProgramState, loc_db: LocationDB):
+    def __init__(self, state: ReadableProgramState, loc_db: LocationDB):
         self._state = state
         self._loc_db = loc_db
 
-    def resolve_register(self, regname: str) -> int | None:
+    @staticmethod
+    def _miasm_to_regname(regname: str) -> str:
+        """Convert a register name as used by Miasm to one that follows
+        Focaccia's naming conventions."""
         regname = regname.upper()
-        if regname in self.miasm_flag_aliases:
-            regname = self.miasm_flag_aliases[regname]
-        return self._state.read_register(regname)
+        return MiasmSymbolResolver.miasm_flag_aliases.get(regname, regname)
+
+    def resolve_register(self, regname: str) -> int | None:
+        try:
+            return self._state.read_register(self._miasm_to_regname(regname))
+        except RegisterAccessError:
+            return None
 
     def resolve_memory(self, addr: int, size: int) -> bytes | None:
-        return self._state.read_memory(addr, size)
+        try:
+            return self._state.read_memory(addr, size)
+        except MemoryAccessError:
+            return None
 
     def resolve_location(self, loc: LocKey) -> int | None:
         return self._loc_db.get_location_offset(loc)
 
-def eval_expr(expr: Expr, conc_state: MiasmConcreteState) -> Expr:
+def eval_expr(expr: Expr, conc_state: MiasmSymbolResolver) -> Expr:
     """Evaluate a symbolic expression with regard to a concrete reference
     state.
 
@@ -95,7 +106,7 @@ def _eval_exprint(expr: ExprInt, _):
     """Evaluate an ExprInt using the current state"""
     return expr
 
-def _eval_exprid(expr: ExprId, state: MiasmConcreteState):
+def _eval_exprid(expr: ExprId, state: MiasmSymbolResolver):
     """Evaluate an ExprId using the current state"""
     val = state.resolve_register(expr.name)
     if val is None:
@@ -104,14 +115,14 @@ def _eval_exprid(expr: ExprId, state: MiasmConcreteState):
         return ExprInt(val, expr.size)
     return val
 
-def _eval_exprloc(expr: ExprLoc, state: MiasmConcreteState):
+def _eval_exprloc(expr: ExprLoc, state: MiasmSymbolResolver):
     """Evaluate an ExprLoc using the current state"""
     offset = state.resolve_location(expr.loc_key)
     if offset is None:
         return expr
     return ExprInt(offset, expr.size)
 
-def _eval_exprmem(expr: ExprMem, state: MiasmConcreteState):
+def _eval_exprmem(expr: ExprMem, state: MiasmSymbolResolver):
     """Evaluate an ExprMem using the current state.
     This function first evaluates the memory pointer value.
     """
@@ -133,14 +144,14 @@ def _eval_exprmem(expr: ExprMem, state: MiasmConcreteState):
     assert(len(mem) * 8 == expr.size)
     return ExprInt(int.from_bytes(mem), expr.size)
 
-def _eval_exprcond(expr, state: MiasmConcreteState):
+def _eval_exprcond(expr, state: MiasmSymbolResolver):
     """Evaluate an ExprCond using the current state"""
     cond = eval_expr(expr.cond, state)
     src1 = eval_expr(expr.src1, state)
     src2 = eval_expr(expr.src2, state)
     return ExprCond(cond, src1, src2)
 
-def _eval_exprslice(expr, state: MiasmConcreteState):
+def _eval_exprslice(expr, state: MiasmSymbolResolver):
     """Evaluate an ExprSlice using the current state"""
     arg = eval_expr(expr.arg, state)
     return ExprSlice(arg, expr.start, expr.stop)
@@ -161,7 +172,7 @@ def _eval_cpuid(rax: ExprInt, out_reg: ExprInt):
         raise ValueError(f'Output register may not be {out_reg}.')
     return ExprInt(regs[int(out_reg)], out_reg.size)
 
-def _eval_exprop(expr, state: MiasmConcreteState):
+def _eval_exprop(expr, state: MiasmSymbolResolver):
     """Evaluate an ExprOp using the current state"""
     args = []
     for oarg in expr.args:
@@ -175,7 +186,7 @@ def _eval_exprop(expr, state: MiasmConcreteState):
         return _eval_cpuid(args[0], args[1])
     return ExprOp(expr.op, *args)
 
-def _eval_exprcompose(expr, state: MiasmConcreteState):
+def _eval_exprcompose(expr, state: MiasmSymbolResolver):
     """Evaluate an ExprCompose using the current state"""
     args = []
     for arg in expr.args:
diff --git a/focaccia/snapshot.py b/focaccia/snapshot.py
index 6342660..592c2bb 100644
--- a/focaccia/snapshot.py
+++ b/focaccia/snapshot.py
@@ -79,24 +79,47 @@ class SparseMemory:
 
         assert(len(data) == offset)  # Exactly all data was written
 
-class ProgramState:
+class ReadableProgramState:
+    """Interface for read-only program states. Used for typing purposes."""
+
+    def read_register(self, reg: str) -> int:
+        """Read a register's value.
+
+        :raise RegisterAccessError: If `reg` is not a register name, or if the
+                                    register has no value.
+        """
+        raise NotImplementedError('ReadableProgramState.read_register is abstract.')
+
+    def read_memory(self, addr: int, size: int) -> bytes:
+        """Read a number of bytes from memory.
+
+        :param addr: The address from which to read data.
+        :param data: Number of bytes to read, starting at `addr`. Must be
+                     at least zero.
+
+        :raise MemoryAccessError: If `[addr, addr + size)` is not entirely
+                                  contained in the set of stored bytes.
+        :raise ValueError: If `size < 0`.
+        """
+        raise NotImplementedError('ReadableProgramState.read_memory is abstract.')
+
+class ProgramState(ReadableProgramState):
     """A snapshot of the program's state."""
     def __init__(self, arch: Arch):
         self.arch = arch
 
-        dict_t = dict[str, int | None]
-        self.regs: dict_t = { reg: None for reg in arch.regnames }
+        self.regs: dict[str, int | None] = {reg: None for reg in arch.regnames}
         self.mem = SparseMemory()
 
     def read_register(self, reg: str) -> int:
         """Read a register's value.
 
-        :raise KeyError:            If `reg` is not a register name.
-        :raise RegisterAccessError: If the register has no value.
+        :raise RegisterAccessError: If `reg` is not a register name, or if the
+                                    register has no value.
         """
         regname = self.arch.to_regname(reg)
         if regname is None:
-            raise KeyError(f'Not a register name: {reg}')
+            raise RegisterAccessError(reg, f'Not a register name: {reg}')
 
         assert(regname in self.regs)
         regval = self.regs[regname]
@@ -111,11 +134,11 @@ class ProgramState:
     def set_register(self, reg: str, value: int):
         """Assign a value to a register.
 
-        :raise KeyError: If `reg` is not a register name.
+        :raise RegisterAccessError: If `reg` is not a register name.
         """
         regname = self.arch.to_regname(reg)
         if regname is None:
-            raise KeyError(f'Not a register name: {reg}')
+            raise RegisterAccessError(reg, f'Not a register name: {reg}')
 
         self.regs[regname] = value
 
diff --git a/focaccia/symbolic.py b/focaccia/symbolic.py
index 50108d4..66ccaef 100644
--- a/focaccia/symbolic.py
+++ b/focaccia/symbolic.py
@@ -14,10 +14,11 @@ from .arch import Arch, supported_architectures
 from .lldb_target import LLDBConcreteTarget, \
                          ConcreteRegisterError, \
                          ConcreteMemoryError
-from .miasm_util import MiasmConcreteState, eval_expr
-from .snapshot import ProgramState
+from .miasm_util import MiasmSymbolResolver, eval_expr
+from .snapshot import ProgramState, ReadableProgramState, \
+                      RegisterAccessError, MemoryAccessError
 
-def eval_symbol(symbol: Expr, conc_state: ProgramState) -> int:
+def eval_symbol(symbol: Expr, conc_state: ReadableProgramState) -> int:
     """Evaluate a symbol based on a concrete reference state.
 
     :param conc_state: A concrete state.
@@ -28,21 +29,19 @@ def eval_symbol(symbol: Expr, conc_state: ProgramState) -> int:
     :raise MemoryAccessError: If the concrete state does not contain memory
                               that is referenced by the symbolic expression.
     """
-    class ConcreteStateWrapper(MiasmConcreteState):
+    class ConcreteStateWrapper(MiasmSymbolResolver):
         """Extend the state resolver with assumptions about the expressions
         that may be resolved with `eval_symbol`."""
-        def __init__(self, conc_state: ProgramState):
+        def __init__(self, conc_state: ReadableProgramState):
             super().__init__(conc_state, LocationDB())
 
         def resolve_register(self, regname: str) -> int:
-            regname = regname.upper()
-            regname = self.miasm_flag_aliases.get(regname, regname)
-            return self._state.read_register(regname)
+            return self._state.read_register(self._miasm_to_regname(regname))
 
         def resolve_memory(self, addr: int, size: int) -> bytes:
             return self._state.read_memory(addr, size)
 
-        def resolve_location(self, _):
+        def resolve_location(self, loc):
             raise ValueError(f'[In eval_symbol]: Unable to evaluate symbols'
                              f' that contain IR location expressions.')
 
@@ -190,22 +189,20 @@ class SymbolicTransform:
         """
         accessed_regs = set[str]()
 
-        class ConcreteStateWrapper(MiasmConcreteState):
+        class RegisterCollector(MiasmSymbolResolver):
             def __init__(self): pass
             def resolve_register(self, regname: str) -> int | None:
-                accessed_regs.add(regname)
+                accessed_regs.add(self._miasm_to_regname(regname))
                 return None
-            def resolve_memory(self, addr: int, size: int):
-                pass
-            def resolve_location(self, _):
-                assert(False)
+            def resolve_memory(self, addr: int, size: int): pass
+            def resolve_location(self, loc): assert(False)
 
-        state = ConcreteStateWrapper()
+        resolver = RegisterCollector()
         for expr in self.changed_regs.values():
-            eval_expr(expr, state)
+            eval_expr(expr, resolver)
         for addr_expr, mem_expr in self.changed_mem.items():
-            eval_expr(addr_expr, state)
-            eval_expr(mem_expr, state)
+            eval_expr(addr_expr, resolver)
+            eval_expr(mem_expr, resolver)
 
         return list(accessed_regs)
 
@@ -388,7 +385,7 @@ class DisassemblyError(Exception):
         self.faulty_pc = faulty_pc
         self.err_msg = err_msg
 
-def _run_block(pc: int, conc_state: MiasmConcreteState, ctx: DisassemblyContext) \
+def _run_block(pc: int, conc_state: MiasmSymbolResolver, ctx: DisassemblyContext) \
         -> tuple[int | None, list[dict]]:
     """Run a basic block.
 
@@ -457,22 +454,22 @@ def _run_block(pc: int, conc_state: MiasmConcreteState, ctx: DisassemblyContext)
             # instructions are translated to multiple IR instructions.
             pass
 
-class _LLDBConcreteState:
-    """A back-end replacement for the `ProgramState` object from which
-    `MiasmConcreteState` reads its values. This reads values directly from an
-    LLDB target instead. This saves us the trouble of recording a full program
-    state, and allows us instead to read values from LLDB on demand.
+class _LLDBConcreteState(ReadableProgramState):
+    """A wrapper around `LLDBConcreteTarget` that provides access via a
+    `ReadableProgramState` interface. Reads values directly from an LLDB
+    target. This saves us the trouble of recording a full program state, and
+    allows us instead to read values from LLDB on demand.
     """
     def __init__(self, target: LLDBConcreteTarget, arch: Arch):
         self._target = target
         self._arch = arch
 
-    def read_register(self, reg: str) -> int | None:
+    def read_register(self, reg: str) -> int:
         from focaccia.arch import x86
 
         regname = self._arch.to_regname(reg)
         if regname is None:
-            return None
+            raise RegisterAccessError(reg, f'Not a register name: {reg}')
 
         try:
             return self._target.read_register(regname)
@@ -482,13 +479,13 @@ class _LLDBConcreteState:
                 rflags = x86.decompose_rflags(self._target.read_register('rflags'))
                 if regname in rflags:
                     return rflags[regname]
-            return None
+            raise RegisterAccessError(regname, '')
 
-    def read_memory(self, addr: int, size: int):
+    def read_memory(self, addr: int, size: int) -> bytes:
         try:
             return self._target.read_memory(addr, size)
         except ConcreteMemoryError:
-            return None
+            raise MemoryAccessError(addr, size, 'Unable to read memory from LLDB.')
 
 def collect_symbolic_trace(binary: str,
                            args: list[str],
@@ -533,7 +530,7 @@ def collect_symbolic_trace(binary: str,
         try:
             pc, strace = _run_block(
                 pc,
-                MiasmConcreteState(conc_state, ctx.loc_db),
+                MiasmSymbolResolver(conc_state, ctx.loc_db),
                 ctx)
         except DisassemblyError as err:
             # This happens if we encounter an instruction that is not