summary refs log tree commit diff stats
path: root/classification/test.py
blob: a8bc7e1ee68b84302353b42543378ade8bc5d831 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from os import listdir, path

# Folder of sample inputs scanned by test(). The path is relative, so it is
# resolved against the current working directory, not this file's location.
directory: str = "./test_input"

def test(classifier, *, input_dir=None, candidate_labels=None):
    """Run ``classifier`` over every input file and print the label scores.

    Each regular file in the input directory is processed in sorted order,
    except ``README.md``. The file is read as UTF-8 text and passed to
    ``classifier`` together with the candidate labels, with
    ``multi_label=True``. For each file, the function prints the file name,
    then one ``label: score`` line per returned label (score to three
    decimals), then a blank separator line.

    Args:
        classifier: Callable invoked as
            ``classifier(text, labels, multi_label=True)``. It must return a
            mapping with parallel ``"labels"`` and ``"scores"`` sequences.
            NOTE(review): presumably a zero-shot-classification pipeline;
            confirm with the caller.
        input_dir: Directory to scan. Defaults to the module-level
            ``directory`` constant.
        candidate_labels: Labels to score. Defaults to
            ``['semantic', 'other', 'mistranslation', 'instruction']``.
    """
    if input_dir is None:
        input_dir = directory
    # Build the label list once instead of once per file.
    labels = (
        ['semantic', 'other', 'mistranslation', 'instruction']
        if candidate_labels is None
        else list(candidate_labels)
    )

    # os.listdir returns names in arbitrary order; sort for a stable report.
    for name in sorted(listdir(input_dir)):
        file_path = path.join(input_dir, name)
        # Skip the folder's README. Also skip anything that is not a regular
        # file (e.g. a subdirectory), because open() cannot read it as text.
        if name == "README.md" or not path.isfile(file_path):
            continue

        # Use an explicit encoding: the platform default is locale-dependent.
        with open(file_path, "r", encoding="utf-8") as file:
            sequence_to_classify = file.read()

        result = classifier(sequence_to_classify, labels, multi_label=True)

        print(name)
        for label, score in zip(result["labels"], result["scores"]):
            print(f"{label}: {score:.3f}")
        print("")


# This is a manual driver, not a pytest test. Its name matches pytest's default
# "test*" pattern, so opt it out of collection in case it is ever imported into
# a module that pytest collects.
test.__test__ = False