rookiemango's picture
Upload folder using huggingface_hub
da66274 verified
import os
import torch
import numpy as np
from typing import Optional, List, Dict, Set, Any, Union
import torch.distributed as dist
import re
from utils.game24.decoding import extract_expressions, get_answer_label
class GeneratorAnswerAcc:
def __init__(self, n_data: int):
self.n_data = n_data
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.corrs = []
self.gather = False
@torch.inference_mode(mode=True)
def __call__(self, completions: List[str], questions: List[str]):
expressions = extract_expressions(completions)
corrs = [float(get_answer_label(expression, question)) == True for expression, question in zip(expressions, questions)]
self.corrs.append(corrs)
def get_metric(self, reset=True):
if not self.gather:
if self.world_size != 1:
gathered_corrs = [None] * self.world_size
for obj, container in [
(self.corrs, gathered_corrs),
]:
dist.all_gather_object(container, obj)
flatten_corrs = []
for corrs_gpus in zip(*gathered_corrs):
for corrs in corrs_gpus:
flatten_corrs.extend(corrs)
else:
flatten_corrs = [item for sublist in self.corrs for item in sublist]
self.corrs = flatten_corrs[:self.n_data]
self.gather = True
acc = (sum(self.corrs) / len(self.corrs))
if reset:
self.corrs = []
self.gather = False
return acc
class MultiSamplingAnswerAcc:
def __init__(self, n_data: int = None):
self.n_data = n_data
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.answers = []
self.questions = []
def start_new_sol_epoch(self):
self.cur_answers = []
self.cur_questions = []
def end_the_sol_epoch(self):
if self.world_size != 1:
gathered_answers, gathered_questions = tuple([None] * self.world_size for _ in range(2))
for obj, container in [
(self.cur_answers, gathered_answers),
(self.cur_questions, gathered_questions),
]:
dist.all_gather_object(container, obj)
flatten_answers, flatten_questions = [], []
for answers_gpus, questions_gpus in zip(zip(*gathered_answers), zip(*gathered_questions)):
for answers, questions in zip(answers_gpus, questions_gpus):
flatten_answers.extend(answers)
flatten_questions.extend(questions)
else:
flatten_answers, flatten_questions = tuple([item for sublist in container for item in sublist]
for container in [self.cur_answers, self.cur_questions])
self.answers.append(flatten_answers[:self.n_data])
self.questions.append(flatten_questions[:self.n_data])
@torch.inference_mode(mode=True)
def __call__(self, completions: List[str], questions: List[str]):
expressions = extract_expressions(completions)
self.cur_answers.append(expressions)
self.cur_questions.append(questions)
def get_metric(self, n_solution: int=3, reset=True):
# [n_question, n_solution]
answers = self.answers[:n_solution]
# [n_question]
questions = self.questions[:n_solution][0]
pass_k = np.mean([is_passk(expressions, question) for expressions, question in zip(answers, questions)])
acc_majority = np.mean([is_majority(expressions, question) for expressions, question in zip(answers, questions)])
if reset:
self.answers = []
self.questions = []
return pass_k, acc_majority
def is_passk(expressions, question):
return any(get_answer_label(expression, question) for expression in expressions)
def is_majority(expressions, question):
repres = [get_semantics(expr) for expr in expressions]
final_repre = max(repres, key=repres.count)
index = repres.index(final_repre)
return get_answer_label(expressions[index], question)
def get_semantics(expression):
numbers = re.findall(r'\d+', expression)
symbols = re.findall(r'[+\-\*\/]', expression)
try:
value = eval(expression)
except:
value = None
value = str(value)
if value[-2:] == '.0':
value = value[:-2]
return tuple(sorted(numbers) + sorted(symbols) + [f'value={value}'])