about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-11-27 14:08:55 +0100
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-11-27 14:08:55 +0100
commit836e42215fda0cbd330caef2dc5fc93336d4722c (patch)
tree0b9ce5cca67c511b74b9ae91a8fda2fc0a35e65c
parent5d51b4fe0bb41bc9e86c5775de35a9aef023fec5 (diff)
downloadfocaccia-836e42215fda0cbd330caef2dc5fc93336d4722c.tar.gz
focaccia-836e42215fda0cbd330caef2dc5fc93336d4722c.zip
Add memory storage capabilities to `ProgramState`
The `SparseMemory` class represents a program's memory. While the user
can read from and write to arbitrary memory addresses, it manages its
memory in pages/chunks internally. This is a tradeoff between space
consumption (this solution might have a memory overhead) and lookup
speed of individual memory addresses.

Add two small unit tests for `SparseMemory`.
-rw-r--r--miasm_test.py172
-rw-r--r--miasm_util.py102
-rw-r--r--snapshot.py78
-rw-r--r--test/test_sparse_memory.py33
4 files changed, 229 insertions, 156 deletions
diff --git a/miasm_test.py b/miasm_test.py
index 7ec76a9..64dc04a 100644
--- a/miasm_test.py
+++ b/miasm_test.py
@@ -1,21 +1,16 @@
 import sys
-from typing import Any
 
 from miasm.arch.x86.sem import Lifter_X86_64
 from miasm.analysis.machine import Machine
-from miasm.analysis.binary import Container, ContainerELF
-from miasm.core.asmblock import disasmEngine, AsmCFG
-from miasm.core.interval import interval
+from miasm.analysis.binary import ContainerELF
 from miasm.core.locationdb import LocationDB
-from miasm.expression.expression import ExprId, ExprInt, ExprLoc
 from miasm.ir.symbexec import SymbolicExecutionEngine, SymbolicState
-from miasm.ir.ir import IRBlock, AsmBlock
-from miasm.analysis.dse import DSEEngine
 
+from arch import x86
 from lldb_target import LLDBConcreteTarget, SimConcreteMemoryError, \
                         SimConcreteRegisterError
-from arch import x86
-from miasm_util import MiasmProgramState, eval_expr
+from miasm_util import MiasmConcreteState, eval_expr
+from snapshot import ProgramState
 
 def print_blocks(asmcfg, file=sys.stdout):
     print('=' * 80, file=file)
@@ -30,95 +25,54 @@ def print_state(state: SymbolicState):
         print(f'{str(reg):10s} = {val}')
     print('=' * 80)
 
-def flag_names_to_miasm(regs: dict[str, Any]) -> dict:
-    """Convert standard flag names to Miasm's names.
-
-    :param regs: Modified in-place.
-    :return: Returns `regs`.
-    """
-    regs['NF']     = regs.pop('SF')
-    regs['I_F']    = regs.pop('IF')
-    regs['IOPL_F'] = regs.pop('IOPL')
-    regs['I_D']    = regs.pop('ID')
-    return regs
+def create_state(target: LLDBConcreteTarget) -> ProgramState:
+    def standardize_flag_name(regname: str) -> str:
+        regname = regname.upper()
+        if regname in MiasmConcreteState.miasm_flag_aliases:
+            return MiasmConcreteState.miasm_flag_aliases[regname]
+        return regname
 
-def disasm_elf(addr, mdis: disasmEngine) -> AsmCFG:
-    """Try to disassemble all contents of an ELF file.
-
-    Based on the full-disassembly algorithm in
-    `https://github.com/cea-sec/miasm/blob/master/example/disasm/full.py`
-    (as of commit `a229f4e`).
-
-    :return: An asmcfg.
-    """
-    # Settings for the engine
-    mdis.follow_call = True
-
-    # Initial run
-    asmcfg = mdis.dis_multiblock(addr)
-
-    todo = [addr]
-    done = set()
-    done_interval = interval()
-
-    while todo:
-        while todo:
-            ad = todo.pop(0)
-            if ad in done:
-                continue
-            done.add(ad)
-            asmcfg = mdis.dis_multiblock(ad, asmcfg)
-
-            for block in asmcfg.blocks:
-                for l in block.lines:
-                    done_interval += interval([(l.offset, l.offset + l.l)])
-
-            # Process recursive functions
-            for block in asmcfg.blocks:
-                instr = block.get_subcall_instr()
-                if not instr:
-                    continue
-                for dest in instr.getdstflow(mdis.loc_db):
-                    if not dest.is_loc():
-                        continue
-                    offset = mdis.loc_db.get_location_offset(dest.loc_key)
-                    todo.append(offset)
-
-        # Disassemble all:
-        for _, b in done_interval.intervals:
-            if b in done:
-                continue
-            todo.append(b)
-
-    return asmcfg
-
-def create_state(target: LLDBConcreteTarget) -> MiasmProgramState:
-    regs: dict[ExprId, ExprInt] = {}
-    mem = []
+    state = ProgramState(x86.ArchX86())
 
     # Query and store register state
-    rflags = target.read_register('rflags')
-    rflags = flag_names_to_miasm(x86.decompose_rflags(rflags))
+    rflags = x86.decompose_rflags(target.read_register('rflags'))
     for reg in machine.mn.regs.all_regs_ids_no_alias:
-        regname = reg.name.upper()  # Make flag names upper case, too
+        regname = reg.name
         try:
             conc_val = target.read_register(regname)
-            regs[reg] = ExprInt(conc_val, reg.size)
+            state.set(regname, conc_val)
+        except KeyError:
+            pass
         except SimConcreteRegisterError:
+            regname = standardize_flag_name(regname)
             if regname in rflags:
-                regs[reg] = ExprInt(rflags[regname], reg.size)
+                state.set(regname, rflags[regname])
 
     # Query and store memory state
     for mapping in target.get_mappings():
         assert(mapping.end_address > mapping.start_address)
         size = mapping.end_address - mapping.start_address
         try:
-            mem_state = target.read_memory(mapping.start_address, size)
+            data = target.read_memory(mapping.start_address, size)
+            state.write_memory(mapping.start_address, data)
         except SimConcreteMemoryError:
-            mem_state = f'<unable to access "{mapping.name}">'
-        mem.append((mapping, mem_state))
+            # Unable to read memory from mapping
+            pass
+
+    return state
 
-    return MiasmProgramState(regs, mem)
+def record_concrete_states(binary) -> list[tuple[int, ProgramState]]:
+    """Record a trace of concrete program states by stepping through an
+    executable.
+    """
+    addrs = set()
+    states = []
+    target = LLDBConcreteTarget(binary)
+    while not target.is_exited():
+        addrs.add(target.read_register('pc'))
+        states.append((target.read_register('pc'), create_state(target)))
+        target.step()
+    return states
 
 binary = 'test_program'
 
@@ -131,7 +85,9 @@ pc = int(cont.entry_point)
 # Disassemble binary
 print(f'Disassembling "{binary}"...')
 mdis = machine.dis_engine(cont.bin_stream, loc_db=loc_db)
-asmcfg = disasm_elf(pc, mdis)
+mdis.follow_call = True
+asmcfg = mdis.dis_multiblock(pc)
+
 with open('full_disasm', 'w') as file:
     print(f'Entry point: {hex(pc)}\n', file=file)
     print_blocks(asmcfg, file)
@@ -149,19 +105,13 @@ with open('full_ir', 'w') as file:
     print('=' * 80, file=file)
 print(f'--- Lifted disassembly to IR. Log written to "full_ir.log".')
 
-def record_concrete_states(binary):
-    states = {}
-    target = LLDBConcreteTarget(binary)
-    while not target.is_exited():
-        states[target.read_register('pc')] = create_state(target)
-        target.step()
-    return states
-
+# Record concrete reference states to guide symbolic execution
 print(f'Recording concrete program trace...')
-conc_states = record_concrete_states(binary)
-print(f'Recorded {len(conc_states)} trace points.')
+conc_trace = record_concrete_states(binary)
+conc_trace = [(a, MiasmConcreteState(s, loc_db)) for a, s in conc_trace]
+print(f'Recorded {len(conc_trace)} trace points.')
 
-def run_block(pc: int, conc_state: MiasmProgramState) -> int | None:
+def run_block(pc: int, conc_state: MiasmConcreteState) -> int | None:
     """Run a basic block.
 
     Tries to run IR blocks until the end of an ASM block/basic block is
@@ -186,11 +136,19 @@ def run_block(pc: int, conc_state: MiasmProgramState) -> int | None:
         # The new program counter might be a symbolic value. Try to evaluate
         # it based on the last recorded concrete state at the start of the
         # current basic block.
-        pc = eval_expr(symbolic_pc, conc_state, loc_db)
+        pc = eval_expr(symbolic_pc, conc_state)
+
+        # Initial disassembly might not find all blocks in the binary.
+        # Disassemble code ad-hoc if the new PC has not yet been disassembled.
         if ircfg.get_block(pc) is None:
-            print(f'Unable to access IR block at PC {pc}'
-                  f' (evaluated from the expression PC = {symbolic_pc}).')
-            return None
+            addr = int(pc)
+            cfg = mdis.dis_multiblock(addr)
+            for block in cfg.blocks:
+                lifter.add_asmblock_to_ircfg(block, ircfg)
+            assert(ircfg.get_block(pc) is not None)
+
+            print(f'Disassembled {len(cfg.blocks):4} new blocks at {hex(addr)}'
+                  f' (evaluated from symbolic PC {symbolic_pc}).')
 
         # If the resulting PC is an integer, i.e. a concrete address that can
         # be mapped to the assembly code, we return as we have reached the end
@@ -207,14 +165,28 @@ last_pc = None  # Debugging info
 # Run until no more states can be reached
 print(f'Re-tracing symbolically...')
 while pc is not None:
+    def step_trace(trace, pc: int):
+        for i, (addr, _) in enumerate(trace):
+            if addr == pc:
+                return trace[i:]
+        return []
+
     assert(type(pc) is int)
-    if pc not in conc_states:
+
+    # Find next trace point (the concrete trace may have stopped at more
+    # states than the symbolic trace does)
+    conc_trace = step_trace(conc_trace, pc)
+    if not conc_trace:
         print(f'Next PC {hex(pc)} is not contained in the concrete program'
               f' trace. Last valid PC: {hex(last_pc)}')
         break
     last_pc = pc
 
-    initial_state = conc_states[pc]
+    addr, initial_state = conc_trace[0]
+    assert(addr == pc)
+    conc_trace.pop(0)
+
+    # Run symbolic execution
     pc = run_block(pc, initial_state)
 
 print(f'--- No new PC found. Exiting.')
diff --git a/miasm_util.py b/miasm_util.py
index 4f20dd8..31083d9 100644
--- a/miasm_util.py
+++ b/miasm_util.py
@@ -1,43 +1,36 @@
-from angr_targets.memory_map import MemoryMap
-from miasm.core.locationdb import LocationDB
+from miasm.core.locationdb import LocationDB, LocKey
 from miasm.expression.expression import Expr, ExprOp, ExprId, ExprLoc, \
                                         ExprInt, ExprMem, ExprCompose, \
                                         ExprSlice, ExprCond
 from miasm.expression.simplifications import expr_simp_explicit
 
-class MiasmProgramState:
-    def __init__(self,
-                 regs: dict[ExprId, ExprInt],
-                 mem: list[tuple[MemoryMap, bytes]]
-                 ):
-        self.regs = regs
-        self.memory = mem
-
-    def _find_mem_map(self, addr: int) \
-            -> tuple[MemoryMap, bytes] | tuple[None, None]:
-        for m, data in self.memory:
-            if addr >= m.start_address and addr < m.end_address:
-                return m, data
-        return None, None
-
-    def read_memory(self, addr: int, size: int) -> bytes:
-        res = bytes()
-        while size > 0:
-            m, data = self._find_mem_map(addr)
-            if m is None:
-                raise AttributeError(f'No memory mapping contains the address {addr}.')
-
-            assert(m is not None and data is not None)
-            read_off = addr - m.start_address
-            read_size = min(size, m.end_address - addr)
-            assert(read_off + read_size <= len(data))
-            res += data[read_off:read_off+read_size]
-
-            size -= read_size
-            addr += read_size
-        return res
-
-def eval_expr(expr: Expr, conc_state: MiasmProgramState, loc_db) -> int:
+from snapshot import ProgramState
+
+class MiasmConcreteState:
+    miasm_flag_aliases = {
+        'NF':     'SF',
+        'I_F':    'IF',
+        'IOPL_F': 'IOPL',
+        'I_D':    'ID',
+    }
+
+    def __init__(self, state: ProgramState, loc_db: LocationDB):
+        self.state = state
+        self.loc_db = loc_db
+
+    def resolve_register(self, regname: str) -> int:
+        regname = regname.upper()
+        if regname in self.miasm_flag_aliases:
+            regname = self.miasm_flag_aliases[regname]
+        return self.state.read(regname)
+
+    def resolve_memory(self, addr: int, size: int) -> bytes:
+        return self.state.read_memory(addr, size)
+
+    def resolve_location(self, loc: LocKey) -> int | None:
+        return self.loc_db.get_location_offset(loc)
+
+def eval_expr(expr: Expr, conc_state: MiasmConcreteState) -> int:
     # Most of these implementation are just copy-pasted members of
     # `SymbolicExecutionEngine`.
     expr_to_visitor = {
@@ -55,28 +48,29 @@ def eval_expr(expr: Expr, conc_state: MiasmProgramState, loc_db) -> int:
     if visitor is None:
         raise TypeError("Unknown expr type")
 
-    ret = visitor(expr, conc_state, loc_db)
+    ret = visitor(expr, conc_state)
     ret = expr_simp_explicit(ret)
     assert(ret is not None)
 
     return ret
 
-def _eval_exprint(expr: ExprInt, _, __: LocationDB):
+def _eval_exprint(expr: ExprInt, _):
     """Evaluate an ExprInt using the current state"""
     return expr
 
-def _eval_exprid(expr: ExprId, state: MiasmProgramState, _):
+def _eval_exprid(expr: ExprId, state: MiasmConcreteState):
     """Evaluate an ExprId using the current state"""
-    return state.regs[expr]
+    val = state.resolve_register(expr.name)
+    return ExprInt(val, expr.size)
 
-def _eval_exprloc(expr: ExprLoc, _, loc_db: LocationDB):
+def _eval_exprloc(expr: ExprLoc, state: MiasmConcreteState):
     """Evaluate an ExprLoc using the current state"""
-    offset = loc_db.get_location_offset(expr.loc_key)
+    offset = state.resolve_location(expr.loc_key)
     if offset is not None:
         return ExprInt(offset, expr.size)
     return expr
 
-def _eval_exprmem(expr: ExprMem, state: MiasmProgramState, loc_db: LocationDB):
+def _eval_exprmem(expr: ExprMem, state: MiasmConcreteState):
     """Evaluate an ExprMem using the current state.
     This function first evaluates the memory pointer value.
     """
@@ -87,34 +81,34 @@ def _eval_exprmem(expr: ExprMem, state: MiasmProgramState, loc_db: LocationDB):
     assert(expr.size <= 64)
     assert(expr.size % 8 == 0)
 
-    addr = eval_expr(expr.ptr, state, loc_db)
-    ret = state.read_memory(int(addr), int(expr.size / 8))
+    addr = eval_expr(expr.ptr, state)
+    ret = state.resolve_memory(int(addr), int(expr.size / 8))
     assert(len(ret) * 8 == expr.size)
     return ExprInt(int.from_bytes(ret, byteorder='little'), expr.size)
 
-def _eval_exprcond(expr, state: MiasmProgramState, loc_db: LocationDB):
+def _eval_exprcond(expr, state: MiasmConcreteState):
     """Evaluate an ExprCond using the current state"""
-    cond = eval_expr(expr.cond, state, loc_db)
-    src1 = eval_expr(expr.src1, state, loc_db)
-    src2 = eval_expr(expr.src2, state, loc_db)
+    cond = eval_expr(expr.cond, state)
+    src1 = eval_expr(expr.src1, state)
+    src2 = eval_expr(expr.src2, state)
     return ExprCond(cond, src1, src2)
 
-def _eval_exprslice(expr, state: MiasmProgramState, loc_db: LocationDB):
+def _eval_exprslice(expr, state: MiasmConcreteState):
     """Evaluate an ExprSlice using the current state"""
-    arg = eval_expr(expr.arg, state, loc_db)
+    arg = eval_expr(expr.arg, state)
     return ExprSlice(arg, expr.start, expr.stop)
 
-def _eval_exprop(expr, state: MiasmProgramState, loc_db: LocationDB):
+def _eval_exprop(expr, state: MiasmConcreteState):
     """Evaluate an ExprOp using the current state"""
     args = []
     for oarg in expr.args:
-        arg = eval_expr(oarg, state, loc_db)
+        arg = eval_expr(oarg, state)
         args.append(arg)
     return ExprOp(expr.op, *args)
 
-def _eval_exprcompose(expr, state: MiasmProgramState, loc_db: LocationDB):
+def _eval_exprcompose(expr, state: MiasmConcreteState):
     """Evaluate an ExprCompose using the current state"""
     args = []
     for arg in expr.args:
-        args.append(eval_expr(arg, state, loc_db))
+        args.append(eval_expr(arg, state))
     return ExprCompose(*args)
diff --git a/snapshot.py b/snapshot.py
index 3170649..a4bfb0f 100644
--- a/snapshot.py
+++ b/snapshot.py
@@ -1,4 +1,71 @@
 from arch.arch import Arch
+from interpreter import SymbolResolver, SymbolResolveError
+
+class MemoryAccessError(Exception):
+    def __init__(self, msg: str):
+        super().__init__(msg)
+
+class SparseMemory:
+    """Sparse memory.
+
+    Note that out-of-bound reads are possible when performed on unwritten
+    sections of existing pages and that there is no safeguard check for them.
+    """
+    def __init__(self, page_size=1024):
+        self.page_size = page_size
+        self._pages: dict[int, bytes] = {}
+
+    def _to_page_addr_and_offset(self, addr: int) -> tuple[int, int]:
+        off = addr % self.page_size
+        return addr - off, off
+
+    def read(self, addr: int, size: int) -> bytes:
+        """Read a number of bytes from memory.
+        :param addr: The offset from where to read.
+        :param size: The number of bytes to read, starting at at `addr`.
+
+        :return: `size` bytes of data.
+        :raise MemoryAccessError: If `[addr, addr + size)` is not entirely
+                                  contained in the set of stored bytes.
+        :raise ValueError: If `size < 0`.
+        """
+        if size < 0:
+            raise ValueError(f'A negative size is not allowed!')
+
+        res = bytes()
+        while size > 0:
+            page_addr, off = self._to_page_addr_and_offset(addr)
+            if page_addr not in self._pages:
+                raise MemoryAccessError(f'Address {addr} is not contained in'
+                                        f' the sparse memory.')
+            data = self._pages[page_addr]
+            assert(len(data) == self.page_size)
+            read_size = min(size, self.page_size - off)
+            res += data[off:off+read_size]
+
+            size -= read_size
+            addr += read_size
+        return res
+
+    def write(self, addr: int, data: bytes):
+        """Store bytes in the memory.
+        :param addr: The address at which to store the data.
+        :param data: The data to store at `addr`.
+        """
+        while len(data) > 0:
+            page_addr, off = self._to_page_addr_and_offset(addr)
+            if page_addr not in self._pages:
+                self._pages[page_addr] = bytes(self.page_size)
+            page = self._pages[page_addr]
+            assert(len(page) == self.page_size)
+
+            write_size = min(len(data), self.page_size - off)
+            new_page = page[:off] + data[:write_size] + page[off+write_size:]
+            assert(len(new_page) == self.page_size)
+            self._pages[page_addr] = new_page
+
+            data = data[write_size:]
+            addr += write_size
 
 class ProgramState:
     """A snapshot of the program's state."""
@@ -7,6 +74,7 @@ class ProgramState:
 
         dict_t = dict[str, int | None]
         self.regs: dict_t = { reg: None for reg in arch.regnames }
+        self.mem = SparseMemory()
 
     def read(self, reg: str) -> int:
         """Read a register's value.
@@ -28,14 +96,20 @@ class ProgramState:
     def set(self, reg: str, value: int):
         """Assign a value to a register.
 
-        :raise KeyError:   If `reg` is not a register name.
+        :raise KeyError: If `reg` is not a register name.
         """
         regname = self.arch.to_regname(reg)
         if regname is None:
-            raise KeyError(f'Not a register name: {regname}')
+            raise KeyError(f'Not a register name: {reg}')
 
         self.regs[regname] = value
 
+    def read_memory(self, addr: int, size: int) -> bytes:
+        return self.mem.read(addr, size)
+
+    def write_memory(self, addr: int, data: bytes):
+        self.mem.write(addr, data)
+
     def __repr__(self):
         return repr(self.regs)
 
diff --git a/test/test_sparse_memory.py b/test/test_sparse_memory.py
new file mode 100644
index 0000000..87b4456
--- /dev/null
+++ b/test/test_sparse_memory.py
@@ -0,0 +1,33 @@
+import unittest
+
+from snapshot import SparseMemory, MemoryAccessError
+
+class TestSparseMemory(unittest.TestCase):
+    def test_oob_read(self):
+        mem = SparseMemory()
+        for addr in range(mem.page_size):
+            self.assertRaises(MemoryAccessError, mem.read, addr, 1)
+            self.assertRaises(MemoryAccessError, mem.read, addr, 30)
+            self.assertRaises(MemoryAccessError, mem.read, addr + 0x10, 30)
+            self.assertRaises(MemoryAccessError, mem.read, addr, mem.page_size)
+            self.assertRaises(MemoryAccessError, mem.read, addr, mem.page_size - 1)
+            self.assertRaises(MemoryAccessError, mem.read, addr, mem.page_size + 1)
+
+    def test_basic_read_write(self):
+        mem = SparseMemory()
+
+        data = b'a' * mem.page_size * 2
+        mem.write(0x300, data)
+        self.assertEqual(mem.read(0x300, len(data)), data)
+        self.assertEqual(mem.read(0x300, 1), b'a')
+        self.assertEqual(mem.read(0x400, 1), b'a')
+        self.assertEqual(mem.read(0x299 + mem.page_size * 2, 1), b'a')
+        self.assertEqual(mem.read(0x321, 12), b'aaaaaaaaaaaa')
+
+        mem.write(0x321, b'Hello World!')
+        self.assertEqual(mem.read(0x321, 12), b'Hello World!')
+
+        self.assertRaises(MemoryAccessError, mem.read, 0x300, mem.page_size * 3)
+
+if __name__ == '__main__':
+    unittest.main()