summary refs log tree commit diff stats
path: root/classification/main.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-x classification/main.py 43
1 files changed, 28 insertions, 15 deletions
diff --git a/classification/main.py b/classification/main.py
index db54a3f3..ad336899 100755
--- a/classification/main.py
+++ b/classification/main.py
@@ -33,8 +33,27 @@ def output(text : str, category : str, labels : list, scores : list, identifier
         file.write("\n")
         file.write(text)
 
+def get_category(classification : dict):
+    result = classification['labels'][0]
+
+    for label, score in zip(classification["labels"], classification["scores"]):
+        if label in negative_categories and score >= 0.92:
+            result = label
+            break
+
+    if classification['labels'][0] == "semantic" and classification['scores'][0] <= 0.91:
+        result = "other"
+
+    if all(i > 0.9 for i in classification["scores"]):
+        result = "all"
+    elif all(i < 0.6 for i in classification["scores"]):
+        result = "none"
+
+    return result
+
 def main():
-    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+    classifier_one = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+    classifier_two = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0")
 
     bugs = list_files_recursive("../results/scraper/mailinglist")
     if args.minimal:
@@ -49,23 +68,17 @@ def main():
         with open(bug, "r") as file:
             text = file.read()
 
-        result = classifier(text, categories, multi_label=True)
-        category = result['labels'][0]
-
-        for label, score in zip(result["labels"], result["scores"]):
-            if label in negative_categories and score >= 0.92:
-                category = label
-                break
+        result_one = classifier_one(text, categories, multi_label=True)
+        result_two = classifier_two(text, categories, multi_label=True)
 
-        if result['labels'][0] == "semantic" and result['scores'][0] <= 0.91:
-            category = "other"
+        category_one = get_category(result_one)
+        category_two = get_category(result_two)
 
-        if all(i > 0.9 for i in result["scores"]):
-            category = "all"
-        elif all(i < 0.6 for i in result["scores"]):
-            category = "none"
+        category = category_one
+        if category_one != category_two:
+            category = "review"
 
-        output(text, category, result['labels'], result['scores'], path.basename(bug))
+        output(text, category, result_one['labels']+result_two['labels'], result_one['scores']+result_two['scores'], path.basename(bug))
 
 if __name__ == "__main__":
     start_time = monotonic()