import numpy as np


def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Return the softmax of *logits* along *axis*.

    Args:
        logits: Array (or array-like) of unnormalized scores.
        axis: Axis along which the probabilities are normalized. Defaults to
            0, matching the previous behavior for 1-D and 2-D inputs.

    Returns:
        Array of the same shape as *logits* whose entries along *axis* sum
        to 1.
    """
    logits = np.asarray(logits)
    # Subtract the max along the normalization axis (not the global max) for
    # numerical stability: with a global max, a slice whose values all lie far
    # below it underflows to zeros and the division yields 0/0 = NaN.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / exp_logits.sum(axis=axis, keepdims=True)


def one_hot(probs: np.ndarray) -> np.ndarray:
    """Return an array shaped like *probs* with 1 at the argmax and 0 elsewhere.

    Ties resolve to the first maximal element (``np.argmax`` semantics). For
    multi-dimensional input exactly one element -- the flattened argmax -- is
    set.
    """
    encoded = np.zeros_like(probs)
    # np.argmax without an axis returns an index into the flattened array, so
    # assign through .flat; plain indexing would address a whole sub-array
    # (or go out of bounds) when probs has more than one dimension.
    encoded.flat[np.argmax(probs)] = 1
    return encoded


def opt_to_index(s):
    """Convert an option label such as ``"(B)"`` or ``"b"`` to a zero-based index.

    Accepts a single ASCII letter, optionally wrapped in parentheses. The
    letter is case-insensitive: ``"A"``, ``"(a)"`` -> 0, ``"B"`` -> 1, ...

    Raises:
        ValueError: if *s* is not in one of those formats.
    """
    if s.startswith("(") and s.endswith(")"):
        letter = s[1:-1]  # Extract the text inside the parentheses
    else:
        letter = s
    # Require exactly one ASCII letter; otherwise inputs like "()" or "(é)"
    # would map to a meaningless (even negative) index instead of failing.
    if is_single_letter(letter) and letter.isascii():
        return ord(letter.upper()) - ord("A")  # Convert to zero-based index
    raise ValueError("Invalid format")


def is_single_letter(s):
    """Return True if *s* is exactly one alphabetic character (``str.isalpha``)."""
    return len(s) == 1 and s.isalpha()


def get_test_target(doc):
    """Return ``(value, key)`` for the reference answer stored in *doc*.

    Looks up ``"target"`` first, then ``"answer"``. If neither key is
    present, returns ``("", "")``.
    """
    if "target" in doc:
        return doc["target"], "target"
    elif "answer" in doc:
        return doc["answer"], "answer"
    else:
        return "", ""