diff options
Diffstat (limited to 'focaccia')
| -rw-r--r-- | focaccia/compare.py | 3 | ||||
| -rw-r--r-- | focaccia/lldb_target.py | 38 | ||||
| -rw-r--r-- | focaccia/parser.py | 1 | ||||
| -rw-r--r-- | focaccia/reproducer.py | 172 |
4 files changed, 202 insertions, 12 deletions
diff --git a/focaccia/compare.py b/focaccia/compare.py index 43a0133..65c0f49 100644 --- a/focaccia/compare.py +++ b/focaccia/compare.py @@ -316,7 +316,8 @@ def compare_symbolic(test_states: Iterable[ProgramState], 'pc': pc_cur, 'txl': _calc_transformation(cur_state, next_state), 'ref': transform, - 'errors': errors + 'errors': errors, + 'snap': cur_state, }) # Step forward diff --git a/focaccia/lldb_target.py b/focaccia/lldb_target.py index 903e73d..b51ec3d 100644 --- a/focaccia/lldb_target.py +++ b/focaccia/lldb_target.py @@ -77,15 +77,14 @@ class LLDBConcreteTarget: """Step forward by a single instruction.""" thread: lldb.SBThread = self.process.GetThreadAtIndex(0) thread.StepInstruction(False) - + def run_until(self, address: int) -> None: """Continue execution until the address is arrived, ignores other breakpoints""" bp = self.target.BreakpointCreateByAddress(address) while self.read_register("pc") != address: - self.target.run() + self.run() self.target.BreakpointDelete(bp.GetID()) - def record_snapshot(self) -> ProgramState: """Record the concrete target's state in a ProgramState object.""" # Determine current arch @@ -219,8 +218,8 @@ class LLDBConcreteTarget: command = f'breakpoint delete {addr}' result = lldb.SBCommandReturnObject() self.interpreter.HandleCommand(command, result) - - def get_basic_block(self, addr: int) -> [lldb.SBInstruction]: + + def get_basic_block(self, addr: int) -> list[lldb.SBInstruction]: """Returns a basic block pointed by addr a code section is considered a basic block only if the last instruction is a brach, e.g. JUMP, CALL, RET @@ -232,12 +231,29 @@ class LLDBConcreteTarget: block.append(self.target.ReadInstructions(lldb.SBAddress(addr, self.target), 1)[0]) return block - + + def get_basic_block_inst(self, addr: int) -> list[str]: + inst = [] + for bb in self.get_basic_block(addr): + inst.append(f'{bb.GetMnemonic(self.target)} {bb.GetOperands(self.target)}') + return inst + + def get_next_basic_block(self) -> list[lldb.SBInstruction]: + return self.get_basic_block(self.read_register("pc")) + def get_symbol(self, addr: int) -> lldb.SBSymbol: - """Returns the symbol that belongs to the addr""" - for s in self.target.module.symbols: - if (s.GetType() == lldb.eSymbolTypeCode - and s.GetStartAddress().GetLoadAddress(self.target) <= addr - and addr < s.GetEndAddress().GetLoadAddress(self.target)): + """Returns the symbol that belongs to the addr + """ + for s in self.module.symbols: + if (s.GetType() == lldb.eSymbolTypeCode and s.GetStartAddress().GetLoadAddress(self.target) <= addr < s.GetEndAddress().GetLoadAddress(self.target)): return s raise ConcreteSectionError(f'Error getting the symbol to which address {hex(addr)} belongs to') + + def get_symbol_limit(self) -> int: + """Returns the address after all the symbols""" + addr = 0 + for s in self.module.symbols: + if s.GetStartAddress().IsValid(): + if s.GetStartAddress().GetLoadAddress(self.target) > addr: + addr = s.GetEndAddress().GetLoadAddress(self.target) + return addr diff --git a/focaccia/parser.py b/focaccia/parser.py index a5a1014..9fb83d8 100644 --- a/focaccia/parser.py +++ b/focaccia/parser.py @@ -4,6 +4,7 @@ import base64 import json import re from typing import TextIO +import lldb from .arch import supported_architectures, Arch from .snapshot import ProgramState diff --git a/focaccia/reproducer.py b/focaccia/reproducer.py new file mode 100644 index 0000000..90e1378 --- /dev/null +++ b/focaccia/reproducer.py @@ -0,0 +1,172 @@ + +from .lldb_target import LLDBConcreteTarget +from .snapshot import ProgramState +from .symbolic import SymbolicTransform, eval_symbol +from .arch import x86 + +class ReproducerMemoryError(Exception): + pass +class ReproducerBasicBlockError(Exception): + pass +class ReproducerRegisterError(Exception): + pass + +class Reproducer(): + def __init__(self, oracle: str, argv: str, snap: ProgramState, sym: SymbolicTransform) -> None: + + target = LLDBConcreteTarget(oracle) + + self.pc = snap.read_register("pc") + self.bb = target.get_basic_block_inst(self.pc) + self.sl = target.get_symbol_limit() + self.snap = snap + self.sym = sym + + def get_bb(self) -> str: + try: + asm = "" + asm += f'_bb_{hex(self.pc)}:\n' + for i in self.bb[:-1]: + asm += f'{i}\n' + asm += f'ret\n' + asm += f'\n' + + return asm + except: + raise ReproducerBasicBlockError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_regs(self) -> str: + general_regs = ['RIP', 'RAX', 'RBX','RCX','RDX', 'RSI','RDI','RBP','RSP','R8','R9','R10','R11','R12','R13','R14','R15',] + flag_regs = ['CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT',] + eflag_regs = ['RF', 'VM', 'AC', 'VIF', 'VIP', 'ID',] + + try: + asm = "" + asm += f'_setup_regs:\n' + for reg in self.sym.get_used_registers(): + if reg in general_regs: + asm += f'mov ${hex(self.snap.read_register(reg))}, %{reg.lower()}\n' + + if 'RFLAGS' in self.sym.get_used_registers(): + asm += f'pushfq ${hex(self.snap.read_register("RFLAGS"))}\n' + + if any(reg in self.sym.get_used_registers() for reg in flag_regs+eflag_regs): + asm += f'pushfd ${hex(x86.compose_rflags(self.snap.regs))}\n' + asm += f'ret\n' + asm += f'\n' + + return asm + except: + raise ReproducerRegisterError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_mem(self) -> str: + try: + asm = "" + asm += f'_setup_mem:\n' + for mem in self.sym.get_used_memory_addresses(): + addr = eval_symbol(mem.ptr, self.snap) + val = self.snap.read_memory(addr, int(mem.size/8)) + + if addr < self.sl: + asm += f'.org {hex(addr)}\n' + for b in val: + asm += f'.byte ${hex(b)}\n' + asm += f'\n' + + return asm + except: + raise ReproducerMemoryError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_dyn(self) -> str: + try: + asm = "" + asm += f'_setup_dyn:\n' + for mem in self.sym.get_used_memory_addresses(): + addr = eval_symbol(mem.ptr, self.snap) + val = self.snap.read_memory(addr, int(mem.size/8)) + + if addr >= self.sl: + asm += f'mov ${hex(addr)}, %rdi\n' + asm += f'call _alloc\n' + for b in val: + asm += f'mov ${hex(addr)}, %rax\n' + asm += f'movb ${hex(b)}, (%rax)\n' + addr += 1 + asm += f'ret\n' + asm += f'\n' + + return asm + except: + raise ReproducerMemoryError(f'{hex(self.pc)}\n{self.snap}\n{self.sym}\n{self.bb}') + + def get_start(self) -> str: + asm = "" + asm += f'_start:\n' + asm += f'call _setup_dyn\n' + asm += f'call _setup_regs\n' + asm += f'call _bb_{hex(self.pc)}\n' + asm += f'call _exit\n' + asm += f'\n' + + return asm + + def get_exit(self) -> str: + asm = "" + asm += f'_exit:\n' + asm += f'movq $0, %rdi\n' + asm += f'movq $60, %rax\n' + asm += f'syscall\n' + asm += f'\n' + + return asm + + def get_alloc(self) -> str: + asm = "" + asm += f'_alloc:\n' + asm += f'movq $4096, %rsi\n' + asm += f'movq $(PROT_READ | PROT_WRITE), %rdx\n' + asm += f'movq $(MAP_PRIVATE | MAP_ANONYMOUS), %r10\n' + asm += f'movq $-1, %r8\n' + asm += f'movq $0, %r9\n' + asm += f'movq $syscall_mmap, %rax\n' + asm += f'syscall\n' + asm += f'ret\n' + asm += f'\n' + + return asm + + def get_code(self) -> str: + asm = "" + asm += f'.section .text\n' + asm += f'.global _start\n' + asm += f'\n' + asm += f'.org {hex(self.pc)}\n' + asm += self.get_bb() + asm += self.get_start() + asm += self.get_exit() + asm += self.get_alloc() + asm += self.get_regs() + asm += self.get_dyn() + + return asm + + def get_data(self) -> str: + asm = "" + asm += f'.section .data\n' + asm += f'PROT_READ = 0x1\n' + asm += f'PROT_WRITE = 0x2\n' + asm += f'MAP_PRIVATE = 0x2\n' + asm += f'MAP_ANONYMOUS = 0x20\n' + asm += f'syscall_mmap = 9\n' + asm += f'\n' + + asm += self.get_mem() + + return asm + + def asm(self) -> str: + asm = "" + asm += self.get_code() + asm += self.get_data() + + return asm |