diff options
Diffstat (limited to 'classification/main.py')
| -rwxr-xr-x | classification/main.py | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/classification/main.py b/classification/main.py index 7c1f2d4b4..3394a3fba 100755 --- a/classification/main.py +++ b/classification/main.py @@ -1,18 +1,35 @@ from transformers import pipeline from argparse import ArgumentParser +from os import path from test import test +from files import list_files_recursive +from output import output parser = ArgumentParser(prog='classifier') parser.add_argument('-d', '--deepseek', action='store_true') +parser.add_argument('-t', '--test', action='store_true') args = parser.parse_args() +categories = ['semantic', 'other', 'mistranslation', 'instruction'] + def main(): if args.deepseek: - print("deepseek not supported") - else: - classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") - test(classifier) + print("deepseek currently not supported") + exit() + + classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") + if args.test: + test(classifier, categories) + exit() + + bugs = list_files_recursive("../mailinglist/output_mailinglist") + for bug in bugs: + with open(bug, "r") as file: + text = file.read() + + result = classifier(text, categories, multi_label=True) + output(text, result['labels'], result['scores'], path.basename(bug)) if __name__ == "__main__": main() |