summaryrefslogtreecommitdiffstats
path: root/classification/main.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xclassification/main.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/classification/main.py b/classification/main.py
index 04f2d8c4..2a6f6d9a 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,10 +1,17 @@
from transformers import pipeline
+from os import listdir, path
+directory : str = "./test_mails"
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-with open("test", "r") as file:
- sequence_to_classify = file.read()
-candidate_labels = ['semantic bug', 'no semantic bug']
-result = classifier(sequence_to_classify, candidate_labels, multi_label=False)
-print(result['labels'])
-print(result['scores'])
+for name in listdir(directory):
+ with open(path.join(directory, name), "r") as file:
+ sequence_to_classify = file.read()
+
+ candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
+ result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+
+ print(name)
+ for label, score in zip(result["labels"], result["scores"]):
+ print(f"{label}: {score:.3f}")
+ print("\n")