about summary refs log tree commit diff stats
path: root/rebuild_wrappers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rebuild_wrappers.py')
-rwxr-xr-xrebuild_wrappers.py265
1 files changed, 148 insertions, 117 deletions
diff --git a/rebuild_wrappers.py b/rebuild_wrappers.py
index 39c2698b..67be322b 100755
--- a/rebuild_wrappers.py
+++ b/rebuild_wrappers.py
@@ -1,52 +1,93 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 try:
-	# Python 3.5.2+
-	from typing import Union, List, Sequence, Dict, Tuple, NewType, final
+	# Python 3.5.2+ (NewType)
+	from typing import Union, List, Sequence, Dict, Tuple, NewType, TypeVar
 except ImportError:
-	print("Your Python version does not have the typing module, fallback to empty 'types'")
-	# Stubs
-	class GTStub:
+	#print("Your Python version does not have the typing module, fallback to empty 'types'")
+	# Dummies
+	class GTDummy:
 		def __getitem__(self, t):
-			return None
-	Union = GTStub() # type: ignore
-	List = GTStub() # type: ignore
-	Sequence = GTStub() # type: ignore
-	Dict = GTStub() # type: ignore
-	Tuple = GTStub() # type: ignore
-	def NewType(T, *largs): return largs[0] if len(largs) > 0 else None # type: ignore
-	final = lambda x: x # type: ignore
+			return self
+	Union = GTDummy() # type: ignore
+	List = GTDummy() # type: ignore
+	Sequence = GTDummy() # type: ignore
+	Dict = GTDummy() # type: ignore
+	Tuple = GTDummy() # type: ignore
+	def NewType(T, b): return b # type: ignore
+	def TypeVar(T): return object # type: ignore
+try:
+	# Python 3.8+
+	from typing import final
+	import typing
+except ImportError:
+	#print("Your Python version does not have all typing utilities, fallback to dummy ones")
+	def _overload_dummy(*args, **kwds):
+		raise NotImplementedError(
+				"You should not call an overloaded function. "
+				"A series of @overload-decorated functions "
+				"outside a stub module should always be followed "
+				"by an implementation that is not @overload-ed.")
+	def overload(fun): return _overload_dummy # type: ignore
+	class Typing:
+		pass
+	typing = Typing() # type: ignore
+	typing.overload = overload # type: ignore
+	final = lambda fun: fun # type: ignore
+
+import os
+import sys
+
+class FunctionType(str):
+	values: List[str] = ['E', 'e', 'v', 'c', 'w', 'i', 'I', 'C', 'W', 'u', 'U', 'f', 'd', 'D', 'K', 'l', 'L', 'p', 'V', 'O', 'S', 'N', 'M', 'H', 'P']
+	
+	@staticmethod
+	def validate(s: str, post: str) -> bool:
+		if len(s) < 3:
+			raise NotImplementedError("Type {0} too short{1}".format(s, post))
+		if "E" in s:
+			if ("E" in s[:2]) or ("E" in s[3:]):
+				raise NotImplementedError(
+					"emu64_t* not as the first parameter{0}".format(post))
+			if len(s) < 4:
+				raise NotImplementedError("Type {0} too short{1}".format(s, post))
+		
+		if s[1] not in ["F"]:
+			raise NotImplementedError("Bad middle letter {0}{1}".format(s[1], post))
+		
+		return all(c in FunctionType.values for c in s[2:]) and (('v' not in s[2:]) or (len(s) == 3))
+	
+	def splitchar(self) -> List[int]:
+		"""
+		splitchar -- Sorting key function for function signatures
+		
+		The longest strings are at the end, and for identical length, the string
+		are sorted using a pseudo-lexicographic sort, where characters have a value
+		of `values.index`.
+		"""
+		try:
+			ret = [len(self), FunctionType.values.index(self[0])]
+			for c in self[2:]:
+				ret.append(FunctionType.values.index(c))
+			return ret
+		except ValueError as e:
+			raise ValueError("Value is " + self) from e
+	
+	def __getitem__(self, i: Union[int, slice]) -> 'FunctionType':
+		return FunctionType(super().__getitem__(i))
+
+assert(all(c not in FunctionType.values[:i] for i, c in enumerate(FunctionType.values)))
 
-#FunctionType = NewType('FunctionType', str)
-FunctionType = str
 RedirectType = NewType('RedirectType', FunctionType)
 #DefineType = NewType('DefineType', str)
 DefineType = str
 
+T = TypeVar('T')
+U = TypeVar('U')
+
 Filename = str
 ClausesStr = str
 
-import os
-import sys
-
-values: List[str] = ['E', 'e', 'v', 'c', 'w', 'i', 'I', 'C', 'W', 'u', 'U', 'f', 'd', 'D', 'K', 'l', 'L', 'p', 'V', 'O', 'S', 'N', 'M', 'H', 'P']
-assert(all(c not in values[:i] for i, c in enumerate(values)))
-def splitchar(s: str) -> List[int]:
-	"""
-	splitchar -- Sorting key function for function signatures
-	
-	The longest strings are at the end, and for identical length, the string
-	are sorted using a pseudo-lexicographic sort, where characters have a value
-	of `values.index`.
-	"""
-	try:
-		ret = [len(s), values.index(s[0])]
-		for c in s[2:]:
-			ret.append(values.index(c))
-		return ret
-	except ValueError as e:
-		raise ValueError("Value is " + s) from e
-
 @final
 class Define:
 	name: DefineType
@@ -73,7 +114,7 @@ class Define:
 		self.inverted_ = not self.inverted_
 	def inverted(self) -> "Define":
 		"""
-		invert -- Transform a `defined()` into a `!defined()` and vice-versa, out-of-place.
+		inverted -- Transform a `defined()` into a `!defined()` and vice-versa, out-of-place.
 		"""
 		return Define(self.name, not self.inverted_)
 	
@@ -82,7 +123,6 @@ class Define:
 			return "!defined(" + self.name + ")"
 		else:
 			return "defined(" + self.name + ")"
-
 @final
 class Defines:
 	defines: List[Define]
@@ -113,7 +153,6 @@ class Defines:
 	
 	def __str__(self) -> str:
 		return " && ".join(map(str, self.defines))
-
 @final
 class Clauses:
 	"""
@@ -163,9 +202,9 @@ class Clauses:
 			return "(" + ") || (".join(map(str, self.definess)) + ")"
 
 def readFiles(files: Sequence[Filename]) -> \
-		Tuple[Dict[str, List[FunctionType]],
-		      Dict[str,      Dict[RedirectType, FunctionType]],
-		      Dict[Filename, Dict[RedirectType, List[str]]]]:
+		Tuple[Dict[ClausesStr, List[FunctionType]],
+		      Dict[ClausesStr, Dict[RedirectType, FunctionType]],
+		      Dict[Filename,   Dict[RedirectType, List[str]]]]:
 	"""
 	readFiles
 	
@@ -223,22 +262,10 @@ def readFiles(files: Sequence[Filename]) -> \
 							filename, line[:-1]
 						))
 					
-					if len(ln) < 3:
-						raise NotImplementedError("Type {0} too short ({1}:{2})".format(ln, filename, line[:-1]))
-					if "E" in ln:
-						if ("E" in ln[:2]) or ("E" in ln[3:]):
-							raise NotImplementedError(
-								"emu64_t* not as the first parameter ({0}:{1})".format(filename, line[:-1]))
-						if len(ln) < 4:
-							raise NotImplementedError("Type {0} too short ({1}:{2})".format(ln, filename, line[:-1]))
-					
-					if ln[1] not in ["F"]:
-						raise NotImplementedError("Bad middle letter {0} ({1}:{2})".format(ln[1], filename, line[:-1]))
-					
-					if any(c not in values for c in ln[2:]) or (('v' in ln[2:]) and (len(ln) > 3)):
-						old = RedirectType(ln)
+					if not FunctionType.validate(ln, " ({0}:{1})".format(filename, line[:-1])):
+						old = RedirectType(FunctionType(ln))
 						# This needs more work
-						acceptables = ['v', '0', '1'] + values
+						acceptables = ['v', '0', '1'] + FunctionType.values
 						if any(c not in acceptables for c in ln[2:]):
 							raise NotImplementedError("{0} ({1}:{2})".format(ln[2:], filename, line[:-1]))
 						# Ok, this is acceptable: there is 0, 1 and/or void
@@ -248,18 +275,18 @@ def readFiles(files: Sequence[Filename]) -> \
 							.replace("1", "i")) # 1      -> integer
 						assert(len(ln) >= 3)
 						redirects.setdefault(str(dependants), {})
-						redirects[str(dependants)][old] = ln
+						redirects[str(dependants)][old] = FunctionType(ln)
 					# Simply append the function type if it's not yet existing
 					gbl.setdefault(str(dependants), [])
 					if ln not in gbl[str(dependants)]:
-						gbl[str(dependants)].append(ln)
+						gbl[str(dependants)].append(FunctionType(ln))
 					
 					if ln[2] == "E":
 						# filename isn't stored with the '_private.h' part
 						if len(ln) > 3:
-							funtype = RedirectType(ln[:2] + ln[3:])
+							funtype = RedirectType(FunctionType(ln[:2] + ln[3:]))
 						else:
-							funtype = RedirectType(ln[:2] + "v")
+							funtype = RedirectType(FunctionType(ln[:2] + "v"))
 						mytypedefs.setdefault(filename[:-10], {})
 						mytypedefs[filename[:-10]].setdefault(funtype, [])
 						mytypedefs[filename[:-10]][funtype].append(funname)
@@ -273,17 +300,15 @@ def readFiles(files: Sequence[Filename]) -> \
 	
 	return gbl, redirects, mytypedefs
 
+COrderedDict = Tuple[Dict[T, U], List[T]]
 def sortArrays(
 	gbl_tmp   : Dict[str,      List[FunctionType]],
 	red_tmp   : Dict[str,      Dict[RedirectType, FunctionType]],
 	mytypedefs: Dict[Filename, Dict[RedirectType, List[str]]]) -> \
 		Tuple[
-			Tuple[Dict[ClausesStr, List[FunctionType]],
-				List[ClausesStr]],
-			Tuple[Dict[ClausesStr, List[Tuple[RedirectType, FunctionType]]],
-				List[ClausesStr]],
-			Tuple[Dict[Filename, Dict[RedirectType, List[str]]],
-				Dict[Filename, List[RedirectType]]]
+			COrderedDict[ClausesStr, List[FunctionType]],
+			COrderedDict[ClausesStr, List[Tuple[RedirectType, FunctionType]]],
+			Dict[Filename, COrderedDict[RedirectType, List[str]]]
 		]:
 	# Now, take all function types, and make a new table gbl_vals
 	# This table contains all #if conditions for when a function type needs to
@@ -386,20 +411,20 @@ def sortArrays(
 	
 	# Sort the function types as defined in `splitchar`
 	for k3 in gbl:
-		gbl[k3].sort(key=lambda v: splitchar(v))
+		gbl[k3].sort(key=FunctionType.splitchar)
 	
-	global values
-	values = values + ['0', '1']
+	FunctionType.values = FunctionType.values + ['0', '1']
 	for k3 in redirects:
-		redirects[k3].sort(key=lambda v: splitchar(v[0]) + splitchar(v[1]))
-	values = values[:-2]
+		redirects[k3].sort(key=lambda v: v[0].splitchar() + v[1].splitchar())
+	FunctionType.values = FunctionType.values[:-2]
 	
-	mytypedefs_vals: Dict[Filename, List[RedirectType]] = dict((fn, sorted(mytypedefs[fn].keys(), key=lambda v: splitchar(v))) for fn in mytypedefs)
+	mytypedefs_vals: Dict[Filename, List[RedirectType]] = dict((fn, sorted(mytypedefs[fn].keys(), key=FunctionType.splitchar)) for fn in mytypedefs)
 	for fn in mytypedefs:
 		for v in mytypedefs_vals[fn]:
 			mytypedefs[fn][v].sort()
 	
-	return (gbl, gbl_idxs), (redirects, redirects_idxs), (mytypedefs, mytypedefs_vals)
+	return (gbl, gbl_idxs), (redirects, redirects_idxs), \
+		dict((fn, (mytypedefs[fn], mytypedefs_vals[fn])) for fn in mytypedefs)
 	
 def main(root: str, files: Sequence[Filename], ver: str):
 	"""
@@ -416,22 +441,22 @@ def main(root: str, files: Sequence[Filename], ver: str):
 	#  "defined() && ..." -> [vFEv -> vFv, ...]
 	# tdf_tmp:
 	#  "filename" -> [vFEv -> fopen, ...]
-	gbl_tmp: Dict[str,      List[FunctionType]]
-	red_tmp: Dict[str,      Dict[RedirectType, FunctionType]]
-	tdf_tmp: Dict[Filename, Dict[RedirectType, List[str]]]
+	gbl_tmp: Dict[ClausesStr, List[FunctionType]]
+	red_tmp: Dict[ClausesStr, Dict[RedirectType, FunctionType]]
+	tdf_tmp: Dict[Filename,   Dict[RedirectType, List[str]]]
 	
 	gbl_tmp, red_tmp, tdf_tmp = readFiles(files)
 	
-	gbl            : Dict[ClausesStr, List[FunctionType]]
-	redirects      : Dict[ClausesStr, List[Tuple[RedirectType, FunctionType]]]
-	mytypedefs     : Dict[Filename, Dict[RedirectType, List[str]]]
-	gbl_idxs       : List[ClausesStr]
-	redirects_idxs : List[ClausesStr]
-	mytypedefs_vals: Dict[Filename, List[RedirectType]]
+	gbls      : COrderedDict[ClausesStr, List[FunctionType]]
+	redirects_: COrderedDict[ClausesStr, List[Tuple[RedirectType, FunctionType]]]
+	mytypedefs: Dict[Filename, COrderedDict[RedirectType, List[str]]]
 	
-	(gbl, gbl_idxs), (redirects, redirects_idxs), (mytypedefs, mytypedefs_vals) = \
+	gbls, redirects_, mytypedefs = \
 		sortArrays(gbl_tmp, red_tmp, tdf_tmp)
 	
+	gbl, gbl_idxs = gbls
+	redirects, redirects_idxs = redirects_
+	
 	# Check if there was any new functions compared to last run
 	functions_list: str = ""
 	for k in [str(Clauses())] + gbl_idxs:
@@ -440,6 +465,12 @@ def main(root: str, files: Sequence[Filename], ver: str):
 	for k in [str(Clauses())] + redirects_idxs:
 		for vr, vf in redirects[k]:
 			functions_list = functions_list + "#" + k + " " + vr + " -> " + vf + "\n"
+	for filename in sorted(mytypedefs.keys()):
+		functions_list = functions_list + filename + ":\n"
+		for vr in mytypedefs[filename][1]:
+			functions_list = functions_list + "- " + vr + ":\n"
+			for fn in mytypedefs[filename][0][vr]:
+				functions_list = functions_list + "  - " + fn + "\n"
 	
 	# functions_list is a unique string, compare it with the last run
 	try:
@@ -466,11 +497,11 @@ def main(root: str, files: Sequence[Filename], ver: str):
 	
 	# Sanity checks
 	forbidden_simple: str = "EeDKVOSNMHP"
-	assert(len(allowed_simply) + len(allowed_regs) + len(allowed_fpr) + len(forbidden_simple) == len(values))
+	assert(len(allowed_simply) + len(allowed_regs) + len(allowed_fpr) + len(forbidden_simple) == len(FunctionType.values))
 	assert(all(c not in allowed_regs for c in allowed_simply))
 	assert(all(c not in allowed_simply + allowed_regs for c in allowed_fpr))
 	assert(all(c not in allowed_simply + allowed_regs + allowed_fpr for c in forbidden_simple))
-	assert(all(c in allowed_simply + allowed_regs + allowed_fpr + forbidden_simple for c in values))
+	assert(all(c in allowed_simply + allowed_regs + allowed_fpr + forbidden_simple for c in FunctionType.values))
 	
 	# Only search on real wrappers
 	for k in [str(Clauses())] + gbl_idxs:
@@ -596,12 +627,12 @@ int isSimpleWrapper(wrapper_t fun);
 		# i and u should only be 32 bits
 		#         E            e             v       c         w          i          I          C          W           u           U           f        d         D              K         l           L            p        V        O          S        N      M      H                    P
 		types = ["x64emu_t*", "x64emu_t**", "void", "int8_t", "int16_t", "int64_t", "int64_t", "uint8_t", "uint16_t", "uint64_t", "uint64_t", "float", "double", "long double", "double", "intptr_t", "uintptr_t", "void*", "void*", "int32_t", "void*", "...", "...", "unsigned __int128", "void*"]
-		if len(values) != len(types):
-			raise NotImplementedError("len(values) = {lenval} != len(types) = {lentypes}".format(lenval=len(values), lentypes=len(types)))
+		if len(FunctionType.values) != len(types):
+			raise NotImplementedError("len(values) = {lenval} != len(types) = {lentypes}".format(lenval=len(FunctionType.values), lentypes=len(types)))
 		
 		for v in arr:
-			file.write("typedef " + types[values.index(v[0])] + " (*" + v + "_t)"
-				+ "(" + ', '.join(types[values.index(t)] for t in v[2:]) + ");\n")
+			file.write("typedef " + types[FunctionType.values.index(v[0])] + " (*" + v + "_t)"
+				+ "(" + ', '.join(types[FunctionType.values.index(t)] for t in v[2:]) + ");\n")
 	
 	with open(os.path.join(root, "src", "wrapped", "generated", "wrapper.c"), 'w') as file:
 		file.write(files_header["wrapper.c"].format(lbr="{", rbr="}", version=ver))
@@ -771,24 +802,24 @@ int isSimpleWrapper(wrapper_t fun);
 		]
 
 		# Asserts
-		if len(values) != len(vstack):
-			raise NotImplementedError("len(values) = {lenval} != len(vstack) = {lenvstack}".format(lenval=len(values), lenvstack=len(vstack)))
-		if len(values) != len(vreg):
-			raise NotImplementedError("len(values) = {lenval} != len(vreg) = {lenvreg}".format(lenval=len(values), lenvreg=len(vreg)))
-		if len(values) != len(vxmm):
-			raise NotImplementedError("len(values) = {lenval} != len(vxmm) = {lenvxmm}".format(lenval=len(values), lenvxmm=len(vxmm)))
-		if len(values) != len(vother):
-			raise NotImplementedError("len(values) = {lenval} != len(vother) = {lenvother}".format(lenval=len(values), lenvother=len(vother)))
-		if len(values) != len(arg_s):
-			raise NotImplementedError("len(values) = {lenval} != len(arg_s) = {lenargs}".format(lenval=len(values), lenargs=len(arg_s)))
-		if len(values) != len(arg_r):
-			raise NotImplementedError("len(values) = {lenval} != len(arg_r) = {lenargr}".format(lenval=len(values), lenargr=len(arg_r)))
-		if len(values) != len(arg_x):
-			raise NotImplementedError("len(values) = {lenval} != len(arg_x) = {lenargx}".format(lenval=len(values), lenargx=len(arg_x)))
-		if len(values) != len(arg_o):
-			raise NotImplementedError("len(values) = {lenval} != len(arg_o) = {lenargo}".format(lenval=len(values), lenargo=len(arg_o)))
-		if len(values) != len(vals):
-			raise NotImplementedError("len(values) = {lenval} != len(vals) = {lenvals}".format(lenval=len(values), lenvals=len(vals)))
+		if len(FunctionType.values) != len(vstack):
+			raise NotImplementedError("len(values) = {lenval} != len(vstack) = {lenvstack}".format(lenval=len(FunctionType.values), lenvstack=len(vstack)))
+		if len(FunctionType.values) != len(vreg):
+			raise NotImplementedError("len(values) = {lenval} != len(vreg) = {lenvreg}".format(lenval=len(FunctionType.values), lenvreg=len(vreg)))
+		if len(FunctionType.values) != len(vxmm):
+			raise NotImplementedError("len(values) = {lenval} != len(vxmm) = {lenvxmm}".format(lenval=len(FunctionType.values), lenvxmm=len(vxmm)))
+		if len(FunctionType.values) != len(vother):
+			raise NotImplementedError("len(values) = {lenval} != len(vother) = {lenvother}".format(lenval=len(FunctionType.values), lenvother=len(vother)))
+		if len(FunctionType.values) != len(arg_s):
+			raise NotImplementedError("len(values) = {lenval} != len(arg_s) = {lenargs}".format(lenval=len(FunctionType.values), lenargs=len(arg_s)))
+		if len(FunctionType.values) != len(arg_r):
+			raise NotImplementedError("len(values) = {lenval} != len(arg_r) = {lenargr}".format(lenval=len(FunctionType.values), lenargr=len(arg_r)))
+		if len(FunctionType.values) != len(arg_x):
+			raise NotImplementedError("len(values) = {lenval} != len(arg_x) = {lenargx}".format(lenval=len(FunctionType.values), lenargx=len(arg_x)))
+		if len(FunctionType.values) != len(arg_o):
+			raise NotImplementedError("len(values) = {lenval} != len(arg_o) = {lenargo}".format(lenval=len(FunctionType.values), lenargo=len(arg_o)))
+		if len(FunctionType.values) != len(vals):
+			raise NotImplementedError("len(values) = {lenval} != len(vals) = {lenvals}".format(lenval=len(FunctionType.values), lenvals=len(vals)))
 		
 		# Helper functions to write the function definitions
 		def function_args(args: FunctionType, d: int = 8, r: int = 0, x: int = 0) -> str:
@@ -805,7 +836,7 @@ int isSimpleWrapper(wrapper_t fun);
 			elif args[0] == "1":
 				return "1, " + function_args(args[1:], d, r, x)
 			
-			idx = values.index(args[0])
+			idx = FunctionType.values.index(args[0])
 			if (r < 6) and (vreg[idx] > 0):
 				# Value is in a general register (and there is still one available)
 				if (vreg[idx] == 2) and (r == 6):
@@ -824,17 +855,17 @@ int isSimpleWrapper(wrapper_t fun);
 				# Value is somewhere else
 				return arg_o[idx].format(p=d) + function_args(args[1:], d, r, x)
 		
-		def function_writer(f, N: FunctionType, W: FunctionType) -> None:
+		def function_writer(f, N: FunctionType, W: str) -> None:
 			# Write to f the function type N (real type W)
 			# rettype is a single character, args is the string of argument types
 			# (those could actually be deduced from N)
 			
 			f.write("void {0}(x64emu_t *emu, uintptr_t fcn) {2} {1} fn = ({1})fcn; ".format(N, W, "{"))
 			# Generic function
-			f.write(vals[values.index(N[0])].format(function_args(N[2:])[:-2]) + " }\n")
+			f.write(vals[FunctionType.values.index(N[0])].format(function_args(N[2:])[:-2]) + " }\n")
 		
 		for v in gbl[str(Clauses())]:
-			if v == "vFv":
+			if v == FunctionType("vFv"):
 				# Suppress all warnings...
 				file.write("void vFv(x64emu_t *emu, uintptr_t fcn) { vFv_t fn = (vFv_t)fcn; fn(); (void)emu; }\n")
 			else:
@@ -892,10 +923,10 @@ int isSimpleWrapper(wrapper_t fun);
 	for fn in mytypedefs:
 		with open(os.path.join(root, "src", "wrapped", "generated", fn + "types.h"), 'w') as file:
 			file.write(files_header["fntypes.h"].format(lbr="{", rbr="}", version=ver, filename=fn))
-			generate_typedefs(mytypedefs_vals[fn], file)
+			generate_typedefs(mytypedefs[fn][1], file)
 			file.write("\n#define SUPER() ADDED_FUNCTIONS()")
-			for v in mytypedefs_vals[fn]:
-				for f in mytypedefs[fn][v]:
+			for v in mytypedefs[fn][1]:
+				for f in mytypedefs[fn][0][v]:
 					file.write(" \\\n\tGO({0}, {1}_t)".format(f, v))
 			file.write("\n")
 			file.write(files_guard["fntypes.h"].format(lbr="{", rbr="}", version=ver, filename=fn))
@@ -912,6 +943,6 @@ if __name__ == '__main__':
 		if v == "--":
 			limit.append(i)
 	Define.defines = sys.argv[2:limit[0]]
-	if main(sys.argv[1], sys.argv[limit[0]+1:], "2.0.0.13") != 0:
+	if main(sys.argv[1], sys.argv[limit[0]+1:], "2.0.1.14") != 0:
 		exit(2)
 	exit(0)