about summary refs log tree commit diff stats
path: root/src/miasm/jitter/jitload.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/miasm/jitter/jitload.py')
-rw-r--r--src/miasm/jitter/jitload.py585
1 files changed, 585 insertions, 0 deletions
diff --git a/src/miasm/jitter/jitload.py b/src/miasm/jitter/jitload.py
new file mode 100644
index 00000000..99e4429d
--- /dev/null
+++ b/src/miasm/jitter/jitload.py
@@ -0,0 +1,585 @@
+import logging
+import warnings
+from functools import wraps
+from collections import namedtuple
+try:
+    from collections.abc import Sequence, Iterator
+except ImportError:
+    from collections import Sequence, Iterator
+
+from future.utils import viewitems
+
+from miasm.jitter.csts import *
+from miasm.core.utils import *
+from miasm.core.bin_stream import bin_stream_vm
+from miasm.jitter.emulatedsymbexec import EmulatedSymbExec
+from miasm.jitter.codegen import CGen
+from miasm.jitter.jitcore_cc_base import JitCore_Cc_Base
+
+hnd = logging.StreamHandler()
+hnd.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s"))
+log = logging.getLogger('jitload.py')
+log.addHandler(hnd)
+log.setLevel(logging.CRITICAL)
+log_func = logging.getLogger('jit function call')
+log_func.addHandler(hnd)
+log_func.setLevel(logging.CRITICAL)
+
+try:
+    from miasm.jitter import VmMngr
+except ImportError:
+    log.error('cannot import VmMngr')
+
+
+def named_arguments(func):
+    """Function decorator to allow the use of .func_args_*() methods
+    with either the number of arguments or the list of the argument
+    names.
+
+    The wrapper is also used to log the argument values.
+
+    @func: function
+
+    """
+    @wraps(func)
+    def newfunc(self, args):
+        if isinstance(args, Sequence):
+            ret_ad, arg_vals = func(self, len(args))
+            arg_vals = namedtuple("args", args)(*arg_vals)
+            # func_name(arguments) return address
+            log_func.info(
+                '%s(%s) ret addr: %s',
+                get_caller_name(1),
+                ', '.join(
+                    "%s=0x%x" % (field, value)
+                    for field, value in viewitems(arg_vals._asdict())
+                ),
+                hex(ret_ad)
+            )
+            return ret_ad, namedtuple("args", args)(*arg_vals)
+        else:
+            ret_ad, arg_vals = func(self, args)
+            # func_name(arguments) return address
+            log_func.info('%s(%s) ret addr: %s',
+                get_caller_name(1),
+                ', '.join(hex(arg) for arg in arg_vals),
+                hex(ret_ad))
+            return ret_ad, arg_vals
+    return newfunc
+
+
+class CallbackHandler(object):
+
+    "Handle a list of callback"
+
+    def __init__(self):
+        self.callbacks = {}  # Key -> [callback list]
+
+    def add_callback(self, key, callback):
+        """Add a callback to the key @key, iff the @callback isn't already
+        assigned to it"""
+        if callback not in self.callbacks.get(key, []):
+            self.callbacks[key] = self.callbacks.get(key, []) + [callback]
+
+    def set_callback(self, key, *args):
+        "Set the list of callback for key 'key'"
+        self.callbacks[key] = list(args)
+
+    def get_callbacks(self, key):
+        "Return the list of callbacks associated to key 'key'"
+        return self.callbacks.get(key, [])
+
+    def remove_callback(self, callback):
+        """Remove the callback from the list.
+        Return the list of empty keys (removed)"""
+
+        to_check = set()
+        for key, cb_list in viewitems(self.callbacks):
+            try:
+                cb_list.remove(callback)
+                to_check.add(key)
+            except ValueError:
+                pass
+
+        empty_keys = []
+        for key in to_check:
+            if len(self.callbacks[key]) == 0:
+                empty_keys.append(key)
+                del(self.callbacks[key])
+
+        return empty_keys
+
+    def has_callbacks(self, key):
+        return key in self.callbacks
+
+    def remove_key(self, key):
+        """Remove and return all callbacks associated to @key"""
+        callbacks = self.callbacks.get(key, [])
+        del self.callbacks[key]
+        return callbacks
+
+    def call_callbacks(self, key, *args):
+        """Call callbacks associated to key 'key' with arguments args. While
+        callbacks return True, continue with next callback.
+        Iterator on other results."""
+
+        res = True
+
+        for c in self.get_callbacks(key):
+            res = c(*args)
+            if res is not True:
+                yield res
+
+    def __call__(self, key, *args):
+        "Wrapper for call_callbacks"
+        return self.call_callbacks(key, *args)
+
+
+class CallbackHandlerBitflag(CallbackHandler):
+
+    "Handle a list of callback with conditions on bitflag"
+
+    def call_callbacks(self, bitflag, *args):
+        """Call each callbacks associated with bit set in bitflag. While
+        callbacks return True, continue with next callback.
+        Iterator on other results"""
+
+        for bitflag_expected in self.callbacks:
+            if bitflag_expected & bitflag == bitflag_expected:
+                # If the flag matched
+                for res in super(CallbackHandlerBitflag,
+                                 self).call_callbacks(bitflag_expected, *args):
+                    if res is not True:
+                        yield res
+
+
+class ExceptionHandle(object):
+
+    "Return type for exception handler"
+
+    def __init__(self, except_flag):
+        self.except_flag = except_flag
+
+    @classmethod
+    def memoryBreakpoint(cls):
+        return cls(EXCEPT_BREAKPOINT_MEMORY)
+
+    def __eq__(self, to_cmp):
+        if not isinstance(to_cmp, ExceptionHandle):
+            return False
+        return (self.except_flag == to_cmp.except_flag)
+
+    def __ne__(self, to_cmp):
+        return not self.__eq__(to_cmp)
+
+
+class JitterException(Exception):
+
+    "Raised when any unhandled exception occurs (in jitter.vm or jitter.cpu)"
+
+    def __init__(self, exception_flag):
+        super(JitterException, self).__init__()
+        self.exception_flag = exception_flag
+
+    def __str__(self):
+        return "A jitter exception occurred: %s (0x%x)" % (
+            self.exception_flag_to_str(), self.exception_flag
+        )
+
+    def exception_flag_to_str(self):
+        exception_flag_list = []
+        for name, value in JitterExceptions.items():
+            if value & self.exception_flag == value:
+                exception_flag_list.append(name)
+        return ' & '.join(exception_flag_list)
+
+
+class Jitter(object):
+
+    "Main class for JIT handling"
+
+    C_Gen = CGen
+
+    def __init__(self, lifter, jit_type="gcc"):
+        """Init an instance of jitter.
+        @lifter: Lifter instance for this architecture
+        @jit_type: JiT backend to use. Available options are:
+            - "gcc"
+            - "llvm"
+            - "python"
+        """
+
+        self.arch = lifter.arch
+        self.attrib = lifter.attrib
+        arch_name = lifter.arch.name  # (lifter.arch.name, lifter.attrib)
+        self.running = False
+
+        try:
+            if arch_name == "x86":
+                from miasm.jitter.arch import JitCore_x86 as jcore
+            elif arch_name == "arm":
+                from miasm.jitter.arch import JitCore_arm as jcore
+            elif arch_name == "armt":
+                from miasm.jitter.arch import JitCore_arm as jcore
+                lifter.arch.name = 'arm'
+            elif arch_name == "aarch64":
+                from miasm.jitter.arch import JitCore_aarch64 as jcore
+            elif arch_name == "msp430":
+                from miasm.jitter.arch import JitCore_msp430 as jcore
+            elif arch_name == "mips32":
+                from miasm.jitter.arch import JitCore_mips32 as jcore
+            elif arch_name == "ppc32":
+                from miasm.jitter.arch import JitCore_ppc32 as jcore
+            elif arch_name == "mep":
+                from miasm.jitter.arch import JitCore_mep as jcore
+            else:
+                raise ValueError("unknown jit arch: %s" % arch_name)
+        except ImportError:
+            raise RuntimeError('Unsupported jit arch: %s' % arch_name)
+
+        self.vm = VmMngr.Vm()
+        self.cpu = jcore.JitCpu()
+        self.lifter = lifter
+        self.bs = bin_stream_vm(self.vm)
+        self.ircfg = self.lifter.new_ircfg()
+
+        self.symbexec = EmulatedSymbExec(
+            self.cpu, self.vm, self.lifter, {}
+        )
+        self.symbexec.reset_regs()
+
+        try:
+            if jit_type == "llvm":
+                from miasm.jitter.jitcore_llvm import JitCore_LLVM as JitCore
+            elif jit_type == "python":
+                from miasm.jitter.jitcore_python import JitCore_Python as JitCore
+            elif jit_type == "gcc":
+                from miasm.jitter.jitcore_gcc import JitCore_Gcc as JitCore
+            else:
+                raise ValueError("Unknown jitter %s" % jit_type)
+        except ImportError:
+            raise RuntimeError('Unsupported jitter: %s' % jit_type)
+
+        self.jit = JitCore(self.lifter, self.bs)
+        if isinstance(self.jit, JitCore_Cc_Base):
+            self.jit.init_codegen(self.C_Gen(self.lifter))
+        elif jit_type == "python":
+            self.jit.set_cpu_vm(self.cpu, self.vm)
+
+        self.cpu.init_regs()
+        self.vm.init_memory_page_pool()
+        self.vm.init_code_bloc_pool()
+        self.vm.init_memory_breakpoint()
+
+        self.jit.load()
+        self.cpu.vmmngr = self.vm
+        self.cpu.jitter = self.jit
+        self.stack_size = 0x10000
+        self.stack_base = 0x1230000
+
+        # Init callback handler
+        self.breakpoints_handler = CallbackHandler()
+        self.exceptions_handler = CallbackHandlerBitflag()
+        self.init_exceptions_handler()
+        self.exec_cb = None
+
+    def init_exceptions_handler(self):
+        "Add common exceptions handlers"
+
+        def exception_automod(jitter):
+            "Tell the JiT backend to update blocks modified"
+
+            self.jit.updt_automod_code(jitter.vm)
+            self.vm.set_exception(0)
+
+            return True
+
+        self.add_exception_handler(EXCEPT_CODE_AUTOMOD, exception_automod)
+
+    def add_breakpoint(self, addr, callback):
+        """Add a callback associated with addr.
+        @addr: breakpoint address
+        @callback: function with definition (jitter instance)
+        """
+        self.breakpoints_handler.add_callback(addr, callback)
+        self.jit.add_disassembly_splits(addr)
+        # De-jit previously jitted blocks
+        self.jit.updt_automod_code_range(self.vm, [(addr, addr)])
+
+    def set_breakpoint(self, addr, *args):
+        """Set callbacks associated with addr.
+        @addr: breakpoint address
+        @args: functions with definition (jitter instance)
+        """
+        self.breakpoints_handler.set_callback(addr, *args)
+        self.jit.add_disassembly_splits(addr)
+
+    def get_breakpoint(self, addr):
+        """
+        Return breakpoints handlers for address @addr
+        @addr: integer
+        """
+        return self.breakpoints_handler.get_callbacks(addr)
+
+    def remove_breakpoints_by_callback(self, callback):
+        """Remove callbacks associated with breakpoint.
+        @callback: callback to remove
+        """
+        empty_keys = self.breakpoints_handler.remove_callback(callback)
+        for key in empty_keys:
+            self.jit.remove_disassembly_splits(key)
+
+    def remove_breakpoints_by_address(self, address):
+        """Remove all breakpoints associated with @address.
+        @address: address of breakpoints to remove
+        """
+        callbacks = self.breakpoints_handler.remove_key(address)
+        if callbacks:
+            self.jit.remove_disassembly_splits(address)
+
+    def add_exception_handler(self, flag, callback):
+        """Add a callback associated with an exception flag.
+        @flag: bitflag
+        @callback: function with definition (jitter instance)
+        """
+        self.exceptions_handler.add_callback(flag, callback)
+
+    def run_at(self, pc):
+        """Wrapper on JiT backend. Run the code at PC and return the next PC.
+        @pc: address of code to run"""
+
+        return self.jit.run_at(
+            self.cpu, pc,
+            set(self.breakpoints_handler.callbacks)
+        )
+
+    def runiter_once(self, pc):
+        """Iterator on callbacks results on code running from PC.
+        Check exceptions before breakpoints."""
+
+        self.pc = pc
+        # Callback called before exec
+        if self.exec_cb is not None:
+            res = self.exec_cb(self)
+            if res is not True:
+                yield res
+
+        # Check breakpoints
+        old_pc = self.pc
+        for res in self.breakpoints_handler.call_callbacks(self.pc, self):
+            if res is not True:
+                if isinstance(res, Iterator):
+                    # If the breakpoint is a generator, yield it step by step
+                    for tmp in res:
+                        yield tmp
+                else:
+                    yield res
+
+        # Check exceptions (raised by breakpoints)
+        exception_flag = self.get_exception()
+        for res in self.exceptions_handler(exception_flag, self):
+            if res is not True:
+                if isinstance(res, Iterator):
+                    for tmp in res:
+                        yield tmp
+                else:
+                    yield res
+
+        # If a callback changed pc, re call every callback
+        if old_pc != self.pc:
+            return
+
+        # Exceptions should never be activated before run
+        exception_flag = self.get_exception()
+        if exception_flag:
+            raise JitterException(exception_flag)
+
+        # Run the block at PC
+        self.pc = self.run_at(self.pc)
+
+        # Check exceptions (raised by the execution of the block)
+        exception_flag = self.get_exception()
+        for res in self.exceptions_handler(exception_flag, self):
+            if res is not True:
+                if isinstance(res, Iterator):
+                    for tmp in res:
+                        yield tmp
+                else:
+                    yield res
+
+    def init_run(self, pc):
+        """Create an iterator on pc with runiter.
+        @pc: address of code to run
+        """
+        self.run_iterator = self.runiter_once(pc)
+        self.pc = pc
+        self.running = True
+
+    def continue_run(self, step=False, trace=False):
+        """PRE: init_run.
+        Continue the run of the current session until iterator returns or run is
+        set to False.
+        If step is True, run only one time.
+        If trace is True, activate trace log option until execution stops
+        Return the iterator value"""
+
+        if trace:
+            self.set_trace_log()
+        while self.running:
+            try:
+                return next(self.run_iterator)
+            except StopIteration:
+                pass
+
+            self.run_iterator = self.runiter_once(self.pc)
+
+            if step is True:
+                break
+        if trace:
+            self.set_trace_log(False, False, False)
+        return None
+
+
+    def run(self, addr):
+        """
+        Launch emulation
+        @addr: (int) start address
+        """
+        self.init_run(addr)
+        return self.continue_run()
+
+    def run_until(self, addr, trace=False):
+        """PRE: init_run.
+        Continue the run of the current session until iterator returns, run is
+        set to False or addr is reached.
+        If trace is True, activate trace log option until execution stops
+        Return the iterator value"""
+
+        def stop_exec(jitter):
+            jitter.remove_breakpoints_by_callback(stop_exec)
+            return False
+        self.add_breakpoint(addr, stop_exec)
+        return self.continue_run(trace=trace)
+
+    def init_stack(self):
+        self.vm.add_memory_page(
+            self.stack_base,
+            PAGE_READ | PAGE_WRITE,
+            b"\x00" * self.stack_size,
+            "Stack")
+        sp = self.arch.getsp(self.attrib)
+        setattr(self.cpu, sp.name, self.stack_base + self.stack_size)
+        # regs = self.cpu.get_gpreg()
+        # regs[sp.name] = self.stack_base+self.stack_size
+        # self.cpu.set_gpreg(regs)
+
+    def get_exception(self):
+        return self.cpu.get_exception() | self.vm.get_exception()
+
+    # common functions
+    def get_c_str(self, addr, max_char=None):
+        """Get C str from vm.
+        @addr: address in memory
+        @max_char: maximum len"""
+        l = 0
+        tmp = addr
+        while ((max_char is None or l < max_char) and
+               self.vm.get_mem(tmp, 1) != b"\x00"):
+            tmp += 1
+            l += 1
+        value = self.vm.get_mem(addr, l)
+        value = force_str(value)
+        return value
+
+    def set_c_str(self, addr, value):
+        """Set C str str from vm.
+        @addr: address in memory
+        @value: str"""
+        value = force_bytes(value)
+        self.vm.set_mem(addr, value + b'\x00')
+
+    def get_str_ansi(self, addr, max_char=None):
+        raise NotImplementedError("Deprecated: use os_dep.win_api_x86_32.get_win_str_a")
+
+    def get_str_unic(self, addr, max_char=None):
+        raise NotImplementedError("Deprecated: use os_dep.win_api_x86_32.get_win_str_a")
+
+    @staticmethod
+    def handle_lib(jitter):
+        """Resolve the name of the function which cause the handler call. Then
+        call the corresponding handler from users callback.
+        """
+        fname = jitter.libs.fad2cname[jitter.pc]
+        if fname in jitter.user_globals:
+            func = jitter.user_globals[fname]
+        else:
+            log.debug('%r', fname)
+            raise ValueError('unknown api', hex(jitter.pc), repr(fname))
+        ret = func(jitter)
+        jitter.pc = getattr(jitter.cpu, jitter.lifter.pc.name)
+
+        # Don't break on a None return
+        if ret is None:
+            return True
+        else:
+            return ret
+
+    def handle_function(self, f_addr):
+        """Add a breakpoint which will trigger the function handler"""
+        self.add_breakpoint(f_addr, self.handle_lib)
+
+    def add_lib_handler(self, libs, user_globals=None):
+        """Add a function to handle libs call with breakpoints
+        @libs: libimp instance
+        @user_globals: dictionary for defined user function
+        """
+        if user_globals is None:
+            user_globals = {}
+
+        self.libs = libs
+        out = {}
+        for name, func in viewitems(user_globals):
+            out[name] = func
+        self.user_globals = out
+
+        for f_addr in libs.fad2cname:
+            self.handle_function(f_addr)
+
+    def eval_expr(self, expr):
+        """Eval expression @expr in the context of the current instance. Side
+        effects are passed on it"""
+        self.symbexec.update_engine_from_cpu()
+        ret = self.symbexec.eval_updt_expr(expr)
+        self.symbexec.update_cpu_from_engine()
+
+        return ret
+
+    def set_trace_log(self,
+                      trace_instr=True, trace_regs=True,
+                      trace_new_blocks=False):
+        """
+        Activate/Deactivate trace log options
+
+        @trace_instr: activate instructions tracing log
+        @trace_regs: activate registers tracing log
+        @trace_new_blocks: dump new code blocks log
+        """
+
+        # As trace state changes, clear already jitted blocks
+        self.jit.clear_jitted_blocks()
+
+        self.jit.log_mn = trace_instr
+        self.jit.log_regs = trace_regs
+        self.jit.log_newbloc = trace_new_blocks
+
+
+class jitter(Jitter):
+    """
+    DEPRECATED object
+    Use Jitter instead of jitter
+    """
+
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn("Deprecated API: use Jitter")
+        super(jitter, self).__init__(*args, **kwargs)