summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rwxr-xr-xclassification/main.py25
-rw-r--r--classification/test.py16
-rw-r--r--classification/test_input/mail_other_1 (renamed from classification/test_mails/mail_other_1)0
-rw-r--r--classification/test_input/mail_other_2 (renamed from classification/test_mails/mail_other_2)0
-rw-r--r--classification/test_input/mail_other_3 (renamed from classification/test_mails/mail_other_3)0
-rw-r--r--classification/test_input/mail_semantic_1 (renamed from classification/test_mails/mail_semantic_1)0
-rw-r--r--classification/test_input/mail_semantic_2 (renamed from classification/test_mails/mail_semantic_2)0
7 files changed, 29 insertions, 12 deletions
diff --git a/classification/main.py b/classification/main.py
index 2a6f6d9ad..7c1f2d4b4 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,17 +1,18 @@
 from transformers import pipeline
-from os import listdir, path
+from argparse import ArgumentParser
 
-directory : str = "./test_mails"
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+from test import test
 
-for name in listdir(directory):
-    with open(path.join(directory, name), "r") as file:
-        sequence_to_classify = file.read()
+parser = ArgumentParser(prog='classifier')
+parser.add_argument('-d', '--deepseek', action='store_true')
+args = parser.parse_args()
 
-    candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+def main():
+    if args.deepseek:
+        print("deepseek not supported")
+    else:
+        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+        test(classifier)
 
-    print(name)
-    for label, score in zip(result["labels"], result["scores"]):
-        print(f"{label}: {score:.3f}")
-    print("\n")
+if __name__ == "__main__":
+    main()
diff --git a/classification/test.py b/classification/test.py
new file mode 100644
index 000000000..e0db00313
--- /dev/null
+++ b/classification/test.py
@@ -0,0 +1,16 @@
+from os import listdir, path
+
+directory : str = "./test_input"
+
+def test(classifier):
+    for name in listdir(directory):
+        with open(path.join(directory, name), "r") as file:
+            sequence_to_classify = file.read()
+
+        candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
+        result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+
+        print(name)
+        for label, score in zip(result["labels"], result["scores"]):
+            print(f"{label}: {score:.3f}")
+        print("")
diff --git a/classification/test_mails/mail_other_1 b/classification/test_input/mail_other_1
index f4a855325..f4a855325 100644
--- a/classification/test_mails/mail_other_1
+++ b/classification/test_input/mail_other_1
diff --git a/classification/test_mails/mail_other_2 b/classification/test_input/mail_other_2
index df6aceba1..df6aceba1 100644
--- a/classification/test_mails/mail_other_2
+++ b/classification/test_input/mail_other_2
diff --git a/classification/test_mails/mail_other_3 b/classification/test_input/mail_other_3
index 504ddc488..504ddc488 100644
--- a/classification/test_mails/mail_other_3
+++ b/classification/test_input/mail_other_3
diff --git a/classification/test_mails/mail_semantic_1 b/classification/test_input/mail_semantic_1
index af6a2480d..af6a2480d 100644
--- a/classification/test_mails/mail_semantic_1
+++ b/classification/test_input/mail_semantic_1
diff --git a/classification/test_mails/mail_semantic_2 b/classification/test_input/mail_semantic_2
index 4c78171d2..4c78171d2 100644
--- a/classification/test_mails/mail_semantic_2
+++ b/classification/test_input/mail_semantic_2