diff options
Diffstat (limited to 'classification/classifier.py')
| -rwxr-xr-x | classification/classifier.py | 51 |
1 files changed, 37 insertions, 14 deletions
diff --git a/classification/classifier.py b/classification/classifier.py index ba01a864..43f5c13c 100755 --- a/classification/classifier.py +++ b/classification/classifier.py @@ -3,10 +3,12 @@ from os import path, listdir, makedirs from datetime import timedelta from time import monotonic from argparse import ArgumentParser +from ollama import chat, ChatResponse parser = ArgumentParser(prog='classifier.py') parser.add_argument('-f', '--full', action='store_true', help="use whole dataset") parser.add_argument('-m', '--multi_label', action='store_true', help="enable multi_label for the classifier") +parser.add_argument('-d', '--deepseek', nargs='?', const="deepseek-r1:7b", type=str, help="use deepseek") parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, help="main model to use") parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison") args = parser.parse_args() @@ -26,7 +28,8 @@ def list_files_recursive(directory): result.append(full_path) return result -def output(text : str, category : str, labels : list, scores : list, identifier : str): +def output(text : str, category : str, labels : list, scores : list, identifier : str, reasoning : str = None): + print(f"Category: {category}, Time: {timedelta(seconds=monotonic() - start_time)}") file_path = f"output/{category}/{identifier}" makedirs(path.dirname(file_path), exist_ok = True) @@ -40,6 +43,13 @@ def output(text : str, category : str, labels : list, scores : list, identifier file.write("\n") file.write(text) + if reasoning: + file_path = f"reasoning/{category}/{identifier}" + makedirs(path.dirname(file_path), exist_ok = True) + + with open(file_path, "w") as file: + file.write(reasoning) + def get_category(classification : dict): highest_category = classification['labels'][0] @@ -85,11 +95,19 @@ def compare_category(classification : dict, category : str): return "review" def main(): - classifier = pipeline("zero-shot-classification", model=args.model) - if args.compare: - compare_classifier = pipeline("zero-shot-classification", model=args.compare) + if not args.deepseek: + classifier = pipeline("zero-shot-classification", model=args.model) + print(f"The model {args.model} will be used") + if args.compare: + compare_classifier = pipeline("zero-shot-classification", model=args.compare) + print(f"The comparison model {args.compare} will be used") + else: + print(f"The model {args.deepseek} will be used") + with open("preambel", "r") as file: + preambel = file.read() bugs = list_files_recursive("../results/scraper/mailinglist") + bugs = [] if not args.full: bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues") bugs = bugs + [ "../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313" ] @@ -98,22 +116,27 @@ def main(): bugs = bugs + list_files_recursive("../results/scraper/gitlab/issues_text") print(f"{len(bugs)} number of bugs will be processed") - for bug in bugs: - print(f"Processing {bug}") + for i, bug in enumerate(bugs): + print(f"Bug: {bug}, Number: {i+1},", end=" ") with open(bug, "r") as file: text = file.read() - result = classifier(text, categories, multi_label=args.multi_label) - category = get_category(result) + if args.deepseek: + response = chat(args.deepseek, [{'role': 'user', 'content': preambel + "\n" + text,}]) + category = response['message']['content'].split()[-1].strip("* ") + output(text, category, [], [], path.basename(bug), response['message']['content']) + else: + result = classifier(text, categories, multi_label=args.multi_label) + category = get_category(result) - if args.compare and sum(1 for i in positive_categories if i in category) >= 1: - compare_result = compare_classifier(text, categories, multi_label=args.multi_label) - category = compare_category(compare_result, category) + if args.compare and sum(1 for i in positive_categories if i in category) >= 1: + compare_result = compare_classifier(text, categories, multi_label=args.multi_label) + category = compare_category(compare_result, category) - result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels'] - result['scores'] = result['scores'] + [0] + compare_result['scores'] + result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels'] + result['scores'] = result['scores'] + [0] + compare_result['scores'] - output(text, category, result['labels'], result['scores'], path.basename(bug)) + output(text, category, result['labels'], result['scores'], path.basename(bug)) if __name__ == "__main__": start_time = monotonic() |