summary refs log tree commit diff stats
path: root/classification/classifier.py
blob: 43f5c13c4ebb06f5b61d18f6a3d1c2e24ef2ed96 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from transformers import pipeline
from os import path, listdir, makedirs
from datetime import timedelta
from time import monotonic
from argparse import ArgumentParser
from ollama import chat, ChatResponse

# Command-line options. They are parsed at import time; get_category() and
# main() read the resulting module-level ``args``.
parser = ArgumentParser(prog='classifier.py')
# -f/--full: process every Launchpad and GitLab issue instead of the small semantic subset.
parser.add_argument('-f', '--full', action='store_true', help="use whole dataset")
# -m/--multi_label: score labels independently and apply get_category()'s threshold rules.
parser.add_argument('-m', '--multi_label', action='store_true', help="enable multi_label for the classifier")
# -d/--deepseek [MODEL]: use an Ollama chat model instead of the zero-shot pipeline;
# a bare "-d" selects deepseek-r1:7b.
parser.add_argument('-d', '--deepseek', nargs='?', const="deepseek-r1:7b", type=str, help="use deepseek")
# --model: zero-shot model used when --deepseek is not given.
parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, help="main model to use")
# --compare [MODEL]: second zero-shot model used to confirm positive categories;
# a bare "--compare" selects the DeBERTa model below.
parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
args = parser.parse_args()

# Zero-shot candidate labels.
# positive_categories: labels of interest. get_category() builds the result name
# from them, and main() only runs the comparison model when the category contains one.
positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'user-level']
# architectures: appended by get_category() as a "-<arch>" suffix to a positive category.
architectures = ['x86', 'arm', 'risc-v', 'i386', 'ppc']
# negative_categories: labels get_category() may return outright, overriding positives
# (see its score rules).
negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual' ]
# Full candidate label list handed to the zero-shot classifiers.
categories = positive_categories + negative_categories + architectures

def list_files_recursive(directory):
    """Return the paths of all non-directory entries below *directory*.

    Subdirectories are descended into recursively (``path.isdir`` follows
    symlinks, so linked directories are descended into too). Entries are
    visited in sorted name order. ``listdir`` itself returns entries in
    arbitrary order, so sorting keeps the processing order, and the "Number"
    printed by main(), reproducible across runs and filesystems.

    :param directory: directory to scan. It is the prefix of every returned path.
    :return: list of file paths built with ``path.join``; empty for an empty tree.
    :raises FileNotFoundError: if *directory* does not exist (from ``listdir``).
    """
    result = []
    for entry in sorted(listdir(directory)):
        full_path = path.join(directory, entry)
        if path.isdir(full_path):
            # Extend in place. Rebuilding with ``result + ...`` would copy every
            # previously collected path again at each subdirectory.
            result.extend(list_files_recursive(full_path))
        else:
            result.append(full_path)
    return result

def _safe_path_component(name: str, fallback: str = "invalid") -> str:
    """Return *name* as a single path component that cannot escape its parent directory.

    Path separators (and NUL) are replaced with "_". Names that would
    collapse into or climb out of the parent directory ("", ".", "..") are
    replaced by *fallback*. Ordinary category names such as "semantic-x86"
    or "user-level" are returned unchanged.
    """
    cleaned = name.replace("/", "_").replace("\\", "_").replace("\0", "_")
    if cleaned in ("", ".", ".."):
        return fallback
    return cleaned


def output(text : str, category : str, labels : list, scores : list, identifier : str, reasoning : str = None):
    """Report progress and persist one classification result to disk.

    Prints *category* and the time elapsed since the module-level
    ``start_time`` (set in the ``__main__`` guard). Then writes
    ``output/<category>/<identifier>``, which contains:

    * one ``label: score`` line per label (score to three decimals), with a
      "SPLIT" label written as a dashed separator line;
    * a blank line;
    * the original *text*.

    If *reasoning* is non-empty, it is also written to
    ``reasoning/<category>/<identifier>``.

    *category* may come straight from free-form model output (the deepseek
    branch of main()). It is therefore reduced to a single safe path
    component before it is used as a directory name. Files are written as
    UTF-8 so non-ASCII bug text does not depend on the platform locale.

    :param text: original bug report text, appended after the scores.
    :param category: category name; selects the output sub-directory.
    :param labels: labels in the order they should be listed.
    :param scores: scores parallel to *labels*.
    :param identifier: file name to write inside the category directory.
    :param reasoning: optional free-form model reasoning to store alongside.
    """
    print(f"Category: {category}, Time: {timedelta(seconds=monotonic() - start_time)}")
    folder = _safe_path_component(category)
    file_path = f"output/{folder}/{identifier}"
    makedirs(path.dirname(file_path), exist_ok=True)

    with open(file_path, "w", encoding="utf-8") as file:
        for label, score in zip(labels, scores):
            if label == "SPLIT":
                # Separator between the main model's ranking and the comparison model's.
                file.write("--------------------\n")
            else:
                file.write(f"{label}: {score:.3f}\n")

        file.write("\n")
        file.write(text)

    if reasoning:
        file_path = f"reasoning/{folder}/{identifier}"
        makedirs(path.dirname(file_path), exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as file:
            file.write(reasoning)

def get_category(classification : dict) -> str:
    """Derive a single category name from a zero-shot classification result.

    ``classification`` is the result of the zero-shot pipeline in main(): a
    dict with parallel ``labels`` and ``scores`` lists. The code treats index
    0 as the best label and ``scores[0] - scores[-1]`` as the score range, so
    it presumably expects the lists sorted by descending score (TODO confirm
    against the transformers pipeline docs).

    Without ``--multi_label`` the top label is returned unchanged. With it,
    the checks below apply in order:

    * "none": every score is below 0.8;
    * "all": at least 20 labels score above 0.85;
    * "unknown": the first and last scores differ by at most 0.2;
    * otherwise the labels are scanned in order:
      - a negative category is returned immediately if no positive or
        architecture label has been accepted yet, or if it scores >= 0.92;
      - the first positive label scoring above 0.8 and the first
        architecture label scoring above 0.8 are combined as
        "<positive>-<arch>" (or just "<positive>" without an architecture);
      - if no positive label is accepted, the top label is returned.
    """
    highest_category = classification['labels'][0]

    if not args.multi_label:
        return highest_category 

    if all(i < 0.8 for i in classification["scores"]):
        return "none"
    elif sum(1 for i in classification["scores"] if i > 0.85) >= 20:
        return "all"
    elif classification["scores"][0] - classification["scores"][-1] <= 0.2:
        # Scores span at most 0.2, so no label stands out from the rest.
        return "unknown"

    # NOTE(review): a top score of exactly 0.8 passes the "none" check above but
    # is not accepted by the strict `> 0.8` tests below.
    result = highest_category
    arch = None
    pos = None
    for label, score in zip(classification["labels"], classification["scores"]):
        # Precedence: (no positive/arch accepted yet) OR (score >= 0.92).
        if label in negative_categories and (not arch and not pos or score >= 0.92):
            return label

        # Accept only the first sufficiently confident positive label.
        if label in positive_categories and not pos and score > 0.8:
            pos = label
            if not arch:
                result = label
            else:
                result = label + "-" + arch

        # Accept only the first sufficiently confident architecture. It only
        # changes the result when a positive label has already been accepted.
        if label in architectures and not arch and score > 0.8:
            arch = label
            if pos:
                result = pos + "-" + label

    return result

def compare_category(classification : dict, category : str):
    """Confirm *category* using a second model's ranking, or flag it for review.

    *category* is kept when some label in *classification* scores at least
    0.85 and either is one of ``positive_categories`` or occurs as a
    substring of *category* (so "x86" matches "semantic-x86"). Otherwise
    "review" is returned.
    """
    threshold = 0.85
    pairs = zip(classification["labels"], classification["scores"])
    confirmed = any(
        (label in positive_categories and score >= threshold)
        or (label in category and score >= threshold)
        for label, score in pairs
    )
    if confirmed:
        return category
    return "review"

def main():
    """Classify every selected bug report and write each result via output().

    Behaviour is driven by the module-level ``args``.

    * With ``--deepseek``, the contents of the local ``preambel`` file are
      prepended to each report, and the result is sent to the named Ollama
      model through ``chat``. The reply's final whitespace-separated token,
      with surrounding ``*`` and spaces stripped, is taken as the category.
      The full reply is stored as reasoning.
    * Otherwise a zero-shot ``pipeline`` (``args.model``) scores
      ``categories`` and get_category() picks the category. With
      ``--compare``, a report whose category contains a positive category is
      re-scored by the second model. compare_category() may then turn it into
      "review". Both rankings are written, separated by a "SPLIT" marker.

    Without ``--full`` only the GitLab semantic issues plus two hand-picked
    Launchpad reports are processed. With it, every Launchpad report and
    GitLab issue text is processed.
    """
    if not args.deepseek:
        classifier = pipeline("zero-shot-classification", model=args.model)
        print(f"The model {args.model} will be used")
        if args.compare:
            compare_classifier = pipeline("zero-shot-classification", model=args.compare)
            print(f"The comparison model {args.compare} will be used")
    else:
        print(f"The model {args.deepseek} will be used")
        with open("preambel", "r", encoding="utf-8") as file:
            preambel = file.read()

    # Mailing-list reports are not part of either dataset. To include them, add
    # list_files_recursive("../results/scraper/mailinglist") here.
    bugs = []
    if not args.full:
        bugs += list_files_recursive("../results/scraper/gitlab/semantic_issues")
        bugs += ["../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313"]
    else:
        bugs += list_files_recursive("../results/scraper/launchpad")
        bugs += list_files_recursive("../results/scraper/gitlab/issues_text")

    print(f"{len(bugs)} number of bugs will be processed")
    for number, bug in enumerate(bugs, start=1):
        print(f"Bug: {bug}, Number: {number},", end=" ")
        # errors="replace": a single undecodable scraped report should not abort
        # a long classification run.
        with open(bug, "r", encoding="utf-8", errors="replace") as file:
            text = file.read()

        if args.deepseek:
            response = chat(args.deepseek, [{'role': 'user', 'content': preambel + "\n" + text,}])
            content = response['message']['content']
            tokens = content.split()
            # The category is taken to be the reply's final token. An empty
            # reply has none, so fall back to "unknown" instead of raising
            # IndexError.
            category = tokens[-1].strip("* ") if tokens else "unknown"
            output(text, category, [], [], path.basename(bug), content)
        else:
            result = classifier(text, categories, multi_label=args.multi_label)
            category = get_category(result)

            # Only categories containing a positive label are double-checked.
            if args.compare and any(positive in category for positive in positive_categories):
                compare_result = compare_classifier(text, categories, multi_label=args.multi_label)
                category = compare_category(compare_result, category)

                result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels']
                result['scores'] = result['scores'] + [0] + compare_result['scores']

            output(text, category, result['labels'], result['scores'], path.basename(bug))

if __name__ == "__main__":
    # start_time must be a module-level global: output() reads it to print
    # the elapsed time after every bug.
    start_time = monotonic()
    main()
    total_runtime = timedelta(seconds=monotonic() - start_time)
    print(total_runtime)