summary refs log tree commit diff stats
path: root/classification/main.py
blob: 7c1f2d4b4cbc551e73078e5b87a2a29d10c927a6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import sys
from argparse import ArgumentParser

from transformers import pipeline

# NOTE: this resolves to the local test.py only when this script's directory
# is first on sys.path (e.g. `python main.py`); otherwise Python's stdlib
# `test` package could be picked up instead.
from test import test

# Hugging Face model used by the zero-shot pipeline unless --model overrides it.
DEFAULT_MODEL = "facebook/bart-large-mnli"

parser = ArgumentParser(prog='classifier')
parser.add_argument('-d', '--deepseek', action='store_true',
                    help='use a DeepSeek model instead (not supported yet)')
parser.add_argument('-m', '--model', default=DEFAULT_MODEL,
                    help=f'model for the zero-shot pipeline (default: {DEFAULT_MODEL})')


def main(argv=None):
    """Parse command-line arguments and run the classifier test.

    Arguments are parsed here rather than at import time, so importing this
    module (e.g. from a test runner) does not consume or reject the host
    process's ``sys.argv``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 after building the zero-shot pipeline and
        passing it to ``test``; 1 if ``--deepseek`` was requested, since that
        backend is not implemented.
    """
    args = parser.parse_args(argv)
    if args.deepseek:
        # Unsupported backend: report on stderr and signal failure to the shell
        # instead of exiting with a success status.
        print("deepseek not supported", file=sys.stderr)
        return 1

    classifier = pipeline("zero-shot-classification", model=args.model)
    test(classifier)
    return 0


if __name__ == "__main__":
    sys.exit(main())