File size: 6,984 Bytes
424ff6a 0450c4e 424ff6a 0450c4e 424ff6a e5427e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import re

import pytest

from llmdataparser.gsm8k_parser import GSM8KDatasetParser, GSM8KParseEntry
@pytest.fixture
def gsm8k_parser():
    """Provide a newly constructed ``GSM8KDatasetParser`` to the requesting test."""
    parser = GSM8KDatasetParser()
    return parser
@pytest.fixture
def loaded_gsm8k_parser(gsm8k_parser):
    """Provide the parser after ``load`` has been called for the main task.

    The ``test`` split is requested because it is the smaller one.
    """
    load_options = {"task_name": "main", "split": "test"}
    gsm8k_parser.load(**load_options)
    return gsm8k_parser
@pytest.fixture
def sample_row():
    """Provide one hand-written row with ``question`` and ``answer`` keys.

    The answer text ends with a ``#### 5`` line after the worked steps.
    """
    question_text = (
        "Janet has 3 apples. She buys 2 more. How many apples does she have now?"
    )
    answer_text = "Let's solve this step by step:\n1) Initially, Janet has 3 apples\n2) She buys 2 more apples\n3) Total apples = 3 + 2\n#### 5"
    return dict(question=question_text, answer=answer_text)
def test_gsm8k_parse_entry_creation_valid():
    """Build an entry via ``GSM8KParseEntry.create`` and verify its fields."""
    creation_kwargs = {
        "question": "Test question",
        "answer": "5",
        "raw_question": "Test question",
        "raw_answer": "Solution steps #### 5",
        "solution": "Solution steps",
        "task_name": "main",
        "numerical_answer": 5,
    }
    entry = GSM8KParseEntry.create(**creation_kwargs)

    assert isinstance(entry, GSM8KParseEntry)
    # Only these attributes are checked; the raw_* fields are passed but not asserted.
    checked_attributes = (
        "question",
        "answer",
        "solution",
        "numerical_answer",
        "task_name",
    )
    for attribute in checked_attributes:
        assert getattr(entry, attribute) == creation_kwargs[attribute]
def test_gsm8k_parser_initialization(gsm8k_parser):
    """Verify the attributes a freshly constructed parser exposes."""
    expected_link = "https://huggingface.co/datasets/openai/gsm8k"
    expected_tasks = ["main", "socratic"]

    assert gsm8k_parser._data_source == "openai/gsm8k"
    assert gsm8k_parser._default_task == "main"
    assert gsm8k_parser._task_names == expected_tasks
    assert gsm8k_parser.get_huggingface_link == expected_link
@pytest.mark.integration
def test_load_dataset(loaded_gsm8k_parser):
    """Verify the parser state after ``load(task_name="main", split="test")``.

    Marked ``integration`` for consistency with the other tests in this module
    that rely on the ``loaded_gsm8k_parser`` fixture.
    NOTE(review): this presumes ``load`` needs access to the real dataset;
    confirm against the parser implementation.
    """
    assert loaded_gsm8k_parser.raw_data is not None
    # Only the test split was requested by the fixture.
    assert loaded_gsm8k_parser.split_names == ["test"]
    assert loaded_gsm8k_parser._current_task == "main"
@pytest.mark.integration
def test_full_parse_workflow(loaded_gsm8k_parser):
    """Run load followed by parse and inspect the first parsed entry."""
    parser = loaded_gsm8k_parser
    parser.parse(split_names="test", force=True)
    entries = parser.get_parsed_data

    # Parsing must yield at least one entry.
    assert len(entries) > 0

    # Inspect the shape of the leading entry.
    head = entries[0]
    accepted_answer_types = (str, int, float)
    assert isinstance(head, GSM8KParseEntry)
    assert head.task_name == "main"
    assert isinstance(head.numerical_answer, accepted_answer_types)
    assert "####" in head.raw_answer
    assert head.solution
def test_process_entry(gsm8k_parser, sample_row):
    """Feed one sample row through ``process_entry`` and check the result."""
    parsed = gsm8k_parser.process_entry(sample_row, task_name="main")

    assert isinstance(parsed, GSM8KParseEntry)
    assert parsed.numerical_answer == 5
    assert "Janet has 3 apples" in parsed.raw_question
    assert "#### 5" in parsed.raw_answer
    assert "Let's solve this step by step:" in parsed.solution
    assert parsed.task_name == "main"
@pytest.mark.parametrize("split_name", ["invalid_split", "wrong_split"])
def test_parse_with_invalid_split(gsm8k_parser, split_name):
    """Check that ``parse`` raises ``ValueError`` for a split absent from the data.

    ``raw_data`` is replaced with a stub holding only empty ``train`` and
    ``test`` splits, so the parametrized ``split_name`` is never present.
    """
    gsm8k_parser.raw_data = {"train": [], "test": []}  # Mock data
    # ``match`` is interpreted as a regular expression; escape the message so
    # the interpolated split name is always compared literally.
    expected_message = re.escape(f"Split '{split_name}' not found in the dataset")
    with pytest.raises(ValueError, match=expected_message):
        gsm8k_parser.parse(split_name)
def test_parse_without_loaded_data(gsm8k_parser):
    """Check that ``parse`` raises ``ValueError`` when called before any load."""
    # ``match`` is a regular expression, so an unescaped "." would match any
    # character; escape the message to require the literal period.
    expected_message = re.escape("No data loaded. Please load the dataset first")
    with pytest.raises(ValueError, match=expected_message):
        gsm8k_parser.parse()
@pytest.mark.parametrize(
    "test_case",
    [
        {
            "question": "Test question",
            "answer": "Some solution steps #### 42",
        },
        {
            "question": "Test question",
            "answer": "Complex solution\nWith multiple lines\n#### 123.45",
        },
        {
            "question": "Test question",
            "answer": "No steps #### 0",
        },
    ],
)
def test_numerical_answer_extraction(gsm8k_parser, test_case):
    """Compare the parsed numerical answer with the text after the last ``####``."""
    parsed = gsm8k_parser.process_entry(test_case, task_name="main")

    # Expected value: everything after the final marker, trimmed, commas removed.
    _, _, trailing_text = test_case["answer"].rpartition("####")
    expected_text = trailing_text.strip().replace(",", "")

    assert str(parsed.numerical_answer) == expected_text
def test_solution_extraction(gsm8k_parser):
    """Verify that the solution holds the steps without the ``####`` marker."""
    raw_row = dict(
        question="Test question",
        answer="Step 1: Do this\nStep 2: Do that\n#### 42",
    )
    parsed = gsm8k_parser.process_entry(raw_row, task_name="main")

    assert parsed.solution == "Step 1: Do this\nStep 2: Do that"
    assert parsed.task_name == "main"
    assert "####" not in parsed.solution
def test_parser_properties(gsm8k_parser):
    """Verify the ``task_names`` and ``total_tasks`` accessors."""
    expected_tasks = ["main", "socratic"]
    assert gsm8k_parser.task_names == expected_tasks
    assert gsm8k_parser.total_tasks == 2
@pytest.mark.integration
def test_parser_string_representation(loaded_gsm8k_parser):
    """Verify that ``str(parser)`` mentions the class, source, task and state.

    Marked ``integration`` for consistency with the other tests in this module
    that rely on the ``loaded_gsm8k_parser`` fixture.
    NOTE(review): this presumes ``load`` needs access to the real dataset;
    confirm against the parser implementation.
    """
    repr_str = str(loaded_gsm8k_parser)
    assert "GSM8KDatasetParser" in repr_str
    assert "openai/gsm8k" in repr_str
    assert "main" in repr_str
    assert "loaded" in repr_str
@pytest.mark.integration
def test_different_splits_parsing(gsm8k_parser):
    """Load and parse the test split, then the train split, and compare sizes."""
    entry_counts = {}
    # Order matters: the test split is handled first, then the train split.
    for split in ("test", "train"):
        gsm8k_parser.load(task_name="main", split=split)
        gsm8k_parser.parse(split_names=split, force=True)
        entry_counts[split] = len(gsm8k_parser.get_parsed_data)

    assert entry_counts["test"] > 0
    assert entry_counts["train"] > 0
    assert entry_counts["train"] != entry_counts["test"]
def test_get_dataset_description(gsm8k_parser):
    """Verify selected fields of the object from ``get_dataset_description``."""
    info = gsm8k_parser.get_dataset_description()

    assert info.name == "Grade School Math 8K (GSM8K)"
    assert info.source == "OpenAI"
    assert info.language == "English"
    assert "Cobbe" in info.citation
def test_get_evaluation_metrics(gsm8k_parser):
    """Verify the metric names and the details of the ``exact_match`` metric."""
    metrics = gsm8k_parser.get_evaluation_metrics()

    # The reported metric names must be exactly the expected set.
    seen_names = set()
    for metric in metrics:
        seen_names.add(metric.name)
    assert seen_names == {
        "exact_match",
        "solution_validity",
        "step_accuracy",
        "step_count",
    }

    # Take the first metric named exact_match and inspect it.
    exact_match = next(filter(lambda m: m.name == "exact_match", metrics))
    assert exact_match.type == "string"
    assert exact_match.primary is True
    assert "exact match" in exact_match.description.lower()
|