about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/core/types.py56
-rw-r--r--test/core/test_types.py5
2 files changed, 43 insertions, 18 deletions
diff --git a/miasm2/core/types.py b/miasm2/core/types.py
index 2108d446..374fc804 100644
--- a/miasm2/core/types.py
+++ b/miasm2/core/types.py
@@ -124,6 +124,29 @@ def set_allocator(alloc_func):
 
 # Helpers
 
+def to_type(obj):
+    """If possible, return the Type associated with @obj, otherwise raises
+    a ValueError.
+
+    Works with a Type instance (returns obj) or a MemType subclass or instance
+    (returns obj.get_type()).
+    """
+    # obj is a python type
+    if isinstance(obj, type):
+        if issubclass(obj, MemType):
+            if obj.get_type() is None:
+                raise ValueError("%r has no static type; use a subclasses "
+                                 "with a non null _type or use a "
+                                 "Type instance" % obj)
+            return obj.get_type()
+    # obj is not not a type
+    else:
+        if isinstance(obj, Type):
+            return obj
+        elif isinstance(obj, MemType):
+            return obj.get_type()
+    raise ValueError("%r is not a Type or a MemType" % obj)
+
 def indent(s, size=4):
     """Indent a string with @size spaces"""
     return ' '*size + ('\n' + ' '*size).join(s.split('\n'))
@@ -301,7 +324,7 @@ class Type(object):
 
     def _set_self_type(self, self_type):
         """If this field refers to MemSelf/Self, replace it with @self_type
-        (a MemType subclass) when using it. Generally not used outside this
+        (a Type instance) when using it. Generally not used outside this
         module.
         """
         self._self_type = self_type
@@ -383,7 +406,7 @@ class Ptr(Num):
         """
         @fmt: (str) Num compatible format that will be the Ptr representation
             in memory
-        @dst_type: (MemType or Type) the MemType this Ptr points to.
+        @dst_type: (MemType or Type) the Type this Ptr points to.
             If a Type is given, it is transformed into a MemType with
             TheType.lval.
         *type_args, **type_kwargs: arguments to pass to the the pointed
@@ -413,13 +436,12 @@ class Ptr(Num):
         self._type_kwargs = type_kwargs
 
     def _fix_dst_type(self):
-        if self._dst_type == MemSelf:
+        if self._dst_type in [MemSelf, SELF_TYPE_INSTANCE]:
             if self._get_self_type() is not None:
                 self._dst_type = self._get_self_type()
             else:
-                raise ValueError("Unsupported usecase for MemSelf, sorry")
-        if isinstance(self._dst_type, Type):
-            self._dst_type = self._dst_type.lval
+                raise ValueError("Unsupported usecase for (Mem)Self, sorry")
+        self._dst_type = to_type(self._dst_type)
 
     @property
     def dst_type(self):
@@ -450,15 +472,15 @@ class Ptr(Num):
         Equivalent to a pointer dereference rvalue in C.
         """
         dst_addr = self.get_val(vm, addr)
-        return self.dst_type(vm, dst_addr,
-                             *self._type_args, **self._type_kwargs)
+        return self.dst_type.lval(vm, dst_addr,
+                                  *self._type_args, **self._type_kwargs)
 
     def deref_set(self, vm, addr, val):
         """Serializes the @val MemType subclass instance in @vm (VmMngr) at
         @addr. Equivalent to a pointer dereference assignment in C.
         """
         # Sanity check
-        if self.dst_type != val.__class__:
+        if self.dst_type != val.get_type():
             log.warning("Original type was %s, overriden by value of type %s",
                         self._dst_type.__name__, val.__class__.__name__)
 
@@ -470,7 +492,7 @@ class Ptr(Num):
         return MemPtr
 
     def __repr__(self):
-        return "%s(%r)" % (self.__class__.__name__, self.dst_type.get_type())
+        return "%s(%r)" % (self.__class__.__name__, self.dst_type)
 
     def __eq__(self, other):
         return super(Ptr, self).__eq__(other) and \
@@ -538,12 +560,7 @@ class Struct(Type):
         real_fields = []
         uniq_count = 0
         for fname, field in fields:
-            if isinstance(field, type) and issubclass(field, MemType):
-                if field._type is None:
-                    raise ValueError("%r has no static type; use a subclasses "
-                                     "with a non null _type or use a "
-                                     "Type instance")
-                field = field.get_type()
+            field = to_type(field)
 
             # For reflexion
             field._set_self_type(self)
@@ -726,7 +743,8 @@ class Array(Type):
     """
 
     def __init__(self, field_type, array_len=None):
-        self.field_type = field_type
+        # Handle both Type instance and MemType subclasses
+        self.field_type = to_type(field_type)
         self.array_len = array_len
 
     def _set_self_type(self, self_type):
@@ -1113,6 +1131,10 @@ class Self(Void):
     def _build_pinned_type(self):
         return MemSelf
 
+# To avoid reinstanciation when testing equality
+SELF_TYPE_INSTANCE = Self()
+VOID_TYPE_INSTANCE = Void()
+
 
 # MemType classes
 
diff --git a/test/core/test_types.py b/test/core/test_types.py
index bb1d5da1..f6e5cb13 100644
--- a/test/core/test_types.py
+++ b/test/core/test_types.py
@@ -108,7 +108,7 @@ assert other == other2 # But same value
 ## Same stuff for Ptr to MemField
 alloc_addr = my_heap.vm_alloc(jitter.vm,
                               mstruct.get_type().get_field_type("i")
-                                     .dst_type.sizeof())
+                                     .dst_type.size)
 mstruct.i = alloc_addr
 mstruct.i.deref.val = 8
 assert mstruct.i.deref.val == 8
@@ -154,6 +154,9 @@ assert memstr3.val == memstr.val # But the python value is the same
 
 
 # Array tests
+# Construction methods
+assert Array(MyStruct) == Array(MyStruct.get_type())
+assert Array(MyStruct, 10) == Array(MyStruct.get_type(), 10)
 # Allocate buffer manually, since memarray is unsized
 alloc_addr = my_heap.vm_alloc(jitter.vm, 0x100)
 memarray = Array(Num("I")).lval(jitter.vm, alloc_addr)