about summary refs log tree commit diff stats
path: root/miasm2/core/objc.py
diff options
context:
space:
mode:
authorFabrice Desclaux <fabrice.desclaux@cea.fr>2017-08-01 14:22:34 +0200
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2017-08-07 16:42:35 +0200
commit8052568e455d8336bf640226c084329378917342 (patch)
treeca75c5c17e99442c86217e65f03e2d99ed689d44 /miasm2/core/objc.py
parentfd58654a591c4c745a1edb25f619fd3f69ebde74 (diff)
downloadmiasm-8052568e455d8336bf640226c084329378917342.tar.gz
miasm-8052568e455d8336bf640226c084329378917342.zip
Objc: add ObjC cmp
Diffstat (limited to 'miasm2/core/objc.py')
-rw-r--r--miasm2/core/objc.py127
1 files changed, 83 insertions, 44 deletions
diff --git a/miasm2/core/objc.py b/miasm2/core/objc.py
index ca005da6..a29fab3a 100644
--- a/miasm2/core/objc.py
+++ b/miasm2/core/objc.py
@@ -36,6 +36,19 @@ class ObjC(object):
                 self.align == other.align and
                 self.size == other.size)
 
+    def cmp_base(self, other):
+        assert self.__class__ in OBJC_PRIO
+        assert other.__class__ in OBJC_PRIO
+
+        if OBJC_PRIO[self.__class__] != OBJC_PRIO[other.__class__]:
+            return cmp(OBJC_PRIO[self.__class__], OBJC_PRIO[other.__class__])
+        if self.align != other.align:
+            return cmp(self.align, other.align)
+        return cmp(self.size, other.size)
+
+    def __str__(self):
+        return objc_to_str(self)
+
 
 class ObjCDecl(ObjC):
     """C Declaration identified"""
@@ -50,10 +63,11 @@ class ObjCDecl(ObjC):
     def __str__(self):
         return '%s' % (self.name)
 
-    def __eq__(self, other):
-        if not self.eq_base(other):
-            return False
-        return self.name == other.name
+    def __cmp__(self, other):
+        ret = self.cmp_base(other)
+        if ret:
+            return ret
+        return cmp(self.name, other.name)
 
 
 class ObjCInt(ObjC):
@@ -67,8 +81,8 @@ class ObjCInt(ObjC):
     def __str__(self):
         return 'int'
 
-    def __eq__(self, other):
-        return self.eq_base(other)
+    def __cmp__(self, other):
+        return self.cmp_base(other)
 
 
 class ObjCPtr(ObjC):
@@ -109,10 +123,11 @@ class ObjCPtr(ObjC):
         else:
             return '*%s' % (target)
 
-    def __eq__(self, other):
-        if not self.eq_base(other):
-            return False
-        return self.objtype == other.objtype
+    def __cmp__(self, other):
+        ret = self.cmp_base(other)
+        if ret:
+            return ret
+        return cmp(self.objtype, other.objtype)
 
 
 class ObjCArray(ObjC):
@@ -137,11 +152,14 @@ class ObjCArray(ObjC):
     def __str__(self):
         return '%s[%d]' % (self.objtype, self.elems)
 
-    def __eq__(self, other):
-        if not self.eq_base(other):
-            return False
-        return (self.elems == other.elems and
-                self.objtype == other.objtype)
+    def __cmp__(self, other):
+        ret = self.cmp_base(other)
+        if ret:
+            return ret
+        ret = cmp(self.elems, other.elems)
+        if ret:
+            return ret
+        return cmp(self.objtype, other.objtype)
 
 
 class ObjCStruct(ObjC):
@@ -174,16 +192,18 @@ class ObjCStruct(ObjC):
     def __str__(self):
         return 'struct %s' % (self.name)
 
-    def __eq__(self, other):
-        if not (self.eq_base(other) and self.name == other.name):
-            return False
-        if len(self.fields) != len(other.fields):
-            return False
+    def __cmp__(self, other):
+        ret = self.cmp_base(other)
+        if ret:
+            return ret
+        ret = cmp(len(self.fields), len(other.fields))
+        if ret:
+            return ret
         for field_a, field_b in zip(self.fields, other.fields):
-            if field_a != field_b:
-                return False
-        return True
-
+            ret = cmp(field_a, field_b)
+            if ret:
+                return ret
+        return 0
 
 class ObjCUnion(ObjC):
     """C object for unions"""
@@ -215,16 +235,18 @@ class ObjCUnion(ObjC):
     def __str__(self):
         return 'union %s' % (self.name)
 
-    def __eq__(self, other):
-        if not (self.eq_base(other) and self.name == other.name):
-            return False
-        if len(self.fields) != len(other.fields):
-            return False
+    def __cmp__(self, other):
+        ret = self.cmp_base(other)
+        if ret:
+            return ret
+        ret = cmp(len(self.fields), len(other.fields))
+        if ret:
+            return ret
         for field_a, field_b in zip(self.fields, other.fields):
-            if field_a != field_b:
-                return False
-        return True
-
+            ret = cmp(field_a, field_b)
+            if ret:
+                return ret
+        return 0
 
 class ObjCEllipsis(ObjC):
     """C integer"""
@@ -234,8 +256,8 @@ class ObjCEllipsis(ObjC):
         self.size = None
         self.align = None
 
-    def __eq__(self, other):
-        return self.eq_base(other)
+    def __cmp__(self, other):
+        return self.cmp_base(other)
 
 
 class ObjCFunc(ObjC):
@@ -263,16 +285,33 @@ class ObjCFunc(ObjC):
             out.append("  %s" % arg)
         return '\n'.join(out)
 
-    def __eq__(self, other):
-        if not (self.eq_base(other) and self.name == other.name and
-                self.type_ret == other.type_ret):
-            return False
-        if len(self.args) != len(other.args):
-            return False
+    def __cmp__(self, other):
+        ret = self.cmp_base(other)
+        if ret:
+            return ret
+        ret = cmp(self.name, other.name)
+        if ret:
+            return ret
+        ret = cmp(len(self.args), len(other.args))
+        if ret:
+            return ret
         for arg_a, arg_b in zip(self.args, other.args):
-            if arg_a != arg_b:
-                return False
-        return True
+            ret = cmp(arg_a, arg_b)
+            if ret:
+                return ret
+        return 0
+
+OBJC_PRIO = {
+    ObjC: 0,
+    ObjCDecl:1,
+    ObjCInt:2,
+    ObjCPtr:3,
+    ObjCArray:4,
+    ObjCStruct:5,
+    ObjCUnion:6,
+    ObjCEllipsis:7,
+    ObjCFunc:8,
+}
 
 
 def access_simplifier(expr):