about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm/arch/java_arch.py191
1 files changed, 151 insertions, 40 deletions
diff --git a/miasm/arch/java_arch.py b/miasm/arch/java_arch.py
index 19b201a2..efc8d032 100644
--- a/miasm/arch/java_arch.py
+++ b/miasm/arch/java_arch.py
@@ -51,6 +51,9 @@ ARG_ARRAYTYPE_TYPES = {
     11:"long"
 }
 
+AFS_symb = "symb__intern__"
+AFS_imm = "imm"
+
 class mnemonic:
     def __init__(self, name, code, size, desc, fmt=None,
                  breakflow=False, splitflow=False, dstflow=False):
@@ -385,7 +388,7 @@ mnemonic('impdep2', 255, 1, 'reserved for implementation-dependent operations wi
 
 
 class java_mnemo_metaclass(type):
-    rebuilt_inst = True
+    rebuilt_inst = False
     
     def dis(cls, op, admode=None, sex=0, offset=0, ):
         i = cls.__new__(cls)
@@ -398,6 +401,54 @@ class java_mnemo_metaclass(type):
         i = cls.__new__(cls)
         i.__init__(sex)
         return i._asm(l, symbol_reloc_off, address=address)
+    
+    def asm_instr(cls, l, sex=0):
+        i = cls.__new__(cls)
+        i.__init__(sex)
+        i._asm_instr(l)
+        return i
+    
+    def fix_symbol(cls, a, symbol_pool = None):
+        if not AFS_symb in a: return a
+        cp = a.copy()
+        if not symbol_pool:
+            del cp[AFS_symb]
+            if not AFS_imm in cp:
+                cp[AFS_imm] = 0
+            return cp
+        raise Exception('.fix_symbol() cannot handle that for now (and should not have too do so).')
+    
+    def is_mem(cls, a): return False
+    
+    def get_label(cls, a):
+        if not AFS_symb in a:
+            return None
+        n = a[AFS_symb]
+        if len(n) != 1:
+            return None
+        k = n.keys()[0]
+        if n[k] != 1:
+            return None
+        return k
+    
+    def has_symb(cls, a):
+        return AFS_symb in a
+    
+    def get_symbols(cls, a):
+        if AFS_symb in a:
+            return a[AFS_symb].items()
+        return None
+    
+    def names2symbols(cls, a, s_dict):
+        all_s = a[AFS_symb]
+        for name, s in s_dict.items():
+            count = all_s[name]
+            del(all_s[name])
+            all_s[s] = count
+    
+    def parse_address(cls, a):
+        if a.isdigit(): return {AFS_imm: int(a)}
+        return {AFS_symb: {a: 1}}
 
 
 class java_mn:
@@ -414,44 +465,88 @@ class java_mn:
     
     def getnextflow(self):
         return self.offset + self.m.size
-
+    
     def getdstflow(self):
         if len(self.arg) == 1:
-            return [ self.offset + self.arg[0] ]
-        if self.m.name == 'tableswitch':
-            return map(lambda x: self.offset + x, self.arg[:1]+self.arg[3:])
-        if self.m.name == 'lookupswitch':
-            return map(lambda x: self.offset + x, self.arg[:1]+[ self.arg[2*i+3] for i in range(len(self.arg[2:])/2) ])
-        raise ValueError('incorrect (?) dstflow intruction.')
+            dsts = [ self.arg[0] ]
+        elif self.m.name == 'tableswitch':
+            dsts = self.arg[:1]+self.arg[3:]
+        elif self.m.name == 'lookupswitch':
+            dsts = self.arg[:1]+[ self.arg[2*i+3] for i in range(len(self.arg[2:])/2) ]
+        out = []
+        for d in dsts:
+            if type(d) is int:
+                out.append(self.offset + d)
+            elif not AFS_symb in d:
+                out.append(self.offset + a[AFS_imm])
+            else:
+                out.append(d)
+        return out
     
     def setdstflow(self, dst):
         if len(self.arg) == 1:
-            self.arg = dst
+            self.arg = [{AFS_symb:{dst[0]:1}}]
         elif self.m.name == 'tableswitch':
             self.arg = [dst[0], self.arg[1], self.arg[2]] + dst[1:]
         elif self.m.name == 'lookupswitch':
             self.arg = [dst[0], self.arg[1]] + reduce(lambda x, y: x+y, [ [self.arg[2*i+2], dst[i+1]] for i in range(len(dst)-1) ])
     
-    def fixdst(self, *args): pass
-
+    def fixdst(self, lbls, my_offset, is_mem):
+        dsts = [0]
+        if self.m.name == 'tableswitch':
+            dsts += range(3, len(args))
+        elif self.m.name == 'lookupswitch':
+            dsts += range(3, len(args), 2)
+        newarg = []
+        for i, a in enumerate(self.arg):
+            if not i in dsts:
+                newarg.append(a)
+                continue
+            offset = lbls[a[AFS_symb].keys()[0].name]
+            if self.m.size == 0:
+                self.fixsize()
+            newarg.append({AFS_imm:offset-(my_offset)+self.m.size})
+        self.arg = newarg
+    
+    def fixsize(self):
+        if self.m.name.endswith('switch'):
+            self.size = 4 * len(self.arg) + 1 # opcode + args
+            self.size +=  ((4 - ((self.offset+1) % 4)) % 4) # align
+        else:
+            raise ValueError(".fixsize() should not be called for %s." % self.m.name)
+    
     def set_args_symbols(self, cpsymbols={}):
         self.arg = self.m.argfmt.resolve(self.arg, cpsymbols=cpsymbols)
-
+    
     def is_subcall(self):
         return self.m.name.startswith('jsr') or self.m.name.startswith('invoke')
     
     def __str__(self):
+        arg = []
+        for a in self.arg:
+            if type(a) is not dict:
+                arg.append(a)
+                continue
+            if len(a) == 1:
+                if AFS_imm in a:
+                    arg.append(a[AFS_imm])
+                    continue
+                elif AFS_symb in a and len(a[AFS_symb]) == 1:
+                    arg.append(a[AFS_symb].keys()[0])
+                    continue
+            log.warning('Weird argument spotted while assembling %s %r' % (self.m.name, self.arg))
+            arg.append(0)
         if self.m.name == 'tableswitch':
-            out = "tableswitch    %d %d\n" % (self.arg[1], self.arg[2])
-            out += "    %s: %s\n" % ('default', self.arg[0])
-            out += "\n".join(["    %-8s %s" % (str(i)+':', self.arg[3+i-int(self.arg[1])]) for i in range(int(self.arg[1]), int(self.arg[2])+1)])
+            out = "tableswitch    %d %d\n" % (arg[1], arg[2])
+            out += "    %s: %s\n" % ('default', arg[0])
+            out += "\n".join(["    %-8s %s" % (str(i)+':', arg[3+i-int(arg[1])]) for i in range(int(arg[1]), int(arg[2])+1)])
             return out
         if self.m.name == 'lookupswitch':
-            out = "lookupswitch   %s\n" % self.arg[1]
-            out += "    %s:\t%s\n" % ('default', self.arg[0])
-            out += "\n".join(["    %-8s %s" % (str(self.arg[2*i+2])+':', self.arg[2*i+3]) for i in range(len(self.arg)/2-1) ])
+            out = "lookupswitch   %s\n" % arg[1]
+            out += "    %s:\t%s\n" % ('default', arg[0])
+            out += "\n".join(["    %-8s %s" % (str(arg[2*i+2])+':', arg[2*i+3]) for i in range(len(arg)/2-1) ])
             return out
-        return "%-15s" % self.m.name + " ".join(map(str, self.arg))
+        return "%-15s" % self.m.name + " ".join(map(str, arg))
     
     def _dis(self, bin, offset=0):
         if type(bin) is str:
@@ -465,8 +560,11 @@ class java_mn:
             log.warning(e.message)
             return False
         return True
-
-    def _asm(self, txt, symbol_reloc_off={}, address=0):
+    
+    @classmethod
+    def parse_mnemo(cls, txt):
+        if ';' in txt: txt = txt[:txt.index(';')]
+        txt = txt.strip()
         txt = filter(lambda x: x != ',', list(shlex(txt)))
         t = []
         r = ''
@@ -476,26 +574,36 @@ class java_mn:
             else:
                 t.append(r+l)
                 r = ''
-        mnemo = mnemo_db_name[t[0]]
+        return None, t[0], t[1:]
+    
+    def _asm_instr(self, txt, address=0):
+        p, mn, t = self.parse_mnemo(txt)
+        self.m = mnemo_db_name[mn]
+        self.arg = t
+        self.offset = address
+    
+    def _asm(self, txt, symbol_reloc_off={}, address=0):
+        p, mn, t = self.parse_mnemo(txt)
+        mnemo = mnemo_db_name[mn]
         if mnemo.name == 'tableswitch':
             table = {}
             dflt = None
-            if len(t) % 3 == 0:
+            if len(t) % 3 == 2:
                 # 'tableswitch' has the second (optional) argument
-                arg = t[1:3]
-                rest = t[3:]
-            elif len(t) % 3 == 2:
+                arg = t[0:2]
+                rest = t[2:]
+            elif len(t) % 3 == 1:
                 # 'tableswitch' does not have the second argument ; we
                 # will have to set it
-                arg = [ t[1] ]
-                rest = t[2:]
+                arg = t[:1]
+                rest = t[1:]
             else:
                 # 'tableswitch' have the second argument plus the "to"
                 # keyword, just before.
-                if t[2] != "to":
-                    log.warning("Wrong argument format for tableswitch instruction: expecting 'to', but got '%s'" % t[2])
-                arg = [ t[1], t[3] ]
-                rest = t[4:]
+                if t[1] != "to":
+                    log.warning("Wrong argument format for tableswitch instruction: expecting 'to', but got '%s'" % t[1])
+                arg = [ t[0], t[2] ]
+                rest = t[3:]
             for i in range(len(rest)/3):
                 k = rest[3*i]
                 v = rest[3*i+2]
@@ -522,14 +630,14 @@ class java_mn:
             for k in range(min(keys), len(keys)):
                 arg.append(table[str(k)])
         elif mnemo.name == 'lookupswitch':
-            nbr = int(t[1])
+            nbr = int(t[0])
             table = {}
             dflt = None
-            for i in range(len(t[2:])/3):
-                k = t[3*i+2]
-                v = t[3*i+4]
-                if t[3*i+3] != ':':
-                    log.warning("Invalid lookupswitch format: expecting ':', but got '%s'." %  t[3*i+3])
+            for i in range(len(t[1:])/3):
+                k = t[3*i+1]
+                v = t[3*i+3]
+                if t[3*i+2] != ':':
+                    log.warning("Invalid lookupswitch format: expecting ':', but got '%s'." %  t[3*i+2])
                 if k == 'default': dflt = v
                 else:
                     if k in table:
@@ -542,5 +650,8 @@ class java_mn:
             for k in table:
                 arg += [ k, table[k] ]
         else:
-            arg = t[1:]
-        return chr(mnemo.code) + mnemo.argfmt.set(arg, address=address)
+            arg = t
+        try:
+            return [ chr(mnemo.code) + mnemo.argfmt.set(arg, address=address) ]
+        except: pass
+        return [ chr(mnemo.code) + mnemo.argfmt.set([0]*len(arg), address=address) ]