about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAjax <commial@gmail.com>2017-03-31 15:09:01 +0200
committerAjax <commial@gmail.com>2017-04-06 13:47:38 +0200
commit620c96e891d0ad356332713a23b39b9d2382470c (patch)
tree7ce86a2fcc502800a4c426a43cb362e17ed80002
parentb1ed94019554b25d4d8924594f8868318e8a8c4a (diff)
downloadmiasm-620c96e891d0ad356332713a23b39b9d2382470c.tar.gz
miasm-620c96e891d0ad356332713a23b39b9d2382470c.zip
Introduce a naive "System V" calling convention
-rw-r--r--miasm2/arch/aarch64/jit.py4
-rw-r--r--miasm2/arch/arm/jit.py13
-rw-r--r--miasm2/arch/x86/jit.py59
-rw-r--r--miasm2/os_dep/linux_stdlib.py56
-rw-r--r--test/arch/x86/qemu/testqemu.py8
5 files changed, 81 insertions, 59 deletions
diff --git a/miasm2/arch/aarch64/jit.py b/miasm2/arch/aarch64/jit.py
index 255bb91d..e3f3e3fa 100644
--- a/miasm2/arch/aarch64/jit.py
+++ b/miasm2/arch/aarch64/jit.py
@@ -56,6 +56,10 @@ class jitter_aarch64l(jitter):
             arg = self.get_stack_arg(index - self.max_reg_arg)
         return arg
 
+    func_args_systemv = func_args_stdcall
+    func_ret_systemv = func_ret_stdcall
+    get_arg_n_systemv = get_arg_n_stdcall
+
     def init_run(self, *args, **kwargs):
         jitter.init_run(self, *args, **kwargs)
         self.cpu.PC = self.pc
diff --git a/miasm2/arch/arm/jit.py b/miasm2/arch/arm/jit.py
index 70c708e1..e0d08679 100644
--- a/miasm2/arch/arm/jit.py
+++ b/miasm2/arch/arm/jit.py
@@ -34,11 +34,7 @@ class jitter_arml(jitter):
 
     @named_arguments
     def func_args_stdcall(self, n_args):
-        args = []
-        for i in xrange(min(n_args, 4)):
-            args.append(self.cpu.get_gpreg()['R%d' % i])
-        for i in xrange(max(0, n_args - 4)):
-            args.append(self.get_stack_arg(i))
+        args = [self.get_arg_n_stdcall(i) for i in xrange(n_args)]
         ret_ad = self.cpu.LR
         return ret_ad, args
 
@@ -48,13 +44,18 @@ class jitter_arml(jitter):
             self.cpu.R0 = ret_value
         return True
 
+
     def get_arg_n_stdcall(self, index):
         if index < 4:
-            arg = self.cpu.get_gpreg()['R%d' % index]
+            arg = getattr(self.cpu, 'R%d' % index)
         else:
             arg = self.get_stack_arg(index-4)
         return arg
 
+    func_args_systemv = func_args_stdcall
+    func_ret_systemv = func_ret_stdcall
+    get_arg_n_systemv = get_arg_n_stdcall
+
     def init_run(self, *args, **kwargs):
         jitter.init_run(self, *args, **kwargs)
         self.cpu.PC = self.pc
diff --git a/miasm2/arch/x86/jit.py b/miasm2/arch/x86/jit.py
index cfdabf8c..4f50315f 100644
--- a/miasm2/arch/x86/jit.py
+++ b/miasm2/arch/x86/jit.py
@@ -92,6 +92,10 @@ class jitter_x86_32(jitter):
     def get_stack_arg(self, index):
         return upck32(self.vm.get_mem(self.cpu.ESP + 4 * index, 4))
 
+    def init_run(self, *args, **kwargs):
+        jitter.init_run(self, *args, **kwargs)
+        self.cpu.EIP = self.pc
+
     # calling conventions
 
     # stdcall
@@ -108,6 +112,8 @@ class jitter_x86_32(jitter):
         if ret_value2 is not None:
             self.cpu.EDX = ret_value2
 
+    get_arg_n_stdcall = get_stack_arg
+
     # cdecl
     @named_arguments
     def func_args_cdecl(self, n_args):
@@ -115,18 +121,23 @@ class jitter_x86_32(jitter):
         args = [self.get_stack_arg(i) for i in xrange(n_args)]
         return ret_ad, args
 
-    def func_ret_cdecl(self, ret_addr, ret_value):
+    def func_ret_cdecl(self, ret_addr, ret_value=None):
         self.cpu.EIP = ret_addr
-        self.cpu.EAX = ret_value
+        if ret_value is not None:
+            self.cpu.EAX = ret_value
 
-    def init_run(self, *args, **kwargs):
-        jitter.init_run(self, *args, **kwargs)
-        self.cpu.EIP = self.pc
+    get_arg_n_cdecl = get_stack_arg
+
+    # System V
+    func_args_systemv = func_args_cdecl
+    func_ret_systemv = func_ret_cdecl
+    get_arg_n_systemv = get_stack_arg
 
 
 class jitter_x86_64(jitter):
 
     C_Gen = x86_64_CGen
+    args_regs_systemv = ['RDI', 'RSI', 'RDX', 'RCX', 'R8', 'R9']
 
     def __init__(self, *args, **kwargs):
         sp = asmblock.AsmSymbolPool()
@@ -152,6 +163,13 @@ class jitter_x86_64(jitter):
     def get_stack_arg(self, index):
         return upck64(self.vm.get_mem(self.cpu.RSP + 8 * index, 8))
 
+    def init_run(self, *args, **kwargs):
+        jitter.init_run(self, *args, **kwargs)
+        self.cpu.RIP = self.pc
+
+    # calling conventions
+
+    # stdcall
     @named_arguments
     def func_args_stdcall(self, n_args):
         args_regs = ['RCX', 'RDX', 'R8', 'R9']
@@ -169,23 +187,22 @@ class jitter_x86_64(jitter):
             self.cpu.RAX = ret_value
         return True
 
+    # cdecl
+    func_args_cdecl = func_args_stdcall
+    func_ret_cdecl = func_ret_stdcall
+
+    # System V
+
+    def get_arg_n_systemv(self, index):
+        args_regs = self.args_regs_systemv
+        if index < len(args_regs):
+            return getattr(self.cpu, args_regs[index])
+        return self.get_stack_arg(index - len(args_regs))
+
     @named_arguments
-    def func_args_cdecl(self, n_args):
-        args_regs = ['RCX', 'RDX', 'R8', 'R9']
+    def func_args_systemv(self, n_args):
         ret_ad = self.pop_uint64_t()
-        args = []
-        for i in xrange(min(n_args, 4)):
-            args.append(self.cpu.get_gpreg()[args_regs[i]])
-        for i in xrange(max(0, n_args - 4)):
-            args.append(self.get_stack_arg(i))
+        args = [self.get_arg_n_systemv(index) for index in xrange(n_args)]
         return ret_ad, args
 
-    def func_ret_cdecl(self, ret_addr, ret_value=None):
-        self.pc = self.cpu.RIP = ret_addr
-        if ret_value is not None:
-            self.cpu.RAX = ret_value
-        return True
-
-    def init_run(self, *args, **kwargs):
-        jitter.init_run(self, *args, **kwargs)
-        self.cpu.RIP = self.pc
+    func_ret_systemv = func_ret_cdecl
diff --git a/miasm2/os_dep/linux_stdlib.py b/miasm2/os_dep/linux_stdlib.py
index b05b2cd9..683104d0 100644
--- a/miasm2/os_dep/linux_stdlib.py
+++ b/miasm2/os_dep/linux_stdlib.py
@@ -25,9 +25,9 @@ def xxx_isprint(jitter):
 
     checks for any printable character including space.
     '''
-    ret_addr, args = jitter.func_args_stdcall(['c'])
+    ret_addr, args = jitter.func_args_systemv(['c'])
     ret = 1 if chr(args.c & 0xFF) in printable else 0
-    return jitter.func_ret_stdcall(ret_addr, ret)
+    return jitter.func_ret_systemv(ret_addr, ret)
 
 
 def xxx_memcpy(jitter):
@@ -37,9 +37,9 @@ def xxx_memcpy(jitter):
 
     copies n bytes from memory area src to memory area dest.
     '''
-    ret_addr, args = jitter.func_args_stdcall(['dest', 'src', 'n'])
+    ret_addr, args = jitter.func_args_systemv(['dest', 'src', 'n'])
     jitter.vm.set_mem(args.dest, jitter.vm.get_mem(args.src, args.n))
-    return jitter.func_ret_stdcall(ret_addr, args.dest)
+    return jitter.func_ret_systemv(ret_addr, args.dest)
 
 
 def xxx_memset(jitter):
@@ -50,9 +50,9 @@ def xxx_memset(jitter):
     fills the first n bytes of the memory area pointed to by s with the constant
     byte c.'''
 
-    ret_addr, args = jitter.func_args_stdcall(['dest', 'c', 'n'])
+    ret_addr, args = jitter.func_args_systemv(['dest', 'c', 'n'])
     jitter.vm.set_mem(args.dest, chr(args.c & 0xFF) * args.n)
-    return jitter.func_ret_stdcall(ret_addr, args.dest)
+    return jitter.func_ret_systemv(ret_addr, args.dest)
 
 
 def xxx_puts(jitter):
@@ -62,7 +62,7 @@ def xxx_puts(jitter):
 
     writes the string s and a trailing newline to stdout.
     '''
-    ret_addr, args = jitter.func_args_stdcall(['s'])
+    ret_addr, args = jitter.func_args_systemv(['s'])
     index = args.s
     char = jitter.vm.get_mem(index, 1)
     while char != '\x00':
@@ -70,7 +70,7 @@ def xxx_puts(jitter):
         index += 1
         char = jitter.vm.get_mem(index, 1)
     stdout.write('\n')
-    return jitter.func_ret_stdcall(ret_addr, 1)
+    return jitter.func_ret_systemv(ret_addr, 1)
 
 
 def get_fmt_args(jitter, fmt, cur_arg):
@@ -89,9 +89,9 @@ def get_fmt_args(jitter, fmt, cur_arg):
                 if char.lower() in '%cdfsux':
                     break
             if token.endswith('s'):
-                arg = jitter.get_str_ansi(jitter.get_arg_n_stdcall(cur_arg))
+                arg = jitter.get_str_ansi(jitter.get_arg_n_systemv(cur_arg))
             else:
-                arg = jitter.get_arg_n_stdcall(cur_arg)
+                arg = jitter.get_arg_n_systemv(cur_arg)
             char = token % arg
             cur_arg += 1
         output += char
@@ -99,67 +99,67 @@ def get_fmt_args(jitter, fmt, cur_arg):
 
 
 def xxx_snprintf(jitter):
-    ret_addr, args = jitter.func_args_stdcall(['string', 'size', 'fmt'])
+    ret_addr, args = jitter.func_args_systemv(['string', 'size', 'fmt'])
     cur_arg, fmt = 3, args.fmt
     size = args.size if args.size else 1
     output = get_fmt_args(jitter, fmt, cur_arg)
     output = output[:size - 1]
     ret = len(output)
     jitter.vm.set_mem(args.string, output + '\x00')
-    return jitter.func_ret_stdcall(ret_addr, ret)
+    return jitter.func_ret_systemv(ret_addr, ret)
 
 
 def xxx_sprintf(jitter):
-    ret_addr, args = jitter.func_args_stdcall(['string', 'fmt'])
+    ret_addr, args = jitter.func_args_systemv(['string', 'fmt'])
     cur_arg, fmt = 2, args.fmt
     output = get_fmt_args(jitter, fmt, cur_arg)
     ret = len(output)
     jitter.vm.set_mem(args.string, output + '\x00')
-    return jitter.func_ret_stdcall(ret_addr, ret)
+    return jitter.func_ret_systemv(ret_addr, ret)
 
 
 def xxx_printf(jitter):
-    ret_addr, args = jitter.func_args_stdcall(['fmt'])
+    ret_addr, args = jitter.func_args_systemv(['fmt'])
     cur_arg, fmt = 1, args.fmt
     output = get_fmt_args(jitter, fmt, cur_arg)
     ret = len(output)
     print output,
-    return jitter.func_ret_stdcall(ret_addr, ret)
+    return jitter.func_ret_systemv(ret_addr, ret)
 
 
 def xxx_strcpy(jitter):
-    ret_ad, args = jitter.func_args_stdcall(["dst", "src"])
+    ret_ad, args = jitter.func_args_systemv(["dst", "src"])
     str_src = jitter.get_str_ansi(args.src) + '\x00'
     jitter.vm.set_mem(args.dst, str_src)
-    jitter.func_ret_stdcall(ret_ad, args.dst)
+    jitter.func_ret_systemv(ret_ad, args.dst)
 
 
 def xxx_strlen(jitter):
-    ret_ad, args = jitter.func_args_stdcall(["src"])
+    ret_ad, args = jitter.func_args_systemv(["src"])
     str_src = jitter.get_str_ansi(args.src)
-    jitter.func_ret_stdcall(ret_ad, len(str_src))
+    jitter.func_ret_systemv(ret_ad, len(str_src))
 
 
 def xxx_malloc(jitter):
-    ret_ad, args = jitter.func_args_stdcall(["msize"])
+    ret_ad, args = jitter.func_args_systemv(["msize"])
     addr = linobjs.heap.alloc(jitter, args.msize)
-    jitter.func_ret_stdcall(ret_ad, addr)
+    jitter.func_ret_systemv(ret_ad, addr)
 
 
 def xxx_free(jitter):
-    ret_ad, args = jitter.func_args_stdcall(["ptr"])
-    jitter.func_ret_stdcall(ret_ad, 0)
+    ret_ad, args = jitter.func_args_systemv(["ptr"])
+    jitter.func_ret_systemv(ret_ad, 0)
 
 
 def xxx_strcmp(jitter):
-    ret_ad, args = jitter.func_args_stdcall(["ptr_str1", "ptr_str2"])
+    ret_ad, args = jitter.func_args_systemv(["ptr_str1", "ptr_str2"])
     s1 = jitter.get_str_ansi(args.ptr_str1)
     s2 = jitter.get_str_ansi(args.ptr_str2)
-    jitter.func_ret_stdcall(ret_ad, cmp(s1, s2))
+    jitter.func_ret_systemv(ret_ad, cmp(s1, s2))
 
 
 def xxx_strncmp(jitter):
-    ret_ad, args = jitter.func_args_stdcall(["ptr_str1", "ptr_str2", "size"])
+    ret_ad, args = jitter.func_args_systemv(["ptr_str1", "ptr_str2", "size"])
     s1 = jitter.get_str_ansi(args.ptr_str1, args.size)
     s2 = jitter.get_str_ansi(args.ptr_str2, args.size)
-    jitter.func_ret_stdcall(ret_ad, cmp(s1, s2))
+    jitter.func_ret_systemv(ret_ad, cmp(s1, s2))
diff --git a/test/arch/x86/qemu/testqemu.py b/test/arch/x86/qemu/testqemu.py
index 5f26d6f3..e6c487f2 100644
--- a/test/arch/x86/qemu/testqemu.py
+++ b/test/arch/x86/qemu/testqemu.py
@@ -40,7 +40,7 @@ nb_tests = 1
 def xxx___printf_chk(jitter):
     """Tiny implementation of printf_chk"""
     global nb_tests
-    ret_ad, args = jitter.func_args_cdecl(["out", "format"])
+    ret_ad, args = jitter.func_args_systemv(["out", "format"])
     if args.out != 1:
         raise RuntimeError("Not implemented")
     fmt = jitter.get_str_ansi(args.format)
@@ -89,7 +89,7 @@ def xxx___printf_chk(jitter):
 
     sys.stdout.write("[%d] %s" % (nb_tests, output))
     nb_tests += 1
-    jitter.func_ret_cdecl(ret_ad, 0)
+    jitter.func_ret_systemv(ret_ad, 0)
 
 def xxx_puts(jitter):
     '''
@@ -98,7 +98,7 @@ def xxx_puts(jitter):
 
     writes the string s and a trailing newline to stdout.
     '''
-    ret_addr, args = jitter.func_args_cdecl(['target'])
+    ret_addr, args = jitter.func_args_systemv(['target'])
     output = jitter.get_str_ansi(args.target)
     # Check with expected result
     line = expected.next()
@@ -106,7 +106,7 @@ def xxx_puts(jitter):
         print "Expected:", line
         print "Obtained:", output
         raise RuntimeError("Bad semantic")
-    return jitter.func_ret_cdecl(ret_addr, 1)
+    return jitter.func_ret_systemv(ret_addr, 1)
 
 # Parse arguments
 parser = Sandbox_Linux_x86_32.parser(description="ELF sandboxer")