---
language:
  - en
tags:
  - math
  - reasoning
  - grpo
  - gsm8k
  - reinforcement-learning
license: apache-2.0
datasets:
  - openai/gsm8k
metrics:
  - accuracy
  - format_adherence
library_name: transformers
pipeline_tag: text-generation
base_model: premai-io/prem-1B-chat
---

# GRPO-Trained Math Reasoning Model

## Model Description

This model was fine-tuned with GRPO (Group Relative Policy Optimization), a reinforcement learning technique that optimizes a language model against multiple reward functions. Here, the rewards target both output-format consistency and mathematical reasoning ability.

## Disclaimer

Not an official Prem Labs product. This model is a work in progress.

## Base Model

- Started from the premai-io/prem-1B-chat model
- Uses Flash Attention 2 for efficient training (see the loading sketch below)
- Architecture: causal language model (CLM)
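Flash Attention 2 is typically enabled when the model is loaded, for example (a sketch; requires the `flash-attn` package and a supported GPU):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "premai-io/prem-1B-chat",
    torch_dtype=torch.bfloat16,               # matches the bf16 training precision
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)
```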

## GRPO Training Details

GRPO training optimizes the model against two reward functions (a code sketch of both follows this list):

1. **Format reward function**

   Encourages responses to follow a strict XML-style format:

   ```
   <reasoning>
   [step-by-step solution]
   </reasoning>
   <answer>
   [numerical answer]
   </answer>
   ```

   Rewards are given for:
   - Strict format adherence (0.5 points)
   - Soft format matching (0.3 points)
   - Integer answer format (0.5 points)
   - Proper XML structure (up to 0.5 points)

2. **Correctness reward function**

   - Evaluates mathematical accuracy
   - Awards 2.0 points for a correct numerical answer
   - Awards 0.0 points for an incorrect answer
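A minimal sketch of these two reward functions, written against TRL's reward-function convention (completions plus dataset columns arrive as keyword arguments). The exact regexes, the per-component XML weighting, and the `final_answer` column name are illustrative assumptions, not the verbatim training code; the point values mirror the list above:

```python
import re

STRICT_PATTERN = re.compile(
    r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\s*$", re.DOTALL
)
SOFT_PATTERN = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)

def extract_answer(text: str) -> str:
    """Return the contents of the <answer> block, or an empty string."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return match.group(1) if match else ""

def format_reward(completions, **kwargs):
    """Score format adherence; component weights mirror the list above."""
    scores = []
    for text in completions:
        score = 0.0
        if STRICT_PATTERN.match(text):
            score += 0.5  # strict format adherence
        if SOFT_PATTERN.search(text):
            score += 0.3  # soft format matching
        if extract_answer(text).lstrip("-").isdigit():
            score += 0.5  # integer answer format
        # up to 0.5 points for proper XML structure (each tag appears exactly once)
        tags = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
        score += 0.125 * sum(text.count(tag) == 1 for tag in tags)
        scores.append(score)
    return scores

def correctness_reward(completions, final_answer, **kwargs):
    """2.0 points when the extracted answer matches the reference, else 0.0."""
    return [
        2.0 if extract_answer(c) == a else 0.0
        for c, a in zip(completions, final_answer)
    ]
```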

## Training Configuration

The main hyperparameters (mapped onto a TRL `GRPOConfig` after this list):

- Learning rate: 5e-6
- Batch size: 2 per device
- Gradient accumulation steps: 2
- Training epochs: 1
- Learning rate scheduler: cosine
- Warmup ratio: 0.1
- Precision: bfloat16
- Maximum prompt length: 256 tokens
- Maximum completion length: 200 tokens
- Number of generations per prompt: 16
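In TRL, these hyperparameters map roughly onto a `GRPOConfig` as sketched below; the output directory is a placeholder and all other arguments are left at their defaults:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="outputs/prem-1B-grpo",  # placeholder path
    learning_rate=5e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True,
    max_prompt_length=256,
    max_completion_length=200,
    num_generations=16,
    max_grad_norm=0.1,  # gradient clipping (see Training Infrastructure)
)
```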

## Dataset

- Trained on the GSM8K (Grade School Math 8K) dataset
- Contains grade-school-level math word problems
- Each problem includes a question and a step-by-step solution (see the preprocessing sketch below)
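GSM8K reference solutions end with a `#### <number>` line, so a typical preprocessing step extracts that number for the correctness reward. A sketch (the `final_answer` column name is an assumption carried over from the reward sketch above):

```python
from datasets import load_dataset

def add_final_answer(example):
    """GSM8K solutions end with '#### <answer>'; keep just the number."""
    answer = example["answer"].split("####")[-1].strip().replace(",", "")
    example["final_answer"] = answer
    return example

dataset = load_dataset("openai/gsm8k", "main", split="train").map(add_final_answer)
print(dataset[0]["question"])
print(dataset[0]["final_answer"])
```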

## Intended Use

Typical uses (a minimal inference sketch follows the list):

- Solving mathematical word problems
- Providing step-by-step reasoning for solutions
- Educational assistance and math tutoring
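The repository id and the system prompt below are assumptions; the prompt simply asks for the XML format the model was rewarded for:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ucalyptus/prem-1B-grpo"  # assumed repository id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

messages = [
    {"role": "system", "content": "Answer in <reasoning>...</reasoning> and <answer>...</answer> tags."},
    {"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```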

## Limitations

- Limited to the complexity level of GSM8K problems
- May struggle with problems requiring knowledge beyond the training data
- Performance depends on proper formatting of input queries

## Training Infrastructure

- Supports distributed training across multiple GPUs (example launch command below)
- Uses the NCCL backend for distributed communication
- Implements gradient clipping (max norm: 0.1)
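For example, a multi-GPU run could be launched with `torchrun`, which uses NCCL by default on CUDA devices (the script name is a placeholder):

```bash
torchrun --nproc_per_node=4 train_grpo.py  # train_grpo.py is hypothetical
```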

## Evaluation

The model's performance is evaluated continuously during training on the following criteria; a sketch of the corresponding metrics follows the list:

1. Format adherence to the XML structure
2. Mathematical accuracy of answers
3. Quality of step-by-step reasoning
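Reusing `extract_answer` and `STRICT_PATTERN` from the reward sketch above, the two reported metrics could be computed as follows (illustrative only):

```python
def evaluate(completions: list[str], references: list[str]) -> dict[str, float]:
    """accuracy = exact match on extracted answers; format_adherence = strict-format rate."""
    n = len(completions)
    return {
        "accuracy": sum(extract_answer(c) == r for c, r in zip(completions, references)) / n,
        "format_adherence": sum(bool(STRICT_PATTERN.match(c)) for c in completions) / n,
    }
```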

## Citation

If you use this model, please cite both the original GRPO paper and the GSM8K dataset.
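GRPO was introduced in the DeepSeekMath paper; the corresponding BibTeX entries are:

```bibtex
@article{shao2024deepseekmath,
  title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
  author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Bi, Xiao and Zhang, Haowei and Zhang, Mingchuan and Li, Y. K. and Wu, Y. and Guo, Daya},
  journal={arXiv preprint arXiv:2402.03300},
  year={2024}
}

@article{cobbe2021gsm8k,
  title={Training Verifiers to Solve Math Word Problems},
  author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
  journal={arXiv preprint arXiv:2110.14168},
  year={2021}
}
```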

## Training Details

### Training Data

The model was trained on the GSM8K dataset, which contains 8.5K grade-school math word problems.

### Training Procedure

- Hardware: multi-GPU setup with NVIDIA GPUs
- Framework: Hugging Face Transformers and TRL (Transformer Reinforcement Learning)
- Optimization: GRPO with the two reward functions described above (see the end-to-end sketch below)
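Putting the pieces together, a TRL entry point would look roughly like this. It assumes the `training_args`, reward functions, and `dataset` sketched earlier, plus the `prompt` column that TRL's `GRPOTrainer` expects; treat it as an outline rather than the exact training script:

```python
from trl import GRPOTrainer

# GRPOTrainer expects a "prompt" column; here we use the raw question text.
dataset = dataset.rename_column("question", "prompt")

trainer = GRPOTrainer(
    model="premai-io/prem-1B-chat",
    args=training_args,  # GRPOConfig from the Training Configuration section
    train_dataset=dataset,
    reward_funcs=[format_reward, correctness_reward],
)
trainer.train()
```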