summary refs log tree commit diff stats
path: root/classification/classifier.py
blob: c3fca28dd49df788f68eba394ee6cc1ce7fc928f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from transformers import pipeline
from os import path, listdir, makedirs
from datetime import timedelta
from time import monotonic
from argparse import ArgumentParser

# Command-line interface of the classifier script.
parser = ArgumentParser(prog='classifier.py')

# Dataset selection: without this flag only the reduced data set is read.
parser.add_argument(
    '-f', '--full',
    action='store_true',
    help="use whole dataset",
)

# Forwarded as `multi_label` to the zero-shot pipeline and also switches
# get_category() to its threshold-based selection.
parser.add_argument(
    '-m', '--multi_label',
    action='store_true',
    help="enable multi_label for the classifier",
)

# Primary model used for every bug report.
parser.add_argument(
    '--model',
    default="facebook/bart-large-mnli",
    type=str,
    help="main model to use",
)

# Optional second opinion: `--compare` on its own selects the `const` model,
# `--compare NAME` selects NAME, and omitting the flag leaves it unset (None).
parser.add_argument(
    '--compare',
    nargs='?',
    const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
    type=str,
    help="second model for comparison",
)

# Parsed once at import time; the rest of the module reads this global.
args = parser.parse_args()

# Candidate labels, grouped by how get_category() treats them.
positive_categories = [
    'semantic',
    'TCG',
    'assembly',
    'architecture',
    'mistranslation',
    'register',
]
architectures = [
    'x86',
    'arm',
    'risc-v',
    'i386',
    'alpha',
    'ppc',
]
negative_categories = [
    'other',
    'boot',
    'network',
    'kernel virtual machine',
    'vnc',
    'graphic',
    'device',
    'socket',
    'debug',
    'files',
    'PID',
    'permissions',
    'performance',
]

# Full candidate list handed to the classifier: positives, negatives, then
# architectures.
categories = [*positive_categories, *negative_categories, *architectures]

def list_files_recursive(directory):
    """Return the paths of every non-directory entry found below *directory*.

    Sub-directories are descended into depth-first, in the order `listdir`
    reports them; each returned path is *directory* joined with the relative
    location of the entry. An empty directory yields an empty list.
    """
    collected = []
    for name in listdir(directory):
        candidate = path.join(directory, name)
        if not path.isdir(candidate):
            collected.append(candidate)
            continue
        # Directory: splice in everything found underneath it.
        collected.extend(list_files_recursive(candidate))
    return collected

def output(text: str, category: str, labels: list, scores: list, identifier: str):
    """Write one classification report to ``output/<category>/<identifier>``.

    The file starts with one ``label: score`` line per (label, score) pair,
    scores printed with three decimals. A label equal to ``"SPLIT"`` is
    rendered as a dashed separator line instead and its score is ignored.
    After an empty line the original *text* is appended verbatim.

    Missing parent directories are created; an existing file is overwritten.
    """
    target = f"output/{category}/{identifier}"
    makedirs(path.dirname(target), exist_ok=True)

    def render(label, score):
        # "SPLIT" is a marker entry, not a real label.
        if label == "SPLIT":
            return "--------------------\n"
        return f"{label}: {score:.3f}\n"

    with open(target, "w") as handle:
        handle.writelines(render(label, score) for label, score in zip(labels, scores))
        handle.write("\n")
        handle.write(text)

def get_category(classification: dict):
    """Reduce a zero-shot classification result to a single category name.

    Parameters:
        classification: mapping with parallel ``'labels'`` and ``'scores'``
            lists. ``labels[0]`` is treated as the top category, so the
            lists are assumed to be ordered by descending score.

    Returns:
        * ``"all"`` if at least six scores exceed 0.9,
        * ``"none"`` if at least six scores are below 0.6,
        * the top label when multi-label mode is off,
        * otherwise the label chosen by the multi-label rules below, falling
          back to the top label when no rule fires.

    Multi-label rules (labels visited in list order, only scores >= 0.92
    count):
        * a negative category ends the search and becomes the result;
        * once both a positive category and an architecture have been seen,
          the search ends and the kind seen first wins (the most recent
          label of that kind).

    NOTE(review): the fixed count of 6 in the "all"/"none" checks does not
    depend on how many candidate labels were scored - confirm it is still the
    intended threshold.
    """
    labels = classification['labels']
    scores = classification["scores"]
    highest_category = labels[0]

    if sum(1 for score in scores if score > 0.9) >= 6:
        return "all"
    if sum(1 for score in scores if score < 0.6) >= 6:
        return "none"

    if not args.multi_label:
        return highest_category

    result = highest_category
    # These trackers must outlive a single iteration: they record what has
    # been seen so far. Re-initialising them inside the loop (as before) made
    # the pairing branches below unreachable.
    arch = None
    pos = None
    for label, score in zip(labels, scores):
        if score < 0.92:
            # Every rule below requires a confident score.
            continue

        if label in negative_categories:
            result = label
            break

        if label in positive_categories:
            if not arch:
                pos = label
            else:
                # An architecture was seen earlier: the pair is complete.
                result = arch
                break

        if label in architectures:
            if not pos:
                arch = label
            else:
                # A positive category was seen earlier: the pair is complete.
                result = pos
                break

    return result

def main():
    """Classify every scraped bug report and write one report file per bug.

    Loads the main zero-shot model (and the comparison model when
    ``--compare`` is set), gathers the input files, classifies each text and
    hands the result to output(). When both models are used and they disagree
    on the category, the bug is filed under ``"review"``; the report then
    lists the main model's scores, a ``SPLIT`` marker, and the comparison
    model's scores.

    NOTE(review): reports are named after the input file's basename only, so
    inputs from different directories that share a basename and category end
    up at the same output path.
    """
    classifier = pipeline("zero-shot-classification", model=args.model)
    compare_classifier = None
    if args.compare:
        compare_classifier = pipeline("zero-shot-classification", model=args.compare)

    # The mailing-list dump is always included; the remaining sources depend
    # on whether the full data set was requested.
    if args.full:
        extra_sources = (
            "../results/scraper/launchpad",
            "../results/scraper/gitlab/issues_text",
        )
    else:
        extra_sources = ("../results/scraper/gitlab/semantic_issues",)

    bugs = list_files_recursive("../results/scraper/mailinglist")
    for source in extra_sources:
        bugs.extend(list_files_recursive(source))

    print(f"{len(bugs)} number of bugs will be processed")
    for bug in bugs:
        print(f"Processing {bug}")
        with open(bug, "r") as file:
            text = file.read()

        primary = classifier(text, categories, multi_label=args.multi_label)
        category = get_category(primary)
        labels = primary['labels']
        scores = primary['scores']

        if compare_classifier is not None:
            secondary = compare_classifier(text, categories, multi_label=args.multi_label)
            if get_category(secondary) != category:
                # The two models disagree: flag for manual inspection.
                category = "review"

            # 'SPLIT' (with a dummy score) separates the two score lists.
            labels = labels + ['SPLIT'] + secondary['labels']
            scores = scores + [0] + secondary['scores']

        output(text, category, labels, scores, path.basename(bug))

if __name__ == "__main__":
    # Report the wall-clock duration of the whole run; the monotonic clock is
    # unaffected by system time adjustments.
    started = monotonic()
    main()
    elapsed_seconds = monotonic() - started
    print(timedelta(seconds=elapsed_seconds))