summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-06-01 14:54:03 +0200
committerChristian Krinitsin <mail@krinitsin.com>2025-06-01 14:54:03 +0200
commitb372f5a5601734a150a32651c507a1e9ead575a5 (patch)
tree14f8fa513bbae677304106c3969080e12f695e01 /classification/main.py
parent0d401089e9e72a8d9fb9b41d920126aa9fb23b05 (diff)
downloadqemu-analysis-b372f5a5601734a150a32651c507a1e9ead575a5.tar.gz
qemu-analysis-b372f5a5601734a150a32651c507a1e9ead575a5.zip
classifier: iterates through mailing list
Diffstat (limited to 'classification/main.py')
-rwxr-xr-xclassification/main.py25
1 files changed, 21 insertions, 4 deletions
diff --git a/classification/main.py b/classification/main.py
index 7c1f2d4b4..3394a3fba 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,18 +1,35 @@
 from transformers import pipeline
 from argparse import ArgumentParser
+from os import path
 
 from test import test
+from files import list_files_recursive
+from output import output
 
 parser = ArgumentParser(prog='classifier')
 parser.add_argument('-d', '--deepseek', action='store_true')
+parser.add_argument('-t', '--test', action='store_true')
 args = parser.parse_args()
 
+categories = ['semantic', 'other', 'mistranslation', 'instruction']
+
 def main():
     if args.deepseek:
-        print("deepseek not supported")
-    else:
-        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-        test(classifier)
+        print("deepseek currently not supported")
+        exit()
+
+    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+    if args.test:
+        test(classifier, categories)
+        exit()
+
+    bugs = list_files_recursive("../mailinglist/output_mailinglist")
+    for bug in bugs:
+        with open(bug, "r") as file:
+            text = file.read()
+
+        result = classifier(text, categories, multi_label=True)
+        output(text, result['labels'], result['scores'], path.basename(bug))
 
 if __name__ == "__main__":
     main()