about summary refs log tree commit diff stats
path: root/rebuild_wrappers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rebuild_wrappers.py')
-rwxr-xr-xrebuild_wrappers.py206
1 files changed, 146 insertions, 60 deletions
diff --git a/rebuild_wrappers.py b/rebuild_wrappers.py
index 79e97736..65f0f786 100755
--- a/rebuild_wrappers.py
+++ b/rebuild_wrappers.py
@@ -21,43 +21,34 @@ except ImportError:
 try:
 	# Python 3.8+
 	from typing import final
-	import typing
 except ImportError:
 	#print("Your Python version does not have all typing utilities, fallback to dummy ones")
-	def _overload_dummy(*args, **kwds):
-		raise NotImplementedError(
-				"You should not call an overloaded function. "
-				"A series of @overload-decorated functions "
-				"outside a stub module should always be followed "
-				"by an implementation that is not @overload-ed.")
-	def overload(fun): return _overload_dummy # type: ignore
-	class Typing:
-		pass
-	typing = Typing() # type: ignore
-	typing.overload = overload # type: ignore
 	final = lambda fun: fun # type: ignore
 
 import os
 import sys
 
 class FunctionType(str):
-	values: List[str] = ['E', 'e', 'v', 'c', 'w', 'i', 'I', 'C', 'W', 'u', 'U', 'f', 'd', 'D', 'K', 'l', 'L', 'p', 'V', 'O', 'S', 'N', 'M', 'H', 'P', 'A']
+	values: List[str] = ['E', 'v', 'c', 'w', 'i', 'I', 'C', 'W', 'u', 'U', 'f', 'd', 'D', 'K', 'l', 'L', 'p', 'V', 'O', 'S', 'N', 'M', 'H', 'P', 'A']
 	
 	@staticmethod
 	def validate(s: str, post: str) -> bool:
 		if len(s) < 3:
 			raise NotImplementedError("Type {0} too short{1}".format(s, post))
+		chk_type = s[0] + s[2:]
 		if "E" in s:
 			if ("E" in s[:2]) or ("E" in s[3:]):
 				raise NotImplementedError(
 					"emu64_t* not as the first parameter{0}".format(post))
 			if len(s) < 4:
 				raise NotImplementedError("Type {0} too short{1}".format(s, post))
+			# TODO: change *FEv into true functions (right now they are redirected to *FE)
+			#chk_type = s[0] + s[3:]
 		
 		if s[1] not in ["F"]:
 			raise NotImplementedError("Bad middle letter {0}{1}".format(s[1], post))
 		
-		return all(c in FunctionType.values for c in s[2:]) and (('v' not in s[2:]) or (len(s) == 3))
+		return all(c in FunctionType.values for c in chk_type) and (('v' not in chk_type[1:]) or (len(chk_type) == 2))
 	
 	def splitchar(self) -> List[int]:
 		"""
@@ -219,12 +210,71 @@ def readFiles(files: Sequence[Filename]) -> \
 	redirects : Dict[ClausesStr, Dict[RedirectType, FunctionType]] = {}
 	mytypedefs: Dict[Filename,   Dict[RedirectType, List[str]]]    = {}
 	
-	halt_required = False # Is there a GO(*, .FE*)?
-	
+	functions: Dict[str, Filename] = {}
+	halt_required = False # Is there a GO(*, .FE*) or similar in-batch error(s)?
 	# First read the files inside the headers
 	for filepath in files:
 		filename: Filename = filepath.split("/")[-1]
 		dependants: Clause = Clause()
+		
+		def add_function_name(funname: Union[str, None], funsname: Dict[ClausesStr, List[str]] = {"": []}):
+			# Optional arguments are evaluated only once!
+			nonlocal halt_required
+			if funname == None:
+				for k in funsname:
+					if (k != "") and (len(funsname[k]) != 0):
+						# Note: if this condition ever raises, check the wrapper pointed by it.
+						# If you find no problem, comment the error below, add a "pass" line (so python is happy)
+						# and open a ticket so I can fix this.
+						raise NotImplementedError("Some functions are only implemented under one condition (probably) ({0}/{1})"
+							.format(k, filename) + " [extra note in the script]")
+					for f in funsname[k]:
+						if f in ['_fini', '_init', '__bss_start', '__data_start', '_edata', '_end']:
+							continue # Always allow those symbols [TODO: check if OK]
+						if f in functions:
+							# Check for resemblances between functions[f] and filename
+							if filename.startswith(functions[f][:-12]) or functions[f].startswith(filename[:-12]):
+								# Probably OK
+								continue
+							# Manual compatible libs detection
+							match = lambda l, r: (filename[7:-10], functions[f][7:-10]) in [(l, r), (r, l)]
+							if match("sdl1image", "sdl2image") \
+							 or match("sdl1mixer", "sdl2mixer") \
+							 or match("sdl1net", "sdl2net") \
+							 or match("sdl1ttf", "sdl2ttf") \
+							 or match("libc", "tcmallocminimal"):
+								continue
+							
+							# Note: this test is very (too) simple. If it ever raises, comment
+							# `halt_required = True` and open an issue.
+							print("The function or data {0} is declared in multiple files ({1}/{2})"
+								.format(f, functions[f], filename) + " [extra note in the script]")
+							halt_required = True
+						functions[f] = filename
+			else:
+				if funname == "":
+					raise NotImplementedError("This function name (\"\") is suspicious... ({0})".format(filename))
+				l = len(dependants.defines)
+				already_pst = funname in funsname[""]
+				if l > 1:
+					return
+				elif l == 1:
+					funsname.setdefault(str(dependants), [])
+					already_pst = already_pst or (funname in funsname[str(dependants)])
+				if already_pst:
+					print("Function or data {0} is duplicated! ({1})".format(funname, filename))
+					halt_required = True
+					return
+				if l == 1:
+					s = str(dependants.defines[0].inverted())
+					if (s in funsname) and (funname in funsname[s]):
+						funsname[s].remove(funname)
+						funsname[""].append(funname)
+					else:
+						funsname[str(dependants)].append(funname)
+				else:
+					funsname[""].append(funname)
+		
 		with open(filepath, 'r') as file:
 			for line in file:
 				ln = line.strip()
@@ -261,7 +311,8 @@ def readFiles(files: Sequence[Filename]) -> \
 						gotype = ln.split("(")[0].strip()
 						funname = ln.split(",")[0].split("(")[1].strip()
 						ln = ln.split(",")[1].split(")")[0].strip()
-					except IndexError as e:
+						add_function_name(funname)
+					except IndexError:
 						raise NotImplementedError("Invalid GO command: {0}:{1}".format(
 							filename, line[:-1]
 						))
@@ -275,7 +326,7 @@ def readFiles(files: Sequence[Filename]) -> \
 						# Ok, this is acceptable: there is 0, 1 and/or void
 						ln = ln[:2] + (ln[2:]
 							.replace("v", "")   # void   -> nothing
-							.replace("0", "p")  # 0      -> pointer
+							.replace("0", "i")  # 0      -> integer
 							.replace("1", "i")) # 1      -> integer
 						assert(len(ln) >= 3)
 						redirects.setdefault(str(dependants), {})
@@ -286,22 +337,36 @@ def readFiles(files: Sequence[Filename]) -> \
 						gbl[str(dependants)].append(FunctionType(ln))
 					
 					if ln[2] == "E":
-						if (gotype != "GOM"):
+						if (gotype != "GOM") and (gotype != "GOWM"):
 							if (gotype != "GO2") or not (line.split(',')[2].split(')')[0].strip().startswith('my_')):
 								print("\033[91mThis is probably not what you meant!\033[m ({0}:{1})".format(filename, line[:-1]))
 								halt_required = True
-						# filename isn't stored with the '_private.h' part
 						if len(ln) > 3:
 							funtype = RedirectType(FunctionType(ln[:2] + ln[3:]))
 						else:
 							funtype = RedirectType(FunctionType(ln[:2] + "v"))
+						# filename isn't stored with the '_private.h' part
 						mytypedefs.setdefault(filename[:-10], {})
 						mytypedefs[filename[:-10]].setdefault(funtype, [])
 						mytypedefs[filename[:-10]][funtype].append(funname)
-					# OK on box64
-					# elif gotype == "GOM":
-					# 	print("\033[94mAre you sure of this?\033[m ({0}:{1})".format(filename, line[:-1]))
-					# 	halt_required = True
+					elif (gotype == "GOM") or (gotype == "GOWM"):
+						# OK on box64 for a GOM to not have emu...
+						funtype = RedirectType(FunctionType(ln))
+						mytypedefs.setdefault(filename[:-10], {})
+						mytypedefs[filename[:-10]].setdefault(funtype, [])
+						mytypedefs[filename[:-10]][funtype].append(funname)
+						# print("\033[94mAre you sure of this?\033[m ({0}:{1})".format(filename, line[:-1]))
+						# halt_required = True
+				elif ("GO" in ln) or ("DATA" in ln):
+					# Probably "//GO(..., " or "DATA(...," at least
+					try:
+						funname = ln.split('(')[1].split(',')[0].strip()
+						add_function_name(funname)
+					except IndexError:
+						# Oops, it wasn't...
+						pass
+		
+		add_function_name(None)
 	
 	if halt_required:
 		raise ValueError("Fix all previous errors before proceeding")
@@ -511,7 +576,7 @@ def main(root: str, files: Sequence[Filename], ver: str):
 	allowed_fpr   : str = "fd"
 	
 	# Sanity checks
-	forbidden_simple: str = "EeDKVOSNMHPA"
+	forbidden_simple: str = "EDKVOSNMHPA"
 	assert(len(allowed_simply) + len(allowed_regs) + len(allowed_fpr) + len(forbidden_simple) == len(FunctionType.values))
 	assert(all(c not in allowed_regs for c in allowed_simply))
 	assert(all(c not in allowed_simply + allowed_regs for c in allowed_fpr))
@@ -598,21 +663,24 @@ typedef struct x64emu_s x64emu_t;
 typedef void (*wrapper_t)(x64emu_t* emu, uintptr_t fnc);
 
 // list of defined wrapper
-// v = void, i = int32, u = uint32, U/I= (u)int64
-// l = signed long, L = unsigned long (long is an int with the size of a pointer)
-// p = pointer, P = void* on the stack
-// f = float, d = double, D = long double, K = fake long double
-// V = vaargs, E = current x86emu struct, e = ref to current x86emu struct
-// 0 = constant 0, 1 = constant 1
-// o = stdout
+// E = current x86emu struct
+// v = void
 // C = unsigned byte c = char
 // W = unsigned short w = short
+// u = uint32, i = int32
+// U = uint64, I= int64
+// L = unsigned long, l = signed long (long is an int with the size of a pointer)
+// H = Huge 128bits value/struct
+// p = pointer, P = void* on the stack
+// f = float, d = double, D = long double, K = fake long double
+// V = vaargs
 // O = libc O_ flags bitfield
+// o = stdout
 // S = _IO_2_1_stdXXX_ pointer (or FILE*)
 // N = ... automatically sending 1 arg
 // M = ... automatically sending 2 args
-// H = Huge 128bits value/struct
 // A = va_list
+// 0 = constant 0, 1 = constant 1
 
 """,
 		"fntypes.h": """/*******************************************************************
@@ -643,8 +711,8 @@ int isSimpleWrapper(wrapper_t fun);
 	
 	# Rewrite the wrapper.c file:
 	# i and u should only be 32 bits
-	#            E            e             v       c         w          i          I          C          W           u           U           f        d         D              K         l           L            p        V        O          S        N      M      H                    P        A
-	td_types = ["x64emu_t*", "x64emu_t**", "void", "int8_t", "int16_t", "int64_t", "int64_t", "uint8_t", "uint16_t", "uint64_t", "uint64_t", "float", "double", "long double", "double", "intptr_t", "uintptr_t", "void*", "void*", "int32_t", "void*", "...", "...", "unsigned __int128", "void*", "void*"]
+	#            E            v       c         w          i          I          C          W           u           U           f        d         D              K         l           L            p        V        O          S        N      M      H                    P        A
+	td_types = ["x64emu_t*", "void", "int8_t", "int16_t", "int64_t", "int64_t", "uint8_t", "uint16_t", "uint64_t", "uint64_t", "float", "double", "long double", "double", "intptr_t", "uintptr_t", "void*", "void*", "int32_t", "void*", "...", "...", "unsigned __int128", "void*", "void*"]
 	if len(FunctionType.values) != len(td_types):
 		raise NotImplementedError("len(values) = {lenval} != len(td_types) = {lentypes}".format(lenval=len(FunctionType.values), lentypes=len(td_types)))
 	
@@ -671,7 +739,6 @@ int isSimpleWrapper(wrapper_t fun);
 		# Return type template
 		vals = [
 			"\n#error Invalid return type: emulator\n",                # E
-			"\n#error Invalid return type: &emulator\n",               # e
 			"fn({0});",                                                # v
 			"R_RAX=fn({0});",                                          # c
 			"R_RAX=fn({0});",                                          # w
@@ -700,21 +767,21 @@ int isSimpleWrapper(wrapper_t fun);
 		
 		# Name of the registers
 		reg_arg = ["R_RDI", "R_RSI", "R_RDX", "R_RCX", "R_R8", "R_R9"]
+		assert(len(reg_arg) == 6)
 		# vreg: value is in a general register
-		#         E  e  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
-		vreg   = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 2, 2, 0, 1]
+		#         E  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
+		vreg   = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 2, 2, 0, 1]
 		# vxmm: value is in a XMM register
-		#         E  e  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
-		vxmm   = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+		#         E  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
+		vxmm   = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 		# vother: value is elsewere
-		#         E  e  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
-		vother = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
+		#         E  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
+		vother = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
 		# vstack: value is on the stack (or out of register)
-		#         E  e  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
-		vstack = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1, 1, 2, 2, 1, 1]
+		#         E  v  c  w  i  I  C  W  u  U  f  d  D  K  l  L  p  V  O  S  N  M  H  P  A
+		vstack = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1, 1, 2, 2, 1, 1]
 		arg_r = [
 			"",                            # E
-			"",                            # e
 			"",                            # v
 			"(int8_t){p}, ",               # c
 			"(int16_t){p}, ",              # w
@@ -742,7 +809,6 @@ int isSimpleWrapper(wrapper_t fun);
 		]
 		arg_x = [
 			"",                      # E
-			"",                      # e
 			"",                      # v
 			"",                      # c
 			"",                      # w
@@ -770,7 +836,6 @@ int isSimpleWrapper(wrapper_t fun);
 		]
 		arg_o = [
 			"emu, ",                   # E
-			"&emu, ",                  # e
 			"",                        # v
 			"",                        # c
 			"",                        # w
@@ -794,11 +859,10 @@ int isSimpleWrapper(wrapper_t fun);
 			"",                        # M
 			"",                        # H
 			"",                        # P
-			"",                      # A
+			"",                        # A
 		]
 		arg_s = [
 			"",                                         # E
-			"",                                         # e
 			"",                                         # v
 			"*(int8_t*)(R_RSP + {p}), ",                # c
 			"*(int16_t*)(R_RSP + {p}), ",               # w
@@ -844,6 +908,24 @@ int isSimpleWrapper(wrapper_t fun);
 			raise NotImplementedError("len(values) = {lenval} != len(arg_o) = {lenargo}".format(lenval=len(FunctionType.values), lenargo=len(arg_o)))
 		if len(FunctionType.values) != len(vals):
 			raise NotImplementedError("len(values) = {lenval} != len(vals) = {lenvals}".format(lenval=len(FunctionType.values), lenvals=len(vals)))
+		# When arg_* is not empty, v* should not be 0
+		if any(map(lambda v, a: (a != "") and (v == 0), vstack, arg_s)):
+			raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string")
+		if any(map(lambda v, a: (a != "") and (v == 0), vreg, arg_r)):
+			raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string")
+		if any(map(lambda v, a: (a != "") and (v == 0), vxmm, arg_x)):
+			raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string")
+		if any(map(lambda v, a: (a != "") and (v == 0), vother, arg_o)):
+			raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string")
+		# Everything is either in the stack or somewhere else, it cannot be in a GPr and in an XMMr...
+		if any(map(lambda o, s: (o == 0) == (s == 0), vother, vstack)):
+			raise NotImplementedError("Something cannot be in exactly one of the stack and somewhere else")
+		if any(map(lambda r, x: (r > 0) and (x > 0), vreg, vxmm)):
+			raise NotImplementedError("Something can be in both a general purpose register and in an XMM register")
+		if any(map(lambda r, s: (r > 0) and (s == 0), vreg, vstack)):
+			raise NotImplementedError("Something can be in a general purpose register but not in the stack")
+		if any(map(lambda x, s: (x > 0) and (s == 0), vxmm, vstack)):
+			raise NotImplementedError("Something can be in an XMM register but not in the stack")
 		
 		# Helper functions to write the function definitions
 		def function_args(args: FunctionType, d: int = 8, r: int = 0, x: int = 0) -> str:
@@ -862,15 +944,20 @@ int isSimpleWrapper(wrapper_t fun);
 			
 			idx = FunctionType.values.index(args[0])
 			if (r < 6) and (vreg[idx] > 0):
-				# Value is in a general register (and there is still one available)
-				if (vreg[idx] == 2) and (r == 6):
-					return arg_r[idx-1].format(p=reg_arg[r]) + arg_s[idx-1].format(p=d) + function_args(args[1:], d + vother[idx-1]*8, r+1, x)
-				elif (vreg[idx] == 2) and (r < 6):
-					return arg_r[idx].format(p=reg_arg[r]) + arg_r[idx].format(p=reg_arg[r+1]) + function_args(args[1:], d, r+2, x)
-				else:
-					return arg_r[idx].format(p=reg_arg[r]) + function_args(args[1:], d, r+1, x)
+				ret = ""
+				for _ in range(vreg[idx]):
+					# There may be values in multiple registers
+					if r < 6:
+						# Value is in a general register
+						ret = ret + arg_r[idx].format(p=reg_arg[r])
+						r = r + 1
+					else:
+						# Remaining is in the stack
+						ret = ret + arg_s[idx].format(p=d)
+						d = d + 8
+				return ret + function_args(args[1:], d, r, x)
 			elif (x < 8) and (vxmm[idx] > 0):
-				# Value is in an XMM register (and there is still one available)
+				# Value is in an XMM register
 				return arg_x[idx].format(p=x) + function_args(args[1:], d, r, x+1)
 			elif vstack[idx] > 0:
 				# Value is in the stack
@@ -881,8 +968,6 @@ int isSimpleWrapper(wrapper_t fun);
 		
 		def function_writer(f, N: FunctionType, W: str) -> None:
 			# Write to f the function type N (real type W)
-			# rettype is a single character, args is the string of argument types
-			# (those could actually be deduced from N)
 			
 			f.write("void {0}(x64emu_t *emu, uintptr_t fcn) {2} {1} fn = ({1})fcn; ".format(N, W, "{"))
 			# Generic function
@@ -945,6 +1030,7 @@ int isSimpleWrapper(wrapper_t fun);
 	
 	# Rewrite the *types.h files:
 	td_types[FunctionType.values.index('A')] = "va_list"
+	td_types[FunctionType.values.index('V')] = "..."
 	for fn in mytypedefs:
 		with open(os.path.join(root, "src", "wrapped", "generated", fn + "types.h"), 'w') as file:
 			file.write(files_header["fntypes.h"].format(lbr="{", rbr="}", version=ver, filename=fn))
@@ -968,6 +1054,6 @@ if __name__ == '__main__':
 		if v == "--":
 			limit.append(i)
 	Define.defines = list(map(DefineType, sys.argv[2:limit[0]]))
-	if main(sys.argv[1], sys.argv[limit[0]+1:], "2.0.1.14") != 0:
+	if main(sys.argv[1], sys.argv[limit[0]+1:], "2.0.2.15") != 0:
 		exit(2)
 	exit(0)