about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Theofilos Augoustis <theofilos.augoustis@gmail.com> 2024-02-02 17:54:41 +0100
committer Theofilos Augoustis <theofilos.augoustis@gmail.com> 2024-02-02 17:54:41 +0100
commit a24c5f12d6d909898472f7208cbe16b086a001c9 (patch)
tree 64c25b6f537295b1529bff5d12a3fb0cac80d236
parent c6503a3ddcce2fdefc5c93c6901f26c761ae859b (diff)
download focaccia-a24c5f12d6d909898472f7208cbe16b086a001c9.tar.gz
focaccia-a24c5f12d6d909898472f7208cbe16b086a001c9.zip
Basic reproducer generator setup
Co-authored-by: Alp Berkman <alp.berkman@no-reply.com>
Co-authored-by: Theofilos Augoustis <theofilos.augoustis@gmail.com>
-rw-r--r-- .gitignore 4
-rwxr-xr-x focaccia.py 54
-rw-r--r-- focaccia/compare.py 3
-rw-r--r-- focaccia/lldb_target.py 38
-rw-r--r-- focaccia/parser.py 1
-rw-r--r-- focaccia/reproducer.py 172
-rwxr-xr-x [-rw-r--r--] tools/capture_transforms.py 2
-rwxr-xr-x [-rw-r--r--] tools/convert.py 2
-rwxr-xr-x [-rw-r--r--] tools/verify_qemu.py 2
9 files changed, 257 insertions, 21 deletions
diff --git a/.gitignore b/.gitignore
index 39b5bdb..ea2880a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,7 @@ __pycache__/
 
 # Dev environment
 .gdbinit
+
+# Focaccia files
+qemu.sym
+qemu.trace
diff --git a/focaccia.py b/focaccia.py
index bf5a9ff..0637690 100755
--- a/focaccia.py
+++ b/focaccia.py
@@ -1,17 +1,21 @@
-#! /usr/bin/env python3
+#!/usr/bin/env python3
 
 import argparse
 import platform
-from typing import Iterable
+from typing import Iterable, Tuple
 
 from focaccia.arch import supported_architectures
 from focaccia.compare import compare_simple, compare_symbolic, ErrorTypes
 from focaccia.lldb_target import LLDBConcreteTarget
-from focaccia.match import fold_traces
-from focaccia.parser import parse_arancini
+from focaccia.match import fold_traces, match_traces
+from focaccia.parser import parse_arancini, parse_snapshots
 from focaccia.snapshot import ProgramState
 from focaccia.symbolic import collect_symbolic_trace
 from focaccia.utils import print_result
+from focaccia.reproducer import Reproducer
+from focaccia.compare import ErrorSeverity
+
+
 
 verbosity = {
     'info':    ErrorTypes.INFO,
@@ -20,7 +24,7 @@ verbosity = {
 }
 
 def collect_concrete_trace(oracle_program: str, breakpoints: Iterable[int]) \
-        -> list[ProgramState]:
+        -> Tuple[list[ProgramState], list]:
     """Gather snapshots from a native execution via an external debugger.
 
     :param oracle_program: Program to execute.
@@ -37,11 +41,13 @@ def collect_concrete_trace(oracle_program: str, breakpoints: Iterable[int]) \
 
     # Execute the native program
     snapshots = []
+    basic_blocks = []
     while not target.is_exited():
         snapshots.append(target.record_snapshot())
+        basic_blocks.append(target.get_next_basic_block())
         target.run()
 
-    return snapshots
+    return snapshots, basic_blocks
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference')
@@ -77,9 +83,32 @@ def parse_arguments():
                              ' may as well stem from incomplete input data.'
                              ' \'info\' will report absolutely everything.'
                              ' [Default: warning]')
+    parser.add_argument('-r', '--reproducer',
+                        action='store_true',
+                        default=False,
+                        help='Enable reproducer to get assembly code'
+                             ' which should replicate the first error.')
+    parser.add_argument('--trace-type',
+                        type=str,
+                        default='qemu',
+                        choices=['qemu', 'arancini'],
+                        help='Trace type of the emulator.'
+                             ' Currently only Qemu and Arancini traces are accepted.'
+                             ' Use \'qemu\' for Qemu and \'arancini\' for Arancini.'
+                             ' [Default: qemu]')
     args = parser.parse_args()
     return args
 
+def print_reproducer(result, min_severity: ErrorSeverity, oracle, oracle_args):
+    for res in result:
+        errs = [e for e in res['errors'] if e.severity >= min_severity]
+        # Only the first state with reportable errors gets a reproducer.
+        if errs:
+            rep = Reproducer(oracle, oracle_args, res['snap'], res['ref'])
+            print(rep.asm())
+            return
+
+
 def main():
     args = parse_arguments()
 
@@ -97,13 +126,19 @@ def main():
 
     # Parse reference trace
     with open(txl_path, "r") as txl_file:
-        test_states = parse_arancini(txl_file, arch)
+        if args.trace_type == 'qemu':
+            test_states = parse_snapshots(txl_file)
+        elif args.trace_type == 'arancini':
+            test_states = parse_arancini(txl_file, arch)
+        else:
+            test_states = parse_snapshots(txl_file)
 
     # Compare reference trace to a truth
     if args.symbolic:
         print(f'Tracing {oracle} symbolically with arguments {oracle_args}...')
         transforms = collect_symbolic_trace(oracle, oracle_args)
-        fold_traces(test_states, transforms)
+        test_states, transforms = match_traces(test_states, transforms)
+        #fold_traces(test_states, transforms)
         result = compare_symbolic(test_states, transforms)
     else:
         # Record truth states from a concrete execution of the oracle
@@ -113,5 +148,8 @@ def main():
 
     print_result(result, verbosity[args.error_level])
 
+    if args.reproducer:
+        print_reproducer(result, verbosity[args.error_level], oracle, oracle_args)
+
 if __name__ == '__main__':
     main()
diff --git a/focaccia/compare.py b/focaccia/compare.py
index 43a0133..65c0f49 100644
--- a/focaccia/compare.py
+++ b/focaccia/compare.py
@@ -316,7 +316,8 @@ def compare_symbolic(test_states: Iterable[ProgramState],
                 'pc': pc_cur,
                 'txl': _calc_transformation(cur_state, next_state),
                 'ref': transform,
-                'errors': errors
+                'errors': errors,
+                'snap': cur_state,
             })
 
             # Step forward
diff --git a/focaccia/lldb_target.py b/focaccia/lldb_target.py
index 903e73d..b51ec3d 100644
--- a/focaccia/lldb_target.py
+++ b/focaccia/lldb_target.py
@@ -77,15 +77,14 @@ class LLDBConcreteTarget:
         """Step forward by a single instruction."""
         thread: lldb.SBThread = self.process.GetThreadAtIndex(0)
         thread.StepInstruction(False)
-    
+
     def run_until(self, address: int) -> None:
         """Continue execution until the address is arrived, ignores other breakpoints"""
         bp = self.target.BreakpointCreateByAddress(address)
         while self.read_register("pc") != address:
-            self.target.run()
+            self.run()
         self.target.BreakpointDelete(bp.GetID())
 
-
     def record_snapshot(self) -> ProgramState:
         """Record the concrete target's state in a ProgramState object."""
         # Determine current arch
@@ -219,8 +218,8 @@ class LLDBConcreteTarget:
         command = f'breakpoint delete {addr}'
         result = lldb.SBCommandReturnObject()
         self.interpreter.HandleCommand(command, result)
-    
-    def get_basic_block(self, addr: int) -> [lldb.SBInstruction]:
+
+    def get_basic_block(self, addr: int) -> list[lldb.SBInstruction]:
         """Returns a basic block pointed by addr
         a code section is considered a basic block only if
         the last instruction is a brach, e.g. JUMP, CALL, RET
@@ -232,12 +231,29 @@ class LLDBConcreteTarget:
         block.append(self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0])
 
         return block
-    
+
+    def get_basic_block_inst(self, addr: int) -> list[str]:
+        inst = []
+        for bb in self.get_basic_block(addr):
+            inst.append(f'{bb.GetMnemonic(self.target)} {bb.GetOperands(self.target)}')
+        return inst
+
+    def get_next_basic_block(self) -> list[lldb.SBInstruction]:
+        return self.get_basic_block(self.read_register("pc"))
+
     def get_symbol(self, addr: int) -> lldb.SBSymbol:
-        """Returns the symbol that belongs to the addr"""
-        for s in self.target.module.symbols:
-            if (s.GetType() == lldb.eSymbolTypeCode 
-            and s.GetStartAddress().GetLoadAddress(self.target) <= addr 
-            and addr < s.GetEndAddress().GetLoadAddress(self.target)):
+        """Return the code symbol whose load-address range contains addr.
+        Raises ConcreteSectionError if no code symbol contains addr."""
+        for s in self.module.symbols:
+            if (s.GetType() == lldb.eSymbolTypeCode and s.GetStartAddress().GetLoadAddress(self.target) <= addr  < s.GetEndAddress().GetLoadAddress(self.target)):
                 return s
         raise ConcreteSectionError(f'Error getting the symbol to which address {hex(addr)} belongs to')
+
+    def get_symbol_limit(self) -> int:
+        """Returns the address after all the symbols"""
+        addr = 0
+        for s in self.module.symbols:
+            if s.GetStartAddress().IsValid():
+                if s.GetStartAddress().GetLoadAddress(self.target) > addr:
+                    addr = s.GetEndAddress().GetLoadAddress(self.target)
+        return addr
diff --git a/focaccia/parser.py b/focaccia/parser.py
index a5a1014..9fb83d8 100644
--- a/focaccia/parser.py
+++ b/focaccia/parser.py
@@ -4,6 +4,7 @@ import base64
 import json
 import re
 from typing import TextIO
+import lldb
 
 from .arch import supported_architectures, Arch
 from .snapshot import ProgramState
diff --git a/focaccia/reproducer.py b/focaccia/reproducer.py
new file mode 100644
index 0000000..90e1378
--- /dev/null
+++ b/focaccia/reproducer.py
@@ -0,0 +1,172 @@
+
+from .lldb_target import LLDBConcreteTarget
+from .snapshot import ProgramState
+from .symbolic import SymbolicTransform, eval_symbol
+from .arch import x86
+
+class ReproducerMemoryError(Exception):
+    pass
+class ReproducerBasicBlockError(Exception):
+    pass
+class ReproducerRegisterError(Exception):
+    pass
+
+class Reproducer():
+    def __init__(self, oracle: str, argv: str, snap: ProgramState, sym: SymbolicTransform) -> None:
+
+        target = LLDBConcreteTarget(oracle)
+
+        self.pc = snap.read_register("pc")
+        self.bb = target.get_basic_block_inst(self.pc)
+        self.sl = target.get_symbol_limit()
+        self.snap = snap
+        self.sym = sym
+
+    def get_bb(self) -> str:
+        try:
+            asm = ""
+            asm += f'_bb_{hex(self.pc)}:\n'
+            for i in self.bb[:-1]:
+                asm += f'{i}\n'
+            asm += f'ret\n'
+            asm += f'\n'
+
+            return asm
+        except:
+            raise ReproducerBasicBlockError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}')
+
+    def get_regs(self) -> str:
+        general_regs = ['RIP', 'RAX', 'RBX','RCX','RDX', 'RSI','RDI','RBP','RSP','R8','R9','R10','R11','R12','R13','R14','R15',]
+        flag_regs = ['CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT',]
+        eflag_regs = ['RF', 'VM', 'AC', 'VIF', 'VIP', 'ID',]
+
+        try:
+            asm = ""
+            asm += f'_setup_regs:\n'
+            for reg in self.sym.get_used_registers():
+                if reg in general_regs:
+                    asm += f'mov ${hex(self.snap.read_register(reg))}, %{reg.lower()}\n'
+
+            if 'RFLAGS' in self.sym.get_used_registers():
+                asm += f'pushfq ${hex(self.snap.read_register("RFLAGS"))}\n'
+
+            if any(reg in self.sym.get_used_registers() for reg in flag_regs+eflag_regs):
+                asm += f'pushfd ${hex(x86.compose_rflags(self.snap.regs))}\n'
+            asm += f'ret\n'
+            asm += f'\n'
+
+            return asm
+        except:
+            raise ReproducerRegisterError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}')
+
+    def get_mem(self) -> str:
+        try:
+            asm = ""
+            asm += f'_setup_mem:\n'
+            for mem in self.sym.get_used_memory_addresses():
+                addr = eval_symbol(mem.ptr, self.snap)
+                val = self.snap.read_memory(addr, int(mem.size/8))
+
+                if addr < self.sl:
+                    asm += f'.org {hex(addr)}\n'
+                    for b in val:
+                        asm += f'.byte ${hex(b)}\n'
+            asm += f'\n'
+
+            return asm
+        except:
+            raise ReproducerMemoryError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}')
+
+    def get_dyn(self) -> str:
+        try:
+            asm = ""
+            asm += f'_setup_dyn:\n'
+            for mem in self.sym.get_used_memory_addresses():
+                addr = eval_symbol(mem.ptr, self.snap)
+                val = self.snap.read_memory(addr, int(mem.size/8))
+
+                if addr >= self.sl:
+                    asm += f'mov ${hex(addr)}, %rdi\n'
+                    asm += f'call _alloc\n'
+                    for b in val:
+                        asm += f'mov ${hex(addr)}, %rax\n'
+                        asm += f'movb ${hex(b)}, (%rax)\n'
+                        addr += 1
+            asm += f'ret\n'
+            asm += f'\n'
+
+            return asm
+        except:
+            raise ReproducerMemoryError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}')
+
+    def get_start(self) -> str:
+        asm = ""
+        asm += f'_start:\n'
+        asm += f'call _setup_dyn\n'
+        asm += f'call _setup_regs\n'
+        asm += f'call _bb_{hex(self.pc)}\n'
+        asm += f'call _exit\n'
+        asm += f'\n'
+
+        return asm
+
+    def get_exit(self) -> str:
+        asm = ""
+        asm += f'_exit:\n'
+        asm += f'movq $0, %rdi\n'
+        asm += f'movq $60, %rax\n'
+        asm += f'syscall\n'
+        asm += f'\n'
+
+        return asm
+
+    def get_alloc(self) -> str:
+        asm = ""
+        asm += f'_alloc:\n'
+        asm += f'movq $4096, %rsi\n'
+        asm += f'movq $(PROT_READ | PROT_WRITE), %rdx\n'
+        asm += f'movq $(MAP_PRIVATE | MAP_ANONYMOUS), %r10\n'
+        asm += f'movq $-1, %r8\n'
+        asm += f'movq $0, %r9\n'
+        asm += f'movq $syscall_mmap, %rax\n'
+        asm += f'syscall\n'
+        asm += f'ret\n'
+        asm += f'\n'
+
+        return asm
+
+    def get_code(self) -> str:
+        asm = ""
+        asm += f'.section .text\n'
+        asm += f'.global _start\n'
+        asm += f'\n'
+        asm += f'.org {hex(self.pc)}\n'
+        asm += self.get_bb()
+        asm += self.get_start()
+        asm += self.get_exit()
+        asm += self.get_alloc()
+        asm += self.get_regs()
+        asm += self.get_dyn()
+
+        return asm
+
+    def get_data(self) -> str:
+        asm = ""
+        asm += f'.section .data\n'
+        asm += f'PROT_READ  = 0x1\n'
+        asm += f'PROT_WRITE = 0x2\n'
+        asm += f'MAP_PRIVATE = 0x2\n'
+        asm += f'MAP_ANONYMOUS = 0x20\n'
+        asm += f'syscall_mmap = 9\n'
+        asm += f'\n'
+
+        asm += self.get_mem()
+
+        return asm
+
+    def asm(self) -> str:
+        asm = ""
+        asm += self.get_code()
+        asm += self.get_data()
+
+        return asm
diff --git a/tools/capture_transforms.py b/tools/capture_transforms.py
index de35d86..5439b05 100644..100755
--- a/tools/capture_transforms.py
+++ b/tools/capture_transforms.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import argparse
 import logging
diff --git a/tools/convert.py b/tools/convert.py
index 27a8a4a..f21a2fa 100644..100755
--- a/tools/convert.py
+++ b/tools/convert.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
 import argparse
 import sys
 
diff --git a/tools/verify_qemu.py b/tools/verify_qemu.py
index da2e985..779b903 100644..100755
--- a/tools/verify_qemu.py
+++ b/tools/verify_qemu.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
 """
 Spawn GDB, connect to QEMU's GDB server, and read test states from that.