summaryrefslogtreecommitdiffstats
path: root/classification
diff options
context:
space:
mode:
authorChristian Krinitsin <mail@krinitsin.com>2025-06-10 17:04:21 +0000
committerChristian Krinitsin <mail@krinitsin.com>2025-06-10 17:04:21 +0000
commit7b681b9f9eedaad2f081ae11a32f459f5a1312ff (patch)
tree447529eab427f2cb024d33933794a27f30369c4d /classification
parentd804cb5b8f55b5e32c217e728fe02f6e53ecdf78 (diff)
downloademulator-bug-study-7b681b9f9eedaad2f081ae11a32f459f5a1312ff.tar.gz
emulator-bug-study-7b681b9f9eedaad2f081ae11a32f459f5a1312ff.zip
add 17th version of the classifier, including results
Diffstat (limited to '')
-rwxr-xr-xclassification/classifier.py63
1 files changed, 34 insertions, 29 deletions
diff --git a/classification/classifier.py b/classification/classifier.py
index c3fca28d..213f64aa 100755
--- a/classification/classifier.py
+++ b/classification/classifier.py
@@ -11,9 +11,9 @@ parser.add_argument('--model', default="facebook/bart-large-mnli", type=str, hel
parser.add_argument('--compare', nargs='?', const="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", type=str, help="second model for comparison")
args = parser.parse_args()
-positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register' ]
+positive_categories = ['semantic', 'TCG', 'assembly', 'architecture', 'mistranslation', 'register', 'user-level']
architectures = ['x86', 'arm', 'risc-v', 'i386', 'alpha', 'ppc']
-negative_categories = ['other', 'boot', 'network', 'kernel virtual machine', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance']
+negative_categories = ['boot', 'network', 'KVM', 'vnc', 'graphic', 'device', 'socket', 'debug', 'files', 'PID', 'permissions', 'performance', 'kernel', 'peripherals', 'VMM', 'hypervisor', 'virtual', 'operating system']
categories = positive_categories + negative_categories + architectures
def list_files_recursive(directory):
@@ -43,41 +43,48 @@ def output(text : str, category : str, labels : list, scores : list, identifier
def get_category(classification : dict):
highest_category = classification['labels'][0]
- if sum(1 for i in classification["scores"] if i > 0.9) >= 6:
- return "all"
- elif sum(1 for i in classification["scores"] if i < 0.6) >= 6:
- return "none"
-
if not args.multi_label:
return highest_category
+ if all(i < 0.8 for i in classification["scores"]):
+ return "none"
+ elif sum(1 for i in classification["scores"] if i > 0.85) >= 20:
+ return "all"
+ elif classification["scores"][0] - classification["scores"][-1] <= 0.2:
+ return "unknown"
+
result = highest_category
+ arch = None
+ pos = None
for label, score in zip(classification["labels"], classification["scores"]):
- if label in negative_categories and score >= 0.92:
- result = label
- break
+ if label in negative_categories and (not arch and not pos or score >= 0.92):
+ return label
- arch = None
- pos = None
- if label in positive_categories and score >= 0.92:
+ if label in positive_categories and not pos and score > 0.8:
+ pos = label
if not arch:
- pos = label
+ result = label
else:
- result = arch
- break
+ result = label + "-" + arch
- if label in architectures and score >= 0.92:
- if not pos:
- arch = label
- else:
- result = pos
- break
+ if label in architectures and not arch and score > 0.8:
+ arch = label
+ if pos:
+ result = pos + "-" + label
+ return result
- # if highest_category == "semantic" and classification['scores'][0] <= 0.92:
- # result = "other"
+def compare_category(classification : dict, category : str):
+ if sum(1 for i in positive_categories if i in category) < 1:
+ return category
- return result
+ for label, score in zip(classification["labels"], classification["scores"]):
+ if label in positive_categories and score >= 0.85:
+ return category
+ if label in category and score >= 0.85:
+ return category
+
+ return "review"
def main():
classifier = pipeline("zero-shot-classification", model=args.model)
@@ -87,6 +94,7 @@ def main():
bugs = list_files_recursive("../results/scraper/mailinglist")
if not args.full:
bugs = bugs + list_files_recursive("../results/scraper/gitlab/semantic_issues")
+ bugs = bugs + [ "../results/scraper/launchpad/1809546", "../results/scraper/launchpad/1156313" ]
else:
bugs = bugs + list_files_recursive("../results/scraper/launchpad")
bugs = bugs + list_files_recursive("../results/scraper/gitlab/issues_text")
@@ -102,10 +110,7 @@ def main():
if args.compare:
compare_result = compare_classifier(text, categories, multi_label=args.multi_label)
- compare_category = get_category(compare_result)
-
- if category != compare_category:
- category = "review"
+ category = compare_category(compare_result, category)
result['labels'] = result['labels'] + ['SPLIT'] + compare_result['labels']
result['scores'] = result['scores'] + [0] + compare_result['scores']