# tests/test_tmlu_parser.py
# Unit and integration tests for TMLUDatasetParser / TMLUParseEntry
# (llmdataparser.tmlu_parser). Originally from the LLMEval-Dataset-Parser
# repository; an earlier refactor removed the system prompt.
import pytest
from llmdataparser.tmlu_parser import TMLUDatasetParser, TMLUParseEntry
@pytest.fixture
def tmlu_parser():
    """Provide a fresh TMLUDatasetParser instance for each test."""
    parser = TMLUDatasetParser()
    return parser
@pytest.fixture
def sample_tmlu_entries():
    """Provide two representative raw TMLU records (Chinese and math) for tests."""
    chinese_entry = {
        "question": "閱讀下文,選出依序最適合填入□內的選項:",
        "A": "張揚/綢繆未雨/奏疏",
        "B": "抽搐/煮繭抽絲/奏疏",
        "C": "張揚/煮繭抽絲/進貢",
        "D": "抽搐/綢繆未雨/進貢",
        "answer": "B",
        "explanation": "根據文意,選項B最為恰當。",
        "metadata": {
            "timestamp": "2023-10-09T18:27:20.304623",
            "source": "AST chinese - 108",
            "explanation_source": "",
        },
    }
    math_entry = {
        "question": "下列何者是質數?",
        "A": "21",
        "B": "27",
        "C": "31",
        "D": "33",
        "answer": "C",
        "explanation": "31是質數,其他選項都是合數。",
        "metadata": {
            "timestamp": "2023-10-09T18:27:20.304623",
            "source": "AST mathematics - 108",
            "explanation_source": "",
        },
    }
    return [chinese_entry, math_entry]
def test_tmlu_parse_entry_creation_valid():
    """A well-formed entry should round-trip through TMLUParseEntry.create."""
    choices = ["choice1", "choice2", "choice3", "choice4"]
    result = TMLUParseEntry.create(
        question="Test question",
        answer="A",
        raw_question="Test question",
        raw_choices=choices,
        raw_answer="A",
        task_name="AST_chinese",
        explanation="Test explanation",
        metadata={"source": "test"},
    )
    assert isinstance(result, TMLUParseEntry)
    assert result.question == "Test question"
    assert result.answer == "A"
    assert result.raw_choices == choices
    assert result.explanation == "Test explanation"
    assert result.metadata == {"source": "test"}
@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_tmlu_parse_entry_creation_invalid(invalid_answer):
    """Answers outside A-D must raise a descriptive ValueError."""
    expected_pattern = "Invalid answer_letter.*must be one of A, B, C, D"
    with pytest.raises(ValueError, match=expected_pattern):
        TMLUParseEntry.create(
            question="Test question",
            answer=invalid_answer,
            raw_question="Test question",
            raw_choices=["choice1", "choice2", "choice3", "choice4"],
            raw_answer=invalid_answer,
            task_name="AST_chinese",
        )
def test_process_entry(tmlu_parser, sample_tmlu_entries):
    """process_entry should map a raw record onto a TMLUParseEntry."""
    parsed = tmlu_parser.process_entry(sample_tmlu_entries[0], task_name="AST_chinese")
    assert isinstance(parsed, TMLUParseEntry)
    assert parsed.answer == "B"
    assert parsed.task_name == "AST_chinese"
    assert len(parsed.raw_choices) == 4
    assert parsed.explanation == "根據文意,選項B最為恰當。"
    assert "AST chinese - 108" in parsed.metadata["source"]
def test_tmlu_parser_initialization(tmlu_parser):
    """Parser defaults and task registry should match the TMLU dataset layout."""
    task_names = tmlu_parser.task_names
    assert isinstance(task_names, list)
    # TMLU defines 37 subject tasks in total.
    assert len(task_names) == 37
    assert "AST_chinese" in task_names
    assert "GSAT_mathematics" in task_names
    assert tmlu_parser._data_source == "miulab/tmlu"
    assert tmlu_parser._default_task == "AST_chinese"
    expected_link = "https://huggingface.co/datasets/miulab/tmlu"
    assert tmlu_parser.get_huggingface_link == expected_link
@pytest.mark.integration
def test_load_dataset(tmlu_parser):
    """Loading the test split should populate raw data and task bookkeeping."""
    tmlu_parser.load(task_name="AST_chinese", split="test")
    assert tmlu_parser._current_task == "AST_chinese"
    assert tmlu_parser.split_names == ["test"]
    assert tmlu_parser.raw_data is not None
def test_parser_string_representation(tmlu_parser):
    """str() of an unloaded parser should name the class, source, and state."""
    rendered = str(tmlu_parser)
    for fragment in ("TMLUDatasetParser", "miulab/tmlu", "not loaded"):
        assert fragment in rendered
@pytest.mark.integration
def test_different_tasks_parsing(tmlu_parser):
    """Loading and parsing two distinct tasks should each yield non-empty data."""
    parsed_counts = {}
    for task in ("AST_chinese", "AST_mathematics"):
        tmlu_parser.load(task_name=task, split="test")
        tmlu_parser.parse(split_names="test", force=True)
        parsed_counts[task] = len(tmlu_parser.get_parsed_data)
    assert parsed_counts["AST_chinese"] > 0
    assert parsed_counts["AST_mathematics"] > 0
def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
    """Metadata fields on the raw record must survive process_entry unchanged."""
    parsed = tmlu_parser.process_entry(sample_tmlu_entries[0])
    for key in ("timestamp", "source", "explanation_source"):
        assert key in parsed.metadata
    assert parsed.metadata["source"] == "AST chinese - 108"
def test_get_dataset_description(tmlu_parser):
    """The dataset description should carry the canonical TMLU facts."""
    desc = tmlu_parser.get_dataset_description()
    # Exact-match fields.
    assert desc.name == "Taiwan Multiple-choice Language Understanding (TMLU)"
    assert desc.language == "Traditional Chinese"
    assert desc.format == "Multiple choice questions (A/B/C/D)"
    # Substring checks on the longer free-text fields.
    assert "Taiwan-specific educational" in desc.purpose
    assert "Various Taiwan standardized tests" in desc.source
    assert "Advanced Subjects Test (AST)" in desc.characteristics
    assert "DBLP:journals/corr/abs-2403-20180" in desc.citation
def test_get_evaluation_metrics(tmlu_parser):
    """Both TMLU metrics should be defined and flagged as primary."""
    metrics = tmlu_parser.get_evaluation_metrics()
    # TMLU defines exactly two metrics, and both are primary.
    assert len(metrics) == 2
    primary = [metric for metric in metrics if metric.primary]
    assert len(primary) == 2
    primary_names = {metric.name for metric in primary}
    assert "accuracy" in primary_names
    assert "per_subject_accuracy" in primary_names