summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
author Christian Krinitsin <mail@krinitsin.com> 2025-06-01 21:19:28 +0200
committer Christian Krinitsin <mail@krinitsin.com> 2025-06-01 21:19:28 +0200
commit 05ed9b00104b3da2e9cb0d541b5c5bbceb027fde (patch)
tree d47bd9721b78db6131d56780e7230cf69412eaa6 /classification/main.py
parent cddf7dfa5a6e92a9057e7f6afae5d9b970585e6f (diff)
download qemu-analysis-05ed9b00104b3da2e9cb0d541b5c5bbceb027fde.tar.gz
qemu-analysis-05ed9b00104b3da2e9cb0d541b5c5bbceb027fde.zip
adjust classifier
Diffstat (limited to 'classification/main.py')
-rwxr-xr-x classification/main.py 15
1 files changed, 13 insertions, 2 deletions
diff --git a/classification/main.py b/classification/main.py
index 3394a3fba..e38460acb 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -11,7 +11,9 @@ parser.add_argument('-d', '--deepseek', action='store_true')
 parser.add_argument('-t', '--test', action='store_true')
 args = parser.parse_args()
 
-categories = ['semantic', 'other', 'mistranslation', 'instruction']
+positive_categories = ['semantic', 'mistranslation', 'instruction', 'assembly'] # to add: register
+negative_categories = ['other', 'boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket'] # to add: performance
+categories = positive_categories + negative_categories
 
 def main():
     if args.deepseek:
@@ -24,12 +26,21 @@ def main():
         exit()
 
     bugs = list_files_recursive("../mailinglist/output_mailinglist")
+    bugs = bugs + list_files_recursive("./semantic_issues")
     for bug in bugs:
+        print(f"Processing {bug}")
         with open(bug, "r") as file:
             text = file.read()
 
         result = classifier(text, categories, multi_label=True)
-        output(text, result['labels'], result['scores'], path.basename(bug))
+        category = result['labels'][0]
+
+        for label, score in zip(result["labels"], result["scores"]):
+            if label in negative_categories and score >= 0.92:
+                category = label
+                break
+
+        output(text, category, result['labels'], result['scores'], path.basename(bug))
 
 if __name__ == "__main__":
     main()