import re | |
import sympy | |
from typing import List | |
def extract_expression(response: str):
    """Pull the candidate arithmetic expression out of a model response.

    Only the final line of the stripped response is inspected. That line is
    lower-cased, and only the text after the last ``'the answer is '`` is kept
    (the whole line if the phrase is absent). Everything from the first ``'='``
    onward is dropped, and the remainder is stripped of surrounding whitespace.

    Args:
        response: Raw text produced by the model.

    Returns:
        The extracted expression string (possibly empty).
    """
    final_line = response.strip().rpartition('\n')[2]
    after_marker = final_line.lower().rpartition('the answer is ')[2]
    lhs, _, _ = after_marker.partition('=')
    return lhs.strip()
def extract_expressions(responses: List[str]):
    """Run ``extract_expression`` over each response.

    Args:
        responses: Raw model responses.

    Returns:
        A new list with one extracted expression per response, in input order.
    """
    return list(map(extract_expression, responses))
# refer to https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
def get_answer_label(expression: str, question: str, target: int = 24):
    """Return True if *expression* is a correct Game-of-24 style solution.

    Two checks are applied:

    1. The digit runs in *expression* must equal the digit runs in
       *question* as a multiset of strings. In other words, every puzzle
       number is used exactly once and no other numbers appear.
    2. The expression, parsed and simplified by SymPy, must equal *target*.

    Args:
        expression: Candidate arithmetic expression, e.g. ``"(4 + 8) * (6 - 4)"``.
        question: Puzzle text; every run of digits in it is treated as an
            input number.
        target: Value the expression must evaluate to (default 24, the
            classic Game of 24).

    Returns:
        ``True`` if both checks pass, otherwise ``False``. An expression that
        cannot be parsed or simplified is reported as ``False``.
    """
    numbers = re.findall(r'\d+', expression)
    problem_numbers = re.findall(r'\d+', question)
    # Numbers are compared as strings, so e.g. "04" and "4" are treated as
    # different numbers.
    if sorted(numbers) != sorted(problem_numbers):
        return False
    try:
        # NOTE(review): sympy.simplify on a str goes through sympify(), which
        # uses eval(). *expression* comes from model output and is therefore
        # untrusted; consider validating it against an allow-list of
        # arithmetic characters (digits, whitespace, + - * / parentheses)
        # before parsing.
        return bool(sympy.simplify(expression) == target)
    except Exception:
        # Deliberately best-effort: any malformed or unparsable expression is
        # scored as an incorrect answer rather than aborting evaluation.
        return False