From dbbaa64f16cef5a2b32056a67433116dab84ab81 Mon Sep 17 00:00:00 2001 From: Christian Krinitsin Date: Thu, 29 May 2025 17:10:08 +0200 Subject: first version of categories in the classifier --- classification/main.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) (limited to 'classification/main.py') diff --git a/classification/main.py b/classification/main.py index 04f2d8c49..2a6f6d9ad 100755 --- a/classification/main.py +++ b/classification/main.py @@ -1,10 +1,17 @@ from transformers import pipeline +from os import listdir, path +directory : str = "./test_mails" classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") -with open("test", "r") as file: - sequence_to_classify = file.read() -candidate_labels = ['semantic bug', 'no semantic bug'] -result = classifier(sequence_to_classify, candidate_labels, multi_label=False) -print(result['labels']) -print(result['scores']) +for name in listdir(directory): + with open(path.join(directory, name), "r") as file: + sequence_to_classify = file.read() + + candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction'] + result = classifier(sequence_to_classify, candidate_labels, multi_label=True) + + print(name) + for label, score in zip(result["labels"], result["scores"]): + print(f"{label}: {score:.3f}") + print("\n") -- cgit 1.4.1