summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-06-01 21:35:14 +0200
committerChristian Krinitsin <mail@krinitsin.com>2025-06-01 21:35:14 +0200
commit3e4c5a6261770bced301b5e74233e7866166ea5b (patch)
tree9379fddaba693ef8a045da06efee8529baa5f6f4 /classification/main.py
parente5634e2806195bee44407853c4bf8776f7abfa4f (diff)
downloademulator-bug-study-3e4c5a6261770bced301b5e74233e7866166ea5b.tar.gz
emulator-bug-study-3e4c5a6261770bced301b5e74233e7866166ea5b.zip
clean up repository
Diffstat (limited to 'classification/main.py')
-rwxr-xr-xclassification/main.py38
1 files changed, 21 insertions, 17 deletions
diff --git a/classification/main.py b/classification/main.py
index e38460ac..6786cc2f 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -1,29 +1,33 @@
 from transformers import pipeline
-from argparse import ArgumentParser
 from os import path
 
-from test import test
-from files import list_files_recursive
-from output import output
-
-parser = ArgumentParser(prog='classifier')
-parser.add_argument('-d', '--deepseek', action='store_true')
-parser.add_argument('-t', '--test', action='store_true')
-args = parser.parse_args()
-
 positive_categories = ['semantic', 'mistranslation', 'instruction', 'assembly'] # to add: register
 negative_categories = ['other', 'boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket'] # to add: performance
 categories = positive_categories + negative_categories
 
-def main():
-    if args.deepseek:
-        print("deepseek currently not supported")
-        exit()
+def list_files_recursive(path):
+    result = []
+    for entry in os.listdir(path):
+        full_path = os.path.join(path, entry)
+        if os.path.isdir(full_path):
+            result = result + list_files_recursive(full_path)
+        else:
+            result.append(full_path)
+    return result
+
+def output(text : str, category : str, labels : list, scores : list, identifier : str):
+    file_path = f"output/{category}/{identifier}"
+    makedirs(path.dirname(file_path), exist_ok = True)
 
+    with open(file_path, "w") as file:
+        for label, score in zip(labels, scores):
+            file.write(f"{label}: {score:.3f}\n")
+
+        file.write("\n")
+        file.write(text)
+
+def main():
     classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-    if args.test:
-        test(classifier, categories)
-        exit()
 
     bugs = list_files_recursive("../mailinglist/output_mailinglist")
     bugs = bugs + list_files_recursive("./semantic_issues")