summary refs log tree commit diff stats
path: root/classification/main.py
blob: 04f2d8c493b4fdb6f17f51a3bdca806dbd6d815b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
from transformers import pipeline

# Zero-shot classification of a text file as "semantic bug" vs "no semantic bug".
#
# Builds a Hugging Face zero-shot-classification pipeline backed by
# facebook/bart-large-mnli, reads the whole file at INPUT_PATH, scores it
# against the candidate labels, and prints the labels followed by their scores.

MODEL_NAME = "facebook/bart-large-mnli"
INPUT_PATH = "test"

classifier = pipeline("zero-shot-classification", model=MODEL_NAME)

# Decode explicitly as UTF-8. Without an `encoding` argument, open() falls back
# to the platform's locale encoding (e.g. cp1252 on Windows). That can raise
# UnicodeDecodeError or silently mis-decode non-ASCII source text before it
# ever reaches the model.
with open(INPUT_PATH, "r", encoding="utf-8") as file:
    sequence_to_classify = file.read()

candidate_labels = ['semantic bug', 'no semantic bug']

# multi_label=False: per the transformers docs, the scores are normalized
# across the candidate labels so they sum to 1, i.e. the labels are treated
# as mutually exclusive.
result = classifier(sequence_to_classify, candidate_labels, multi_label=False)

# Per the transformers docs, result['labels'] is sorted by likelihood, and
# result['scores'][i] is the score for result['labels'][i].
print(result['labels'])
print(result['scores'])