summary refs log tree commit diff stats
path: root/scripts/decodetree.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/decodetree.py')
-rw-r--r--scripts/decodetree.py45
1 files changed, 28 insertions, 17 deletions
diff --git a/scripts/decodetree.py b/scripts/decodetree.py
index f85da45ee3..a03dc6b5e3 100644
--- a/scripts/decodetree.py
+++ b/scripts/decodetree.py
@@ -165,11 +165,15 @@ def is_contiguous(bits):
         return -1
 
 
-def eq_fields_for_args(flds_a, flds_b):
-    if len(flds_a) != len(flds_b):
+def eq_fields_for_args(flds_a, arg):
+    if len(flds_a) != len(arg.fields):
         return False
+    # Only allow inference on default types
+    for t in arg.types:
+        if t != 'int':
+            return False
     for k, a in flds_a.items():
-        if k not in flds_b:
+        if k not in arg.fields:
             return False
     return True
 
@@ -313,10 +317,11 @@ class ParameterField:
 
 class Arguments:
     """Class representing the extracted fields of a format"""
-    def __init__(self, nm, flds, extern):
+    def __init__(self, nm, flds, types, extern):
         self.name = nm
         self.extern = extern
-        self.fields = sorted(flds)
+        self.fields = flds
+        self.types = types
 
     def __str__(self):
         return self.name + ' ' + str(self.fields)
@@ -327,8 +332,8 @@ class Arguments:
     def output_def(self):
         if not self.extern:
             output('typedef struct {\n')
-            for n in self.fields:
-                output('    int ', n, ';\n')
+            for (n, t) in zip(self.fields, self.types):
+                output(f'    {t} {n};\n')
             output('} ', self.struct_name(), ';\n\n')
 # end Arguments
 
@@ -719,21 +724,27 @@ def parse_arguments(lineno, name, toks):
     global anyextern
 
     flds = []
+    types = []
     extern = False
-    for t in toks:
-        if re.fullmatch('!extern', t):
+    for n in toks:
+        if re.fullmatch('!extern', n):
             extern = True
             anyextern = True
             continue
-        if not re.fullmatch(re_C_ident, t):
-            error(lineno, f'invalid argument set token "{t}"')
-        if t in flds:
-            error(lineno, f'duplicate argument "{t}"')
-        flds.append(t)
+        if re.fullmatch(re_C_ident + ':' + re_C_ident, n):
+            (n, t) = n.split(':')
+        elif re.fullmatch(re_C_ident, n):
+            t = 'int'
+        else:
+            error(lineno, f'invalid argument set token "{n}"')
+        if n in flds:
+            error(lineno, f'duplicate argument "{n}"')
+        flds.append(n)
+        types.append(t)
 
     if name in arguments:
         error(lineno, 'duplicate argument set', name)
-    arguments[name] = Arguments(name, flds, extern)
+    arguments[name] = Arguments(name, flds, types, extern)
 # end parse_arguments
 
 
@@ -760,11 +771,11 @@ def infer_argument_set(flds):
     global decode_function
 
     for arg in arguments.values():
-        if eq_fields_for_args(flds, arg.fields):
+        if eq_fields_for_args(flds, arg):
             return arg
 
     name = decode_function + str(len(arguments))
-    arg = Arguments(name, flds.keys(), False)
+    arg = Arguments(name, flds.keys(), ['int'] * len(flds), False)
     arguments[name] = arg
     return arg