summaryrefslogtreecommitdiffstats
path: root/classification
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xclassification/classifier.py39
1 files changed, 29 insertions, 10 deletions
diff --git a/classification/classifier.py b/classification/classifier.py
index 5c81f181..c3fca28d 100755
--- a/classification/classifier.py
+++ b/classification/classifier.py
@@ -11,9 +11,10 @@ parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, hel
parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
args = parser.parse_args()
-positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'x86', 'arm', 'risc-v']
+positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register' ]
+architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc']
negative_categories = ['other', 'boot', 'network', 'kernel virtual machine', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance']
-categories = positive_categories + negative_categories
+categories = positive_categories + negative_categories + architectures
def list_files_recursive(directory):
result = []
@@ -40,23 +41,41 @@ def output(text : str, category : str, labels : list, scores : list, identifier
file.write(text)
def get_category(classification : dict):
- result = classification['labels'][0]
+ highest_category = classification['labels'][0]
+
+ if sum(1 for i in classification["scores"] if i > 0.9) >= 6:
+ return "all"
+ elif sum(1 for i in classification["scores"] if i < 0.6) >= 6:
+ return "none"
if not args.multi_label:
- return result
+ return highest_category
+ result = highest_category
for label, score in zip(classification["labels"], classification["scores"]):
if label in negative_categories and score >= 0.92:
result = label
break
- if classification['labels'][0] == "semantic" and classification['scores'][0] <= 0.91:
- result = "other"
+ arch = None
+ pos = None
+ if label in positive_categories and score >= 0.92:
+ if not arch:
+ pos = label
+ else:
+ result = arch
+ break
+
+ if label in architectures and score >= 0.92:
+ if not pos:
+ arch = label
+ else:
+ result = pos
+ break
+
- if all(i > 0.9 for i in classification["scores"]):
- result = "all"
- elif all(i < 0.6 for i in classification["scores"]):
- result = "none"
+ # if highest_category == "semantic" and classification['scores'][0] <= 0.92:
+ # result = "other"
return result