GRPO-Trained Math Reasoning Model
Model Description
This model was fine-tuned using GRPO (Group Relative Policy Optimization), a reinforcement learning technique that optimizes a language model against multiple reward functions to improve both output-format consistency and mathematical reasoning.
Disclaimer
Not an official Prem Labs product; this is a work in progress.
Base Model
- Started from the premai-io/prem-1B-chat base model
- Uses Flash Attention 2 for efficient training
- Model architecture: Causal Language Model (CLM)
GRPO Training Details
GRPO training optimizes the model against two kinds of reward functions:
Format Reward Function
- Ensures responses follow a strict XML-style format:
<reasoning>
[step-by-step solution]
</reasoning>
<answer>
[numerical answer]
</answer>
- Rewards are given for:
- Strict format adherence (0.5 points)
- Soft format matching (0.3 points)
- Integer answer format (0.5 points)
- Proper XML structure (up to 0.5 points)
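The reward implementations themselves are not reproduced in this card; the sketch below shows one way format rewards with these point values could be written, assuming each completion is available as a plain string. The regular expressions and the per-tag weighting of 0.125 points are illustrative assumptions.

```python
import re

# Hypothetical sketch of the format rewards described above; the real
# implementation may differ in regex details and in how completions are passed.
STRICT_FORMAT = re.compile(
    r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$", re.DOTALL
)
SOFT_FORMAT = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)

def strict_format_reward(completions, **kwargs):
    # 0.5 points when the completion matches the strict XML layout exactly.
    return [0.5 if STRICT_FORMAT.match(c) else 0.0 for c in completions]

def soft_format_reward(completions, **kwargs):
    # 0.3 points for a looser match that only requires the tags to appear in order.
    return [0.3 if SOFT_FORMAT.search(c) else 0.0 for c in completions]

def int_answer_reward(completions, **kwargs):
    # 0.5 points when the text inside <answer> parses as an integer.
    rewards = []
    for c in completions:
        m = re.search(r"<answer>\s*(.*?)\s*</answer>", c, re.DOTALL)
        rewards.append(0.5 if m and m.group(1).strip().lstrip("-").isdigit() else 0.0)
    return rewards

def xml_structure_reward(completions, **kwargs):
    # Up to 0.5 points, split evenly across the four expected tags.
    rewards = []
    for c in completions:
        tags = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
        rewards.append(sum(0.125 for tag in tags if c.count(tag) == 1))
    return rewards
```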
Correctness Reward Function
- Evaluates mathematical accuracy
- Awards 2.0 points for correct numerical answers
- Awards 0.0 points for incorrect answers
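A matching sketch of the correctness reward, assuming the GSM8K gold answer is supplied to the reward function as an extra dataset column (the keyword name `answer` is an assumption):

```python
import re

def extract_answer(text):
    """Pull the text between <answer> tags, if present."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return m.group(1).strip() if m else ""

def correctness_reward(completions, answer, **kwargs):
    # 2.0 points when the extracted answer matches the gold answer, else 0.0.
    return [2.0 if extract_answer(c) == a else 0.0 for c, a in zip(completions, answer)]
```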
Training Configuration
- Learning rate: 5e-6
- Batch size: 2 per device
- Gradient accumulation steps: 2
- Training epochs: 1
- Uses cosine learning rate scheduler
- Warmup ratio: 0.1
- Uses bfloat16 precision
- Maximum prompt length: 256 tokens
- Maximum completion length: 200 tokens
- Number of generations per prompt: 16
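Expressed as a TRL `GRPOConfig`, the settings above would look roughly like the following. Argument names follow recent TRL releases and may differ slightly by version; the output directory is a placeholder.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="prem-1B-chat-grpo",        # placeholder output directory
    learning_rate=5e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True,
    max_prompt_length=256,
    max_completion_length=200,
    num_generations=16,
    max_grad_norm=0.1,                     # gradient clipping (see Training Infrastructure)
    model_init_kwargs={"attn_implementation": "flash_attention_2"},  # see Base Model
)
```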
Dataset
- Trained on the GSM8K (Grade School Math 8K) dataset
- Dataset contains grade-school level math word problems
- Each problem includes a question and step-by-step solution
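GSM8K stores the gold numeric answer after a `####` marker at the end of each solution. A typical preprocessing step extracts that answer and attaches a system prompt describing the XML format; the exact prompt wording below is an assumption, not the one used in training.

```python
from datasets import load_dataset

SYSTEM_PROMPT = (
    "Respond in the following format:\n"
    "<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>"
)

def extract_gold_answer(solution):
    # GSM8K solutions end with "#### <numeric answer>".
    return solution.split("####")[-1].strip().replace(",", "")

def to_grpo_example(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["question"]},
        ],
        "answer": extract_gold_answer(example["answer"]),
    }

dataset = load_dataset("openai/gsm8k", "main", split="train").map(to_grpo_example)
```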
Intended Use
- Solving mathematical word problems
- Providing step-by-step reasoning for solutions
- Educational assistance and math tutoring
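A minimal inference sketch with Transformers; the checkpoint path is a placeholder for the released model, and the system prompt and generation settings are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/this-grpo-checkpoint"  # placeholder for the released checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {"role": "system", "content": "Respond in the following format:\n"
     "<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>"},
    {"role": "user", "content": "Natalia sold clips to 48 friends in April and half "
     "as many in May. How many clips did she sell altogether?"},
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(inputs, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))
```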
Limitations
- Limited to the complexity level of GSM8K problems
- May struggle with problems requiring knowledge beyond the training data
- Performance depends on proper formatting of input queries
Training Infrastructure
- Supports distributed training across multiple GPUs
- Uses NCCL backend for distributed processing
- Implements gradient clipping (max norm: 0.1)
Evaluation
The model's performance is continuously evaluated during training based on:
- Format adherence to the XML structure
- Mathematical accuracy of answers
- Quality of step-by-step reasoning
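During training these signals surface as the mean value of each reward function. A stand-alone sketch of the first two checks for held-out completions is shown below (illustrative only; this is not the training-time evaluation loop).

```python
import re

def format_ok(completion):
    """Check that the tags appear in the expected order."""
    return bool(re.search(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>",
                          completion, re.DOTALL))

def is_correct(completion, gold):
    """Compare the <answer> field of a completion against the GSM8K gold answer."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    return bool(m) and m.group(1).strip() == gold

def evaluate(completions, golds):
    n = len(completions)
    return {
        "format_adherence": sum(format_ok(c) for c in completions) / n,
        "answer_accuracy": sum(is_correct(c, g) for c, g in zip(completions, golds)) / n,
    }
```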
Citation
If you use this model, please cite both the original GRPO paper and the GSM8K dataset.
Training Details
Training Data
The model was trained on the GSM8K dataset, which contains approximately 8.5K grade-school math word problems.
Training Procedure
- Hardware: Multi-GPU setup with NVIDIA GPUs
- Framework: Hugging Face Transformers, TRL (Transformer Reinforcement Learning)
- Optimization: GRPO with dual reward functions
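Putting the pieces together with TRL's `GRPOTrainer`, the training loop would look roughly like the sketch below; variable names refer to the earlier snippets and the overall setup is an assumption rather than the exact training script.

```python
from trl import GRPOTrainer

# `training_args`, `dataset`, and the reward functions are defined in the sketches above.
trainer = GRPOTrainer(
    model="premai-io/prem-1B-chat",
    reward_funcs=[
        strict_format_reward,
        soft_format_reward,
        int_answer_reward,
        xml_structure_reward,
        correctness_reward,
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```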