---
language:
- en
tags:
- math
- reasoning
- grpo
- gsm8k
- reinforcement-learning
license: apache-2.0
datasets:
- openai/gsm8k
metrics:
- accuracy
- format_adherence
library_name: transformers
pipeline_tag: text-generation
base_model: premai-io/prem-1B-chat
---
# GRPO-Trained Math Reasoning Model
## Model Description
This model was fine-tuned using GRPO (Group Relative Policy Optimization), a reinforcement learning technique that optimizes a language model against multiple reward functions. Here the rewards target both output-format consistency and mathematical reasoning ability.
## Disclaimer
This is not an official Prem Labs product and is a work in progress.
### Base Model
- Started with the `premai-io/prem-1B-chat` base model
- Uses Flash Attention 2 for efficient training
- Model architecture: Causal Language Model (CLM)
### GRPO Training Details
GRPO training optimizes the model against two reward functions, sketched in code after this list:
1. **Format Reward Function**
- Encourages responses to follow a strict XML-style format:
```xml
<reasoning>
[step-by-step solution]
</reasoning>
<answer>
[numerical answer]
</answer>
```
- Rewards are given for:
- Strict format adherence (0.5 points)
- Soft format matching (0.3 points)
- Integer answer format (0.5 points)
- Proper XML structure (up to 0.5 points)
2. **Correctness Reward Function**
- Evaluates mathematical accuracy
- Awards 2.0 points for correct numerical answers
- Awards 0.0 points for incorrect answers
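The exact training code is not reproduced here, but reward functions of this kind are typically passed to TRL's `GRPOTrainer` as plain Python callables. The sketch below illustrates the two rewards described above under that assumption; the tag names, regular expressions, and the `answer` keyword argument are illustrative, and the integer-format and XML-structure components (which follow the same pattern) are omitted for brevity.

```python
import re

# Assumes completions arrive in TRL's conversational format: each completion
# is a list of messages, and completion[0]["content"] is the model's reply.
# The <reasoning>/<answer> tag names are assumptions, not confirmed by this card.
ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL)
STRICT_RE = re.compile(
    r"^\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*$", re.DOTALL
)

def format_reward(completions, **kwargs):
    """0.5 for strict adherence to the XML-style layout, 0.3 for a looser match."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        if STRICT_RE.match(text):
            rewards.append(0.5)
        elif ANSWER_RE.search(text):
            rewards.append(0.3)
        else:
            rewards.append(0.0)
    return rewards

def correctness_reward(completions, answer, **kwargs):
    """2.0 if the extracted answer matches the reference, else 0.0.

    Assumes the dataset's `answer` column has already been reduced to the
    final numeric answer (see the dataset snippet below).
    """
    rewards = []
    for completion, gold in zip(completions, answer):
        match = ANSWER_RE.search(completion[0]["content"])
        extracted = match.group(1).strip() if match else ""
        rewards.append(2.0 if extracted == str(gold).strip() else 0.0)
    return rewards
```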
### Training Configuration
- Learning rate: 5e-6
- Batch size: 2 per device
- Gradient accumulation steps: 2
- Training epochs: 1
- Uses cosine learning rate scheduler
- Warmup ratio: 0.1
- Uses bfloat16 precision
- Maximum prompt length: 256 tokens
- Maximum completion length: 200 tokens
- Number of generations per prompt: 16
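For reference, these hyperparameters correspond roughly to the following TRL `GRPOConfig` (a sketch; the `output_dir` name is a placeholder and exact argument names can vary across TRL releases):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="prem-1b-chat-grpo-gsm8k",  # placeholder name
    learning_rate=5e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True,
    max_prompt_length=256,
    max_completion_length=200,
    num_generations=16,
)
```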
### Dataset
- Trained on the GSM8K (Grade School Math 8K) dataset
- Dataset contains grade-school level math word problems
- Each problem includes a question and step-by-step solution
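The dataset loads directly from the Hub; the split choice and the answer-extraction helper below are illustrative:

```python
from datasets import load_dataset

# GSM8K's "main" config has "question" and "answer" columns; the final
# numeric answer follows the "#### " marker at the end of each solution.
dataset = load_dataset("openai/gsm8k", "main", split="train")

def extract_final_answer(solution: str) -> str:
    """Return the numeric answer that follows the '####' marker."""
    return solution.split("####")[-1].strip()

example = dataset[0]
print(example["question"])
print(extract_final_answer(example["answer"]))
```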
## Intended Use
- Solving mathematical word problems
- Providing step-by-step reasoning for solutions
- Educational assistance and math tutoring
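A minimal inference sketch using `transformers`; the model id below is a placeholder for this repository's id, and the example assumes the base model's chat template is available:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-username/prem-1b-chat-grpo-gsm8k"  # placeholder: use this repo's id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

question = (
    "Natalia sold clips to 48 of her friends in April, and then she sold "
    "half as many clips in May. How many clips did Natalia sell altogether?"
)
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": question}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```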
## Limitations
- Limited to the complexity level of GSM8K problems
- May struggle with problems requiring knowledge beyond the training data
- Performance depends on proper formatting of input queries
## Training Infrastructure
- Supports distributed training across multiple GPUs
- Uses NCCL backend for distributed processing
- Implements gradient clipping (max norm: 0.1)
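These settings typically map onto a standard multi-process launcher plus two extra arguments on the `GRPOConfig` shown earlier (argument names assume a recent TRL/Transformers release; the script name is a placeholder):

```python
# e.g.  torchrun --nproc_per_node=<num_gpus> train_grpo.py
# The Trainer initializes the NCCL process group automatically when launched this way.
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="prem-1b-chat-grpo-gsm8k",  # placeholder
    # ...arguments from the Training Configuration section...
    max_grad_norm=0.1,   # gradient clipping, max norm 0.1
    ddp_backend="nccl",  # NCCL backend for multi-GPU runs
)
```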
## Evaluation
The model's performance is continuously evaluated during training based on:
1. Format adherence to the XML structure
2. Mathematical accuracy of answers
3. Quality of step-by-step reasoning
## Citation
If you use this model, please cite both the original GRPO paper and the GSM8K dataset.
## Training Details
### Training Data
The model was trained on the [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which contains 8.5K grade school math word problems.
### Training Procedure
- **Hardware:** Multi-GPU setup with NVIDIA GPUs
- **Framework:** Hugging Face Transformers, TRL (Transformer Reinforcement Learning)
- **Optimization:** GRPO with dual reward functions
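Putting the pieces together, the training loop would be wired roughly as follows. This is a sketch under the same assumptions as the earlier snippets: `format_reward`, `correctness_reward`, `training_args`, and `dataset` refer to those snippets, the dataset is assumed to have been mapped into prompt/answer form (a step omitted here), and the `processing_class` argument name assumes a recent TRL release.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer

model = AutoModelForCausalLM.from_pretrained(
    "premai-io/prem-1B-chat",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # as noted under "Base Model"
)
tokenizer = AutoTokenizer.from_pretrained("premai-io/prem-1B-chat")

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[format_reward, correctness_reward],
    args=training_args,
    train_dataset=dataset,  # GSM8K, mapped to include prompts and final answers
)
trainer.train()
```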