summary refs log tree commit diff stats
path: root/classification/classifier.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-x classification/classifier.py (renamed from classification/main.py) 41
1 files changed, 27 insertions, 14 deletions
diff --git a/classification/main.py b/classification/classifier.py
index ad3368992..d29d385a2 100755
--- a/classification/main.py
+++ b/classification/classifier.py
@@ -4,8 +4,11 @@ from datetime import timedelta
 from time import monotonic
 from argparse import ArgumentParser
 
-parser = ArgumentParser(prog='main.py')
-parser.add_argument('-m', '--minimal', action='store_true')
+parser = ArgumentParser(prog='classifier.py')
+parser.add_argument('-f', '--full', action='store_true', help="use whole dataset")
+parser.add_argument('-m', '--multi_label', action='store_true', help="enable multi_label for the classifier")
+parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, help="main model to use")
+parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
 args = parser.parse_args()
 
 positive_categories = ['semantic']
@@ -28,7 +31,10 @@ def output(text : str, category : str, labels : list, scores : list, identifier
 
     with open(file_path, "w") as file:
         for label, score in zip(labels, scores):
-            file.write(f"{label}: {score:.3f}\n")
+            if label == "SPLIT":
+                file.write(f"--------------------\n")
+            else:
+                file.write(f"{label}: {score:.3f}\n")
 
         file.write("\n")
         file.write(text)
@@ -36,6 +42,9 @@ def output(text : str, category : str, labels : list, scores : list, identifier
 def get_category(classification : dict):
     result = classification['labels'][0]
 
+    if not args.multi_label:
+        return result
+
     for label, score in zip(classification["labels"], classification["scores"]):
         if label in negative_categories and score >= 0.92:
             result = label
@@ -52,11 +61,12 @@ def get_category(classification : dict):
     return result
 
 def main():
-    classifier_one = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-    classifier_two = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0")
+    classifier = pipeline("zero-shot-classification", model=args.model)
+    if args.compare:
+        compare_classifier = pipeline("zero-shot-classification", model=args.compare)
 
     bugs = list_files_recursive("../results/scraper/mailinglist")
-    if args.minimal:
+    if not args.full:
         bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues")
     else:
         bugs = bugs + list_files_recursive("../results/scraper/launchpad")
@@ -68,17 +78,20 @@ def main():
         with open(bug, "r") as file:
             text = file.read()
 
-        result_one = classifier_one(text, categories, multi_label=True)
-        result_two = classifier_two(text, categories, multi_label=True)
+        result = classifier(text, categories, multi_label=args.multi_label)
+        category = get_category(result)
+
+        if args.compare:
+            compare_result = compare_classifier(text, categories, multi_label=args.multi_label)
+            compare_category = get_category(compare_result)
 
-        category_one = get_category(result_one)
-        category_two = get_category(result_two)
+            if category != compare_category:
+                category = "review"
 
-        category = category_one
-        if category_one != category_two:
-            category = "review"
+            result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels']
+            result['scores'] = result['scores'] + [0] + compare_result['scores']
 
-        output(text, category, result_one['labels']+result_two['labels'], result_one['scores']+result_two['scores'], path.basename(bug))
+        output(text, category, result['labels'], result['scores'], path.basename(bug))
 
 if __name__ == "__main__":
     start_time = monotonic()