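"""Unit and integration tests for the MMLU parser variants in llmdataparser."""
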
import pytest

from llmdataparser.mmlu_parser import (
    BaseMMLUDatasetParser,
    MMLUParseEntry,
    MMLUProDatasetParser,
    MMLUProParseEntry,
    MMLUReduxDatasetParser,
    TMMLUPlusDatasetParser,
)

@pytest.fixture
def base_parser():
    """Create a base MMLU parser instance."""
    return BaseMMLUDatasetParser()


@pytest.fixture
def redux_parser():
    """Create an MMLU Redux parser instance."""
    return MMLUReduxDatasetParser()


@pytest.fixture
def tmmlu_parser():
    """Create a TMMLU+ parser instance."""
    return TMMLUPlusDatasetParser()


@pytest.fixture
def mmlu_pro_parser():
    """Create an MMLU Pro parser instance."""
    return MMLUProDatasetParser()

@pytest.fixture
def sample_mmlu_entries():
    """Create sample MMLU dataset entries for testing."""
    return [
        {
            "question": "What is the capital of France?",
            "choices": ["London", "Paris", "Berlin", "Madrid"],
            "answer": 1,
            "subject": "geography",
        },
        {
            "question": "Which of these is a primary color?",
            "choices": ["Green", "Purple", "Blue", "Orange"],
            "answer": 2,
            "subject": "art",
        },
    ]


@pytest.fixture
def sample_mmlu_pro_entries():
    """Create sample MMLU Pro dataset entries for testing."""
    return [
        {
            "question": "What is the time complexity of quicksort?",
            "options": ["O(n)", "O(n log n)", "O(n²)", "O(2ⁿ)", "O(n!)", "O(1)"],
            "answer": "The average time complexity of quicksort is O(n log n)",
            "answer_index": 1,
            "category": "computer_science",
        }
    ]

def test_mmlu_parse_entry_creation_valid():
    """Test valid creation of MMLUParseEntry."""
    entry = MMLUParseEntry.create(
        question="Test question",
        answer="A",
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4"],
        raw_answer="0",
        task_name="test_task",
    )
    assert isinstance(entry, MMLUParseEntry)
    assert entry.question == "Test question"
    assert entry.answer == "A"
    assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
    assert entry.task_name == "test_task"


@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_mmlu_parse_entry_creation_invalid(invalid_answer):
    """Test invalid answer handling in MMLUParseEntry creation."""
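    # Only the letters A-D are valid MMLU answers; anything else must raise.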
    with pytest.raises(
        ValueError, match="Invalid answer_letter.*must be one of A, B, C, D"
    ):
        MMLUParseEntry.create(
            question="Test question",
            answer=invalid_answer,
            raw_question="Test question",
            raw_choices=["choice1", "choice2", "choice3", "choice4"],
            raw_answer="4",
            task_name="test_task",
        )

def test_process_entry_base(base_parser, sample_mmlu_entries):
    """Test processing entries in base MMLU parser."""
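    # The raw answer index 1 should map to option letter "B" ("Paris").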
    entry = base_parser.process_entry(sample_mmlu_entries[0], task_name="geography")

    assert isinstance(entry, MMLUParseEntry)
    assert entry.answer == "B"
    assert "A. London" in entry.question
    assert "B. Paris" in entry.question
    assert "C. Berlin" in entry.question
    assert "D. Madrid" in entry.question
    assert entry.raw_question == "What is the capital of France?"
    assert entry.raw_choices == ["London", "Paris", "Berlin", "Madrid"]
    assert entry.raw_answer == "1"
    assert entry.task_name == "geography"

def test_mmlu_pro_parse_entry_creation_valid():
    """Test valid creation of MMLUProParseEntry."""
    entry = MMLUProParseEntry.create(
        question="Test question",
        answer="E",
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4", "choice5"],
        raw_answer="4",
        task_name="test_task",
    )
    assert isinstance(entry, MMLUProParseEntry)
    assert entry.answer == "E"
    assert len(entry.raw_choices) == 5

def test_process_entry_mmlu_pro(mmlu_pro_parser, sample_mmlu_pro_entries):
    """Test processing entries in MMLU Pro parser."""
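    # answer_index 1 in the fixture should render as option letter "B".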
    entry = mmlu_pro_parser.process_entry(
        sample_mmlu_pro_entries[0], task_name="computer_science"
    )

    assert isinstance(entry, MMLUProParseEntry)
    assert entry.answer == "B"
    assert "O(n log n)" in entry.question
    assert entry.task_name == "computer_science"
    assert len(entry.raw_choices) == 6

def test_tmmlu_process_entry(tmmlu_parser):
    """Test processing entries in TMMLU+ parser."""
    # Question: "What is the highest mountain peak in Taiwan?"
    # Choices: Yushan, Alishan, Hehuanshan, Xueshan; answer: Yushan.
    test_row = {
        "question": "什麼是台灣最高的山峰?",
        "A": "玉山",
        "B": "阿里山",
        "C": "合歡山",
        "D": "雪山",
        "answer": "A",
        "subject": "geography_of_taiwan",
    }

    entry = tmmlu_parser.process_entry(test_row, task_name="geography_of_taiwan")
    assert isinstance(entry, MMLUParseEntry)
    assert entry.answer == "A"
    assert entry.raw_choices == ["玉山", "阿里山", "合歡山", "雪山"]
    assert entry.task_name == "geography_of_taiwan"

@pytest.mark.parametrize(
    "parser_fixture,expected_tasks,expected_source",
    [
        ("base_parser", 57, "cais/mmlu"),
        ("redux_parser", 30, "edinburgh-dawg/mmlu-redux"),
        ("tmmlu_parser", 66, "ikala/tmmluplus"),
        ("mmlu_pro_parser", 1, "TIGER-Lab/MMLU-Pro"),
    ],
)
def test_parser_initialization(
    request, parser_fixture, expected_tasks, expected_source
):
    """Test initialization of different MMLU parser variants."""
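    # Look up the fixture by name so one test body covers all four variants.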
    parser = request.getfixturevalue(parser_fixture)
    assert len(parser.task_names) == expected_tasks
    assert parser._data_source == expected_source
    assert (
        parser.get_huggingface_link
        == f"https://huggingface.co/datasets/{expected_source}"
    )

@pytest.mark.integration
def test_load_dataset(base_parser):
    """Test loading the MMLU dataset."""
    base_parser.load(task_name="anatomy", split="test")
    assert base_parser.raw_data is not None
    assert base_parser.split_names == ["test"]
    assert base_parser._current_task == "anatomy"

def test_parser_string_representation(base_parser):
    """Test string representation of MMLU parser."""
    repr_str = str(base_parser)
    assert "MMLUDatasetParser" in repr_str
    assert "cais/mmlu" in repr_str
    assert "not loaded" in repr_str

@pytest.mark.integration
def test_different_splits_parsing(base_parser):
    """Test parsing different splits of the dataset."""
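    # Load and parse the test split.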
    base_parser.load(task_name="anatomy", split="test")
    base_parser.parse(split_names="test", force=True)
    test_count = len(base_parser.get_parsed_data)

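    # Load and parse the validation split, then compare split sizes.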
    base_parser.load(task_name="anatomy", split="validation")
    base_parser.parse(split_names="validation", force=True)
    val_count = len(base_parser.get_parsed_data)

    assert test_count > 0
    assert val_count > 0
    assert test_count != val_count

def test_base_mmlu_dataset_description(base_parser):
    """Test dataset description for base MMLU."""
    description = base_parser.get_dataset_description()

    assert description.name == "Massive Multitask Language Understanding (MMLU)"
    assert "cais/mmlu" in description.source
    assert description.language == "English"
    assert "57 subjects" in description.characteristics.lower()
    assert "hendryckstest2021" in description.citation

def test_mmlu_redux_dataset_description(redux_parser):
    """Test dataset description for MMLU Redux."""
    description = redux_parser.get_dataset_description()

    assert description.name == "MMLU Redux"
    assert "manually re-annotated" in description.purpose.lower()
    assert "edinburgh-dawg/mmlu-redux" in description.source
    assert description.language == "English"
    assert "3,000" in description.characteristics

def test_tmmlu_plus_dataset_description(tmmlu_parser):
    """Test dataset description for TMMLU+."""
    description = tmmlu_parser.get_dataset_description()

    assert "ikala/tmmluplus" in description.source
    assert description.language == "Traditional Chinese"
    assert "66 subjects" in description.characteristics.lower()
    assert "ikala2024improved" in description.citation

def test_mmlu_pro_dataset_description(mmlu_pro_parser):
    """Test dataset description for MMLU Pro."""
    description = mmlu_pro_parser.get_dataset_description()

    assert description.name == "MMLU Pro"
    assert "challenging" in description.purpose.lower()
    assert "TIGER-Lab/MMLU-Pro" in description.source
    assert description.language == "English"

def test_base_mmlu_evaluation_metrics(base_parser):
    """Test evaluation metrics for base MMLU."""
    metrics = base_parser.get_evaluation_metrics()

    assert len(metrics) >= 3
    metric_names = {m.name for m in metrics}

    assert "accuracy" in metric_names
    assert "subject_accuracy" in metric_names
    assert "category_accuracy" in metric_names

    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
    assert accuracy_metric.type == "classification"
    assert accuracy_metric.primary is True
    assert "multiple-choice" in accuracy_metric.description.lower()

def test_mmlu_redux_evaluation_metrics(redux_parser):
    """Test evaluation metrics for MMLU Redux."""
    metrics = redux_parser.get_evaluation_metrics()

    metric_names = {m.name for m in metrics}
    assert "question_clarity" in metric_names


def test_tmmlu_plus_evaluation_metrics(tmmlu_parser):
    """Test evaluation metrics for TMMLU+."""
    metrics = tmmlu_parser.get_evaluation_metrics()

    metric_names = {m.name for m in metrics}
    assert "difficulty_analysis" in metric_names


def test_mmlu_pro_evaluation_metrics(mmlu_pro_parser):
    """Test evaluation metrics for MMLU Pro."""
    metrics = mmlu_pro_parser.get_evaluation_metrics()

    metric_names = {m.name for m in metrics}
    assert "reasoning_analysis" in metric_names
    assert "prompt_robustness" in metric_names