diff options
Diffstat (limited to 'classification/main.py')
| -rwxr-xr-x | classification/main.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/classification/main.py b/classification/main.py index 04f2d8c49..2a6f6d9ad 100755 --- a/classification/main.py +++ b/classification/main.py @@ -1,10 +1,17 @@ from transformers import pipeline +from os import listdir, path +directory : str = "./test_mails" classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") -with open("test", "r") as file: - sequence_to_classify = file.read() -candidate_labels = ['semantic bug', 'no semantic bug'] -result = classifier(sequence_to_classify, candidate_labels, multi_label=False) -print(result['labels']) -print(result['scores']) +for name in listdir(directory): + with open(path.join(directory, name), "r") as file: + sequence_to_classify = file.read() + + candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction'] + result = classifier(sequence_to_classify, candidate_labels, multi_label=True) + + print(name) + for label, score in zip(result["labels"], result["scores"]): + print(f"{label}: {score:.3f}") + print("\n") |