jatingocodeo committed
Commit 6442065 · verified · 1 Parent(s): 75af272

Create train_optimized.py

Files changed (1)
  1. train_optimized.py +298 -0
train_optimized.py ADDED
@@ -0,0 +1,298 @@
import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from datetime import datetime

# Hyperparameters
learning_rate = 3e-4     # Peak learning rate
min_lr = 3e-5            # Minimum learning rate at the end of training
warmup_iters = 2000      # Linear warmup over warmup_iters
lr_decay_iters = 800000  # Cosine decay over lr_decay_iters
weight_decay = 0.1       # AdamW weight decay
beta1 = 0.9              # AdamW beta1
beta2 = 0.95             # AdamW beta2
grad_clip = 1.0          # Clip gradients at this value
decay_lr = True          # Whether to decay learning rate
batch_size = 64          # Training batch size
block_size = 256         # Maximum sequence length
eval_interval = 500      # How often to evaluate
eval_iters = 200         # Number of iterations to use for evaluation
log_interval = 10        # How often to print training progress

# Model architecture
@dataclass
class GPTConfig:
    block_size: int = block_size
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 16
    n_embd: int = 1024
    dropout: float = 0.1
    bias: bool = False

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Combined query/key/value projection and output projection
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # Reshape to (B, n_head, T, head_dim) for multi-head attention
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Fused scaled-dot-product attention with causal masking;
        # apply attention dropout only during training
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd)
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: token embedding and output projection share parameters
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)
        # Scaled init for residual projections (GPT-2 style)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference: only the logits for the last position are needed
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

def get_batch(data, block_size, batch_size):
    # Sample random contiguous chunks; targets are the inputs shifted by one position
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    return x, y

def get_lr(it):
    # 1) Linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) If it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) In between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

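# Sanity check for the schedule with the defaults above
# (learning_rate=3e-4, min_lr=3e-5, warmup_iters=2000, lr_decay_iters=800000):
#   get_lr(0)      == 0.0       (start of warmup)
#   get_lr(1000)   == 1.5e-4    (halfway through warmup)
#   get_lr(2000)   == 3e-4      (peak, warmup complete)
#   get_lr(401000) ~= 1.65e-4   (midpoint of the cosine decay)
#   get_lr(800000) == 3e-5      (decay complete, floor reached)
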
def save_training_log(log_entry, filename='training_logs.md'):
    """Save training logs in markdown format"""
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(filename, 'a') as f:
        if not f.tell():  # If file is empty, write header
            f.write('# Training Logs\n\n')
            f.write('| Timestamp | Iteration | Training Loss | Learning Rate |\n')
            f.write('|-----------|------------|---------------|---------------|\n')
        f.write(f'| {timestamp} | {log_entry["iter"]:10d} | {log_entry["train_loss"]:.6f} | {log_entry["lr"]:.2e} |\n')

def save_model(model, optimizer, iter_num, loss, filename):
    """Save model checkpoint with error handling"""
    tmp_filename = filename + '.tmp'
    try:
        # First save to a temporary file
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'iter_num': iter_num,
            'loss': loss,
        }

        # Use torch.save with zip compression
        torch.save(checkpoint, tmp_filename, _use_new_zipfile_serialization=True)

        # If save was successful, rename tmp file to final filename
        if os.path.exists(filename):
            os.remove(filename)  # Remove old file if it exists
        os.rename(tmp_filename, filename)
        return True
    except Exception as e:
        print(f"Error saving model to {filename}: {str(e)}")
        # Clean up temp file if it exists
        if os.path.exists(tmp_filename):
            try:
                os.remove(tmp_filename)
            except OSError:
                pass
        return False

def main():
    torch.manual_seed(1337)
    # Allow TF32 matmuls for faster training on recent GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Create checkpoint directory
    os.makedirs('checkpoints', exist_ok=True)

    # Load the data and build a character-level vocabulary
    with open('input.txt', 'r') as f:
        text = f.read()
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]  # NOTE: held out but not evaluated in the loop below

    # Initialize the model
    model = GPT(GPTConfig(vocab_size=vocab_size))
    model = model.to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    # Training loop
    best_train_loss = float('inf')
    iter_num = 0

    while True:
        # Get batch and learning rate
        xb, yb = get_batch(train_data, block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)
        lr = get_lr(iter_num) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward and backward pass
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Logging and model saving
        if iter_num % log_interval == 0:
            train_loss = loss.item()
            print(f"iter {iter_num}: loss {train_loss:.4f}, lr {lr:e}")
            save_training_log({
                "iter": iter_num,
                "train_loss": train_loss,
                "lr": lr
            })

            # Save model if loss improved
            if train_loss < best_train_loss:
                best_train_loss = train_loss
                print(f"Saving model with training loss: {best_train_loss:.6f}")

                # Save the latest model
                save_model(
                    model,
                    optimizer,
                    iter_num,
                    best_train_loss,
                    os.path.join('checkpoints', 'latest_model.pt')
                )

                if best_train_loss < 0.099999:
                    print(f"Achieved target loss of {best_train_loss:.6f}")
                    break

        iter_num += 1
        if iter_num > lr_decay_iters:
            break

if __name__ == '__main__':
    main()
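
A minimal sampling sketch for the checkpoint this script writes. It assumes checkpoints/latest_model.pt already exists and that the same input.txt used for training is available to rebuild the character-level vocabulary; both are assumptions about the local setup, not guaranteed by the commit itself. The temperature and top_k values are arbitrary sampling choices.

import torch
from train_optimized import GPT, GPTConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild the character-level vocabulary exactly as main() does.
with open('input.txt', 'r') as f:
    text = f.read()
chars = sorted(list(set(text)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# Recreate the model with the same vocab size and load the saved weights.
model = GPT(GPTConfig(vocab_size=len(chars))).to(device)
checkpoint = torch.load('checkpoints/latest_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Use the first few characters of the corpus as a prompt and sample a continuation.
prompt = torch.tensor([[stoi[c] for c in text[:16]]], dtype=torch.long, device=device)
out = model.generate(prompt, max_new_tokens=200, temperature=0.8, top_k=50)
print(''.join(itos[int(i)] for i in out[0]))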