diff options
Diffstat (limited to 'classification/main.py')
| -rwxr-xr-x | classification/main.py | 43 |
1 files changed, 28 insertions, 15 deletions
diff --git a/classification/main.py b/classification/main.py index db54a3f3..ad336899 100755 --- a/classification/main.py +++ b/classification/main.py @@ -33,8 +33,27 @@ def output(text : str, category : str, labels : list, scores : list, identifier file.write("\n") file.write(text) +def get_category(classification : dict): + result = classification['labels'][0] + + for label, score in zip(classification["labels"], classification["scores"]): + if label in negative_categories and score >= 0.92: + result = label + break + + if classification['labels'][0] == "semantic" and classification['scores'][0] <= 0.91: + result = "other" + + if all(i > 0.9 for i in classification["scores"]): + result = "all" + elif all(i < 0.6 for i in classification["scores"]): + result = "none" + + return result + def main(): - classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") + classifier_one = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") + classifier_two = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0") bugs = list_files_recursive("../results/scraper/mailinglist") if args.minimal: @@ -49,23 +68,17 @@ def main(): with open(bug, "r") as file: text = file.read() - result = classifier(text, categories, multi_label=True) - category = result['labels'][0] - - for label, score in zip(result["labels"], result["scores"]): - if label in negative_categories and score >= 0.92: - category = label - break + result_one = classifier_one(text, categories, multi_label=True) + result_two = classifier_two(text, categories, multi_label=True) - if result['labels'][0] == "semantic" and result['scores'][0] <= 0.91: - category = "other" + category_one = get_category(result_one) + category_two = get_category(result_two) - if all(i > 0.9 for i in result["scores"]): - category = "all" - elif all(i < 0.6 for i in result["scores"]): - category = "none" + category = category_one + if category_one != category_two: + category = "review" - output(text, category, result['labels'], result['scores'], path.basename(bug)) + output(text, category, result_one['labels']+result_two['labels'], result_one['scores']+result_two['scores'], path.basename(bug)) if __name__ == "__main__": start_time = monotonic() |