import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from safetensors.torch import save_file, load_file
from pathlib import Path
import time

def count_parameters_layerwise(model):
    """Print a per-layer breakdown of trainable parameters and return the total count."""
    total_params = 0
    layer_params = {}

    for name, parameter in model.named_parameters():
        # Skip frozen parameters so the summary reflects only what is trained.
        if not parameter.requires_grad:
            continue

        param_count = parameter.numel()
        layer_params[name] = param_count
        total_params += param_count

    print("\nModel Parameter Summary:")
    print("-" * 60)
    for name, count in layer_params.items():
        print(f"{name}: {count:,} parameters")
    print("-" * 60)
    print(f"Total Trainable Parameters: {total_params:,}\n")

    return total_params
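

# Illustrative usage sketch (not part of the original utilities): the two-layer
# MLP below is an assumed stand-in model, chosen only to show what the summary
# call looks like.
def _example_count_parameters():
    model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
    # Prints one line per named parameter (weight and bias of each Linear)
    # followed by the trainable total, and returns that total.
    return count_parameters_layerwise(model)
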
def save_checkpoint(model, filename="checkpoint.safetensors"):
    # Unwrap torch.compile()-wrapped models so the saved keys match the
    # underlying module.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    try:
        # Save in the safetensors format implied by the default filename.
        save_file(model.state_dict(), filename)
    except Exception:
        # Fall back to a plain PyTorch checkpoint (e.g. for shared/tied tensors
        # that safetensors refuses to serialize).
        torch.save(model.state_dict(), filename.replace('.safetensors', '.pt'))

def load_checkpoint(model, filename="checkpoint.safetensors"):
    # Unwrap torch.compile()-wrapped models so the loaded keys match the
    # underlying module.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    try:
        model_state = load_file(filename)
        model.load_state_dict(model_state)
    except Exception:
        # Fall back to the plain PyTorch checkpoint if the safetensors file
        # is missing or cannot be loaded.
        model_state = torch.load(filename.replace('.safetensors', '.pt'), weights_only=True)
        model.load_state_dict(model_state)
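

# Illustrative round-trip sketch (an assumption, not part of the original
# utilities): save a toy model's weights and load them into a fresh copy, then
# check that the two models agree. The "demo.safetensors" filename is a
# placeholder.
def _example_checkpoint_roundtrip():
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    save_checkpoint(model, "demo.safetensors")

    restored = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    load_checkpoint(restored, "demo.safetensors")

    # Identical weights should give identical outputs for the same input.
    x = torch.randn(1, 8)
    assert torch.allclose(model(x), restored(x))
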
class TBLogger:
    """Thin wrapper around tensorboardX SummaryWriter for scalar, histogram and gradient logging."""

    def __init__(self, log_dir='logs/current_run', flush_secs=10, enable_grad_logging=True):
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir, flush_secs=flush_secs)
        self.enable_grad_logging = enable_grad_logging
        self.start_time = time.time()

    def log(self, metrics, step=None, model=None, prefix='', grad_checking=False):
        for name, value in metrics.items():
            full_name = f"{prefix}{name}" if prefix else name

            if isinstance(value, (int, float)):
                self.writer.add_scalar(full_name, value, step)
            elif isinstance(value, torch.Tensor):
                # Scalar tensors (e.g. a loss) are logged by value.
                self.writer.add_scalar(full_name, value.item(), step)
            elif isinstance(value, (list, tuple)) and len(value) > 0:
                # Lists of numbers are logged as histograms rather than scalars.
                if all(isinstance(x, (int, float)) for x in value):
                    self.writer.add_histogram(full_name, torch.tensor(value), step)

        if self.enable_grad_logging and model is not None:
            self._log_gradients(model, step, grad_checking)

    def _log_gradients(self, model, step, grad_checking):
        total_norm = 0.0
        for name, param in model.named_parameters():
            if param.grad is None:
                continue

            if grad_checking:
                # Skip layers whose gradients contain nan/inf so they do not
                # corrupt the logged norms.
                if torch.isnan(param.grad).any():
                    print(f"Warning: Found nan in gradients for layer: {name}")
                    continue
                if torch.isinf(param.grad).any():
                    print(f"Warning: Found inf in gradients for layer: {name}")
                    continue

            param_norm = param.grad.detach().norm(2)
            self.writer.add_scalar(f"gradients/{name}_norm", param_norm, step)
            total_norm += param_norm.item() ** 2

        # Global L2 norm across all logged parameter gradients.
        if total_norm > 0:
            self.writer.add_scalar("gradients/total_norm", total_norm ** 0.5, step)

    def close(self):
        self.writer.close()
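

# Minimal end-to-end sketch, assuming a toy regression setup: the model, random
# data, learning rate, and log/checkpoint paths below are placeholders rather
# than part of the original utilities.
if __name__ == "__main__":
    model = nn.Linear(16, 1)
    count_parameters_layerwise(model)

    logger = TBLogger(log_dir="logs/demo_run")
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    for step in range(10):
        x, y = torch.randn(32, 16), torch.randn(32, 1)
        loss = F.mse_loss(model(x), y)

        optimizer.zero_grad()
        loss.backward()
        # Log the scalar loss plus per-layer and total gradient norms,
        # with nan/inf checking enabled.
        logger.log({"train/loss": loss}, step=step, model=model, grad_checking=True)
        optimizer.step()

    save_checkpoint(model, "demo_checkpoint.safetensors")
    logger.close()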