summary refs log tree commit diff stats
path: root/classification/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'classification/test.py')
-rwxr-xr-x[-rw-r--r--]classification/test.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/classification/test.py b/classification/test.py
index a8bc7e1e..bcd6b439 100644..100755
--- a/classification/test.py
+++ b/classification/test.py
@@ -2,16 +2,15 @@ from os import listdir, path
 
 directory : str = "./test_input"
 
-def test(classifier):
+def test(classifier, categories):
     for name in listdir(directory):
         if name == "README.md":
             continue
 
         with open(path.join(directory, name), "r") as file:
-            sequence_to_classify = file.read()
+            text = file.read()
 
-        candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']
-        result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
+        result = classifier(text, categories, multi_label=True)
 
         print(name)
         for label, score in zip(result["labels"], result["scores"]):