summary refs log tree commit diff stats
path: root/classification/classifier.py
diff options
context:
space:
mode:
author Christian Krinitsin <mail@krinitsin.com> 2025-06-08 16:55:48 +0200
committer Christian Krinitsin <mail@krinitsin.com> 2025-06-08 16:55:48 +0200
commit d804cb5b8f55b5e32c217e728fe02f6e53ecdf78 (patch)
tree c0521bf192213ee8db765fcf807c7837609a4591 /classification/classifier.py
parent 9ebc3c7b58e0820054942a2e22b7c48889c3ee26 (diff)
download qemu-analysis-d804cb5b8f55b5e32c217e728fe02f6e53ecdf78.tar.gz
qemu-analysis-d804cb5b8f55b5e32c217e728fe02f6e53ecdf78.zip
classifier: combine architecture classification with category
Diffstat (limited to '')
-rwxr-xr-x classification/classifier.py 39
1 files changed, 29 insertions, 10 deletions
diff --git a/classification/classifier.py b/classification/classifier.py
index 5c81f1818..c3fca28dd 100755
--- a/classification/classifier.py
+++ b/classification/classifier.py
@@ -11,9 +11,10 @@ parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, hel
 parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
 args = parser.parse_args()
 
-positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'x86', 'arm', 'risc-v']
+positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register' ]
+architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc']
 negative_categories = ['other', 'boot', 'network', 'kernel virtual machine', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance']
-categories = positive_categories + negative_categories
+categories = positive_categories + negative_categories + architectures
 
 def list_files_recursive(directory):
     result = []
@@ -40,23 +41,41 @@ def output(text : str, category : str, labels : list, scores : list, identifier
         file.write(text)
 
 def get_category(classification : dict):
-    result = classification['labels'][0]
+    highest_category = classification['labels'][0]
+
+    if sum(1 for i in classification["scores"] if i > 0.9) >= 6:
+        return "all"
+    elif sum(1 for i in classification["scores"] if i < 0.6) >= 6:
+        return "none"
 
     if not args.multi_label:
-        return result
+        return highest_category 
 
+    result = highest_category
     for label, score in zip(classification["labels"], classification["scores"]):
         if label in negative_categories and score >= 0.92:
             result = label
             break
 
-    if classification['labels'][0] == "semantic" and classification['scores'][0] <= 0.91:
-        result = "other"
+        arch = None
+        pos = None
+        if label in positive_categories and score >= 0.92:
+            if not arch:
+                pos = label
+            else:
+                result = arch
+                break
+
+        if label in architectures and score >= 0.92:
+            if not pos:
+                arch = label
+            else:
+                result = pos
+                break
+
 
-    if all(i > 0.9 for i in classification["scores"]):
-        result = "all"
-    elif all(i < 0.6 for i in classification["scores"]):
-        result = "none"
+    # if highest_category == "semantic" and classification['scores'][0] <= 0.92:
+    #     result = "other"
 
     return result