summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rwxr-xr-x classification/files.py 16
-rwxr-xr-x classification/main.py 25
-rw-r--r-- classification/output.py 12
-rwxr-xr-x [-rw-r--r--] classification/test.py 7
4 files changed, 52 insertions, 8 deletions
diff --git a/classification/files.py b/classification/files.py
new file mode 100755
index 00000000..65efda6f
--- /dev/null
+++ b/classification/files.py
@@ -0,0 +1,16 @@
+import os
+
+def list_files_recursive(path='.'):
+    result = []
+    for entry in os.listdir(path):
+        full_path = os.path.join(path, entry)
+        if os.path.isdir(full_path):
+            result = result + list_files_recursive(full_path)
+        else:
+            result.append(full_path)
+    return result
+
+if __name__ == "__main__":
+    directory_path = '../gitlab/issues_text'
+    arr = list_files_recursive(directory_path)
+    print(arr)
diff --git a/classification/main.py b/classification/main.py
index 7c1f2d4b..3394a3fb 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,18 +1,35 @@
 from transformers import pipeline
 from argparse import ArgumentParser
+from os import path
 
 from test import test
+from files import list_files_recursive
+from output import output
 
 parser = ArgumentParser(prog='classifier')
 parser.add_argument('-d', '--deepseek', action='store_true')
+parser.add_argument('-t', '--test', action='store_true')
 args = parser.parse_args()
 
+categories = ['semantic', 'other', 'mistranslation', 'instruction']
+
 def main():
     if args.deepseek:
-        print("deepseek not supported")
-    else:
-        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-        test(classifier)
+        print("deepseek currently not supported")
+        exit()
+
+    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+    if args.test:
+        test(classifier, categories)
+        exit()
+
+    bugs = list_files_recursive("../mailinglist/output_mailinglist")
+    for bug in bugs:
+        with open(bug, "r") as file:
+            text = file.read()
+
+        result = classifier(text, categories, multi_label=True)
+        output(text, result['labels'], result['scores'], path.basename(bug))
 
 if __name__ == "__main__":
     main()
diff --git a/classification/output.py b/classification/output.py
new file mode 100644
index 00000000..df64dbcc
--- /dev/null
+++ b/classification/output.py
@@ -0,0 +1,12 @@
+from os import path, makedirs
+
+def output(text : str, labels : list, scores : list, identifier : str):
+    file_path = f"output/{labels[0]}/{identifier}"
+    makedirs(path.dirname(file_path), exist_ok = True)
+
+    with open(file_path, "w") as file:
+        for label, score in zip(labels, scores):
+            file.write(f"{label}: {score:.3f}\n")
+
+        file.write("\n")
+        file.write(text)
diff --git a/classification/test.py b/classification/test.py
index a8bc7e1e..bcd6b439 100644..100755
--- a/classification/test.py
+++ b/classification/test.py
@@ -2,16 +2,15 @@ from os import listdir, path
 
 directory : str = "./test_input"
 
-def test(classifier):
+def test(classifier, categories):
     for name in listdir(directory):
         if name == "README.md":
             continue
 
         with open(path.join(directory, name), "r") as file:
-            sequence_to_classify = file.read()
+            text = file.read()
 
-        candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-        result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+        result = classifier(text, categories, multi_label=True)
 
         print(name)
         for label, score in zip(result["labels"], result["scores"]):