diff options
Diffstat (limited to 'rebuild_wrappers.py')
| -rwxr-xr-x | rebuild_wrappers.py | 206 |
1 files changed, 146 insertions, 60 deletions
diff --git a/rebuild_wrappers.py b/rebuild_wrappers.py index 79e97736..65f0f786 100755 --- a/rebuild_wrappers.py +++ b/rebuild_wrappers.py @@ -21,43 +21,34 @@ except ImportError: try: # Python 3.8+ from typing import final - import typing except ImportError: #print("Your Python version does not have all typing utilities, fallback to dummy ones") - def _overload_dummy(*args, **kwds): - raise NotImplementedError( - "You should not call an overloaded function. " - "A series of @overload-decorated functions " - "outside a stub module should always be followed " - "by an implementation that is not @overload-ed.") - def overload(fun): return _overload_dummy # type: ignore - class Typing: - pass - typing = Typing() # type: ignore - typing.overload = overload # type: ignore final = lambda fun: fun # type: ignore import os import sys class FunctionType(str): - values: List[str] = ['E', 'e', 'v', 'c', 'w', 'i', 'I', 'C', 'W', 'u', 'U', 'f', 'd', 'D', 'K', 'l', 'L', 'p', 'V', 'O', 'S', 'N', 'M', 'H', 'P', 'A'] + values: List[str] = ['E', 'v', 'c', 'w', 'i', 'I', 'C', 'W', 'u', 'U', 'f', 'd', 'D', 'K', 'l', 'L', 'p', 'V', 'O', 'S', 'N', 'M', 'H', 'P', 'A'] @staticmethod def validate(s: str, post: str) -> bool: if len(s) < 3: raise NotImplementedError("Type {0} too short{1}".format(s, post)) + chk_type = s[0] + s[2:] if "E" in s: if ("E" in s[:2]) or ("E" in s[3:]): raise NotImplementedError( "emu64_t* not as the first parameter{0}".format(post)) if len(s) < 4: raise NotImplementedError("Type {0} too short{1}".format(s, post)) + # TODO: change *FEv into true functions (right now they are redirected to *FE) + #chk_type = s[0] + s[3:] if s[1] not in ["F"]: raise NotImplementedError("Bad middle letter {0}{1}".format(s[1], post)) - return all(c in FunctionType.values for c in s[2:]) and (('v' not in s[2:]) or (len(s) == 3)) + return all(c in FunctionType.values for c in chk_type) and (('v' not in chk_type[1:]) or (len(chk_type) == 2)) def splitchar(self) -> List[int]: """ @@ -219,12 +210,71 @@ def readFiles(files: Sequence[Filename]) -> \ redirects : Dict[ClausesStr, Dict[RedirectType, FunctionType]] = {} mytypedefs: Dict[Filename, Dict[RedirectType, List[str]]] = {} - halt_required = False # Is there a GO(*, .FE*)? - + functions: Dict[str, Filename] = {} + halt_required = False # Is there a GO(*, .FE*) or similar in-batch error(s)? # First read the files inside the headers for filepath in files: filename: Filename = filepath.split("/")[-1] dependants: Clause = Clause() + + def add_function_name(funname: Union[str, None], funsname: Dict[ClausesStr, List[str]] = {"": []}): + # Optional arguments are evaluated only once! + nonlocal halt_required + if funname == None: + for k in funsname: + if (k != "") and (len(funsname[k]) != 0): + # Note: if this condition ever raises, check the wrapper pointed by it. + # If you find no problem, comment the error below, add a "pass" line (so python is happy) + # and open a ticket so I can fix this. + raise NotImplementedError("Some functions are only implemented under one condition (probably) ({0}/{1})" + .format(k, filename) + " [extra note in the script]") + for f in funsname[k]: + if f in ['_fini', '_init', '__bss_start', '__data_start', '_edata', '_end']: + continue # Always allow those symbols [TODO: check if OK] + if f in functions: + # Check for resemblances between functions[f] and filename + if filename.startswith(functions[f][:-12]) or functions[f].startswith(filename[:-12]): + # Probably OK + continue + # Manual compatible libs detection + match = lambda l, r: (filename[7:-10], functions[f][7:-10]) in [(l, r), (r, l)] + if match("sdl1image", "sdl2image") \ + or match("sdl1mixer", "sdl2mixer") \ + or match("sdl1net", "sdl2net") \ + or match("sdl1ttf", "sdl2ttf") \ + or match("libc", "tcmallocminimal"): + continue + + # Note: this test is very (too) simple. If it ever raises, comment + # `halt_required = True` and open an issue. + print("The function or data {0} is declared in multiple files ({1}/{2})" + .format(f, functions[f], filename) + " [extra note in the script]") + halt_required = True + functions[f] = filename + else: + if funname == "": + raise NotImplementedError("This function name (\"\") is suspicious... ({0})".format(filename)) + l = len(dependants.defines) + already_pst = funname in funsname[""] + if l > 1: + return + elif l == 1: + funsname.setdefault(str(dependants), []) + already_pst = already_pst or (funname in funsname[str(dependants)]) + if already_pst: + print("Function or data {0} is duplicated! ({1})".format(funname, filename)) + halt_required = True + return + if l == 1: + s = str(dependants.defines[0].inverted()) + if (s in funsname) and (funname in funsname[s]): + funsname[s].remove(funname) + funsname[""].append(funname) + else: + funsname[str(dependants)].append(funname) + else: + funsname[""].append(funname) + with open(filepath, 'r') as file: for line in file: ln = line.strip() @@ -261,7 +311,8 @@ def readFiles(files: Sequence[Filename]) -> \ gotype = ln.split("(")[0].strip() funname = ln.split(",")[0].split("(")[1].strip() ln = ln.split(",")[1].split(")")[0].strip() - except IndexError as e: + add_function_name(funname) + except IndexError: raise NotImplementedError("Invalid GO command: {0}:{1}".format( filename, line[:-1] )) @@ -275,7 +326,7 @@ def readFiles(files: Sequence[Filename]) -> \ # Ok, this is acceptable: there is 0, 1 and/or void ln = ln[:2] + (ln[2:] .replace("v", "") # void -> nothing - .replace("0", "p") # 0 -> pointer + .replace("0", "i") # 0 -> integer .replace("1", "i")) # 1 -> integer assert(len(ln) >= 3) redirects.setdefault(str(dependants), {}) @@ -286,22 +337,36 @@ def readFiles(files: Sequence[Filename]) -> \ gbl[str(dependants)].append(FunctionType(ln)) if ln[2] == "E": - if (gotype != "GOM"): + if (gotype != "GOM") and (gotype != "GOWM"): if (gotype != "GO2") or not (line.split(',')[2].split(')')[0].strip().startswith('my_')): print("\033[91mThis is probably not what you meant!\033[m ({0}:{1})".format(filename, line[:-1])) halt_required = True - # filename isn't stored with the '_private.h' part if len(ln) > 3: funtype = RedirectType(FunctionType(ln[:2] + ln[3:])) else: funtype = RedirectType(FunctionType(ln[:2] + "v")) + # filename isn't stored with the '_private.h' part mytypedefs.setdefault(filename[:-10], {}) mytypedefs[filename[:-10]].setdefault(funtype, []) mytypedefs[filename[:-10]][funtype].append(funname) - # OK on box64 - # elif gotype == "GOM": - # print("\033[94mAre you sure of this?\033[m ({0}:{1})".format(filename, line[:-1])) - # halt_required = True + elif (gotype == "GOM") or (gotype == "GOWM"): + # OK on box64 for a GOM to not have emu... + funtype = RedirectType(FunctionType(ln)) + mytypedefs.setdefault(filename[:-10], {}) + mytypedefs[filename[:-10]].setdefault(funtype, []) + mytypedefs[filename[:-10]][funtype].append(funname) + # print("\033[94mAre you sure of this?\033[m ({0}:{1})".format(filename, line[:-1])) + # halt_required = True + elif ("GO" in ln) or ("DATA" in ln): + # Probably "//GO(..., " or "DATA(...," at least + try: + funname = ln.split('(')[1].split(',')[0].strip() + add_function_name(funname) + except IndexError: + # Oops, it wasn't... + pass + + add_function_name(None) if halt_required: raise ValueError("Fix all previous errors before proceeding") @@ -511,7 +576,7 @@ def main(root: str, files: Sequence[Filename], ver: str): allowed_fpr : str = "fd" # Sanity checks - forbidden_simple: str = "EeDKVOSNMHPA" + forbidden_simple: str = "EDKVOSNMHPA" assert(len(allowed_simply) + len(allowed_regs) + len(allowed_fpr) + len(forbidden_simple) == len(FunctionType.values)) assert(all(c not in allowed_regs for c in allowed_simply)) assert(all(c not in allowed_simply + allowed_regs for c in allowed_fpr)) @@ -598,21 +663,24 @@ typedef struct x64emu_s x64emu_t; typedef void (*wrapper_t)(x64emu_t* emu, uintptr_t fnc); // list of defined wrapper -// v = void, i = int32, u = uint32, U/I= (u)int64 -// l = signed long, L = unsigned long (long is an int with the size of a pointer) -// p = pointer, P = void* on the stack -// f = float, d = double, D = long double, K = fake long double -// V = vaargs, E = current x86emu struct, e = ref to current x86emu struct -// 0 = constant 0, 1 = constant 1 -// o = stdout +// E = current x86emu struct +// v = void // C = unsigned byte c = char // W = unsigned short w = short +// u = uint32, i = int32 +// U = uint64, I= int64 +// L = unsigned long, l = signed long (long is an int with the size of a pointer) +// H = Huge 128bits value/struct +// p = pointer, P = void* on the stack +// f = float, d = double, D = long double, K = fake long double +// V = vaargs // O = libc O_ flags bitfield +// o = stdout // S = _IO_2_1_stdXXX_ pointer (or FILE*) // N = ... automatically sending 1 arg // M = ... automatically sending 2 args -// H = Huge 128bits value/struct // A = va_list +// 0 = constant 0, 1 = constant 1 """, "fntypes.h": """/******************************************************************* @@ -643,8 +711,8 @@ int isSimpleWrapper(wrapper_t fun); # Rewrite the wrapper.c file: # i and u should only be 32 bits - # E e v c w i I C W u U f d D K l L p V O S N M H P A - td_types = ["x64emu_t*", "x64emu_t**", "void", "int8_t", "int16_t", "int64_t", "int64_t", "uint8_t", "uint16_t", "uint64_t", "uint64_t", "float", "double", "long double", "double", "intptr_t", "uintptr_t", "void*", "void*", "int32_t", "void*", "...", "...", "unsigned __int128", "void*", "void*"] + # E v c w i I C W u U f d D K l L p V O S N M H P A + td_types = ["x64emu_t*", "void", "int8_t", "int16_t", "int64_t", "int64_t", "uint8_t", "uint16_t", "uint64_t", "uint64_t", "float", "double", "long double", "double", "intptr_t", "uintptr_t", "void*", "void*", "int32_t", "void*", "...", "...", "unsigned __int128", "void*", "void*"] if len(FunctionType.values) != len(td_types): raise NotImplementedError("len(values) = {lenval} != len(td_types) = {lentypes}".format(lenval=len(FunctionType.values), lentypes=len(td_types))) @@ -671,7 +739,6 @@ int isSimpleWrapper(wrapper_t fun); # Return type template vals = [ "\n#error Invalid return type: emulator\n", # E - "\n#error Invalid return type: &emulator\n", # e "fn({0});", # v "R_RAX=fn({0});", # c "R_RAX=fn({0});", # w @@ -700,21 +767,21 @@ int isSimpleWrapper(wrapper_t fun); # Name of the registers reg_arg = ["R_RDI", "R_RSI", "R_RDX", "R_RCX", "R_R8", "R_R9"] + assert(len(reg_arg) == 6) # vreg: value is in a general register - # E e v c w i I C W u U f d D K l L p V O S N M H P A - vreg = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 2, 2, 0, 1] + # E v c w i I C W u U f d D K l L p V O S N M H P A + vreg = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 2, 2, 0, 1] # vxmm: value is in a XMM register - # E e v c w i I C W u U f d D K l L p V O S N M H P A - vxmm = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # E v c w i I C W u U f d D K l L p V O S N M H P A + vxmm = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # vother: value is elsewere - # E e v c w i I C W u U f d D K l L p V O S N M H P A - vother = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0] + # E v c w i I C W u U f d D K l L p V O S N M H P A + vother = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # vstack: value is on the stack (or out of register) - # E e v c w i I C W u U f d D K l L p V O S N M H P A - vstack = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1, 1, 2, 2, 1, 1] + # E v c w i I C W u U f d D K l L p V O S N M H P A + vstack = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1, 1, 2, 2, 1, 1] arg_r = [ "", # E - "", # e "", # v "(int8_t){p}, ", # c "(int16_t){p}, ", # w @@ -742,7 +809,6 @@ int isSimpleWrapper(wrapper_t fun); ] arg_x = [ "", # E - "", # e "", # v "", # c "", # w @@ -770,7 +836,6 @@ int isSimpleWrapper(wrapper_t fun); ] arg_o = [ "emu, ", # E - "&emu, ", # e "", # v "", # c "", # w @@ -794,11 +859,10 @@ int isSimpleWrapper(wrapper_t fun); "", # M "", # H "", # P - "", # A + "", # A ] arg_s = [ "", # E - "", # e "", # v "*(int8_t*)(R_RSP + {p}), ", # c "*(int16_t*)(R_RSP + {p}), ", # w @@ -844,6 +908,24 @@ int isSimpleWrapper(wrapper_t fun); raise NotImplementedError("len(values) = {lenval} != len(arg_o) = {lenargo}".format(lenval=len(FunctionType.values), lenargo=len(arg_o))) if len(FunctionType.values) != len(vals): raise NotImplementedError("len(values) = {lenval} != len(vals) = {lenvals}".format(lenval=len(FunctionType.values), lenvals=len(vals))) + # When arg_* is not empty, v* should not be 0 + if any(map(lambda v, a: (a != "") and (v == 0), vstack, arg_s)): + raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string") + if any(map(lambda v, a: (a != "") and (v == 0), vreg, arg_r)): + raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string") + if any(map(lambda v, a: (a != "") and (v == 0), vxmm, arg_x)): + raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string") + if any(map(lambda v, a: (a != "") and (v == 0), vother, arg_o)): + raise NotImplementedError("Something in the stack has a null offset and a non-empty arg string") + # Everything is either in the stack or somewhere else, it cannot be in a GPr and in an XMMr... + if any(map(lambda o, s: (o == 0) == (s == 0), vother, vstack)): + raise NotImplementedError("Something cannot be in exactly one of the stack and somewhere else") + if any(map(lambda r, x: (r > 0) and (x > 0), vreg, vxmm)): + raise NotImplementedError("Something can be in both a general purpose register and in an XMM register") + if any(map(lambda r, s: (r > 0) and (s == 0), vreg, vstack)): + raise NotImplementedError("Something can be in a general purpose register but not in the stack") + if any(map(lambda x, s: (x > 0) and (s == 0), vxmm, vstack)): + raise NotImplementedError("Something can be in an XMM register but not in the stack") # Helper functions to write the function definitions def function_args(args: FunctionType, d: int = 8, r: int = 0, x: int = 0) -> str: @@ -862,15 +944,20 @@ int isSimpleWrapper(wrapper_t fun); idx = FunctionType.values.index(args[0]) if (r < 6) and (vreg[idx] > 0): - # Value is in a general register (and there is still one available) - if (vreg[idx] == 2) and (r == 6): - return arg_r[idx-1].format(p=reg_arg[r]) + arg_s[idx-1].format(p=d) + function_args(args[1:], d + vother[idx-1]*8, r+1, x) - elif (vreg[idx] == 2) and (r < 6): - return arg_r[idx].format(p=reg_arg[r]) + arg_r[idx].format(p=reg_arg[r+1]) + function_args(args[1:], d, r+2, x) - else: - return arg_r[idx].format(p=reg_arg[r]) + function_args(args[1:], d, r+1, x) + ret = "" + for _ in range(vreg[idx]): + # There may be values in multiple registers + if r < 6: + # Value is in a general register + ret = ret + arg_r[idx].format(p=reg_arg[r]) + r = r + 1 + else: + # Remaining is in the stack + ret = ret + arg_s[idx].format(p=d) + d = d + 8 + return ret + function_args(args[1:], d, r, x) elif (x < 8) and (vxmm[idx] > 0): - # Value is in an XMM register (and there is still one available) + # Value is in an XMM register return arg_x[idx].format(p=x) + function_args(args[1:], d, r, x+1) elif vstack[idx] > 0: # Value is in the stack @@ -881,8 +968,6 @@ int isSimpleWrapper(wrapper_t fun); def function_writer(f, N: FunctionType, W: str) -> None: # Write to f the function type N (real type W) - # rettype is a single character, args is the string of argument types - # (those could actually be deduced from N) f.write("void {0}(x64emu_t *emu, uintptr_t fcn) {2} {1} fn = ({1})fcn; ".format(N, W, "{")) # Generic function @@ -945,6 +1030,7 @@ int isSimpleWrapper(wrapper_t fun); # Rewrite the *types.h files: td_types[FunctionType.values.index('A')] = "va_list" + td_types[FunctionType.values.index('V')] = "..." for fn in mytypedefs: with open(os.path.join(root, "src", "wrapped", "generated", fn + "types.h"), 'w') as file: file.write(files_header["fntypes.h"].format(lbr="{", rbr="}", version=ver, filename=fn)) @@ -968,6 +1054,6 @@ if __name__ == '__main__': if v == "--": limit.append(i) Define.defines = list(map(DefineType, sys.argv[2:limit[0]])) - if main(sys.argv[1], sys.argv[limit[0]+1:], "2.0.1.14") != 0: + if main(sys.argv[1], sys.argv[limit[0]+1:], "2.0.2.15") != 0: exit(2) exit(0) |