diff options
Diffstat (limited to 'classification/classifier.py')
| -rwxr-xr-x | classification/classifier.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/classification/classifier.py b/classification/classifier.py index 213f64aa9..ba01a8644 100755 --- a/classification/classifier.py +++ b/classification/classifier.py @@ -12,8 +12,8 @@ parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large args = parser.parse_args() positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'user-level'] -architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc'] -negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual', 'operating system'] +architectures = ['x86', 'arm', 'risc-v', 'i386', 'ppc'] +negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual' ] categories = positive_categories + negative_categories + architectures def list_files_recursive(directory): @@ -75,8 +75,6 @@ def get_category(classification : dict): return result def compare_category(classification : dict, category : str): - if sum(1 for i in positive_categories if i in category) < 1: - return category for label, score in zip(classification["labels"], classification["scores"]): if label in positive_categories and score >= 0.85: @@ -108,7 +106,7 @@ def main(): result = classifier(text, categories, multi_label=args.multi_label) category = get_category(result) - if args.compare: + if args.compare and sum(1 for i in positive_categories if i in category) >= 1: compare_result = compare_classifier(text, categories, multi_label=args.multi_label) category = compare_category(compare_result, category) |