summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
author: Christian Krinitsin <mail@krinitsin.com>, 2025-05-29 17:10:08 +0200
committer: Christian Krinitsin <mail@krinitsin.com>, 2025-05-29 17:10:08 +0200
commit: dbbaa64f16cef5a2b32056a67433116dab84ab81 (patch)
tree: fcd4509cb4ae6ce059dfe4850fee213e1a12aee3 /classification/main.py
parent: ad77852392240639b9db7b18f8566bd458a20ade (diff)
download: qemu-analysis-dbbaa64f16cef5a2b32056a67433116dab84ab81.tar.gz
          qemu-analysis-dbbaa64f16cef5a2b32056a67433116dab84ab81.zip
first version of categories in the classifier
Diffstat (limited to 'classification/main.py')
-rwxr-xr-x  classification/main.py  19
1 files changed, 13 insertions, 6 deletions
diff --git a/classification/main.py b/classification/main.py
index 04f2d8c49..2a6f6d9ad 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,10 +1,17 @@
 from transformers import pipeline
+from os import listdir, path
 
+directory : str = "./test_mails"
 classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-with open("test", "r") as file:
-    sequence_to_classify = file.read()
-candidate_labels = ['semantic bug', 'no semantic bug']
-result = classifier(sequence_to_classify, candidate_labels, multi_label=False)
 
-print(result['labels'])
-print(result['scores'])
+for name in listdir(directory):
+    with open(path.join(directory, name), "r") as file:
+        sequence_to_classify = file.read()
+
+    candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
+    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+
+    print(name)
+    for label, score in zip(result["labels"], result["scores"]):
+        print(f"{label}: {score:.3f}")
+    print("\n")