diff options
Diffstat (limited to 'classification/main.py')
| -rwxr-xr-x | classification/main.py | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/classification/main.py b/classification/main.py index 3394a3fba..e38460acb 100755 --- a/classification/main.py +++ b/classification/main.py @@ -11,7 +11,9 @@ parser.add_argument('-d', '--deepseek', action='store_true') parser.add_argument('-t', '--test', action='store_true') args = parser.parse_args() -categories = ['semantic', 'other', 'mistranslation', 'instruction'] +positive_categories = ['semantic', 'mistranslation', 'instruction', 'assembly'] # to add: register +negative_categories = ['other', 'boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket'] # to add: performance +categories = positive_categories + negative_categories def main(): if args.deepseek: @@ -24,12 +26,21 @@ def main(): exit() bugs = list_files_recursive("../mailinglist/output_mailinglist") + bugs = bugs + list_files_recursive("./semantic_issues") for bug in bugs: + print(f"Processing {bug}") with open(bug, "r") as file: text = file.read() result = classifier(text, categories, multi_label=True) - output(text, result['labels'], result['scores'], path.basename(bug)) + category = result['labels'][0] + + for label, score in zip(result["labels"], result["scores"]): + if label in negative_categories and score >= 0.92: + category = label + break + + output(text, category, result['labels'], result['scores'], path.basename(bug)) if __name__ == "__main__": main() |