about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm/core/bin_stream.py11
-rw-r--r--miasm/tools/pe_helper.py142
-rw-r--r--miasm/tools/to_c_helper.py49
-rw-r--r--miasm/tools/win_api.py76
4 files changed, 204 insertions, 74 deletions
diff --git a/miasm/core/bin_stream.py b/miasm/core/bin_stream.py
index 5db15bcc..9236e037 100644
--- a/miasm/core/bin_stream.py
+++ b/miasm/core/bin_stream.py
@@ -39,6 +39,17 @@ class bin_stream(object):
     def hexdump(self, offset, l):
         return
 
+    def __getitem__(self, item):
+        if not type(item) is slice: # integer
+            self.offset = item
+            return self.readbs(1)
+        start = item.start
+        stop  = item.stop
+        step  = item.step
+        self.offset = start
+        s = self.readbs(stop-start)
+        return s[::step]
+
 class bin_stream_str(bin_stream):
     def __init__(self, bin ="", offset = 0L):
         if offset>len(bin):
diff --git a/miasm/tools/pe_helper.py b/miasm/tools/pe_helper.py
index 44d680d8..9c600165 100644
--- a/miasm/tools/pe_helper.py
+++ b/miasm/tools/pe_helper.py
@@ -290,41 +290,46 @@ def code_is_jmp_imp(e, ad, imp_d):
     return is_jmp_imp(l, imp_d)
 
 
-#giving e and address in function guess function start
-def guess_func_start(e, middle_ad, max_offset = 0x200):
-    ad = middle_ad+1
-    ad_found = None
-    while ad > middle_ad - max_offset:
-        ad-=1
-
-        ####### heuristic CC pad #######
-        if e.virt[ad] == "\xCC":
-            if e.virt[((ad+3)&~3)-1] == "\xCC":
-                ad_found = ((ad+3)&~3)
-                break
-            else:
-                continue
-
-
-        l = x86_mn.dis(e.virt[ad:ad+15])
-        if not l:
-            continue
-        if l.m.name in ["ret"]:
-            ad_found = ad+l.l
-            break
-
-    if not ad_found:
-        print 'cannot find func start'
-        return None
-
-    while e.virt[ad_found] == "\xCC":
-        ad_found+=1
-
-    if e.virt[ad_found:ad_found+3] == "\x8D\x40\x00":
-        ad_found += 3
-
 
-    return ad_found
+# giving e and address in function guess function start
+def guess_func_start(in_str, line_ad, max_offset = 0x200):
+    ad = line_ad+1
+    done = False
+    func_addrs = set()
+    symbol_pool = asmbloc.asm_symbol_pool()
+    all_bloc = asmbloc.dis_bloc_all(x86_mn, in_str, line_ad,
+                                    func_addrs, symbol_pool)
+    while not done:
+        ad_found = None
+        while ad > line_ad - max_offset:
+            ad-=1
+            ####### heuristic CC pad #######
+            if in_str[ad] == "\xCC":
+                if in_str[((ad+3)&~3)-1] == "\xCC":
+                    ad_found = ((ad+3)&~3)
+                    break
+                else:
+                    continue
+            l = x86_mn.dis(in_str[ad:ad+15])
+            if not l:
+                continue
+            if l.m.name in ["ret"]:
+                ad_found = ad+l.l
+                break
+        if not ad_found:
+            print 'cannot find func start'
+            return None
+        while in_str[ad_found] == "\xCC":
+            ad_found+=1
+        # lea eax, [eax]
+        if in_str[ad_found:ad_found+3] == "\x8D\x40\x00":
+            ad_found += 3
+
+        job_done = set()
+        symbol_pool = asmbloc.asm_symbol_pool()
+        all_bloc = asmbloc.dis_bloc_all(x86_mn, in_str, ad_found, job_done, symbol_pool)
+        if func_addrs.issubset(job_done):
+            return ad_found
 
 def get_nul_term(e, ad):
     out = ""
@@ -387,7 +392,7 @@ class libimp:
         self.fad2info = {}
 
     def lib_get_add_base(self, name):
-        name = name.lower()
+        name = name.lower().strip(' ')
         if not "." in name:
             print 'warning adding .dll to modulename'
             name += '.dll'
@@ -543,8 +548,9 @@ def vm_load_pe(e, align_s = True, load_hdr = True):
 
     if aligned:
         if load_hdr:
-            min_len = min(e.SHList[0].addr, 0x1000)
-            pe_hdr = e.content[:0x400]
+            hdr_len = max(0x200, e.NThdr.sectionalignment)
+            min_len = min(e.SHList[0].addr, hdr_len)
+            pe_hdr = e.content[:hdr_len]
             pe_hdr = pe_hdr+min_len*"\x00"
             pe_hdr = pe_hdr[:min_len]
             to_c_helper.vm_add_memory_page(e.NThdr.ImageBase, to_c_helper.PAGE_READ|to_c_helper.PAGE_WRITE, pe_hdr)
@@ -652,21 +658,59 @@ def get_export_name_addr_list(e):
     return out
 
 
+class pattern_class:
+    pass
+
+class pattern_call_x86(pattern_class):
+    patterns = ["\xE8"]
+    @classmethod
+    def test_candidate(cls, in_str, off_i, off_dst):
+        off = off_i + 5 + struct.unpack('i', in_str[off_i+1:off_i+5])[0]
+        #print "XXX", hex(off_i), hex(off)
+        if off == off_dst:
+            return off_i
+        return None
+
+class pattern_jmp_long_x86(pattern_call_x86):
+    patterns = ["\xE9"]
+
+class pattern_jmp_short_x86(pattern_call_x86):
+    patterns = ["\xEB"]
+    @classmethod
+    def test_candidate(cls, in_str, off_i, off_dst):
+        off = off_i + 2 + struct.unpack('b', in_str[off_i+1:off_i+2])[0]
+        #print "XXX", hex(off_i), hex(off)
+        if off == off_dst:
+            return off_i
+        return None
+
 
-class find_call_xref:
-    def __init__(self, e, off):
+class find_pattern:
+    def __init__(self, in_str, off_dst, find_class):
         import re
-        self.e = e
-        self.off = off
-        #create itertor to find simple CALL offsets
-        p = re.escape("\xE8")
-        self.my_iter = re.finditer(p, e.content)
+        self.in_str = in_str
+        self.off_dst = off_dst
+        if not type(find_class) is list:
+            find_class = [find_class]
+        self.find_classes = find_class
+        self.class_index = 0
+        self.ad = 0
     def next(self):
-        while True:
-            off_i = self.my_iter.next().start()
-            off = off_i + 5 + struct.unpack('i', self.e.content[off_i+1:off_i+5])[0]
-            if off == self.off:
-                return off_i
+        while self.class_index < len(self.find_classes):
+            find_class = self.find_classes[self.class_index]
+            for p in find_class.patterns:
+                while True:
+                    #off_i = self.my_iter.next().start()
+                    self.ad = self.in_str.find(p, self.ad)
+                    if self.ad == -1:
+                        break
+                    off = find_class.test_candidate(self.in_str, self.ad, self.off_dst)
+                    self.ad +=1
+                    if off:
+                        #print 'found', hex(off)
+                        return off
+            self.class_index+=1
+            self.ad = 0
         raise StopIteration
     def __iter__(self):
         return self
diff --git a/miasm/tools/to_c_helper.py b/miasm/tools/to_c_helper.py
index 50d79d0b..37bf5324 100644
--- a/miasm/tools/to_c_helper.py
+++ b/miasm/tools/to_c_helper.py
@@ -1121,6 +1121,16 @@ class bin_stream_vm():
     def setoffset(self, val):
         val = val & 0xFFFFFFFF
         self.offset = val
+    def __getitem__(self, item):
+        if not type(item) is slice: # integer
+            self.offset = item
+            return self.readbs(1)
+        start = item.start
+        stop  = item.stop
+        step  = item.step
+        self.offset = start
+        s = self.readbs(stop-start)
+        return s[::step]
 
 
 
@@ -1131,11 +1141,25 @@ updw = lambda bbbb: struct.unpack('I', bbbb)[0]
 pw = lambda x: struct.pack('H', x)
 upw = lambda x: struct.unpack('H', x)[0]
 
+base_dll_imp = ["ntdll.dll",  "kernel32.dll",   "user32.dll",
+               "imm32.dll",    "msvcrt.dll",
+               "oleaut32.dll", "shlwapi.dll",
+               "version.dll",  "advapi32.dll",
+               "ws2help.dll",
+               "rpcrt4.dll",   "shell32.dll", "winmm.dll",
+               #"mswsock.dll",
+               "ws2_32.dll",
+               "gdi32.dll",   "ole32.dll",
+               "secur32.dll",  "comdlg32.dll",
+               #"wsock32.dll"
+               ]
+
 
 def load_pe_in_vm(fname_in, options, all_imp_dll = None, **kargs):
     import os
     import seh_helper
     import win_api
+    global base_dll_imp
     from miasm.tools import pe_helper
     from miasm.tools import codenat
 
@@ -1149,22 +1173,14 @@ def load_pe_in_vm(fname_in, options, all_imp_dll = None, **kargs):
     codenat_tcc_init()
     runtime_dll = pe_helper.libimp(kargs.get('runtime_basead', 0x71111000))
 
-    pe_helper.vm_load_pe(e, align_s = False, load_hdr = options.loadhdr)
+    align_s = False
+    if 'align_s' in kargs:
+        align_s = kargs['align_s']
+    pe_helper.vm_load_pe(e, align_s = align_s, load_hdr = options.loadhdr)
 
     if all_imp_dll == None:
         if options.loadbasedll:
-            all_imp_dll = ["ntdll.dll",  "kernel32.dll",   "user32.dll",
-                           "imm32.dll",    "msvcrt.dll",
-                           "oleaut32.dll", "shlwapi.dll",
-                           "version.dll",  "advapi32.dll",
-                           "ws2help.dll",
-                           "rpcrt4.dll",   "shell32.dll", "winmm.dll",
-                           #"mswsock.dll",
-                           "ws2_32.dll",
-                           "gdi32.dll",   "ole32.dll",
-                           "secur32.dll",  "comdlg32.dll",
-                           #"wsock32.dll"
-                           ]
+            all_imp_dll = base_dll_imp
         else:
             all_imp_dll = []
 
@@ -1174,7 +1190,7 @@ def load_pe_in_vm(fname_in, options, all_imp_dll = None, **kargs):
     for n in mod_list:
         fname = os.path.join('win_dll', n)
         ee = pe_init.PE(open(fname, 'rb').read())
-        pe_helper.vm_load_pe(ee, align_s = False)
+        pe_helper.vm_load_pe(ee, align_s = align_s)
         runtime_dll.add_export_lib(ee, n)
         exp_funcs = pe_helper.get_export_name_addr_list(ee)
         exp_func[n] = exp_funcs
@@ -1260,10 +1276,11 @@ def vm2pe(fname, runtime_dll = None, e_orig = None, max_addr = 1<<64):
     # generation
     open(fname, 'w').write(str(mye))
 
-def manage_runtime_func(my_eip, api_modues, runtime_dll):
+def manage_runtime_func(my_eip, api_modues, runtime_dll, dbg = False):
     from miasm.tools import win_api
     fname = runtime_dll.fad2cname[my_eip]
-    print "call api", fname, hex(updw(vm_get_str(vm_get_gpreg()['esp'], 4)))
+    if dbg:
+        print "call api", fname, hex(updw(vm_get_str(vm_get_gpreg()['esp'], 4)))
     f = None
     for m in api_modues:
         if isinstance(m, dict):
diff --git a/miasm/tools/win_api.py b/miasm/tools/win_api.py
index 053a7f35..741f1c7f 100644
--- a/miasm/tools/win_api.py
+++ b/miasm/tools/win_api.py
@@ -176,7 +176,7 @@ class mdl:
 def get_str_ansi(ad_str, max_char = None):
     l = 0
     tmp = ad_str
-    while vm_get_str(tmp, 1) != "\x00":
+    while (max_char == None or l < max_char) and vm_get_str(tmp, 1) != "\x00":
         tmp +=1
         l+=1
     return vm_get_str(ad_str, l)
@@ -184,7 +184,7 @@ def get_str_ansi(ad_str, max_char = None):
 def get_str_unic(ad_str, max_char = None):
     l = 0
     tmp = ad_str
-    while vm_get_str(tmp, 2) != "\x00\x00":
+    while (max_char == None or l < max_char) and vm_get_str(tmp, 2) != "\x00\x00":
         tmp +=2
         l+=2
     return vm_get_str(ad_str, l)
@@ -505,7 +505,7 @@ def user32_BlockInput():
     regs['eax'] = 1
     vm_set_gpreg(regs)
 
-def advapi32_CryptAcquireContextA():
+def advapi32_CryptAcquireContext(funcname, get_str):
     ret_ad = vm_pop_uint32_t()
     phprov = vm_pop_uint32_t()
     pszcontainer = vm_pop_uint32_t()
@@ -513,10 +513,12 @@ def advapi32_CryptAcquireContextA():
     dwprovtype = vm_pop_uint32_t()
     dwflags = vm_pop_uint32_t()
 
-    print whoami(), hex(ret_ad), '(', hex(phprov), hex(pszcontainer), hex(pszprovider), hex(dwprovtype), hex(dwflags), ')'
+    print funcname, hex(ret_ad), '(', hex(phprov), hex(pszcontainer), hex(pszprovider), hex(dwprovtype), hex(dwflags), ')'
 
-    prov = vm_get_str(pszprovider, 0x100)
-    prov = prov[:prov.find('\x00')]
+    if pszprovider:
+        prov = get_str(pszprovider)
+    else:
+        prov = "NONE"
     print 'prov:', prov
     vm_set_mem(phprov, pdw(winobjs.cryptcontext_hwnd))
 
@@ -526,6 +528,12 @@ def advapi32_CryptAcquireContextA():
     vm_set_gpreg(regs)
 
 
+def advapi32_CryptAcquireContextA():
+    advapi32_CryptAcquireContext(whoami(), get_str_ansi)
+def advapi32_CryptAcquireContextW():
+    advapi32_CryptAcquireContext(whoami(), get_str_unic)
+
+
 def advapi32_CryptCreateHash():
     ret_ad = vm_pop_uint32_t()
     hprov = vm_pop_uint32_t()
@@ -572,6 +580,48 @@ def advapi32_CryptHashData():
     vm_set_gpreg(regs)
 
 
+def advapi32_CryptGetHashParam():
+    ret_ad = vm_pop_uint32_t()
+    hhash = vm_pop_uint32_t()
+    param = vm_pop_uint32_t()
+    pbdata = vm_pop_uint32_t()
+    dwdatalen = vm_pop_uint32_t()
+    dwflags = vm_pop_uint32_t()
+
+    print whoami(), hex(ret_ad), '(', hex(hhash), hex(pbdata), hex(dwdatalen), hex(dwflags), ')'
+
+    if not hhash in winobjs.cryptcontext:
+        raise ValueError("unknown crypt context")
+
+
+    if param == 2:
+        # XXX todo: save h state?
+        h = winobjs.cryptcontext[hhash].h.digest()
+    else:
+        raise ValueError('not impl', param)
+    vm_set_mem(pbdata, h)
+    vm_set_mem(dwdatalen, pdw(len(h)))
+
+    regs = vm_get_gpreg()
+    regs['eip'] = ret_ad
+    regs['eax'] = 1
+    vm_set_gpreg(regs)
+
+
+
+def advapi32_CryptReleaseContext():
+    ret_ad = vm_pop_uint32_t()
+    hhash = vm_pop_uint32_t()
+    flags = vm_pop_uint32_t()
+
+    print whoami(), hex(ret_ad), '(', hex(hhash), hex(flags), ')'
+
+    regs = vm_get_gpreg()
+    regs['eip'] = ret_ad
+    regs['eax'] = 0
+    vm_set_gpreg(regs)
+
+
 def advapi32_CryptDeriveKey():
     ret_ad = vm_pop_uint32_t()
     hprov = vm_pop_uint32_t()
@@ -667,6 +717,8 @@ def kernel32_CreateFile(funcname, get_str):
 
 def kernel32_CreateFileA():
     kernel32_CreateFile(whoami(), get_str_ansi)
+def kernel32_CreateFileW():
+    kernel32_CreateFile(whoami(), lambda x:get_str_unic(x)[::2])
 
 
 
@@ -1051,13 +1103,13 @@ def kernel32_LoadLibraryW():
     vm_set_gpreg(regs)
 
 
-def kernel32_GetModuleHandleA():
+def kernel32_GetModuleHandle(funcname, get_str):
     ret_ad = vm_pop_uint32_t()
     dllname = vm_pop_uint32_t()
-    print whoami(), hex(ret_ad), hex(dllname)
+    print funcname, hex(ret_ad), hex(dllname)
 
     if dllname:
-        libname = get_str_ansi(dllname)
+        libname = get_str(dllname)
         print repr(libname)
         if libname:
             eax = winobjs.runtime_dll.lib_get_add_base(libname)
@@ -1072,6 +1124,12 @@ def kernel32_GetModuleHandleA():
     regs['eax'] = eax
     vm_set_gpreg(regs)
 
+def kernel32_GetModuleHandleA():
+    kernel32_GetModuleHandle(whoami(), get_str_ansi)
+def kernel32_GetModuleHandleW():
+    kernel32_GetModuleHandle(whoami(), lambda x:get_str_unic(x)[::2])
+
+
 def kernel32_VirtualLock():
     ret_ad = vm_pop_uint32_t()
     lpaddress = vm_pop_uint32_t()