about summary refs log tree commit diff stats
path: root/wrapperhelper/gen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'wrapperhelper/gen.cpp')
-rw-r--r--wrapperhelper/gen.cpp70
1 files changed, 47 insertions, 23 deletions
diff --git a/wrapperhelper/gen.cpp b/wrapperhelper/gen.cpp
index c7cba1e2..743474cd 100644
--- a/wrapperhelper/gen.cpp
+++ b/wrapperhelper/gen.cpp
@@ -289,6 +289,7 @@ std::string WrapperGenerator::GenDeclare(ASTContext *Ctx,
                                          const RecordInfo &Record) {
   (void)Ctx;
   std::string RecordStr;
+  std::string PreDecl;
   RecordStr += "\ntypedef ";
   RecordStr +=
       (Record.is_union ? "union " : "struct ") + Record.type_name + " {\n";
@@ -327,10 +328,10 @@ std::string WrapperGenerator::GenDeclare(ASTContext *Ctx,
         FieldStr += Name;
         RecordStr += FieldStr;
       } else {
-        RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr);
+        RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
       }
     } else {
-      RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr);
+      RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
     }
     RecordStr += ";\n";
   }
@@ -547,6 +548,7 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
   (void)Ctx;
   std::string GuestRecord;
   std::string HostRecord;
+  std::string PreDecl;
   std::vector<uint64_t> GuestFieldOff;
   std::vector<uint64_t> HostFieldOff;
   GuestRecord += "typedef ";
@@ -599,7 +601,7 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
         std::cout << "Err: unknown type size " << typeSize << std::endl;
         break;
       }
-      HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
+      HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
     } else if (Type->isFunctionPointerType()) {
       auto FuncType = StripTypedef(Type->getPointeeType());
       if (callbacks.count(FuncType)) {
@@ -634,12 +636,12 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
         GuestRecord += FieldStr;
         HostRecord += "host_" + FieldStr;
       } else {
-        GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
-        HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
+        GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
+        HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
       }
     } else {
-      HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
-      GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
+      HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
+      GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
     }
     GuestRecord += ";\n";
     HostRecord += ";\n";
@@ -685,12 +687,12 @@ std::string WrapperGenerator::GenRecordConvert(const RecordInfo &Record) {
     if (!AlignDiffFields.size()) {
       return res;
     }
-    res += "void g2h_" + Record.type_name + "(" + "struct host_" + Record.type_name + "*d, struct" + Record.type_name + "*s) {\n";
+    res += "void g2h_" + Record.type_name + "(" + "struct host_" + Record.type_name + " *d, struct " + Record.type_name + " *s) {\n";
     std::string body = "    memcpy(d, s, offsetof(struct " + Record.type_name +
                        ", " + AlignDiffFields[0]->getNameAsString() + "));\n";
     std::string offstr = "offsetof(struct " + Record.type_name + ", " +
                          AlignDiffFields[0]->getNameAsString() + ")";
-    for (size_t i = 1; i < AlignDiffFields.size() - 1; i++) {
+    for (size_t i = 0; i < AlignDiffFields.size() - 1; i++) {
       body += "    memcpy(d->" + AlignDiffFields[i]->getNameAsString() + ", " +
               "s->" + AlignDiffFields[i]->getNameAsString() + ", " +
               "offsetof(struct " + Record.type_name + ", " +
@@ -707,7 +709,7 @@ std::string WrapperGenerator::GenRecordConvert(const RecordInfo &Record) {
             " - " + offstr + ");\n";
     res += body + "}\n";
 
-    res += "void h2g_" + Record.type_name + "(struct" + Record.type_name + "*d, " + "struct host_" + Record.type_name + "*s) {\n";
+    res += "void h2g_" + Record.type_name + "(struct " + Record.type_name + " *d, " + "struct host_" + Record.type_name + " *s) {\n";
     res += body;
     res += "}\n";
   }
@@ -775,6 +777,7 @@ void WrapperGenerator::ParseRecordRecursive(
 std::string WrapperGenerator::TypeStringify(const Type *Type,
                                          FieldDecl *FieldDecl,
                                          ParmVarDecl *ParmDecl,
+                                         std::string& PreDecl,
                                          std::string indent,
                                          std::string Name) {
   std::string res;
@@ -807,7 +810,7 @@ std::string WrapperGenerator::TypeStringify(const Type *Type,
       res += records[StripTypedef(Type->getCanonicalTypeInternal())].type_name;
       res += " ";
     } else {
-      res += AnonRecordDecl(Type->getAs<RecordType>(), indent);
+      res += AnonRecordDecl(Type->getAs<RecordType>(), PreDecl, indent + "  ");
     }
     res += name;
   } else if (Type->isConstantArrayType()) {
@@ -816,6 +819,21 @@ std::string WrapperGenerator::TypeStringify(const Type *Type,
     int EleSize = ArrayType->getSize().getZExtValue();
     if (ArrayType->getElementType()->isPointerType()) {
       res += "void *";
+    } else if (ArrayType->getElementType()->isEnumeralType()) {
+      res += "int ";
+    } else if (ArrayType->getElementType()->isRecordType()) {
+      auto RecordType = ArrayType->getElementType()->getAs<clang::RecordType>();
+      auto RecordDecl = RecordType->getDecl();
+      if (RecordDecl->isCompleteDefinition()) {
+        auto& Ctx = RecordDecl->getDeclContext()->getParentASTContext();
+        PreDecl += "#include \"";
+        PreDecl += GetDeclHeaderFile(Ctx, RecordDecl);
+        PreDecl += "\"";
+        PreDecl += "\n";
+      }
+      res += StripTypedef(ArrayType->getElementType())
+                ->getCanonicalTypeInternal()
+                .getAsString();
     } else {
       res += StripTypedef(ArrayType->getElementType())
                  ->getCanonicalTypeInternal()
@@ -887,13 +905,13 @@ std::string WrapperGenerator::SimpleTypeStringify(const Type *Type,
   return indent + res;
 }
 
-std::string WrapperGenerator::AnonRecordDecl(const RecordType *Type, std::string indent) {
+std::string WrapperGenerator::AnonRecordDecl(const RecordType *Type, std::string& PreDecl, std::string indent) {
   auto RecordDecl = Type->getDecl();
   std::string res;
   res += Type->isUnionType() ? "union {\n" : "struct {\n";
   for (const auto &field : RecordDecl->fields()) {
     auto FieldType = field->getType();
-    res += TypeStringify(StripTypedef(FieldType), field, nullptr, indent + "    ");
+    res += TypeStringify(StripTypedef(FieldType), field, nullptr, PreDecl, indent + "  ");
     res += ";\n";
   }
   res += indent + "} ";
@@ -907,7 +925,7 @@ WrapperGenerator::SimpleAnonRecordDecl(const RecordType *Type, std::string inden
   res += Type->isUnionType() ? "union {\n" : "struct {\n";
   for (const auto &field : RecordDecl->fields()) {
     auto FieldType = field->getType();
-    res += SimpleTypeStringify(StripTypedef(FieldType), field, nullptr, indent + "    ");
+    res += SimpleTypeStringify(StripTypedef(FieldType), field, nullptr, indent + "  ");
     res += ";\n";
   }
   res += indent + "} ";
@@ -917,15 +935,16 @@ WrapperGenerator::SimpleAnonRecordDecl(const RecordType *Type, std::string inden
 // Get func info from FunctionType
 FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) {
   FuncDefinition res;
+  std::string PreDecl;
   auto ProtoType = Type->getAs<FunctionProtoType>();
   res.ret = StripTypedef(ProtoType->getReturnType());
   res.ret_str =
-      TypeStringify(StripTypedef(ProtoType->getReturnType()), nullptr, nullptr);
+      TypeStringify(StripTypedef(ProtoType->getReturnType()), nullptr, nullptr, PreDecl);
   for (unsigned i = 0; i < ProtoType->getNumParams(); i++) {
     auto ParamType = ProtoType->getParamType(i);
     res.arg_types.push_back(StripTypedef(ParamType));
     res.arg_types_str.push_back(
-        TypeStringify(StripTypedef(ParamType), nullptr, nullptr));
+        TypeStringify(StripTypedef(ParamType), nullptr, nullptr, PreDecl));
     res.arg_names.push_back(std::string("a") + std::to_string(i));
   }
   if (ProtoType->isVariadic()) {
@@ -938,15 +957,16 @@ FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) {
 // Get funcdecl info from FunctionDecl
 FuncDefinition WrapperGenerator::GetFuncDefinition(FunctionDecl *Decl) {
   FuncDefinition res;
+  std::string PreDecl;
   auto RetType = Decl->getReturnType();
   res.ret = RetType.getTypePtr();
-  res.ret_str = TypeStringify(StripTypedef(RetType), nullptr, nullptr);
+  res.ret_str = TypeStringify(StripTypedef(RetType), nullptr, nullptr, PreDecl);
   for (unsigned i = 0; i < Decl->getNumParams(); i++) {
     auto ParamDecl = Decl->getParamDecl(i);
     auto ParamType = ParamDecl->getType();
     res.arg_types.push_back(ParamType.getTypePtr());
     res.arg_types_str.push_back(
-        TypeStringify(StripTypedef(ParamType), nullptr, nullptr));
+        TypeStringify(StripTypedef(ParamType), nullptr, nullptr, PreDecl));
     res.arg_names.push_back(ParamDecl->getNameAsString());
   }
   if (Decl->isVariadic()) {
@@ -960,7 +980,8 @@ std::vector<uint64_t> WrapperGenerator::GetRecordFieldOffDiff(
     const Type *Type, const std::string &GuestTriple,
     const std::string &HostTriple, std::vector<uint64_t> &GuestFieldOff,
     std::vector<uint64_t> &HostFieldOff) {
-  std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;");
+  std::string PreDecl;
+  std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
   std::vector<uint64_t> OffsetDiff;
   GuestFieldOff = GetRecordFieldOff(Code, GuestTriple);
   HostFieldOff = GetRecordFieldOff(Code, HostTriple);
@@ -978,15 +999,18 @@ std::vector<uint64_t> WrapperGenerator::GetRecordFieldOffDiff(
 // Get the size under a specific triple
 uint64_t WrapperGenerator::GetRecordSize(const Type *Type,
                                     const std::string &Triple) {
-  std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;");
-  return ::GetRecordSize(Code, Triple);
+  std::string PreDecl;
+  std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
+  auto Size = ::GetRecordSize(PreDecl + Code, Triple);
+  return Size;
 }
 
 // Get the align under a specific triple
 CharUnits::QuantityType WrapperGenerator::GetRecordAlign(const Type *Type,
                                      const std::string &Triple) {
-  std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;");
-  return ::GetRecordAlign(Code, Triple);
+  std::string PreDecl{};
+  std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
+  return ::GetRecordAlign(PreDecl + Code, Triple);
 }
 
 // Generate the func sig by type, used for export func