summary refs log tree commit diff stats
path: root/classification/classifier.py
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-06-10 17:04:21 +0000
committerChristian Krinitsin <mail@krinitsin.com>2025-06-10 17:04:21 +0000
commit7b681b9f9eedaad2f081ae11a32f459f5a1312ff (patch)
tree447529eab427f2cb024d33933794a27f30369c4d /classification/classifier.py
parentd804cb5b8f55b5e32c217e728fe02f6e53ecdf78 (diff)
downloadqemu-analysis-7b681b9f9eedaad2f081ae11a32f459f5a1312ff.tar.gz
qemu-analysis-7b681b9f9eedaad2f081ae11a32f459f5a1312ff.zip
add 17th version of the classifier, including results
Diffstat (limited to 'classification/classifier.py')
-rwxr-xr-xclassification/classifier.py63
1 files changed, 34 insertions, 29 deletions
diff --git a/classification/classifier.py b/classification/classifier.py
index c3fca28dd..213f64aa9 100755
--- a/classification/classifier.py
+++ b/classification/classifier.py
@@ -11,9 +11,9 @@ parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, hel
 parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
 args = parser.parse_args()
 
-positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register' ]
+positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'user-level']
 architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc']
-negative_categories = ['other', 'boot', 'network', 'kernel virtual machine', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance']
+negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual', 'operating system']
 categories = positive_categories + negative_categories + architectures
 
 def list_files_recursive(directory):
@@ -43,41 +43,48 @@ def output(text : str, category : str, labels : list, scores : list, identifier
 def get_category(classification : dict):
     highest_category = classification['labels'][0]
 
-    if sum(1 for i in classification["scores"] if i > 0.9) >= 6:
-        return "all"
-    elif sum(1 for i in classification["scores"] if i < 0.6) >= 6:
-        return "none"
-
     if not args.multi_label:
         return highest_category 
 
+    if all(i < 0.8 for i in classification["scores"]):
+        return "none"
+    elif sum(1 for i in classification["scores"] if i > 0.85) >= 20:
+        return "all"
+    elif classification["scores"][0] - classification["scores"][-1] <= 0.2:
+        return "unknown"
+
     result = highest_category
+    arch = None
+    pos = None
     for label, score in zip(classification["labels"], classification["scores"]):
-        if label in negative_categories and score >= 0.92:
-            result = label
-            break
+        if label in negative_categories and (not arch and not pos or score >= 0.92):
+            return label
 
-        arch = None
-        pos = None
-        if label in positive_categories and score >= 0.92:
+        if label in positive_categories and not pos and score > 0.8:
+            pos = label
             if not arch:
-                pos = label
+                result = label
             else:
-                result = arch
-                break
+                result = label + "-" + arch
 
-        if label in architectures and score >= 0.92:
-            if not pos:
-                arch = label
-            else:
-                result = pos
-                break
+        if label in architectures and not arch and score > 0.8:
+            arch = label
+            if pos:
+                result = pos + "-" + label
 
+    return result
 
-    # if highest_category == "semantic" and classification['scores'][0] <= 0.92:
-    #     result = "other"
+def compare_category(classification : dict, category : str):
+    if sum(1 for i in positive_categories if i in category) < 1:
+        return category
 
-    return result
+    for label, score in zip(classification["labels"], classification["scores"]):
+        if label in positive_categories and score >= 0.85:
+            return category
+        if label in category and score >= 0.85:
+            return category
+
+    return "review"
 
 def main():
     classifier = pipeline("zero-shot-classification", model=args.model)
@@ -87,6 +94,7 @@ def main():
     bugs = list_files_recursive("../results/scraper/mailinglist")
     if not args.full:
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues")
+        bugs = bugs + [ "../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313" ]
     else:
         bugs = bugs + list_files_recursive("../results/scraper/launchpad")
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/issues_text")
@@ -102,10 +110,7 @@ def main():
 
         if args.compare:
             compare_result = compare_classifier(text, categories, multi_label=args.multi_label)
-            compare_category = get_category(compare_result)
-
-            if category != compare_category:
-                category = "review"
+            category = compare_category(compare_result, category)
 
             result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels']
             result['scores'] = result['scores'] + [0] + compare_result['scores']