diff options
Diffstat (limited to 'classification/classifier.py')
| -rwxr-xr-x | classification/classifier.py | 63 |
1 files changed, 34 insertions, 29 deletions
diff --git a/classification/classifier.py b/classification/classifier.py index c3fca28dd..213f64aa9 100755 --- a/classification/classifier.py +++ b/classification/classifier.py @@ -11,9 +11,9 @@ parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, hel parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison") args = parser.parse_args() -positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register' ] +positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'user-level'] architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc'] -negative_categories = ['other', 'boot', 'network', 'kernel virtual machine', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance'] +negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual', 'operating system'] categories = positive_categories + negative_categories + architectures def list_files_recursive(directory): @@ -43,41 +43,48 @@ def output(text : str, category : str, labels : list, scores : list, identifier def get_category(classification : dict): highest_category = classification['labels'][0] - if sum(1 for i in classification["scores"] if i > 0.9) >= 6: - return "all" - elif sum(1 for i in classification["scores"] if i < 0.6) >= 6: - return "none" - if not args.multi_label: return highest_category + if all(i < 0.8 for i in classification["scores"]): + return "none" + elif sum(1 for i in classification["scores"] if i > 0.85) >= 20: + return "all" + elif classification["scores"][0] - classification["scores"][-1] <= 0.2: + return "unknown" + result = highest_category + arch = None + pos = None for label, score in zip(classification["labels"], classification["scores"]): - if label in negative_categories and score >= 0.92: - result = label - break + if label in negative_categories and (not arch and not pos or score >= 0.92): + return label - arch = None - pos = None - if label in positive_categories and score >= 0.92: + if label in positive_categories and not pos and score > 0.8: + pos = label if not arch: - pos = label + result = label else: - result = arch - break + result = label + "-" + arch - if label in architectures and score >= 0.92: - if not pos: - arch = label - else: - result = pos - break + if label in architectures and not arch and score > 0.8: + arch = label + if pos: + result = pos + "-" + label + return result - # if highest_category == "semantic" and classification['scores'][0] <= 0.92: - # result = "other" +def compare_category(classification : dict, category : str): + if sum(1 for i in positive_categories if i in category) < 1: + return category - return result + for label, score in zip(classification["labels"], classification["scores"]): + if label in positive_categories and score >= 0.85: + return category + if label in category and score >= 0.85: + return category + + return "review" def main(): classifier = pipeline("zero-shot-classification", model=args.model) @@ -87,6 +94,7 @@ def main(): bugs = list_files_recursive("../results/scraper/mailinglist") if not args.full: bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues") + bugs = bugs + [ "../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313" ] else: bugs = bugs + list_files_recursive("../results/scraper/launchpad") bugs = bugs + list_files_recursive("../results/scraper/gitlab/issues_text") @@ -102,10 +110,7 @@ def main(): if args.compare: compare_result = compare_classifier(text, categories, multi_label=args.multi_label) - compare_category = get_category(compare_result) - - if category != compare_category: - category = "review" + category = compare_category(compare_result, category) result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels'] result['scores'] = result['scores'] + [0] + compare_result['scores'] |