File size: 4,655 Bytes
da66274 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import torch
import numpy as np
from typing import Optional, List, Dict, Set, Any, Union
import torch.distributed as dist
import re
from utils.game24.decoding import extract_expressions, get_answer_label
class GeneratorAnswerAcc:
    """Accuracy of one generated completion per question (Game of 24).

    Call the instance once per batch to record per-sample correctness, then call
    ``get_metric`` to gather the flags from every rank and average them.
    """

    def __init__(self, n_data: int):
        # Number of items kept after gathering; anything beyond is dropped
        # (presumably sampler padding in distributed runs — confirm with caller).
        self.n_data = n_data
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        # One list of local correctness flags per __call__ (i.e. per batch).
        self.corrs = []
        # True while self._gathered holds an up-to-date gathered result.
        self.gather = False
        # Gathered, flattened and truncated flags; cached so that repeated
        # get_metric(reset=False) calls do not re-run the collective.
        self._gathered = []

    @torch.inference_mode(mode=True)
    def __call__(self, completions: List[str], questions: List[str]):
        """Record whether each completion's extracted expression solves its question."""
        expressions = extract_expressions(completions)
        # `float(label) == 1.0` is exactly the former `float(label) == True`.
        corrs = [
            float(get_answer_label(expression, question)) == 1.0
            for expression, question in zip(expressions, questions)
        ]
        self.corrs.append(corrs)
        # New local data makes any previously gathered result stale.
        self.gather = False

    def _gather_flat(self):
        """Collect all ranks' flags into one flat list truncated to ``n_data``."""
        if self.world_size != 1:
            gathered = [None] * self.world_size
            dist.all_gather_object(gathered, self.corrs)
            # zip(*gathered) walks batch index first, then rank.
            flat = [
                flag
                for batches in zip(*gathered)
                for batch in batches
                for flag in batch
            ]
        else:
            flat = [flag for batch in self.corrs for flag in batch]
        return flat[:self.n_data]

    def get_metric(self, reset=True):
        """Return the mean correctness over all recorded samples.

        In distributed runs the first call after new data is collective and must
        be made on every rank. ``self.corrs`` is never overwritten, so more
        batches may be recorded after ``get_metric(reset=False)``.

        Args:
            reset: clear all recorded and cached state after computing.

        Returns:
            The accuracy as a float, or nan when nothing was recorded.
        """
        if not self.gather:
            self._gathered = self._gather_flat()
            self.gather = True
        flat = self._gathered
        acc = sum(flat) / len(flat) if flat else float('nan')
        if reset:
            self.corrs = []
            self._gathered = []
            self.gather = False
        return acc
class MultiSamplingAnswerAcc:
    """Pass@k and majority-vote accuracy over several sampling rounds.

    Each round ("sol epoch") is bracketed by ``start_new_sol_epoch()`` and
    ``end_the_sol_epoch()``; in between, the instance is called once per batch.
    ``get_metric`` then scores every question against the answers it received
    in the first ``n_solution`` rounds.
    """

    def __init__(self, n_data: int = None):
        # Number of items kept per round after gathering (None keeps all).
        self.n_data = n_data
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        # One flattened list per finished round: [n_round][n_question].
        self.answers = []
        self.questions = []
        # Per-batch buffers for the round in progress. Initialised here too so
        # calling the instance before start_new_sol_epoch() does not raise.
        self.cur_answers = []
        self.cur_questions = []

    @staticmethod
    def _interleave(per_rank_batches):
        """Flatten [rank][batch][item] into one list, batch-major then rank-minor."""
        return [
            item
            for batches in zip(*per_rank_batches)
            for batch in batches
            for item in batch
        ]

    @staticmethod
    def _group_by_question(rounds):
        """Transpose [n_round][n_question] into [n_question][n_round]."""
        return [list(per_question) for per_question in zip(*rounds)]

    def start_new_sol_epoch(self):
        """Clear the per-batch buffers before a new sampling round."""
        self.cur_answers = []
        self.cur_questions = []

    def end_the_sol_epoch(self):
        """Gather this round's batches (across ranks if distributed) and store them.

        Collective in distributed runs: must be called on every rank.
        """
        if self.world_size != 1:
            gathered_answers = [None] * self.world_size
            gathered_questions = [None] * self.world_size
            # Same collective order on every rank: answers first, then questions.
            dist.all_gather_object(gathered_answers, self.cur_answers)
            dist.all_gather_object(gathered_questions, self.cur_questions)
            flat_answers = self._interleave(gathered_answers)
            flat_questions = self._interleave(gathered_questions)
        else:
            flat_answers = [item for batch in self.cur_answers for item in batch]
            flat_questions = [item for batch in self.cur_questions for item in batch]
        self.answers.append(flat_answers[:self.n_data])
        self.questions.append(flat_questions[:self.n_data])

    @torch.inference_mode(mode=True)
    def __call__(self, completions: List[str], questions: List[str]):
        """Extract expressions from one batch of completions and buffer them."""
        expressions = extract_expressions(completions)
        self.cur_answers.append(expressions)
        self.cur_questions.append(questions)

    def get_metric(self, n_solution: int = 3, reset=True):
        """Return ``(pass_k, acc_majority)`` over the first *n_solution* rounds.

        pass_k is the fraction of questions for which any sampled answer is
        correct (``is_passk``); acc_majority is the fraction whose majority-vote
        answer is correct (``is_majority``). Questions come from the first round.

        Args:
            n_solution: number of stored rounds to use per question.
            reset: clear all stored rounds afterwards.
        """
        # self.answers is [n_round][n_question]; transpose so each entry holds
        # every sampled answer for one question: [n_question][n_solution].
        answers = self._group_by_question(self.answers[:n_solution])
        # [n_question]
        questions = self.questions[:n_solution][0]
        pass_k = np.mean([is_passk(expressions, question)
                          for expressions, question in zip(answers, questions)])
        acc_majority = np.mean([is_majority(expressions, question)
                                for expressions, question in zip(answers, questions)])
        if reset:
            self.answers = []
            self.questions = []
        return pass_k, acc_majority
def is_passk(expressions, question):
    """Tell whether at least one sampled expression is a correct answer.

    Candidates are checked in order with ``get_answer_label`` and the scan stops
    at the first truthy label.

    Returns:
        True if some candidate is labelled correct, otherwise False (including
        when *expressions* is empty).
    """
    for candidate in expressions:
        if get_answer_label(candidate, question):
            return True
    return False
def is_majority(expressions, question):
    """Score *question* by the majority-vote answer among sampled *expressions*.

    Expressions are grouped by ``get_semantics``; the group with the most
    members wins, ties going to the group that appears first. The first
    expression of the winning group is judged with ``get_answer_label``.

    Returns:
        The label of the majority expression, or False when *expressions* is
        empty (previously ``max`` raised ValueError; False matches what
        ``is_passk`` yields for empty input).
    """
    if not expressions:
        return False
    repres = [get_semantics(expr) for expr in expressions]
    # max() keeps the first element with the highest count, so ties resolve to
    # the earliest-sampled representation.
    final_repre = max(repres, key=repres.count)
    index = repres.index(final_repre)
    return get_answer_label(expressions[index], question)
def get_semantics(expression):
    """Return a key used to group equivalent expressions for majority voting.

    The key consists of the sorted digit runs, the sorted ``+ - * /``
    characters, and ``'value=<result>'`` where the result is the evaluated
    expression with a trailing ``.0`` removed (so ``24.0`` and ``24`` agree).
    Expressions that cannot be evaluated get ``'value=None'``.
    """
    numbers = sorted(re.findall(r'\d+', expression))
    symbols = sorted(re.findall(r'[+\-\*\/]', expression))
    # SECURITY(review): eval() executes arbitrary Python. Expressions here are
    # presumably extracted from model completions; confirm extract_expressions
    # restricts them to plain arithmetic, or switch to an ast-based evaluator.
    try:
        value = eval(expression)
    except Exception:
        # Malformed / non-evaluable input (SyntaxError, ZeroDivisionError,
        # NameError, ...) maps to None. Unlike the former bare except, this
        # lets KeyboardInterrupt and SystemExit propagate.
        value = None
    value = str(value)
    if value.endswith('.0'):
        value = value[:-2]
    return tuple(numbers + symbols + [f'value={value}'])
|