summary refs log tree commit diff stats
path: root/classification/classifier.py
diff options
context:
space:
mode:
Diffstat (limited to 'classification/classifier.py')
-rwxr-xr-xclassification/classifier.py24
1 files changed, 18 insertions, 6 deletions
diff --git a/classification/classifier.py b/classification/classifier.py
index 43f5c13c4..7b08439ab 100755
--- a/classification/classifier.py
+++ b/classification/classifier.py
@@ -4,6 +4,7 @@ from datetime import timedelta
 from time import monotonic
 from argparse import ArgumentParser
 from ollama import chat, ChatResponse
+from re import sub
 
 parser = ArgumentParser(prog='classifier.py')
 parser.add_argument('-f', '--full', action='store_true', help="use whole dataset")
@@ -15,17 +16,20 @@ args = parser.parse_args()
 
 positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'user-level']
 architectures = ['x86', 'arm', 'risc-v', 'i386', 'ppc']
-negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual' ]
+negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual', 'other']
 categories = positive_categories + negative_categories + architectures
 
-def list_files_recursive(directory):
+def list_files_recursive(directory, basename = False):
     result = []
     for entry in listdir(directory):
         full_path = path.join(directory, entry)
         if path.isdir(full_path):
-            result = result + list_files_recursive(full_path)
+            result = result + list_files_recursive(full_path, basename)
         else:
-            result.append(full_path)
+            if basename:
+                result.append(path.basename(full_path))
+            else:
+                result.append(full_path)
     return result
 
 def output(text : str, category : str, labels : list, scores : list, identifier : str, reasoning : str = None):
@@ -106,24 +110,32 @@ def main():
         with open("preambel", "r") as file:
             preambel = file.read()
 
+    processed_bugs = list_files_recursive("output", True)
     bugs = list_files_recursive("../results/scraper/mailinglist")
     bugs = []
     if not args.full:
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues")
         bugs = bugs + [ "../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313" ]
     else:
-        bugs = bugs + list_files_recursive("../results/scraper/launchpad")
+        bugs = bugs + list_files_recursive("../results/scraper/launchpad-without-comments")
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/issues_text")
 
     print(f"{len(bugs)} number of bugs will be processed")
     for i, bug in enumerate(bugs):
         print(f"Bug: {bug}, Number: {i+1},", end=" ")
+
+        if path.basename(bug) in processed_bugs:
+            print("skipped")
+            continue
+
         with open(bug, "r") as file:
             text = file.read()
 
         if args.deepseek:
             response = chat(args.deepseek, [{'role': 'user', 'content': preambel + "\n" + text,}])
-            category = response['message']['content'].split()[-1].strip("* ")
+            category = sub(r'[^a-zA-Z]', '', response['message']['content'].split()[-1])
+            if not category in categories:
+                category = "manual-review"
             output(text, category, [], [], path.basename(bug), response['message']['content'])
         else:
             result = classifier(text, categories, multi_label=args.multi_label)