summary refs log tree commit diff stats
path: root/scripts/qapi/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/qapi/schema.py')
-rw-r--r--scripts/qapi/schema.py92
1 files changed, 59 insertions, 33 deletions
diff --git a/scripts/qapi/schema.py b/scripts/qapi/schema.py
index cf0045f34e..0bfc5256fb 100644
--- a/scripts/qapi/schema.py
+++ b/scripts/qapi/schema.py
@@ -50,9 +50,6 @@ class QAPISchemaEntity(object):
 
     def check(self, schema):
         assert not self._checked
-        if self.info:
-            self._module = os.path.relpath(self.info.fname,
-                                           os.path.dirname(schema.fname))
         seen = {}
         for f in self.features:
             f.check_clash(self.info, seen)
@@ -68,15 +65,18 @@ class QAPISchemaEntity(object):
         if self.doc:
             self.doc.check()
 
-    @property
-    def ifcond(self):
+    def _set_module(self, schema, info):
         assert self._checked
-        return self._ifcond
+        self._module = schema.module_by_fname(info and info.fname)
+        self._module.add_entity(self)
+
+    def set_module(self, schema):
+        self._set_module(schema, self.info)
 
     @property
-    def module(self):
+    def ifcond(self):
         assert self._checked
-        return self._module
+        return self._ifcond
 
     def is_implicit(self):
         return not self.info
@@ -135,15 +135,29 @@ class QAPISchemaVisitor(object):
         pass
 
 
-class QAPISchemaInclude(QAPISchemaEntity):
+class QAPISchemaModule(object):
+    def __init__(self, name):
+        self.name = name
+        self._entity_list = []
+
+    def add_entity(self, ent):
+        self._entity_list.append(ent)
+
+    def visit(self, visitor):
+        visitor.visit_module(self.name)
+        for entity in self._entity_list:
+            if visitor.visit_needed(entity):
+                entity.visit(visitor)
 
-    def __init__(self, fname, info):
+
+class QAPISchemaInclude(QAPISchemaEntity):
+    def __init__(self, sub_module, info):
         QAPISchemaEntity.__init__(self, None, info, None)
-        self.fname = fname
+        self._sub_module = sub_module
 
     def visit(self, visitor):
         QAPISchemaEntity.visit(self, visitor)
-        visitor.visit_include(self.fname, self.info)
+        visitor.visit_include(self._sub_module.name, self.info)
 
 
 class QAPISchemaType(QAPISchemaEntity):
@@ -276,16 +290,14 @@ class QAPISchemaArrayType(QAPISchemaType):
             self.info and self.info.defn_meta)
         assert not isinstance(self.element_type, QAPISchemaArrayType)
 
+    def set_module(self, schema):
+        self._set_module(schema, self.element_type.info)
+
     @property
     def ifcond(self):
         assert self._checked
         return self.element_type.ifcond
 
-    @property
-    def module(self):
-        assert self._checked
-        return self.element_type.module
-
     def is_implicit(self):
         return True
 
@@ -711,10 +723,11 @@ class QAPISchemaCommand(QAPISchemaEntity):
             self.ret_type = schema.resolve_type(
                 self._ret_type_name, self.info, "command's 'returns'")
             if self.name not in self.info.pragma.returns_whitelist:
-                if not (isinstance(self.ret_type, QAPISchemaObjectType)
-                        or (isinstance(self.ret_type, QAPISchemaArrayType)
-                            and isinstance(self.ret_type.element_type,
-                                           QAPISchemaObjectType))):
+                typ = self.ret_type
+                if isinstance(typ, QAPISchemaArrayType):
+                    typ = self.ret_type.element_type
+                    assert typ
+                if not isinstance(typ, QAPISchemaObjectType):
                     raise QAPISemError(
                         self.info,
                         "command's 'returns' cannot take %s"
@@ -782,6 +795,10 @@ class QAPISchema(object):
         self.docs = parser.docs
         self._entity_list = []
         self._entity_dict = {}
+        self._module_dict = {}
+        self._schema_dir = os.path.dirname(fname)
+        self._make_module(None) # built-ins
+        self._make_module(fname)
         self._predefining = True
         self._def_predefineds()
         self._predefining = False
@@ -825,14 +842,26 @@ class QAPISchema(object):
                 info, "%s uses unknown type '%s'" % (what, name))
         return typ
 
+    def _module_name(self, fname):
+        if fname is None:
+            return None
+        return os.path.relpath(fname, self._schema_dir)
+
+    def _make_module(self, fname):
+        name = self._module_name(fname)
+        if not name in self._module_dict:
+            self._module_dict[name] = QAPISchemaModule(name)
+        return self._module_dict[name]
+
+    def module_by_fname(self, fname):
+        name = self._module_name(fname)
+        assert name in self._module_dict
+        return self._module_dict[name]
+
     def _def_include(self, expr, info, doc):
         include = expr['include']
         assert doc is None
-        main_info = info
-        while main_info.parent:
-            main_info = main_info.parent
-        fname = os.path.relpath(include, os.path.dirname(main_info.fname))
-        self._def_entity(QAPISchemaInclude(fname, info))
+        self._def_entity(QAPISchemaInclude(self._make_module(include), info))
 
     def _def_builtin_type(self, name, json_type, c_type):
         self._def_entity(QAPISchemaBuiltinType(name, json_type, c_type))
@@ -1064,15 +1093,12 @@ class QAPISchema(object):
             ent.check(self)
             ent.connect_doc()
             ent.check_doc()
+        for ent in self._entity_list:
+            ent.set_module(self)
 
     def visit(self, visitor):
         visitor.visit_begin(self)
         module = None
-        visitor.visit_module(module)
-        for entity in self._entity_list:
-            if visitor.visit_needed(entity):
-                if entity.module != module:
-                    module = entity.module
-                    visitor.visit_module(module)
-                entity.visit(visitor)
+        for mod in self._module_dict.values():
+            mod.visit(visitor)
         visitor.visit_end()