about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/analysis/sandbox.py22
-rw-r--r--miasm2/arch/arm/jit.py27
-rw-r--r--miasm2/arch/x86/jit.py58
-rw-r--r--miasm2/jitter/jitload.py39
-rw-r--r--miasm2/os_dep/win_api_x86_32.py2
5 files changed, 57 insertions, 91 deletions
diff --git a/miasm2/analysis/sandbox.py b/miasm2/analysis/sandbox.py
index ca6dcfe6..22bd2094 100644
--- a/miasm2/analysis/sandbox.py
+++ b/miasm2/analysis/sandbox.py
@@ -160,8 +160,11 @@ class OS_Win(OS):
 
     def __init__(self, custom_methods, *args, **kwargs):
         from miasm2.jitter.loader.pe import vm_load_pe, vm_load_pe_libs, preload_pe, libimp_pe
+        from miasm2.os_dep import win_api_x86_32
+        methods = win_api_x86_32.__dict__
+        methods.update(custom_methods)
 
-        super(OS_Win, self).__init__(custom_methods, *args, **kwargs)
+        super(OS_Win, self).__init__(methods, *args, **kwargs)
 
         # Import manager
         libs = libimp_pe()
@@ -187,7 +190,7 @@ class OS_Win(OS):
         preload_pe(self.jitter.vm, self.pe, libs)
 
         # Library calls handler
-        self.jitter.add_lib_handler(libs, custom_methods)
+        self.jitter.add_lib_handler(libs, methods)
 
         # Manage SEH
         if self.options.use_seh:
@@ -217,8 +220,11 @@ class OS_Linux(OS):
 
     def __init__(self, custom_methods, *args, **kwargs):
         from miasm2.jitter.loader.elf import vm_load_elf, preload_elf, libimp_elf
+        from miasm2.os_dep import linux_stdlib
+        methods = linux_stdlib.__dict__
+        methods.update(custom_methods)
 
-        super(OS_Linux, self).__init__(custom_methods, *args, **kwargs)
+        super(OS_Linux, self).__init__(methods, *args, **kwargs)
 
         # Import manager
         self.libs = libimp_elf()
@@ -230,12 +236,16 @@ class OS_Linux(OS):
         self.entry_point = self.elf.Ehdr.entry
 
         # Library calls handler
-        self.jitter.add_lib_handler(self.libs, custom_methods)
+        self.jitter.add_lib_handler(self.libs, methods)
 
 class OS_Linux_str(OS):
     def __init__(self, custom_methods, *args, **kwargs):
         from miasm2.jitter.loader.elf import libimp_elf
-        super(OS_Linux_str, self).__init__(custom_methods, *args, **kwargs)
+        from miasm2.os_dep import linux_stdlib
+        methods = linux_stdlib.__dict__
+        methods.update(custom_methods)
+
+        super(OS_Linux_str, self).__init__(methods, *args, **kwargs)
 
         # Import manager
         libs = libimp_elf()
@@ -246,7 +256,7 @@ class OS_Linux_str(OS):
         self.jitter.vm.add_memory_page(self.options.load_base_addr, PAGE_READ | PAGE_WRITE, data)
 
         # Library calls handler
-        self.jitter.add_lib_handler(libs, custom_methods)
+        self.jitter.add_lib_handler(libs, methods)
 
     @classmethod
     def update_parser(cls, parser):
diff --git a/miasm2/arch/arm/jit.py b/miasm2/arch/arm/jit.py
index d089bafb..8803725e 100644
--- a/miasm2/arch/arm/jit.py
+++ b/miasm2/arch/arm/jit.py
@@ -58,33 +58,6 @@ class jitter_arml(jitter):
             arg = self.get_stack_arg(n-4)
         return arg
 
-    def add_lib_handler(self, libs, user_globals=None):
-        """Add a function to handle libs call with breakpoints
-        @libs: libimp instance
-        @user_globals: dictionnary for defined user function
-        """
-        if user_globals is None:
-            user_globals = {}
-
-        from miasm2.os_dep import linux_stdlib
-
-        def handle_lib(jitter):
-            fname = libs.fad2cname[jitter.pc]
-            if fname in user_globals:
-                f = user_globals[fname]
-            elif fname in linux_stdlib.__dict__:
-                f = linux_stdlib.__dict__[fname]
-            else:
-                log.debug('%s' % repr(fname))
-                raise ValueError('unknown api', hex(jitter.pop_uint32_t()), repr(fname))
-            f(jitter)
-            jitter.pc = getattr(jitter.cpu, jitter.ir_arch.pc.name)
-            return True
-
-        for f_addr in libs.fad2cname:
-            self.add_breakpoint(f_addr, handle_lib)
-
-
     def init_run(self, *args, **kwargs):
         jitter.init_run(self, *args, **kwargs)
         self.cpu.PC = self.pc
diff --git a/miasm2/arch/x86/jit.py b/miasm2/arch/x86/jit.py
index 36afcce5..08bac4db 100644
--- a/miasm2/arch/x86/jit.py
+++ b/miasm2/arch/x86/jit.py
@@ -106,32 +106,6 @@ class jitter_x86_32(jitter):
         self.cpu.EIP = ret_addr
         self.cpu.EAX = ret_value
 
-    def add_lib_handler(self, libs, user_globals=None):
-        """Add a function to handle libs call with breakpoints
-        @libs: libimp instance
-        @user_globals: dictionnary for defined user function
-        """
-        if user_globals is None:
-            user_globals = {}
-
-        from miasm2.os_dep import win_api_x86_32
-
-        def handle_lib(jitter):
-            fname = libs.fad2cname[jitter.pc]
-            if fname in user_globals:
-                f = user_globals[fname]
-            elif fname in win_api_x86_32.__dict__:
-                f = win_api_x86_32.__dict__[fname]
-            else:
-                log.debug('%s' % repr(fname))
-                raise ValueError('unknown api', hex(jitter.pop_uint32_t()), repr(fname))
-            f(jitter)
-            jitter.pc = getattr(jitter.cpu, jitter.ir_arch.pc.name)
-            return True
-
-        for f_addr in libs.fad2cname:
-            self.add_breakpoint(f_addr, handle_lib)
-
     def init_run(self, *args, **kwargs):
         jitter.init_run(self, *args, **kwargs)
         self.cpu.EIP = self.pc
@@ -165,10 +139,6 @@ class jitter_x86_64(jitter):
         x = upck64(self.vm.get_mem(self.cpu.RSP + 8 * n, 8))
         return x
 
-    def init_run(self, *args, **kwargs):
-        jitter.init_run(self, *args, **kwargs)
-        self.cpu.RIP = self.pc
-
     def func_args_stdcall(self, n_args):
         args_regs = ['RCX', 'RDX', 'R8', 'R9']
         ret_ad = self.pop_uint64_t()
@@ -207,28 +177,6 @@ class jitter_x86_64(jitter):
             self.cpu.RAX = ret_value
         return True
 
-    def add_lib_handler(self, libs, user_globals=None):
-        """Add a function to handle libs call with breakpoints
-        @libs: libimp instance
-        @user_globals: dictionnary for defined user function
-        """
-        if user_globals is None:
-            user_globals = {}
-
-        from miasm2.os_dep import win_api_x86_32
-
-        def handle_lib(jitter):
-            fname = libs.fad2cname[jitter.pc]
-            if fname in user_globals:
-                f = user_globals[fname]
-            elif fname in win_api_x86_32.__dict__:
-                f = win_api_x86_32.__dict__[fname]
-            else:
-                log.debug('%s' % repr(fname))
-                raise ValueError('unknown api', hex(jitter.pop_uint64_t()), repr(fname))
-            f(jitter)
-            jitter.pc = getattr(jitter.cpu, jitter.ir_arch.pc.name)
-            return True
-
-        for f_addr in libs.fad2cname:
-            self.add_breakpoint(f_addr, handle_lib)
+    def init_run(self, *args, **kwargs):
+        jitter.init_run(self, *args, **kwargs)
+        self.cpu.RIP = self.pc
diff --git a/miasm2/jitter/jitload.py b/miasm2/jitter/jitload.py
index c297ba50..0405b46d 100644
--- a/miasm2/jitter/jitload.py
+++ b/miasm2/jitter/jitload.py
@@ -40,8 +40,10 @@ class CallbackHandler(object):
         self.callbacks = {}  # Key -> [callback list]
 
     def add_callback(self, name, callback):
-        "Add a callback to the key 'name'"
-        self.callbacks[name] = self.callbacks.get(name, []) + [callback]
+        """Add a callback to the key @name, iff the @callback isn't already
+        assigned to it"""
+        if callback not in self.callbacks.get(name, []):
+            self.callbacks[name] = self.callbacks.get(name, []) + [callback]
 
     def set_callback(self, name, *args):
         "Set the list of callback for key 'name'"
@@ -351,3 +353,36 @@ class jitter:
         """Set an unicode string in memory"""
         s = "\x00".join(list(s)) + '\x00' * 3
         self.vm.set_mem(addr, s)
+
+    @staticmethod
+    def handle_lib(jitter):
+        """Resolve the name of the function which cause the handler call. Then
+        call the corresponding handler from users callback.
+        """
+        fname = jitter.libs.fad2cname[jitter.pc]
+        if fname in jitter.user_globals:
+            func = jitter.user_globals[fname]
+        else:
+            log.debug('%s' % repr(fname))
+            raise ValueError('unknown api', hex(jitter.pc), repr(fname))
+        func(jitter)
+        jitter.pc = getattr(jitter.cpu, jitter.ir_arch.pc.name)
+        return True
+
+    def handle_function(self, f_addr):
+        """Add a brakpoint which will trigger the function handler"""
+        self.add_breakpoint(f_addr, self.handle_lib)
+
+    def add_lib_handler(self, libs, user_globals=None):
+        """Add a function to handle libs call with breakpoints
+        @libs: libimp instance
+        @user_globals: dictionnary for defined user function
+        """
+        if user_globals is None:
+            user_globals = {}
+
+        self.libs = libs
+        self.user_globals = user_globals
+
+        for f_addr in libs.fad2cname:
+            self.handle_function(f_addr)
diff --git a/miasm2/os_dep/win_api_x86_32.py b/miasm2/os_dep/win_api_x86_32.py
index a4c07e59..0996d616 100644
--- a/miasm2/os_dep/win_api_x86_32.py
+++ b/miasm2/os_dep/win_api_x86_32.py
@@ -954,7 +954,7 @@ def kernel32_GetProcAddress(jitter):
     else:
         ad = 0
     ad = winobjs.runtime_dll.lib_get_add_func(libbase, fname)
-
+    jitter.add_breakpoint(ad, jitter.handle_lib)
     jitter.func_ret_stdcall(ret_ad, ad)