summary refs log tree commit diff stats
path: root/classification
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-06-26 07:33:21 +0000
committerChristian Krinitsin <mail@krinitsin.com>2025-06-26 07:33:21 +0000
commitba2789bd7d81618a42dc7f69706a7acfa591630a (patch)
tree4aae3f6659946b6da5b436f38a7786a632c930be /classification
parent9aba81d8eb048db908c94a3c40c25a5fde0caee6 (diff)
downloademulator-bug-study-ba2789bd7d81618a42dc7f69706a7acfa591630a.tar.gz
emulator-bug-study-ba2789bd7d81618a42dc7f69706a7acfa591630a.zip
add prompt-based classifier and first results with DeepSeek-R1:14b
Diffstat (limited to 'classification')
-rwxr-xr-xclassification/classifier.py51
-rw-r--r--classification/preambel27
-rw-r--r--classification/requirements.txt1
-rw-r--r--classification/shell.nix1
4 files changed, 66 insertions, 14 deletions
diff --git a/classification/classifier.py b/classification/classifier.py
index ba01a864..43f5c13c 100755
--- a/classification/classifier.py
+++ b/classification/classifier.py
@@ -3,10 +3,12 @@ from os import path, listdir, makedirs
 from datetime import timedelta
 from time import monotonic
 from argparse import ArgumentParser
+from ollama import chat, ChatResponse
 
 parser = ArgumentParser(prog='classifier.py')
 parser.add_argument('-f', '--full', action='store_true', help="use whole dataset")
 parser.add_argument('-m', '--multi_label', action='store_true', help="enable multi_label for the classifier")
+parser.add_argument('-d', '--deepseek', nargs='?', const="deepseek-r1:7b", type=str, help="use deepseek")
 parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, help="main model to use")
 parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
 args = parser.parse_args()
@@ -26,7 +28,8 @@ def list_files_recursive(directory):
             result.append(full_path)
     return result
 
-def output(text : str, category : str, labels : list, scores : list, identifier : str):
+def output(text : str, category : str, labels : list, scores : list, identifier : str, reasoning : str = None):
+    print(f"Category: {category}, Time: {timedelta(seconds=monotonic() - start_time)}")
     file_path = f"output/{category}/{identifier}"
     makedirs(path.dirname(file_path), exist_ok = True)
 
@@ -40,6 +43,13 @@ def output(text : str, category : str, labels : list, scores : list, identifier
         file.write("\n")
         file.write(text)
 
+    if reasoning:
+        file_path = f"reasoning/{category}/{identifier}"
+        makedirs(path.dirname(file_path), exist_ok = True)
+
+        with open(file_path, "w") as file:
+            file.write(reasoning)
+
 def get_category(classification : dict):
     highest_category = classification['labels'][0]
 
@@ -85,11 +95,19 @@ def compare_category(classification : dict, category : str):
     return "review"
 
 def main():
-    classifier = pipeline("zero-shot-classification", model=args.model)
-    if args.compare:
-        compare_classifier = pipeline("zero-shot-classification", model=args.compare)
+    if not args.deepseek:
+        classifier = pipeline("zero-shot-classification", model=args.model)
+        print(f"The model {args.model} will be used")
+        if args.compare:
+            compare_classifier = pipeline("zero-shot-classification", model=args.compare)
+            print(f"The comparison model {args.compare} will be used")
+    else:
+        print(f"The model {args.deepseek} will be used")
+        with open("preambel", "r") as file:
+            preambel = file.read()
 
     bugs = list_files_recursive("../results/scraper/mailinglist")
+    bugs = []
     if not args.full:
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues")
         bugs = bugs + [ "../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313" ]
@@ -98,22 +116,27 @@ def main():
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/issues_text")
 
     print(f"{len(bugs)} number of bugs will be processed")
-    for bug in bugs:
-        print(f"Processing {bug}")
+    for i, bug in enumerate(bugs):
+        print(f"Bug: {bug}, Number: {i+1},", end=" ")
         with open(bug, "r") as file:
             text = file.read()
 
-        result = classifier(text, categories, multi_label=args.multi_label)
-        category = get_category(result)
+        if args.deepseek:
+            response = chat(args.deepseek, [{'role': 'user', 'content': preambel + "\n" + text,}])
+            category = response['message']['content'].split()[-1].strip("* ")
+            output(text, category, [], [], path.basename(bug), response['message']['content'])
+        else:
+            result = classifier(text, categories, multi_label=args.multi_label)
+            category = get_category(result)
 
-        if args.compare and sum(1 for i in positive_categories if i in category) >= 1:
-            compare_result = compare_classifier(text, categories, multi_label=args.multi_label)
-            category = compare_category(compare_result, category)
+            if args.compare and sum(1 for i in positive_categories if i in category) >= 1:
+                compare_result = compare_classifier(text, categories, multi_label=args.multi_label)
+                category = compare_category(compare_result, category)
 
-            result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels']
-            result['scores'] = result['scores'] + [0] + compare_result['scores']
+                result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels']
+                result['scores'] = result['scores'] + [0] + compare_result['scores']
 
-        output(text, category, result['labels'], result['scores'], path.basename(bug))
+            output(text, category, result['labels'], result['scores'], path.basename(bug))
 
 if __name__ == "__main__":
     start_time = monotonic()
diff --git a/classification/preambel b/classification/preambel
new file mode 100644
index 00000000..c85f8d25
--- /dev/null
+++ b/classification/preambel
@@ -0,0 +1,27 @@
+Classify the following bug report. it is part of qemu.
+
+These are the possible categories:
+
+mistranslation: incorrect semantic mapping from source architecture to IR/target, which happen in user-mode
+assembly: assembly lowering
+other: other
+device
+graphic
+socket
+network
+KVM
+boot
+vnc
+debug
+files
+permissions
+performance
+kernel
+peripherals
+hypervisor
+
+Respond only with a single word, the name of the category. 
+
+Bug report follows:
+
+
diff --git a/classification/requirements.txt b/classification/requirements.txt
index 38502406..32ba382a 100644
--- a/classification/requirements.txt
+++ b/classification/requirements.txt
@@ -1 +1,2 @@
 transformers[torch]
+ollama
diff --git a/classification/shell.nix b/classification/shell.nix
index 5a20b8e4..5a992f0f 100644
--- a/classification/shell.nix
+++ b/classification/shell.nix
@@ -4,6 +4,7 @@ mkShell {
   buildInputs = [
     python312Packages.transformers
     python312Packages.torch
+    python312Packages.ollama
   ];
   shellHook = ''
     # fixes libstdc++ issues and libgl.so issues