import pytest
from llmdataparser.mmlu_parser import (
BaseMMLUDatasetParser,
MMLUParseEntry,
MMLUProDatasetParser,
MMLUProParseEntry,
MMLUReduxDatasetParser,
TMMLUPlusDatasetParser,
)


@pytest.fixture
def base_parser():
"""Create a base MMLU parser instance."""
return BaseMMLUDatasetParser()


@pytest.fixture
def redux_parser():
    """Create an MMLU Redux parser instance."""
return MMLUReduxDatasetParser()


@pytest.fixture
def tmmlu_parser():
"""Create a TMMLU+ parser instance."""
return TMMLUPlusDatasetParser()


@pytest.fixture
def mmlu_pro_parser():
    """Create an MMLU Pro parser instance."""
return MMLUProDatasetParser()


@pytest.fixture
def sample_mmlu_entries():
"""Create sample MMLU dataset entries for testing."""
return [
{
"question": "What is the capital of France?",
"choices": ["London", "Paris", "Berlin", "Madrid"],
"answer": 1, # Paris
"subject": "geography",
},
{
"question": "Which of these is a primary color?",
"choices": ["Green", "Purple", "Blue", "Orange"],
"answer": 2, # Blue
"subject": "art",
},
]


@pytest.fixture
def sample_mmlu_pro_entries():
"""Create sample MMLU Pro dataset entries for testing."""
return [
{
"question": "What is the time complexity of quicksort?",
"options": ["O(n)", "O(n log n)", "O(n²)", "O(2ⁿ)", "O(n!)", "O(1)"],
"answer": "The average time complexity of quicksort is O(n log n)",
"answer_index": 1,
"category": "computer_science",
}
]


def test_mmlu_parse_entry_creation_valid():
"""Test valid creation of MMLUParseEntry."""
entry = MMLUParseEntry.create(
question="Test question",
answer="A",
raw_question="Test question",
raw_choices=["choice1", "choice2", "choice3", "choice4"],
raw_answer="0",
task_name="test_task",
)
assert isinstance(entry, MMLUParseEntry)
assert entry.question == "Test question"
assert entry.answer == "A"
assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
assert entry.task_name == "test_task"


@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_mmlu_parse_entry_creation_invalid(invalid_answer):
"""Test invalid answer handling in MMLUParseEntry creation."""
with pytest.raises(
ValueError, match="Invalid answer_letter.*must be one of A, B, C, D"
):
MMLUParseEntry.create(
question="Test question",
answer=invalid_answer,
raw_question="Test question",
raw_choices=["choice1", "choice2", "choice3", "choice4"],
raw_answer="4",
task_name="test_task",
)
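

# A hedged sketch of the analogous negative test for MMLUProParseEntry: it
# assumes Pro answers run A-J (see the valid-creation test below) and that,
# like MMLUParseEntry, anything else raises ValueError. The exact error
# message is not asserted because it is an assumption, not pinned down here.
@pytest.mark.parametrize("invalid_answer", ["K", "1", "", None])
def test_mmlu_pro_parse_entry_creation_invalid(invalid_answer):
    """Test invalid answer handling in MMLUProParseEntry creation."""
    with pytest.raises(ValueError):
        MMLUProParseEntry.create(
            question="Test question",
            answer=invalid_answer,
            raw_question="Test question",
            raw_choices=["choice1", "choice2", "choice3", "choice4", "choice5"],
            raw_answer="10",
            task_name="test_task",
        )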


def test_process_entry_base(base_parser, sample_mmlu_entries):
"""Test processing entries in base MMLU parser."""
entry = base_parser.process_entry(sample_mmlu_entries[0], task_name="geography")
assert isinstance(entry, MMLUParseEntry)
assert entry.answer == "B" # Index 1 maps to B
assert "A. London" in entry.question
assert "B. Paris" in entry.question
assert "C. Berlin" in entry.question
assert "D. Madrid" in entry.question
assert entry.raw_question == "What is the capital of France?"
assert entry.raw_choices == ["London", "Paris", "Berlin", "Madrid"]
assert entry.raw_answer == "1"
assert entry.task_name == "geography"
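

# Grounded companion test: the second sample entry has answer index 2, which
# must map to "C" under the same index-to-letter scheme asserted above.
def test_process_entry_base_second_entry(base_parser, sample_mmlu_entries):
    """Test processing the second sample MMLU entry (art subject)."""
    entry = base_parser.process_entry(sample_mmlu_entries[1], task_name="art")
    assert isinstance(entry, MMLUParseEntry)
    assert entry.answer == "C"  # Index 2 maps to C
    assert entry.raw_answer == "2"
    assert entry.raw_choices == ["Green", "Purple", "Blue", "Orange"]
    assert entry.task_name == "art"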


def test_mmlu_pro_parse_entry_creation_valid():
"""Test valid creation of MMLUProParseEntry."""
entry = MMLUProParseEntry.create(
question="Test question",
answer="E", # MMLU Pro supports up to J
raw_question="Test question",
raw_choices=["choice1", "choice2", "choice3", "choice4", "choice5"],
raw_answer="4",
task_name="test_task",
)
assert isinstance(entry, MMLUProParseEntry)
assert entry.answer == "E"
assert len(entry.raw_choices) == 5
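

# Hedged boundary check: per the comment in the test above, MMLU Pro answer
# letters are assumed to top out at "J" (index 9); ten choices keep the
# raw_answer index consistent with that letter.
def test_mmlu_pro_parse_entry_creation_upper_bound():
    """Test that 'J' is accepted as a valid MMLU Pro answer letter."""
    entry = MMLUProParseEntry.create(
        question="Test question",
        answer="J",
        raw_question="Test question",
        raw_choices=[f"choice{i}" for i in range(1, 11)],
        raw_answer="9",
        task_name="test_task",
    )
    assert entry.answer == "J"
    assert len(entry.raw_choices) == 10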


def test_process_entry_mmlu_pro(mmlu_pro_parser, sample_mmlu_pro_entries):
"""Test processing entries in MMLU Pro parser."""
entry = mmlu_pro_parser.process_entry(
sample_mmlu_pro_entries[0], task_name="computer_science"
)
assert isinstance(entry, MMLUProParseEntry)
assert entry.answer == "B" # Index 1 maps to B
assert "O(n log n)" in entry.question
assert entry.task_name == "computer_science"
assert len(entry.raw_choices) == 6
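

# A hedged extension, assuming the Pro parser letters choices into the
# question as "A. ..." through "F. ..." the way the base parser does for
# A-D; the test above only guarantees that "O(n log n)" appears.
def test_process_entry_mmlu_pro_choice_lettering(
    mmlu_pro_parser, sample_mmlu_pro_entries
):
    """Test that all six options are lettered into the formatted question."""
    entry = mmlu_pro_parser.process_entry(
        sample_mmlu_pro_entries[0], task_name="computer_science"
    )
    for letter, option in zip("ABCDEF", sample_mmlu_pro_entries[0]["options"]):
        assert f"{letter}. {option}" in entry.question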


def test_tmmlu_process_entry(tmmlu_parser):
"""Test processing entries in TMMLU+ parser."""
test_row = {
"question": "什麼是台灣最高的山峰?",
"A": "玉山",
"B": "阿里山",
"C": "合歡山",
"D": "雪山",
"answer": "A",
"subject": "geography_of_taiwan",
}
entry = tmmlu_parser.process_entry(test_row, task_name="geography_of_taiwan")
assert isinstance(entry, MMLUParseEntry)
assert entry.answer == "A"
assert entry.raw_choices == ["玉山", "阿里山", "合歡山", "雪山"]
assert entry.task_name == "geography_of_taiwan"
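

# A hedged follow-up, assuming the TMMLU+ parser preserves the original
# question text in raw_question and reuses the base "A. ..." choice
# formatting when it assembles the prompt from the A/B/C/D columns.
def test_tmmlu_process_entry_question_formatting(tmmlu_parser):
    """Test that TMMLU+ choices are lettered into the formatted question."""
    test_row = {
        "question": "什麼是台灣最高的山峰?",  # "What is Taiwan's highest peak?"
        "A": "玉山",
        "B": "阿里山",
        "C": "合歡山",
        "D": "雪山",
        "answer": "A",
        "subject": "geography_of_taiwan",
    }
    entry = tmmlu_parser.process_entry(test_row, task_name="geography_of_taiwan")
    assert entry.raw_question == "什麼是台灣最高的山峰?"
    for letter, choice in zip("ABCD", ["玉山", "阿里山", "合歡山", "雪山"]):
        assert f"{letter}. {choice}" in entry.question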


@pytest.mark.parametrize(
"parser_fixture,expected_tasks,expected_source",
[
("base_parser", 57, "cais/mmlu"),
("redux_parser", 30, "edinburgh-dawg/mmlu-redux"),
("tmmlu_parser", 66, "ikala/tmmluplus"),
("mmlu_pro_parser", 1, "TIGER-Lab/MMLU-Pro"),
],
)
def test_parser_initialization(
request, parser_fixture, expected_tasks, expected_source
):
"""Test initialization of different MMLU parser variants."""
parser = request.getfixturevalue(parser_fixture)
assert len(parser.task_names) == expected_tasks
assert parser._data_source == expected_source
assert (
parser.get_huggingface_link
== f"https://huggingface.co/datasets/{expected_source}"
)
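

# A small hedged addition: task names are assumed to be unique within each
# parser, so deduplication should not change the counts asserted above.
@pytest.mark.parametrize(
    "parser_fixture",
    ["base_parser", "redux_parser", "tmmlu_parser", "mmlu_pro_parser"],
)
def test_parser_task_names_unique(request, parser_fixture):
    """Test that each parser's task names contain no duplicates."""
    parser = request.getfixturevalue(parser_fixture)
    assert len(set(parser.task_names)) == len(parser.task_names)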


@pytest.mark.integration
def test_load_dataset(base_parser):
"""Test loading the MMLU dataset."""
base_parser.load(task_name="anatomy", split="test")
assert base_parser.raw_data is not None
assert base_parser.split_names == ["test"]
assert base_parser._current_task == "anatomy"


def test_parser_string_representation(base_parser):
"""Test string representation of MMLU parser."""
repr_str = str(base_parser)
assert "MMLUDatasetParser" in repr_str
assert "cais/mmlu" in repr_str
assert "not loaded" in repr_str


@pytest.mark.integration
def test_different_splits_parsing(base_parser):
"""Test parsing different splits of the dataset."""
# Load and parse test split
base_parser.load(task_name="anatomy", split="test")
base_parser.parse(split_names="test", force=True)
test_count = len(base_parser.get_parsed_data)
# Load and parse validation split
base_parser.load(task_name="anatomy", split="validation")
base_parser.parse(split_names="validation", force=True)
val_count = len(base_parser.get_parsed_data)
assert test_count > 0
assert val_count > 0
assert test_count != val_count


def test_base_mmlu_dataset_description(base_parser):
"""Test dataset description for base MMLU."""
description = base_parser.get_dataset_description()
assert description.name == "Massive Multitask Language Understanding (MMLU)"
assert "cais/mmlu" in description.source
assert description.language == "English"
# Check characteristics
assert "57 subjects" in description.characteristics.lower()
# Check citation
assert "hendryckstest2021" in description.citation


def test_mmlu_redux_dataset_description(redux_parser):
"""Test dataset description for MMLU Redux."""
description = redux_parser.get_dataset_description()
assert description.name == "MMLU Redux"
assert "manually re-annotated" in description.purpose.lower()
assert "edinburgh-dawg/mmlu-redux" in description.source
assert description.language == "English"
# Check characteristics
assert "3,000" in description.characteristics


def test_tmmlu_plus_dataset_description(tmmlu_parser):
"""Test dataset description for TMMLU+."""
description = tmmlu_parser.get_dataset_description()
assert "ikala/tmmluplus" in description.source
assert description.language == "Traditional Chinese"
# Check characteristics
assert "66 subjects" in description.characteristics.lower()
# Check citation
assert "ikala2024improved" in description.citation


def test_mmlu_pro_dataset_description(mmlu_pro_parser):
"""Test dataset description for MMLU Pro."""
description = mmlu_pro_parser.get_dataset_description()
assert description.name == "MMLU Pro"
assert "challenging" in description.purpose.lower()
assert "TIGER-Lab/MMLU-Pro" in description.source
assert description.language == "English"


def test_base_mmlu_evaluation_metrics(base_parser):
"""Test evaluation metrics for base MMLU."""
metrics = base_parser.get_evaluation_metrics()
assert len(metrics) >= 3
metric_names = {m.name for m in metrics}
assert "accuracy" in metric_names
assert "subject_accuracy" in metric_names
assert "category_accuracy" in metric_names
accuracy_metric = next(m for m in metrics if m.name == "accuracy")
assert accuracy_metric.type == "classification"
assert accuracy_metric.primary is True
assert "multiple-choice" in accuracy_metric.description.lower()


def test_mmlu_redux_evaluation_metrics(redux_parser):
"""Test evaluation metrics for MMLU Redux."""
metrics = redux_parser.get_evaluation_metrics()
metric_names = {m.name for m in metrics}
assert "question_clarity" in metric_names


def test_tmmlu_plus_evaluation_metrics(tmmlu_parser):
"""Test evaluation metrics for TMMLU+."""
metrics = tmmlu_parser.get_evaluation_metrics()
metric_names = {m.name for m in metrics}
assert "difficulty_analysis" in metric_names


def test_mmlu_pro_evaluation_metrics(mmlu_pro_parser):
"""Test evaluation metrics for MMLU Pro."""
metrics = mmlu_pro_parser.get_evaluation_metrics()
metric_names = {m.name for m in metrics}
assert "reasoning_analysis" in metric_names
assert "prompt_robustness" in metric_names