summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-05-30 15:56:00 +0200
committerChristian Krinitsin <mail@krinitsin.com>2025-05-30 15:56:00 +0200
commit712310482c3dbef91c3eb6458d1bff82a275fa52 (patch)
treea1bcdb8df87d90ef121a093d4ea416838f84f856 /classification/main.py
parentfb84fa98ea1effc76cea3b3426546b4a3851af0b (diff)
downloademulator-bug-study-712310482c3dbef91c3eb6458d1bff82a275fa52.tar.gz
emulator-bug-study-712310482c3dbef91c3eb6458d1bff82a275fa52.zip
add test script for the classifier
Diffstat (limited to 'classification/main.py')
-rwxr-xr-xclassification/main.py25
1 files changed, 13 insertions, 12 deletions
diff --git a/classification/main.py b/classification/main.py
index 2a6f6d9a..7c1f2d4b 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,17 +1,18 @@
 from transformers import pipeline
-from os import listdir, path
+from argparse import ArgumentParser
 
-directory : str = "./test_mails"
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+from test import test
 
-for name in listdir(directory):
-    with open(path.join(directory, name), "r") as file:
-        sequence_to_classify = file.read()
+parser = ArgumentParser(prog='classifier')
+parser.add_argument('-d', '--deepseek', action='store_true')
+args = parser.parse_args()
 
-    candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+def main():
+    if args.deepseek:
+        print("deepseek not supported")
+    else:
+        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+        test(classifier)
 
-    print(name)
-    for label, score in zip(result["labels"], result["scores"]):
-        print(f"{label}: {score:.3f}")
-    print("\n")
+if __name__ == "__main__":
+    main()