about summary refs log tree commit diff stats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--miasm2/analysis/mem.py75
-rw-r--r--test/analysis/mem.py38
2 files changed, 93 insertions, 20 deletions
diff --git a/miasm2/analysis/mem.py b/miasm2/analysis/mem.py
index 7dda9041..ce389f89 100644
--- a/miasm2/analysis/mem.py
+++ b/miasm2/analysis/mem.py
@@ -231,6 +231,9 @@ class MemField(object):
     def __len__(self):
         return self.size()
 
+    def __neq__(self, other):
+        return not self == other
+
 
 class Struct(MemField):
     """Dumb struct.pack/unpack field. Mainly used to factorize code.
@@ -253,6 +256,9 @@ class Struct(MemField):
     def __repr__(self):
         return "%s(%s)" % (self.__class__.__name__, self._fmt)
 
+    def __eq__(self, other):
+        return self.__class__ == other.__class__ and self._fmt == other._fmt
+
 
 class Num(Struct):
     """Represents a number (integer or float). The number is encoded with
@@ -339,6 +345,12 @@ class Ptr(Num):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self._dst_type)
 
+    def __eq__(self, other):
+        return super(Ptr, self).__eq__(other) and \
+                self.dst_type == other.dst_type and \
+                self._type_args == other._type_args and \
+                self._type_kwargs == other._type_kwargs
+
 
 class Inline(MemField):
     """Field used to inline a MemStruct in another MemStruct. Equivalent to
@@ -381,6 +393,12 @@ class Inline(MemField):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self._il_type)
 
+    def __eq__(self, other):
+        return self.__class__ == other.__class__ and \
+                self._il_type == other._il_type and \
+                self._type_args == other._type_args and \
+                self._type_kwargs == other._type_kwargs
+
 
 class Array(MemField):
     """A fixed size array (contiguous sequence) of a MemField subclass
@@ -398,43 +416,48 @@ class Array(MemField):
         mystruct.array = MemSizedArray(vm, addr2, Num("B"), 4)
     """
 
-    def __init__(self, field_type, length):
-        self._field_type = field_type
-        self._array_len = length
+    def __init__(self, field_type, array_len):
+        self.field_type = field_type
+        self.array_len = array_len
 
     def _set_self_type(self, self_type):
         super(Array, self)._set_self_type(self_type)
-        self._field_type._set_self_type(self_type)
+        self.field_type._set_self_type(self_type)
 
     def set(self, vm, addr, val):
         # MemSizedArray assignment
         if isinstance(val, MemSizedArray):
-            if val.array_len != self._array_len or len(val) != self.size():
+            if val.array_len != self.array_len or len(val) != self.size():
                 raise ValueError("Size mismatch in MemSizedArray assignment")
             raw = str(val)
             vm.set_mem(addr, raw)
 
         # list assignment
         elif isinstance(val, list):
-            if len(val) != self._array_len:
+            if len(val) != self.array_len:
                 raise ValueError("Size mismatch in MemSizedArray assignment ")
             offset = 0
             for elt in val:
-                self._field_type.set(vm, addr + offset, elt)
-                offset += self._field_type.size()
+                self.field_type.set(vm, addr + offset, elt)
+                offset += self.field_type.size()
 
         else:
             raise NotImplementedError(
                 "Assignment only implemented for list and MemSizedArray")
 
     def get(self, vm, addr):
-        return MemSizedArray(vm, addr, self._field_type, self._array_len)
+        return MemSizedArray(vm, addr, self.field_type, self.array_len)
 
     def size(self):
-        return self._field_type.size() * self._array_len
+        return self.field_type.size() * self.array_len
 
     def __repr__(self):
-        return "%r[%s]" % (self._field_type, self._array_len)
+        return "%r[%s]" % (self.field_type, self.array_len)
+
+    def __eq__(self, other):
+        return self.__class__ == other.__class__ and \
+                self.field_type == other.field_type and \
+                self.array_len == other.array_len
 
 
 class Union(MemField):
@@ -478,6 +501,10 @@ class Union(MemField):
                                 for name, field in self.field_list)
         return "%s(%s)" % (self.__class__.__name__, fields_repr)
 
+    def __eq__(self, other):
+        return self.__class__ == other.__class__ and \
+                self.field_list == other.field_list
+
 
 class Bits(MemField):
     """Helper class for BitField, not very useful on its own. Represents some
@@ -533,6 +560,12 @@ class Bits(MemField):
         return "%s%r(%d:%d)" % (self.__class__.__name__, self._num,
                                 self._bit_offset, self._bit_offset + self._bits)
 
+    def __eq__(self, other):
+        return self.__class__ == other.__class__ and \
+                self._num == other._num and self._bits == other._bits and \
+                self._bit_offset == other._bit_offset
+
+
 class BitField(Union):
     """A C-like bitfield.
 
@@ -580,6 +613,10 @@ class BitField(Union):
     def get(self, vm, addr):
         return self._num.get(vm, addr)
 
+    def __eq__(self, other):
+        return self.__class__ == other.__class__ and \
+                self._num == other._num and super(BitField, self).__eq__(other)
+
 
 # MemStruct classes
 
@@ -1101,7 +1138,7 @@ def mem_array_type(field_type):
 
 
 class MemSizedArray(MemArray):
-    """A fixed size MemArray. Its additional arg represents the @length (in
+    """A fixed size MemArray. Its additional arg represents the @array_len (in
     number of elements) of this array.
 
     This type is dynamically sized. Use mem_sized_array_type to generate a
@@ -1109,15 +1146,15 @@ class MemSizedArray(MemArray):
     """
     _array_len = None
 
-    def __init__(self, vm, addr=None, field_type=None, length=None):
+    def __init__(self, vm, addr=None, field_type=None, array_len=None):
         # Set the length before anything else to allow get_size() to work for
         # allocation
         if self._array_len is None:
-            self._array_len = length
+            self._array_len = array_len
         super(MemSizedArray, self).__init__(vm, addr, field_type)
         if self._array_len is None or self._field_type is None:
             raise NotImplementedError(
-                "Provide field_type and length to instanciate this class, "
+                "Provide field_type and array_len to instanciate this class, "
                 "or generate a subclass with mem_sized_array_type.")
 
     @property
@@ -1159,18 +1196,18 @@ class MemSizedArray(MemArray):
         return "[%s] [%r; %s]" % (items, self._field_type, self._array_len)
 
 
-def mem_sized_array_type(field_type, length):
+def mem_sized_array_type(field_type, array_len):
     """Generate a MemSizedArray subclass that has a fixed @field_type and a
-    fixed @length. This allows to instanciate the returned type with only
+    fixed @array_len. This allows to instanciate the returned type with only
     the vm and addr arguments, as are standard MemStructs.
     """
     @classmethod
     def sizeof(cls):
         return cls._field_type.size() * cls._array_len
 
-    array_type = type('MemSizedArray_%r_%s' % (field_type, length),
+    array_type = type('MemSizedArray_%r_%s' % (field_type, array_len),
                       (MemSizedArray,),
-                      {'_array_len': length,
+                      {'_array_len': array_len,
                        '_field_type': field_type,
                        'sizeof': sizeof})
     return array_type
diff --git a/test/analysis/mem.py b/test/analysis/mem.py
index df1df9bc..a3642a4f 100644
--- a/test/analysis/mem.py
+++ b/test/analysis/mem.py
@@ -8,7 +8,8 @@ from miasm2.analysis.machine import Machine
 from miasm2.analysis.mem import MemStruct, Num, Ptr, MemStr, MemArray,\
                                 MemSizedArray, Array, mem_array_type,\
                                 mem_sized_array_type, Struct, Inline, mem,\
-                                Union, BitField, MemSelf, MemVoid, set_allocator
+                                Union, BitField, MemSelf, MemVoid, Bits, \
+                                set_allocator
 from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE
 from miasm2.os_dep.common import heap
 
@@ -430,6 +431,41 @@ p.value = mstruct.get_addr()
 assert p.deref_value.cast(MyStruct) == mstruct
 assert p.cast(MemPtrMyStruct).deref_value == mstruct
 
+# Field equality tests
+assert Struct("IH") == Struct("IH")
+assert Struct("I") != Struct("IH")
+assert Num("I") == Num("I")
+assert Num(">I") != Num("<I")
+assert Ptr("I", MyStruct) == Ptr("I", MyStruct)
+assert Ptr(">I", MyStruct) != Ptr("<I", MyStruct)
+assert Ptr("I", MyStruct) != Ptr("I", MyStruct2)
+assert Inline(MyStruct) == Inline(MyStruct)
+assert Inline(MyStruct) != Inline(MyStruct2)
+assert Array(Num("H"), 12) == Array(Num("H"), 12)
+assert Array(Num("H"), 11) != Array(Num("H"), 12)
+assert Array(Num("I"), 12) != Array(Num("H"), 12)
+assert Union([("f1", Num("B")), ("f2", Num("H"))]) == \
+        Union([("f1", Num("B")), ("f2", Num("H"))])
+assert Union([("f2", Num("B")), ("f2", Num("H"))]) != \
+        Union([("f1", Num("B")), ("f2", Num("H"))])
+assert Union([("f1", Num("B")), ("f2", Num("H"))]) != \
+        Union([("f1", Num("I")), ("f2", Num("H"))])
+assert Bits(Num("I"), 3, 8) == Bits(Num("I"), 3, 8)
+assert Bits(Num("I"), 3, 8) != Bits(Num("I"), 3, 8)
+assert Bits(Num("H"), 2, 8) != Bits(Num("I"), 3, 8)
+assert Bits(Num("I"), 3, 7) != Bits(Num("I"), 3, 8)
+assert BitField(Num("B"), [("f1", 2), ("f2", 4), ("f3", 1)]) == \
+        BitField(Num("B"), [("f1", 2), ("f2", 4), ("f3", 1)])
+assert BitField(Num("H"), [("f1", 2), ("f2", 4), ("f3", 1)]) != \
+        BitField(Num("B"), [("f1", 2), ("f2", 4), ("f3", 1)])
+assert BitField(Num("B"), [("f2", 2), ("f2", 4), ("f3", 1)]) != \
+        BitField(Num("B"), [("f1", 2), ("f2", 4), ("f3", 1)])
+assert BitField(Num("B"), [("f1", 1), ("f2", 4), ("f3", 1)]) != \
+        BitField(Num("B"), [("f1", 2), ("f2", 4), ("f3", 1)])
+
+
+# Repr tests
+
 print "Some struct reprs:\n"
 print repr(mstruct), '\n'
 print repr(ms2), '\n'