summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'classification/main.py')
-rwxr-xr-xclassification/main.py25
1 files changed, 13 insertions, 12 deletions
diff --git a/classification/main.py b/classification/main.py
index 2a6f6d9ad..7c1f2d4b4 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,17 +1,18 @@
 from transformers import pipeline
-from os import listdir, path
+from argparse import ArgumentParser
 
-directory : str = "./test_mails"
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+from test import test
 
-for name in listdir(directory):
-    with open(path.join(directory, name), "r") as file:
-        sequence_to_classify = file.read()
+parser = ArgumentParser(prog='classifier')
+parser.add_argument('-d', '--deepseek', action='store_true')
+args = parser.parse_args()
 
-    candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+def main():
+    if args.deepseek:
+        print("deepseek not supported")
+    else:
+        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+        test(classifier)
 
-    print(name)
-    for label, score in zip(result["labels"], result["scores"]):
-        print(f"{label}: {score:.3f}")
-    print("\n")
+if __name__ == "__main__":
+    main()