From d804cb5b8f55b5e32c217e728fe02f6e53ecdf78 Mon Sep 17 00:00:00 2001 From: Christian Krinitsin Date: Sun, 8 Jun 2025 16:55:48 +0200 Subject: classifier: combine architecture classification with category --- classification/classifier.py | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) (limited to 'classification/classifier.py') diff --git a/classification/classifier.py b/classification/classifier.py index 5c81f1818..c3fca28dd 100755 --- a/classification/classifier.py +++ b/classification/classifier.py @@ -11,9 +11,10 @@ parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, hel parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison") args = parser.parse_args() -positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'x86', 'arm', 'risc-v'] +positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register' ] +architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc'] negative_categories = ['other', 'boot', 'network', 'kernel virtual machine', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance'] -categories = positive_categories + negative_categories +categories = positive_categories + negative_categories + architectures def list_files_recursive(directory): result = [] @@ -40,23 +41,41 @@ def output(text : str, category : str, labels : list, scores : list, identifier file.write(text) def get_category(classification : dict): - result = classification['labels'][0] + highest_category = classification['labels'][0] + + if sum(1 for i in classification["scores"] if i > 0.9) >= 6: + return "all" + elif sum(1 for i in classification["scores"] if i < 0.6) >= 6: + return "none" if not args.multi_label: - return result + return highest_category + result = highest_category for label, score in zip(classification["labels"], classification["scores"]): if label in negative_categories and score >= 0.92: result = label break - if classification['labels'][0] == "semantic" and classification['scores'][0] <= 0.91: - result = "other" + arch = None + pos = None + if label in positive_categories and score >= 0.92: + if not arch: + pos = label + else: + result = arch + break + + if label in architectures and score >= 0.92: + if not pos: + arch = label + else: + result = pos + break + - if all(i > 0.9 for i in classification["scores"]): - result = "all" - elif all(i < 0.6 for i in classification["scores"]): - result = "none" + # if highest_category == "semantic" and classification['scores'][0] <= 0.92: + # result = "other" return result -- cgit 1.4.1