JeffYang52415's picture
refactor: remove system prompt
0450c4e unverified
raw
history blame
6.66 kB
from dataclasses import dataclass
from typing import Any, Final
from llmdataparser.base_parser import (
DatasetDescription,
EvaluationMetric,
HuggingFaceDatasetParser,
HuggingFaceParseEntry,
)
# Answer letters accepted for a TMLU question: the four options A through D.
TMLU_VALID_ANSWERS: Final[set[str]] = set("ABCD")
# Sorted, comma-separated listing of the valid letters, used in error messages.
TMLU_VALID_ANSWER_STR: Final[str] = ", ".join(
    letter for letter in sorted(TMLU_VALID_ANSWERS)
)
@dataclass(frozen=True, kw_only=True, slots=True)
class TMLUParseEntry(HuggingFaceParseEntry):
    """Custom entry class for TMLU, with fields specific to this dataset parser.

    Extends ``HuggingFaceParseEntry`` with the individual answer options, an
    explanation string and a metadata mapping. Build instances through
    :meth:`create`, which validates the answer letter first.
    """

    # Texts of the answer options; the TMLU parser stores them in A, B, C, D order.
    raw_choices: list[str]
    # Explanation of the correct answer; "" when none was supplied.
    explanation: str
    # Extra per-row information passed through from the dataset row.
    metadata: dict[str, Any]

    @classmethod
    def create(
        cls,
        question: str,
        answer: str,
        raw_question: str,
        raw_choices: list[str],
        raw_answer: str,
        task_name: str,
        explanation: str = "",
        metadata: dict[str, Any] | None = None,
    ) -> "TMLUParseEntry":
        """Validate the answer letter and build a ``TMLUParseEntry``.

        Args:
            question: Formatted prompt (question plus lettered options).
            answer: Correct answer letter; must be in ``TMLU_VALID_ANSWERS``.
            raw_question: Question text exactly as it appears in the dataset.
            raw_choices: Option texts.
            raw_answer: Answer value exactly as it appears in the dataset.
            task_name: Name of the TMLU subset the entry came from.
            explanation: Optional explanation text; defaults to "".
            metadata: Optional extra information; a new empty dict is used
                when omitted or ``None``.

        Returns:
            The constructed entry.

        Raises:
            ValueError: If ``answer`` is not one of the valid answer letters.
        """
        if answer not in TMLU_VALID_ANSWERS:
            raise ValueError(
                f"Invalid answer_letter '{answer}'; must be one of {TMLU_VALID_ANSWER_STR}"
            )
        # A literal ``{}`` default would be one dict object shared by every
        # entry created without metadata; allocate a fresh one per entry.
        if metadata is None:
            metadata = {}
        return cls(
            question=question,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            raw_choices=raw_choices,
            task_name=task_name,
            explanation=explanation,
            metadata=metadata,
        )
class TMLUDatasetParser(HuggingFaceDatasetParser[TMLUParseEntry]):
    """Parser for the TMLU dataset (``miulab/tmlu`` on Hugging Face).

    Each dataset row is turned by :meth:`process_entry` into a
    ``TMLUParseEntry`` whose ``question`` is a "Question: ... / A. ... /
    Answer:" multiple-choice prompt.
    """

    # Hugging Face dataset identifier.
    _data_source = "miulab/tmlu"
    # Default subset name (presumably used by the base parser when no task
    # is chosen explicitly — confirm against HuggingFaceDatasetParser).
    _default_task = "AST_chinese"
    # Subset (configuration) names of the TMLU dataset known to this parser.
    _task_names = [
        "AST_chinese",
        "AST_mathematics",
        "AST_biology",
        "AST_chemistry",
        "AST_physics",
        "AST_civics",
        "AST_geography",
        "AST_history",
        "GSAT_chinese",
        "GSAT_chemistry",
        "GSAT_biology",
        "GSAT_physics",
        "GSAT_earth_science",
        "GSAT_mathematics",
        "GSAT_geography",
        "GSAT_history",
        "GSAT_civics",
        "CAP_mathematics",
        "CAP_biology",
        "CAP_physics",
        "CAP_chemistry",
        "CAP_earth_science",
        "CAP_civics",
        "CAP_history",
        "CAP_geography",
        "CAP_chinese",
        "driving_rule",
        "basic_traditional_chinese_medicine",
        "clinical_traditional_chinese_medicine",
        "lawyer_qualification",
        "nutritionist",
        "tour_leader",
        "tour_guide",
        "taiwan_tourist_resources",
        "clinical_psychologist",
        "teacher_qualification",
        "accountant",
    ]

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> TMLUParseEntry:
        """Process a single TMLU entry.

        Args:
            row: One dataset row. Must contain "question", "answer" and the
                option columns "A", "B", "C", "D"; "explanation" and
                "metadata" are optional.
            task_name: Subset name to record on the entry. When falsy, it is
                obtained from ``self._get_current_task(row)``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The validated ``TMLUParseEntry`` for the row.

        Raises:
            KeyError: If a required column is missing from ``row``.
            ValueError: From ``TMLUParseEntry.create`` when the answer is not
                a valid letter.
        """
        task = task_name or self._get_current_task(row)

        # Extract choices in order; the prompt labels them with matching letters.
        raw_choices = [row["A"], row["B"], row["C"], row["D"]]
        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
        )

        raw_question = row["question"]
        raw_answer = row["answer"]
        # ``.get(key, default)`` only covers a *missing* key; a column that is
        # present but null (None) must fall back too, so the entry's
        # ``explanation: str`` / ``metadata: dict`` fields keep their types.
        explanation = row.get("explanation") or ""
        metadata = row.get("metadata") or {}

        question = f"Question: {raw_question}\n{choices}\nAnswer:"

        return TMLUParseEntry.create(
            question=question,
            answer=raw_answer,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=task,
            explanation=explanation,
            metadata=metadata,
        )

    def get_dataset_description(self) -> DatasetDescription:
        """Returns description of the TMLU dataset (name, source, citation, ...)."""
        return DatasetDescription.create(
            name="Taiwan Multiple-choice Language Understanding (TMLU)",
            language="Traditional Chinese",
            purpose="Evaluate models on Taiwan-specific educational and professional knowledge",
            source="Various Taiwan standardized tests and professional certifications",
            category=["Taiwan", "General Knowledge and Reasoning"],
            format="Multiple choice questions (A/B/C/D)",
            # CAP is Taiwan's junior-high exit exam (Comprehensive Assessment
            # Program for Junior High School Students), not a college exam.
            characteristics=(
                "Covers various subjects including Advanced Subjects Test (AST), "
                "General Scholastic Ability Test (GSAT), "
                "Comprehensive Assessment Program for Junior High School Students (CAP), "
                "and professional certifications"
            ),
            citation="""@article{DBLP:journals/corr/abs-2403-20180,
author = {Po-Heng Chen and Sijia Cheng and Wei-Lin Chen and Yen-Ting Lin and Yun-Nung Chen},
title = {Measuring Taiwanese Mandarin Language Understanding},
journal = {CoRR},
volume = {abs/2403.20180},
year = {2024},
url = {https://doi.org/10.48550/arXiv.2403.20180},
doi = {10.48550/ARXIV.2403.20180},
eprinttype = {arXiv},
eprint = {2403.20180},
timestamp = {Wed, 10 Apr 2024 17:37:45 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2403-20180.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}""",
        )

    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
        """Returns recommended evaluation metrics for TMLU.

        Both overall accuracy and per-subject accuracy are marked primary.
        """
        return [
            EvaluationMetric.create(
                name="accuracy",
                type="classification",
                description="Overall percentage of correctly answered questions",
                implementation="datasets.load_metric('accuracy')",
                primary=True,
            ),
            EvaluationMetric.create(
                name="per_subject_accuracy",
                type="classification",
                description="Accuracy broken down by subject areas (AST, GSAT, CAP, etc.)",
                implementation="custom_subject_accuracy",
                primary=True,
            ),
        ]
if __name__ == "__main__":
    # Demo: load and parse TMLU, then print the first parsed entry.
    demo_parser = TMLUDatasetParser()
    demo_parser.load()
    demo_parser.parse()

    # Accessed as an attribute, not called (presumably a property on the
    # base parser — confirm against HuggingFaceDatasetParser).
    entries = demo_parser.get_parsed_data

    if entries:
        first = entries[0]
        print("\nExample parsed entry:")
        print(f"Task: {first.task_name}")
        print(f"Question: {first.question}")
        print("Choices:")
        # Start counting at 65 so chr() yields the letters "A", "B", ...
        for code_point, option_text in enumerate(first.raw_choices, start=65):
            print(f"{chr(code_point)}. {option_text}")
        print(f"Correct Answer: {first.answer}")
        if first.explanation:
            print(f"Explanation: {first.explanation}")
        print(f"Metadata: {first.metadata}")