summary refs log tree commit diff stats
path: root/scripts/decodetree.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/decodetree.py')
-rw-r--r--scripts/decodetree.py265
1 files changed, 246 insertions, 19 deletions
diff --git a/scripts/decodetree.py b/scripts/decodetree.py
index a03dc6b5e3..13db585d04 100644
--- a/scripts/decodetree.py
+++ b/scripts/decodetree.py
@@ -35,6 +35,7 @@ arguments = {}
 formats = {}
 allpatterns = []
 anyextern = False
+testforerror = False
 
 translate_prefix = 'trans'
 translate_scope = 'static '
@@ -53,6 +54,80 @@ re_fld_ident = '%[a-zA-Z0-9_]*'
 re_fmt_ident = '@[a-zA-Z0-9_]*'
 re_pat_ident = '[a-zA-Z0-9_]*'
 
+# Local implementation of a topological sort. We use the same API that
+# the Python graphlib does, so that when QEMU moves forward to a
+# baseline of Python 3.9 or newer this code can all be dropped and
+# replaced with:
+#    from graphlib import TopologicalSorter, CycleError
+#
+# https://docs.python.org/3.9/library/graphlib.html#graphlib.TopologicalSorter
+#
+# We only implement the parts of TopologicalSorter we care about:
+#  ts = TopologicalSorter(graph=None)
+#    create the sorter. graph is a dictionary whose keys are
+#    nodes and whose values are lists of the predecessors of that node.
+#    (That is, if graph contains "A" -> ["B", "C"] then we must output
+#    B and C before A.)
+#  ts.static_order()
+#    returns a list of all the nodes in sorted order, or raises CycleError
+#  CycleError
+#    exception raised if there are cycles in the graph. The second
+#    element in the args attribute is a list of nodes which form a
+#    cycle; the first and last element are the same, eg [a, b, c, a]
+#    (Our implementation doesn't give the order correctly.)
+#
+# For our purposes we can assume that the data set is always small
+# (typically 10 nodes or less, actual links in the graph very rare),
+# so we don't need to worry about efficiency of implementation.
+#
+# The core of this implementation is from
+# https://code.activestate.com/recipes/578272-topological-sort/
+# (but updated to Python 3), and is under the MIT license.
+
+class CycleError(ValueError):
+    """Subclass of ValueError raised if cycles exist in the graph"""
+    pass
+
+class TopologicalSorter:
+    """Topologically sort a graph"""
+    def __init__(self, graph=None):
+        self.graph = graph
+
+    def static_order(self):
+        # We do the sort right here, unlike the stdlib version
+        from functools import reduce
+        data = {}
+        r = []
+
+        if not self.graph:
+            return []
+
+        # This code wants the values in the dict to be specifically sets
+        for k, v in self.graph.items():
+            data[k] = set(v)
+
+        # Find all items that don't depend on anything.
+        extra_items_in_deps = (reduce(set.union, data.values())
+                               - set(data.keys()))
+        # Add empty dependencies where needed
+        data.update({item:{} for item in extra_items_in_deps})
+        while True:
+            ordered = set(item for item, dep in data.items() if not dep)
+            if not ordered:
+                break
+            r.extend(ordered)
+            data = {item: (dep - ordered)
+                    for item, dep in data.items()
+                        if item not in ordered}
+        if data:
+            # This doesn't give as nice results as the stdlib, which
+            # gives you the cycle by listing the nodes in order. Here
+            # we only know the nodes in the cycle but not their order.
+            raise CycleError(f'nodes are in a cycle', list(data.keys()))
+
+        return r
+# end TopologicalSorter
+
 def error_with_file(file, lineno, *args):
     """Print an error message from file:line and args and exit."""
     global output_file
@@ -70,8 +145,13 @@ def error_with_file(file, lineno, *args):
 
     if output_file and output_fd:
         output_fd.close()
-        os.remove(output_file)
-    exit(1)
+        # Do not try to remove e.g. -o /dev/null
+        if not output_file.startswith("/dev"):
+            try:
+                os.remove(output_file)
+            except PermissionError:
+                pass
+    exit(0 if testforerror else 1)
 # end error_with_file
 
 
@@ -205,11 +285,14 @@ class Field:
             s = ''
         return str(self.pos) + ':' + s + str(self.len)
 
-    def str_extract(self):
+    def str_extract(self, lvalue_formatter):
         global bitop_width
         s = 's' if self.sign else ''
         return f'{s}extract{bitop_width}(insn, {self.pos}, {self.len})'
 
+    def referenced_fields(self):
+        return []
+
     def __eq__(self, other):
         return self.sign == other.sign and self.mask == other.mask
 
@@ -228,12 +311,12 @@ class MultiField:
     def __str__(self):
         return str(self.subs)
 
-    def str_extract(self):
+    def str_extract(self, lvalue_formatter):
         global bitop_width
         ret = '0'
         pos = 0
         for f in reversed(self.subs):
-            ext = f.str_extract()
+            ext = f.str_extract(lvalue_formatter)
             if pos == 0:
                 ret = ext
             else:
@@ -241,6 +324,12 @@ class MultiField:
             pos += f.len
         return ret
 
+    def referenced_fields(self):
+        l = []
+        for f in self.subs:
+            l.extend(f.referenced_fields())
+        return l
+
     def __ne__(self, other):
         if len(self.subs) != len(other.subs):
             return True
@@ -264,9 +353,12 @@ class ConstField:
     def __str__(self):
         return str(self.value)
 
-    def str_extract(self):
+    def str_extract(self, lvalue_formatter):
         return str(self.value)
 
+    def referenced_fields(self):
+        return []
+
     def __cmp__(self, other):
         return self.value - other.value
 # end ConstField
@@ -283,8 +375,12 @@ class FunctionField:
     def __str__(self):
         return self.func + '(' + str(self.base) + ')'
 
-    def str_extract(self):
-        return self.func + '(ctx, ' + self.base.str_extract() + ')'
+    def str_extract(self, lvalue_formatter):
+        return (self.func + '(ctx, '
+                + self.base.str_extract(lvalue_formatter) + ')')
+
+    def referenced_fields(self):
+        return self.base.referenced_fields()
 
     def __eq__(self, other):
         return self.func == other.func and self.base == other.base
@@ -304,9 +400,12 @@ class ParameterField:
     def __str__(self):
         return self.func
 
-    def str_extract(self):
+    def str_extract(self, lvalue_formatter):
         return self.func + '(ctx)'
 
+    def referenced_fields(self):
+        return []
+
     def __eq__(self, other):
         return self.func == other.func
 
@@ -314,6 +413,32 @@ class ParameterField:
         return not self.__eq__(other)
 # end ParameterField
 
+class NamedField:
+    """Class representing a field already named in the pattern"""
+    def __init__(self, name, sign, len):
+        self.mask = 0
+        self.sign = sign
+        self.len = len
+        self.name = name
+
+    def __str__(self):
+        return self.name
+
+    def str_extract(self, lvalue_formatter):
+        global bitop_width
+        s = 's' if self.sign else ''
+        lvalue = lvalue_formatter(self.name)
+        return f'{s}extract{bitop_width}({lvalue}, 0, {self.len})'
+
+    def referenced_fields(self):
+        return [self.name]
+
+    def __eq__(self, other):
+        return self.name == other.name
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+# end NamedField
 
 class Arguments:
     """Class representing the extracted fields of a format"""
@@ -337,7 +462,6 @@ class Arguments:
             output('} ', self.struct_name(), ';\n\n')
 # end Arguments
 
-
 class General:
     """Common code between instruction formats and instruction patterns"""
     def __init__(self, name, lineno, base, fixb, fixm, udfm, fldm, flds, w):
@@ -351,12 +475,59 @@ class General:
         self.fieldmask = fldm
         self.fields = flds
         self.width = w
+        self.dangling = None
 
     def __str__(self):
         return self.name + ' ' + str_match_bits(self.fixedbits, self.fixedmask)
 
     def str1(self, i):
         return str_indent(i) + self.__str__()
+
+    def dangling_references(self):
+        # Return a list of all named references which aren't satisfied
+        # directly by this format/pattern. This will be either:
+        #  * a format referring to a field which is specified by the
+        #    pattern(s) using it
+        #  * a pattern referring to a field which is specified by the
+        #    format it uses
+        #  * a user error (referring to a field that doesn't exist at all)
+        if self.dangling is None:
+            # Compute this once and cache the answer
+            dangling = []
+            for n, f in self.fields.items():
+                for r in f.referenced_fields():
+                    if r not in self.fields:
+                        dangling.append(r)
+            self.dangling = dangling
+        return self.dangling
+
+    def output_fields(self, indent, lvalue_formatter):
+        # We use a topological sort to ensure that any use of NamedField
+        # comes after the initialization of the field it is referencing.
+        graph = {}
+        for n, f in self.fields.items():
+            refs = f.referenced_fields()
+            graph[n] = refs
+
+        try:
+            ts = TopologicalSorter(graph)
+            for n in ts.static_order():
+                # We only want to emit assignments for the keys
+                # in our fields list, not for anything that ends up
+                # in the tsort graph only because it was referenced as
+                # a NamedField.
+                try:
+                    f = self.fields[n]
+                    output(indent, lvalue_formatter(n), ' = ',
+                           f.str_extract(lvalue_formatter), ';\n')
+                except KeyError:
+                    pass
+        except CycleError as e:
+            # The second element of args is a list of nodes which form
+            # a cycle (there might be others too, but only one is reported).
+            # Pretty-print it to tell the user.
+            cycle = ' => '.join(e.args[1])
+            error(self.lineno, 'field definitions form a cycle: ' + cycle)
 # end General
 
 
@@ -370,8 +541,7 @@ class Format(General):
     def output_extract(self):
         output('static void ', self.extract_name(), '(DisasContext *ctx, ',
                self.base.struct_name(), ' *a, ', insntype, ' insn)\n{\n')
-        for n, f in self.fields.items():
-            output('    a->', n, ' = ', f.str_extract(), ';\n')
+        self.output_fields(str_indent(4), lambda n: 'a->' + n)
         output('}\n\n')
 # end Format
 
@@ -392,11 +562,36 @@ class Pattern(General):
         ind = str_indent(i)
         arg = self.base.base.name
         output(ind, '/* ', self.file, ':', str(self.lineno), ' */\n')
+        # We might have named references in the format that refer to fields
+        # in the pattern, or named references in the pattern that refer
+        # to fields in the format. This affects whether we extract the fields
+        # for the format before or after the ones for the pattern.
+        # For simplicity we don't allow cross references in both directions.
+        # This is also where we catch the syntax error of referring to
+        # a nonexistent field.
+        fmt_refs = self.base.dangling_references()
+        for r in fmt_refs:
+            if r not in self.fields:
+                error(self.lineno, f'format refers to undefined field {r}')
+        pat_refs = self.dangling_references()
+        for r in pat_refs:
+            if r not in self.base.fields:
+                error(self.lineno, f'pattern refers to undefined field {r}')
+        if pat_refs and fmt_refs:
+            error(self.lineno, ('pattern that uses fields defined in format '
+                                'cannot use format that uses fields defined '
+                                'in pattern'))
+        if fmt_refs:
+            # pattern fields first
+            self.output_fields(ind, lambda n: 'u.f_' + arg + '.' + n)
+            assert not extracted, "dangling fmt refs but it was already extracted"
         if not extracted:
             output(ind, self.base.extract_name(),
                    '(ctx, &u.f_', arg, ', insn);\n')
-        for n, f in self.fields.items():
-            output(ind, 'u.f_', arg, '.', n, ' = ', f.str_extract(), ';\n')
+        if not fmt_refs:
+            # pattern fields last
+            self.output_fields(ind, lambda n: 'u.f_' + arg + '.' + n)
+
         output(ind, 'if (', translate_prefix, '_', self.name,
                '(ctx, &u.f_', arg, ')) return true;\n')
 
@@ -473,7 +668,7 @@ class MultiPattern(General):
 
     def prop_format(self):
         for p in self.pats:
-            p.build_tree()
+            p.prop_format()
 
     def prop_width(self):
         width = None
@@ -505,6 +700,12 @@ class IncMultiPattern(MultiPattern):
                 output(ind, '}\n')
             else:
                 p.output_code(i, extracted, p.fixedbits, p.fixedmask)
+
+    def build_tree(self):
+        if not self.pats:
+            error_with_file(self.file, self.lineno, 'empty pattern group')
+        super().build_tree()
+
 #end IncMultiPattern
 
 
@@ -536,8 +737,10 @@ class Tree:
         ind = str_indent(i)
 
         # If we identified all nodes below have the same format,
-        # extract the fields now.
-        if not extracted and self.base:
+        # extract the fields now. But don't do it if the format relies
+        # on named fields from the insn pattern, as those won't have
+        # been initialised at this point.
+        if not extracted and self.base and not self.base.dangling_references():
             output(ind, self.base.extract_name(),
                    '(ctx, &u.f_', self.base.base.name, ', insn);\n')
             extracted = True
@@ -623,7 +826,7 @@ class ExcMultiPattern(MultiPattern):
         return t
 
     def build_tree(self):
-        super().prop_format()
+        super().build_tree()
         self.tree = self.__build_tree(self.pats, self.fixedbits,
                                       self.fixedmask)
 
@@ -659,6 +862,7 @@ def parse_field(lineno, name, toks):
     """Parse one instruction field from TOKS at LINENO"""
     global fields
     global insnwidth
+    global re_C_ident
 
     # A "simple" field will have only one entry;
     # a "multifield" will have several.
@@ -673,6 +877,25 @@ def parse_field(lineno, name, toks):
             func = func[1]
             continue
 
+        if re.fullmatch(re_C_ident + ':s[0-9]+', t):
+            # Signed named field
+            subtoks = t.split(':')
+            n = subtoks[0]
+            le = int(subtoks[1])
+            f = NamedField(n, True, le)
+            subs.append(f)
+            width += le
+            continue
+        if re.fullmatch(re_C_ident + ':[0-9]+', t):
+            # Unsigned named field
+            subtoks = t.split(':')
+            n = subtoks[0]
+            le = int(subtoks[1])
+            f = NamedField(n, False, le)
+            subs.append(f)
+            width += le
+            continue
+
         if re.fullmatch('[0-9]+:s[0-9]+', t):
             # Signed field extract
             subtoks = t.split(':s')
@@ -1286,11 +1509,12 @@ def main():
     global bitop_width
     global variablewidth
     global anyextern
+    global testforerror
 
     decode_scope = 'static '
 
     long_opts = ['decode=', 'translate=', 'output=', 'insnwidth=',
-                 'static-decode=', 'varinsnwidth=']
+                 'static-decode=', 'varinsnwidth=', 'test-for-error']
     try:
         (opts, args) = getopt.gnu_getopt(sys.argv[1:], 'o:vw:', long_opts)
     except getopt.GetoptError as err:
@@ -1319,6 +1543,8 @@ def main():
                 bitop_width = 64
             elif insnwidth != 32:
                 error(0, 'cannot handle insns of width', insnwidth)
+        elif o == '--test-for-error':
+            testforerror = True
         else:
             assert False, 'unhandled option'
 
@@ -1417,6 +1643,7 @@ def main():
 
     if output_file:
         output_fd.close()
+    exit(1 if testforerror else 0)
 # end main