---
language:
  - en
tags:
  - math
  - reasoning
  - grpo
  - gsm8k
  - reinforcement-learning
license: apache-2.0
datasets:
  - openai/gsm8k
metrics:
  - accuracy
  - format_adherence
library_name: transformers
pipeline_tag: text-generation
base_model: premai-io/prem-1B-chat
---

# GRPO-Trained Math Reasoning Model

## Model Description
This model was fine-tuned with GRPO (Group Relative Policy Optimization), a reinforcement learning technique that optimizes a language model against multiple reward functions, here targeting both output-format consistency and mathematical reasoning ability.

## Disclaimer
This is not an official Prem Labs product and is a work in progress.

### Base Model
- Started with the `premai-io/prem-1B-chat` base model
- Uses Flash Attention 2 for efficient training
- Model architecture: Causal Language Model (CLM)
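
For illustration, the base model can be loaded with Flash Attention 2 as follows; this is a sketch, not the exact training code, and it assumes a GPU build of PyTorch with the `flash-attn` package installed.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the prem-1B-chat base model with Flash Attention 2 in bfloat16,
# matching the precision and attention implementation described in this card.
base_model_id = "premai-io/prem-1B-chat"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```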

### GRPO Training Details
GRPO training optimizes the model against two reward functions (a minimal code sketch follows the list):

1. **Format Reward Function**
   - Ensures responses follow a strict XML-style format:
   ```xml
   <reasoning>
   [step-by-step solution]
   </reasoning>
   <answer>
   [numerical answer]
   </answer>
   ```
   - Rewards are given for:
     - Strict format adherence (0.5 points)
     - Soft format matching (0.3 points)
     - Integer answer format (0.5 points)
     - Proper XML structure (up to 0.5 points)

2. **Correctness Reward Function**
   - Evaluates mathematical accuracy
   - Awards 2.0 points for correct numerical answers
   - Awards 0.0 points for incorrect answers
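
The sketch below illustrates reward functions consistent with the point values above. It assumes TRL's `GRPOTrainer` calling convention (completions passed as a list of plain-text strings, with extra dataset columns such as `answer` forwarded as keyword arguments); the function names and regular expressions are illustrative rather than the exact training code.

```python
import re

def extract_answer(text: str) -> str:
    """Pull the content of the <answer>...</answer> block, if present."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else ""

def format_reward(completions, **kwargs):
    """0.5 for strict adherence to the XML template, 0.3 for a looser match."""
    strict = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\s*$", re.DOTALL
    )
    soft = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
    rewards = []
    for completion in completions:
        if strict.match(completion):
            rewards.append(0.5)
        elif soft.search(completion):
            rewards.append(0.3)
        else:
            rewards.append(0.0)
    return rewards

def integer_reward(completions, **kwargs):
    """0.5 if the extracted answer is a plain integer, 0.0 otherwise."""
    return [0.5 if extract_answer(c).lstrip("-").isdigit() else 0.0 for c in completions]

def correctness_reward(completions, answer, **kwargs):
    """2.0 if the extracted answer matches the reference answer, 0.0 otherwise."""
    return [2.0 if extract_answer(c) == a else 0.0 for c, a in zip(completions, answer)]
```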

### Training Configuration
- Learning rate: 5e-6
- Batch size: 2 per device
- Gradient accumulation steps: 2
- Training epochs: 1
- Uses cosine learning rate scheduler
- Warmup ratio: 0.1
- Uses bfloat16 precision
- Maximum prompt length: 256 tokens
- Maximum completion length: 200 tokens
- Number of generations per prompt: 16
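
For reference, here is a hedged sketch of how these hyperparameters map onto TRL's `GRPOConfig`. The field names follow TRL's API, the `output_dir` is a placeholder, and `max_grad_norm` comes from the Training Infrastructure section below.

```python
from trl import GRPOConfig

# Values below are the ones listed in this card; output_dir is illustrative.
training_args = GRPOConfig(
    output_dir="prem-1b-grpo-gsm8k",
    learning_rate=5e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True,
    max_prompt_length=256,
    max_completion_length=200,
    num_generations=16,
    max_grad_norm=0.1,
)
```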

### Dataset
- Trained on the GSM8K (Grade School Math 8K) dataset
- Dataset contains grade-school level math word problems
- Each problem includes a question and step-by-step solution
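
A minimal sketch of loading GSM8K and mapping it into prompt/answer pairs for GRPO training. The format instruction prepended to each question is an assumption about the prompt used; GSM8K reference solutions end with a `#### <answer>` marker, which the helper strips.

```python
from datasets import load_dataset

# Assumed format instruction matching the XML template described above.
FORMAT_INSTRUCTION = (
    "Respond in the following format:\n"
    "<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>"
)

def extract_gsm8k_answer(solution: str) -> str:
    """GSM8K solutions end with '#### <final answer>'."""
    return solution.split("####")[-1].strip()

dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {
        "prompt": FORMAT_INSTRUCTION + "\n\n" + x["question"],
        "answer": extract_gsm8k_answer(x["answer"]),
    }
)
```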

## Intended Use
- Solving mathematical word problems
- Providing step-by-step reasoning for solutions
- Educational assistance and math tutoring
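
A minimal inference sketch: the repository id is a placeholder, the chat template is assumed to be inherited from the base model, and the format instruction mirrors the XML template the model was trained to produce.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/this-model"  # placeholder: replace with the actual repository id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

question = (
    "Respond in the following format:\n"
    "<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n\n"
    "Natalia sold clips to 48 of her friends in April, and then she sold half "
    "as many clips in May. How many clips did Natalia sell altogether?"
)
messages = [{"role": "user", "content": question}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
output = model.generate(input_ids, max_new_tokens=200)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```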

## Limitations
- Limited to the complexity level of GSM8K problems
- May struggle with problems requiring knowledge beyond the training data
- Performance depends on proper formatting of input queries

## Training Infrastructure
- Supports distributed training across multiple GPUs
- Uses NCCL backend for distributed processing
- Implements gradient clipping (max norm: 0.1)

## Evaluation
The model's performance is continuously evaluated during training based on:
1. Format adherence to the XML structure
2. Mathematical accuracy of answers
3. Quality of step-by-step reasoning

## Citation
If you use this model, please cite both the original GRPO paper and the GSM8K dataset.

## Training Details
### Training Data
The model was trained on the [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which contains 8.5K grade school math word problems.

### Training Procedure
- **Hardware:** Multi-GPU setup with NVIDIA GPUs
- **Framework:** Hugging Face Transformers, TRL (Transformer Reinforcement Learning)
- **Optimization:** GRPO with dual reward functions
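
Putting the pieces together, a hedged sketch of the training entry point with TRL's `GRPOTrainer`, reusing the `training_args`, `dataset`, and reward functions from the earlier sketches:

```python
from trl import GRPOTrainer

# Reuses training_args, dataset, and the reward functions sketched above.
trainer = GRPOTrainer(
    model="premai-io/prem-1B-chat",
    args=training_args,
    train_dataset=dataset,
    reward_funcs=[format_reward, integer_reward, correctness_reward],
)
trainer.train()
```

For the multi-GPU setup described above, the same script can be launched with `accelerate launch` or `torchrun`, which use the NCCL backend for distributed processing on NVIDIA GPUs.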