diff options
| author | Christian Krinitsin <mail@krinitsin.com> | 2025-05-30 15:56:00 +0200 |
|---|---|---|
| committer | Christian Krinitsin <mail@krinitsin.com> | 2025-05-30 15:56:00 +0200 |
| commit | 712310482c3dbef91c3eb6458d1bff82a275fa52 (patch) | |
| tree | a1bcdb8df87d90ef121a093d4ea416838f84f856 | |
| parent | fb84fa98ea1effc76cea3b3426546b4a3851af0b (diff) | |
| download | qemu-analysis-712310482c3dbef91c3eb6458d1bff82a275fa52.tar.gz qemu-analysis-712310482c3dbef91c3eb6458d1bff82a275fa52.zip | |
add test script for the classifier
| -rwxr-xr-x | classification/main.py | 25 | ||||
| -rw-r--r-- | classification/test.py | 16 | ||||
| -rw-r--r-- | classification/test_input/mail_other_1 (renamed from classification/test_mails/mail_other_1) | 0 | ||||
| -rw-r--r-- | classification/test_input/mail_other_2 (renamed from classification/test_mails/mail_other_2) | 0 | ||||
| -rw-r--r-- | classification/test_input/mail_other_3 (renamed from classification/test_mails/mail_other_3) | 0 | ||||
| -rw-r--r-- | classification/test_input/mail_semantic_1 (renamed from classification/test_mails/mail_semantic_1) | 0 | ||||
| -rw-r--r-- | classification/test_input/mail_semantic_2 (renamed from classification/test_mails/mail_semantic_2) | 0 |
7 files changed, 29 insertions, 12 deletions
```diff
diff --git a/classification/main.py b/classification/main.py
index 2a6f6d9ad..7c1f2d4b4 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,17 +1,18 @@
 from transformers import pipeline
-from os import listdir, path
+from argparse import ArgumentParser
 
-directory : str = "./test_mails"
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+from test import test
 
-for name in listdir(directory):
-    with open(path.join(directory, name), "r") as file:
-        sequence_to_classify = file.read()
+parser = ArgumentParser(prog='classifier')
+parser.add_argument('-d', '--deepseek', action='store_true')
+args = parser.parse_args()
 
-    candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+def main():
+    if args.deepseek:
+        print("deepseek not supported")
+    else:
+        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+        test(classifier)
 
-    print(name)
-    for label, score in zip(result["labels"], result["scores"]):
-        print(f"{label}: {score:.3f}")
-    print("\n")
+if __name__ == "__main__":
+    main()
diff --git a/classification/test.py b/classification/test.py
new file mode 100644
index 000000000..e0db00313
--- /dev/null
+++ b/classification/test.py
@@ -0,0 +1,16 @@
+from os import listdir, path
+
+directory : str = "./test_input"
+
+def test(classifier):
+    for name in listdir(directory):
+        with open(path.join(directory, name), "r") as file:
+            sequence_to_classify = file.read()
+
+        candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
+        result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+
+        print(name)
+        for label, score in zip(result["labels"], result["scores"]):
+            print(f"{label}: {score:.3f}")
+        print("")
diff --git a/classification/test_mails/mail_other_1 b/classification/test_input/mail_other_1
index f4a855325..f4a855325 100644
--- a/classification/test_mails/mail_other_1
+++ b/classification/test_input/mail_other_1
diff --git a/classification/test_mails/mail_other_2 b/classification/test_input/mail_other_2
index df6aceba1..df6aceba1 100644
--- a/classification/test_mails/mail_other_2
+++ b/classification/test_input/mail_other_2
diff --git a/classification/test_mails/mail_other_3 b/classification/test_input/mail_other_3
index 504ddc488..504ddc488 100644
--- a/classification/test_mails/mail_other_3
+++ b/classification/test_input/mail_other_3
diff --git a/classification/test_mails/mail_semantic_1 b/classification/test_input/mail_semantic_1
index af6a2480d..af6a2480d 100644
--- a/classification/test_mails/mail_semantic_1
+++ b/classification/test_input/mail_semantic_1
diff --git a/classification/test_mails/mail_semantic_2 b/classification/test_input/mail_semantic_2
index 4c78171d2..4c78171d2 100644
--- a/classification/test_mails/mail_semantic_2
+++ b/classification/test_input/mail_semantic_2
```