summary refs log tree commit diff stats
path: root/classification/main.py
blob: 2a6f6d9ad80d295e982b9b82d698294d7b4178f4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import pipeline
from os import listdir, path

# Directory holding the e-mails to classify; each regular file in it is read
# as one message.
directory: str = "./test_mails"

# Labels scored for every e-mail. Built once here instead of on every loop
# iteration, since it never changes.
candidate_labels = ['semantic', 'other', 'mistranslation', 'instruction']

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# os.listdir returns entries in arbitrary order; sort so the report order is
# reproducible across runs and filesystems.
for name in sorted(listdir(directory)):
    mail_path = path.join(directory, name)

    # Skip sub-directories and other non-regular entries, which open() cannot
    # read and which would otherwise abort the whole run.
    if not path.isfile(mail_path):
        continue

    # Decode explicitly as UTF-8 rather than with the platform's
    # locale-dependent default; replace undecodable bytes (U+FFFD) so a single
    # badly-encoded message does not stop the remaining mails from being
    # classified.
    with open(mail_path, "r", encoding="utf-8", errors="replace") as file:
        sequence_to_classify = file.read()

    # multi_label=True asks the pipeline to score each label independently,
    # so the scores are not forced to sum to 1.
    result = classifier(sequence_to_classify, candidate_labels, multi_label=True)

    # Report: file name, then one "label: score" line per label in the order
    # the pipeline returned them, then a blank separator.
    print(name)
    for label, score in zip(result["labels"], result["scores"]):
        print(f"{label}: {score:.3f}")
    print("\n")