about summary refs log tree commit diff stats
path: root/focaccia
diff options
context:
space:
mode:
Diffstat (limited to 'focaccia')
-rw-r--r-- focaccia/compare.py 3
-rw-r--r-- focaccia/lldb_target.py 38
-rw-r--r-- focaccia/parser.py 1
-rw-r--r-- focaccia/reproducer.py 172
4 files changed, 202 insertions, 12 deletions
diff --git a/focaccia/compare.py b/focaccia/compare.py
index 43a0133..65c0f49 100644
--- a/focaccia/compare.py
+++ b/focaccia/compare.py
@@ -316,7 +316,8 @@ def compare_symbolic(test_states: Iterable[ProgramState],
                 'pc': pc_cur,
                 'txl': _calc_transformation(cur_state, next_state),
                 'ref': transform,
-                'errors': errors
+                'errors': errors,
+                'snap': cur_state,
             })
 
             # Step forward
diff --git a/focaccia/lldb_target.py b/focaccia/lldb_target.py
index 903e73d..b51ec3d 100644
--- a/focaccia/lldb_target.py
+++ b/focaccia/lldb_target.py
@@ -77,15 +77,14 @@ class LLDBConcreteTarget:
         """Step forward by a single instruction."""
         thread: lldb.SBThread = self.process.GetThreadAtIndex(0)
         thread.StepInstruction(False)
-    
+
     def run_until(self, address: int) -> None:
         """Continue execution until the address is arrived, ignores other breakpoints"""
         bp = self.target.BreakpointCreateByAddress(address)
         while self.read_register("pc") != address:
-            self.target.run()
+            self.run()
         self.target.BreakpointDelete(bp.GetID())
 
-
     def record_snapshot(self) -> ProgramState:
         """Record the concrete target's state in a ProgramState object."""
         # Determine current arch
@@ -219,8 +218,8 @@ class LLDBConcreteTarget:
         command = f'breakpoint delete {addr}'
         result = lldb.SBCommandReturnObject()
         self.interpreter.HandleCommand(command, result)
-    
-    def get_basic_block(self, addr: int) -> [lldb.SBInstruction]:
+
+    def get_basic_block(self, addr: int) -> list[lldb.SBInstruction]:
         """Returns a basic block pointed by addr
         a code section is considered a basic block only if
         the last instruction is a brach, e.g. JUMP, CALL, RET
@@ -232,12 +231,29 @@ class LLDBConcreteTarget:
         block.append(self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0])
 
         return block
-    
+
+    def get_basic_block_inst(self, addr: int) -> list[str]:
+        """Return the basic block at addr as '<mnemonic> <operands>' strings."""
+        return [
+            f'{insn.GetMnemonic(self.target)} {insn.GetOperands(self.target)}'
+            for insn in self.get_basic_block(addr)]
+
+    def get_next_basic_block(self) -> list[lldb.SBInstruction]:
+        return self.get_basic_block(self.read_register("pc"))  # block at the current pc
+
     def get_symbol(self, addr: int) -> lldb.SBSymbol:
-        """Returns the symbol that belongs to the addr"""
-        for s in self.target.module.symbols:
-            if (s.GetType() == lldb.eSymbolTypeCode 
-            and s.GetStartAddress().GetLoadAddress(self.target) <= addr 
-            and addr < s.GetEndAddress().GetLoadAddress(self.target)):
+        """Returns the symbol that belongs to the addr
+        """
+        for s in self.module.symbols:
+            if (s.GetType() == lldb.eSymbolTypeCode and s.GetStartAddress().GetLoadAddress(self.target) <= addr  < s.GetEndAddress().GetLoadAddress(self.target)):
                 return s
         raise ConcreteSectionError(f'Error getting the symbol to which address {hex(addr)} belongs to')
+
+    def get_symbol_limit(self) -> int:
+        """Returns the address after all the symbols, i.e. the highest end
+        load address (0 if none); ends are compared so overlaps are not missed"""
+        limit = 0
+        for s in self.module.symbols:
+            if s.GetStartAddress().IsValid() and s.GetEndAddress().IsValid():
+                limit = max(limit, s.GetEndAddress().GetLoadAddress(self.target))
+        return limit
diff --git a/focaccia/parser.py b/focaccia/parser.py
index a5a1014..9fb83d8 100644
--- a/focaccia/parser.py
+++ b/focaccia/parser.py
@@ -4,6 +4,7 @@ import base64
 import json
 import re
 from typing import TextIO
+import lldb
 
 from .arch import supported_architectures, Arch
 from .snapshot import ProgramState
diff --git a/focaccia/reproducer.py b/focaccia/reproducer.py
new file mode 100644
index 0000000..90e1378
--- /dev/null
+++ b/focaccia/reproducer.py
@@ -0,0 +1,172 @@
+
+from .lldb_target import LLDBConcreteTarget
+from .snapshot import ProgramState
+from .symbolic import SymbolicTransform, eval_symbol
+from .arch import x86
+
+class ReproducerMemoryError(Exception):
+    """Raised by Reproducer.get_mem and Reproducer.get_dyn when rendering memory fails."""
+class ReproducerBasicBlockError(Exception):
+    """Raised by Reproducer.get_bb when the basic block cannot be rendered."""
+class ReproducerRegisterError(Exception):
+    """Raised by Reproducer.get_regs when register setup cannot be rendered."""
+
+class Reproducer():
+    """Emit an x86-64 assembly listing that replays one basic block.
+
+    The listing restores the registers and memory recorded in a snapshot
+    (restricted to what the symbolic transform uses) and then calls the
+    basic block that starts at the snapshot's program counter.
+    """
+
+    _GENERAL_REGS = ('RIP', 'RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP', 'RSP',
+                     'R8', 'R9', 'R10', 'R11', 'R12', 'R13', 'R14', 'R15')
+    _FLAG_REGS = ('CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT',
+                  'RF', 'VM', 'AC', 'VIF', 'VIP', 'ID')
+
+    def __init__(self, oracle: str, argv: str, snap: ProgramState, sym: SymbolicTransform) -> None:
+        """Capture everything needed to render the reproducer.
+
+        :param oracle: Forwarded to LLDBConcreteTarget.
+        :param argv: Accepted for interface compatibility; currently unused.
+        :param snap: Program state whose registers and memory are replayed.
+        :param sym: Transform naming the registers and memory that matter.
+        """
+        target = LLDBConcreteTarget(oracle)
+
+        self.pc = snap.read_register("pc")
+        self.bb = target.get_basic_block_inst(self.pc)
+        self.sl = target.get_symbol_limit()
+        self.snap = snap
+        self.sym = sym
+
+    @staticmethod
+    def _emit(lines) -> str:
+        """Join lines into a listing, terminating each one with a newline."""
+        return ''.join(f'{line}\n' for line in lines)
+
+    def _context(self) -> str:
+        """Describe the reproducer's inputs for use in error messages."""
+        return f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}'
+
+    def _used_memory(self):
+        """Yield (address, bytes) for every memory access the transform uses."""
+        for mem in self.sym.get_used_memory_addresses():
+            addr = eval_symbol(mem.ptr, self.snap)
+            yield addr, self.snap.read_memory(addr, int(mem.size/8))
+
+    def get_bb(self) -> str:
+        """Return the basic block under a ``_bb_<pc>`` label.
+
+        The block's final instruction is replaced by ``ret``.
+
+        :raises ReproducerBasicBlockError: if rendering fails for any reason.
+        """
+        try:
+            return self._emit([f'_bb_{hex(self.pc)}:', *self.bb[:-1], 'ret', ''])
+        except Exception as err:
+            # Not a bare except: KeyboardInterrupt/SystemExit must propagate.
+            raise ReproducerBasicBlockError(self._context()) from err
+
+    def get_regs(self) -> str:
+        """Return the ``_setup_regs`` routine loading the used registers.
+
+        :raises ReproducerRegisterError: if rendering fails for any reason.
+        """
+        try:
+            used = list(self.sym.get_used_registers())
+            out = ['_setup_regs:']
+            out += [f'mov ${hex(self.snap.read_register(reg))}, %{reg.lower()}'
+                    for reg in used if reg in self._GENERAL_REGS]
+            # NOTE(review): pushfq/pushfd take no operand on x86, so these two
+            # lines probably do not assemble; `pushq $imm` + `popfq` may be
+            # what was intended. Emitted text is left unchanged; confirm.
+            if 'RFLAGS' in used:
+                out.append(f'pushfq ${hex(self.snap.read_register("RFLAGS"))}')
+            if any(reg in used for reg in self._FLAG_REGS):
+                out.append(f'pushfd ${hex(x86.compose_rflags(self.snap.regs))}')
+            out += ['ret', '']
+            return self._emit(out)
+        except Exception as err:
+            raise ReproducerRegisterError(self._context()) from err
+
+    def get_mem(self) -> str:
+        """Return ``.org``/``.byte`` data for used memory below the symbol limit.
+
+        :raises ReproducerMemoryError: if rendering fails for any reason.
+        """
+        try:
+            out = ['_setup_mem:']
+            for addr, val in self._used_memory():
+                if addr < self.sl:
+                    out.append(f'.org {hex(addr)}')
+                    out += [f'.byte ${hex(b)}' for b in val]
+            out.append('')
+            return self._emit(out)
+        except Exception as err:
+            raise ReproducerMemoryError(self._context()) from err
+
+    def get_dyn(self) -> str:
+        """Return ``_setup_dyn``, filling used memory at or above the symbol limit.
+
+        :raises ReproducerMemoryError: if rendering fails for any reason.
+        """
+        try:
+            out = ['_setup_dyn:']
+            for addr, val in self._used_memory():
+                if addr >= self.sl:
+                    out += [f'mov ${hex(addr)}, %rdi', 'call _alloc']
+                    for offset, b in enumerate(val):
+                        out.append(f'mov ${hex(addr + offset)}, %rax')
+                        out.append(f'movb ${hex(b)}, (%rax)')
+            out += ['ret', '']
+            return self._emit(out)
+        except Exception as err:
+            raise ReproducerMemoryError(self._context()) from err
+
+    def get_start(self) -> str:
+        """Return the ``_start`` entry point: set up state, run the block, exit."""
+        return self._emit(['_start:', 'call _setup_dyn', 'call _setup_regs',
+                           f'call _bb_{hex(self.pc)}', 'call _exit', ''])
+
+    def get_exit(self) -> str:
+        """Return the ``_exit`` routine (``syscall`` with %rax=60, %rdi=0)."""
+        return self._emit(['_exit:', 'movq $0, %rdi', 'movq $60, %rax', 'syscall', ''])
+
+    def get_alloc(self) -> str:
+        """Return the ``_alloc`` routine: a 4096-byte anonymous private mmap."""
+        return self._emit([
+            '_alloc:',
+            'movq $4096, %rsi',
+            'movq $(PROT_READ | PROT_WRITE), %rdx',
+            'movq $(MAP_PRIVATE | MAP_ANONYMOUS), %r10',
+            'movq $-1, %r8',
+            'movq $0, %r9',
+            'movq $syscall_mmap, %rax',
+            'syscall',
+            'ret',
+            '',
+        ])
+
+    def get_code(self) -> str:
+        """Return the ``.text`` section: the block at ``.org <pc>`` plus all routines."""
+        header = self._emit(['.section .text', '.global _start', '', f'.org {hex(self.pc)}'])
+        return (header + self.get_bb() + self.get_start() + self.get_exit()
+                + self.get_alloc() + self.get_regs() + self.get_dyn())
+
+    def get_data(self) -> str:
+        """Return the ``.data`` section: assembler constants followed by get_mem()."""
+        constants = self._emit([
+            '.section .data',
+            'PROT_READ  = 0x1',
+            'PROT_WRITE = 0x2',
+            'MAP_PRIVATE = 0x2',
+            'MAP_ANONYMOUS = 0x20',
+            'syscall_mmap = 9',
+            '',
+        ])
+        return constants + self.get_mem()
+
+    def asm(self) -> str:
+        """Return the complete listing: code section followed by data section."""
+        return self.get_code() + self.get_data()