---
language:
- en
tags:
- math
- reasoning
- grpo
- gsm8k
- reinforcement-learning
license: apache-2.0
datasets:
- openai/gsm8k
metrics:
- accuracy
- format_adherence
library_name: transformers
pipeline_tag: text-generation
base_model: premai-io/prem-1B-chat
---

# GRPO-Trained Math Reasoning Model

## Model Description

This model was fine-tuned with GRPO (Group Relative Policy Optimization), a reinforcement learning technique that optimizes a language model against multiple reward functions to improve both output format consistency and mathematical reasoning ability.

## Disclaimer

Not an official Prem Labs product. This is a work in progress.

### Base Model

- Started from the `premai-io/prem-1B-chat` base model
- Uses Flash Attention 2 for efficient training
- Model architecture: Causal Language Model (CLM)

### GRPO Training Details

GRPO training optimizes the model against two reward functions:

1. **Format Reward Function**
   - Ensures responses follow a strict XML-style format:

     ```xml
     <reasoning>
     [step-by-step solution]
     </reasoning>
     <answer>
     [numerical answer]
     </answer>
     ```

   - Rewards are given for:
     - Strict format adherence (0.5 points)
     - Soft format matching (0.3 points)
     - Integer answer format (0.5 points)
     - Proper XML structure (up to 0.5 points)

2. **Correctness Reward Function**
   - Evaluates mathematical accuracy
   - Awards 2.0 points for a correct numerical answer
   - Awards 0.0 points for an incorrect answer

### Training Configuration

- Learning rate: 5e-6
- Batch size: 2 per device
- Gradient accumulation steps: 2
- Training epochs: 1
- Cosine learning rate scheduler
- Warmup ratio: 0.1
- bfloat16 precision
- Maximum prompt length: 256 tokens
- Maximum completion length: 200 tokens
- Number of generations per prompt: 16

### Dataset

- Trained on the GSM8K (Grade School Math 8K) dataset
- The dataset contains grade-school-level math word problems
- Each problem includes a question and a step-by-step solution

## Intended Use

- Solving mathematical word problems
- Providing step-by-step reasoning for solutions
- Educational assistance and math tutoring

## Limitations

- Limited to the complexity level of GSM8K problems
- May struggle with problems requiring knowledge beyond the training data
- Performance depends on proper formatting of input queries

## Training Infrastructure

- Supports distributed training across multiple GPUs
- Uses the NCCL backend for distributed processing
- Applies gradient clipping (max norm: 0.1)

## Evaluation

The model's performance is evaluated continuously during training on:

1. Format adherence to the XML structure
2. Mathematical accuracy of answers
3. Quality of step-by-step reasoning

## Citation

If you use this model, please cite both the original GRPO paper and the GSM8K dataset.

## Training Details

### Training Data

The model was trained on the [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which contains 8.5K grade school math word problems.

### Training Procedure

- **Hardware:** Multi-GPU setup with NVIDIA GPUs
- **Framework:** Hugging Face Transformers, TRL (Transformer Reinforcement Learning)
- **Optimization:** GRPO with dual reward functions
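
### Loading the Dataset (Sketch)

For reference, GSM8K can be loaded directly from the Hugging Face Hub. The snippet below is a minimal sketch; the `extract_gold_answer` helper is illustrative (not part of this repository) and simply separates the final numerical answer from the step-by-step solution, which GSM8K stores after a `####` marker.

```python
from datasets import load_dataset

def extract_gold_answer(answer_text: str) -> str:
    """GSM8K solutions end with '#### <number>'; keep only that number (illustrative helper)."""
    return answer_text.split("####")[-1].strip()

dataset = load_dataset("openai/gsm8k", "main", split="train")

example = dataset[0]
print(example["question"])                     # word problem
print(extract_gold_answer(example["answer"]))  # gold numerical answer
```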
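
### GRPO Configuration (Sketch)

The hyperparameters listed above map naturally onto TRL's `GRPOConfig`. The snippet below is a sketch under that assumption; the actual training script is not included in this repository, and the `output_dir` value is a placeholder.

```python
from trl import GRPOConfig

# Values mirror the hyperparameters listed above; output_dir is a placeholder.
training_args = GRPOConfig(
    output_dir="grpo-gsm8k-prem-1b-chat",
    learning_rate=5e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True,
    max_grad_norm=0.1,
    max_prompt_length=256,
    max_completion_length=200,
    num_generations=16,
)
```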
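
### Reward Functions (Sketch)

The reward functions are described above but not shipped with this card. The following is a minimal Python sketch of how the correctness and integer-format rewards could be implemented for completions that use the XML-style tags shown earlier; the helper and function names are illustrative, not the training code itself.

```python
import re

def extract_answer(completion: str) -> str:
    """Pull the text between <answer> tags, if present (illustrative helper)."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    return match.group(1).strip() if match else ""

def correctness_reward(completions, gold_answers):
    """2.0 points for an exact match with the gold numerical answer, else 0.0."""
    return [2.0 if extract_answer(c) == gold else 0.0
            for c, gold in zip(completions, gold_answers)]

def int_format_reward(completions):
    """0.5 points if the extracted answer is a bare integer, else 0.0."""
    return [0.5 if extract_answer(c).lstrip("-").isdigit() else 0.0
            for c in completions]
```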
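
## Example Usage

A minimal inference sketch using the `transformers` library. The repository id below is a placeholder, and the example assumes the base model's chat template is preserved after fine-tuning; any system prompt used during GRPO training would need to be supplied in addition.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-username/grpo-gsm8k-prem-1b-chat"  # placeholder repository id

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Example GSM8K-style question.
question = (
    "Natalia sold clips to 48 of her friends in April, and then she sold "
    "half as many clips in May. How many clips did Natalia sell altogether "
    "in April and May?"
)

messages = [{"role": "user", "content": question}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)

# max_new_tokens mirrors the 200-token completion limit used during training.
output_ids = model.generate(input_ids, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```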