summary refs log tree commit diff stats
path: root/classification
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-05-30 15:56:00 +0200
committerChristian Krinitsin <mail@krinitsin.com>2025-05-30 15:56:00 +0200
commit712310482c3dbef91c3eb6458d1bff82a275fa52 (patch)
treea1bcdb8df87d90ef121a093d4ea416838f84f856 /classification
parentfb84fa98ea1effc76cea3b3426546b4a3851af0b (diff)
downloademulator-bug-study-712310482c3dbef91c3eb6458d1bff82a275fa52.tar.gz
emulator-bug-study-712310482c3dbef91c3eb6458d1bff82a275fa52.zip
add test script for the classifier
Diffstat (limited to 'classification')
-rwxr-xr-xclassification/main.py25
-rw-r--r--classification/test.py16
-rw-r--r--classification/test_input/mail_other_1 (renamed from classification/test_mails/mail_other_1)0
-rw-r--r--classification/test_input/mail_other_2 (renamed from classification/test_mails/mail_other_2)0
-rw-r--r--classification/test_input/mail_other_3 (renamed from classification/test_mails/mail_other_3)0
-rw-r--r--classification/test_input/mail_semantic_1 (renamed from classification/test_mails/mail_semantic_1)0
-rw-r--r--classification/test_input/mail_semantic_2 (renamed from classification/test_mails/mail_semantic_2)0
7 files changed, 29 insertions, 12 deletions
diff --git a/classification/main.py b/classification/main.py
index 2a6f6d9a..7c1f2d4b 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,17 +1,18 @@
 from transformers import pipeline
-from os import listdir, path
+from argparse import ArgumentParser
 
-directory : str = "./test_mails"
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+from test import test
 
-for name in listdir(directory):
-    with open(path.join(directory, name), "r") as file:
-        sequence_to_classify = file.read()
+parser = ArgumentParser(prog='classifier')
+parser.add_argument('-d', '--deepseek', action='store_true')
+args = parser.parse_args()
 
-    candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+def main():
+    if args.deepseek:
+        print("deepseek not supported")
+    else:
+        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+        test(classifier)
 
-    print(name)
-    for label, score in zip(result["labels"], result["scores"]):
-        print(f"{label}: {score:.3f}")
-    print("\n")
+if __name__ == "__main__":
+    main()
diff --git a/classification/test.py b/classification/test.py
new file mode 100644
index 00000000..e0db0031
--- /dev/null
+++ b/classification/test.py
@@ -0,0 +1,16 @@
+from os import listdir, path
+
+directory : str = "./test_input"
+
+def test(classifier):
+    for name in listdir(directory):
+        with open(path.join(directory, name), "r") as file:
+            sequence_to_classify = file.read()
+
+        candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
+        result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+
+        print(name)
+        for label, score in zip(result["labels"], result["scores"]):
+            print(f"{label}: {score:.3f}")
+        print("")
diff --git a/classification/test_mails/mail_other_1 b/classification/test_input/mail_other_1
index f4a85532..f4a85532 100644
--- a/classification/test_mails/mail_other_1
+++ b/classification/test_input/mail_other_1
diff --git a/classification/test_mails/mail_other_2 b/classification/test_input/mail_other_2
index df6aceba..df6aceba 100644
--- a/classification/test_mails/mail_other_2
+++ b/classification/test_input/mail_other_2
diff --git a/classification/test_mails/mail_other_3 b/classification/test_input/mail_other_3
index 504ddc48..504ddc48 100644
--- a/classification/test_mails/mail_other_3
+++ b/classification/test_input/mail_other_3
diff --git a/classification/test_mails/mail_semantic_1 b/classification/test_input/mail_semantic_1
index af6a2480..af6a2480 100644
--- a/classification/test_mails/mail_semantic_1
+++ b/classification/test_input/mail_semantic_1
diff --git a/classification/test_mails/mail_semantic_2 b/classification/test_input/mail_semantic_2
index 4c78171d..4c78171d 100644
--- a/classification/test_mails/mail_semantic_2
+++ b/classification/test_input/mail_semantic_2