summaryrefslogtreecommitdiffstats
path: root/classification
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xclassification/main.py25
-rw-r--r--classification/test.py16
-rw-r--r--classification/test_input/mail_other_1 (renamed from classification/test_mails/mail_other_1)0
-rw-r--r--classification/test_input/mail_other_2 (renamed from classification/test_mails/mail_other_2)0
-rw-r--r--classification/test_input/mail_other_3 (renamed from classification/test_mails/mail_other_3)0
-rw-r--r--classification/test_input/mail_semantic_1 (renamed from classification/test_mails/mail_semantic_1)0
-rw-r--r--classification/test_input/mail_semantic_2 (renamed from classification/test_mails/mail_semantic_2)0
7 files changed, 29 insertions, 12 deletions
diff --git a/classification/main.py b/classification/main.py
index 2a6f6d9a..7c1f2d4b 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,17 +1,18 @@
from transformers import pipeline
-from os import listdir, path
+from argparse import ArgumentParser
-directory : str = "./test_mails"
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+from test import test
-for name in listdir(directory):
- with open(path.join(directory, name), "r") as file:
- sequence_to_classify = file.read()
+parser = ArgumentParser(prog='classifier')
+parser.add_argument('-d', '--deepseek', action='store_true')
+args = parser.parse_args()
- candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
- result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+def main():
+ if args.deepseek:
+ print("deepseek not supported")
+ else:
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+ test(classifier)
- print(name)
- for label, score in zip(result["labels"], result["scores"]):
- print(f"{label}: {score:.3f}")
- print("\n")
+if __name__ == "__main__":
+ main()
diff --git a/classification/test.py b/classification/test.py
new file mode 100644
index 00000000..e0db0031
--- /dev/null
+++ b/classification/test.py
@@ -0,0 +1,16 @@
+from os import listdir, path
+
+directory : str = "./test_input"
+
+def test(classifier):
+ for name in listdir(directory):
+ with open(path.join(directory, name), "r") as file:
+ sequence_to_classify = file.read()
+
+ candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
+ result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+
+ print(name)
+ for label, score in zip(result["labels"], result["scores"]):
+ print(f"{label}: {score:.3f}")
+ print("")
diff --git a/classification/test_mails/mail_other_1 b/classification/test_input/mail_other_1
index f4a85532..f4a85532 100644
--- a/classification/test_mails/mail_other_1
+++ b/classification/test_input/mail_other_1
diff --git a/classification/test_mails/mail_other_2 b/classification/test_input/mail_other_2
index df6aceba..df6aceba 100644
--- a/classification/test_mails/mail_other_2
+++ b/classification/test_input/mail_other_2
diff --git a/classification/test_mails/mail_other_3 b/classification/test_input/mail_other_3
index 504ddc48..504ddc48 100644
--- a/classification/test_mails/mail_other_3
+++ b/classification/test_input/mail_other_3
diff --git a/classification/test_mails/mail_semantic_1 b/classification/test_input/mail_semantic_1
index af6a2480..af6a2480 100644
--- a/classification/test_mails/mail_semantic_1
+++ b/classification/test_input/mail_semantic_1
diff --git a/classification/test_mails/mail_semantic_2 b/classification/test_input/mail_semantic_2
index 4c78171d..4c78171d 100644
--- a/classification/test_mails/mail_semantic_2
+++ b/classification/test_input/mail_semantic_2