SpiralSense / train.py
cycool29's picture
Update
e2b5593
raw
history blame
6.19 kB
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from models import *
from torch.utils.tensorboard import SummaryWriter
from configs import *
import data_loader
def setup_tensorboard():
    """Create the TensorBoard writer used for training logs.

    Returns:
        A ``SummaryWriter`` that writes its event files under
        ``output/tensorboard/training``.
    """
    log_dir = "output/tensorboard/training"
    tensorboard_writer = SummaryWriter(log_dir=log_dir)
    return tensorboard_writer
def load_and_preprocess_data():
    """Load the datasets for the configured task.

    The raw, augmented and external data directories are each suffixed with
    ``str(TASK)`` and handed to ``data_loader.load_data`` together with the
    ``preprocess`` object.

    Returns:
        Whatever ``data_loader.load_data`` returns; the training loop in this
        file unpacks it as ``(train_loader, valid_loader)``.
    """
    task_suffix = str(TASK)
    raw_dir = RAW_DATA_DIR + task_suffix
    augmented_dir = AUG_DATA_DIR + task_suffix
    external_dir = EXTERNAL_DATA_DIR + task_suffix
    return data_loader.load_data(raw_dir, augmented_dir, external_dir, preprocess)
def initialize_model_optimizer_scheduler():
    """Build the training components from the module-level configuration.

    Returns:
        Tuple ``(model, criterion, optimizer, scheduler)`` where ``model`` is
        ``MODEL`` moved to ``DEVICE``, ``criterion`` is a cross-entropy loss,
        ``optimizer`` is Adam with ``LEARNING_RATE`` and ``scheduler`` is a
        ``StepLR`` driven by ``STEP_SIZE`` and ``GAMMA``.
    """
    net = MODEL.to(DEVICE)
    adam = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    lr_schedule = optim.lr_scheduler.StepLR(
        adam,
        step_size=STEP_SIZE,
        gamma=GAMMA,
    )
    loss_fn = nn.CrossEntropyLoss()
    return net, loss_fn, adam, lr_schedule
def plot_and_log_metrics(metrics_dict, step, writer, prefix="Train"):
    """Write every metric in ``metrics_dict`` to ``writer`` as a scalar.

    Args:
        metrics_dict: Mapping of metric name to its value.
        step: Step value passed through to ``writer.add_scalar``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        prefix: Tag namespace; each tag is ``"<prefix>/<metric name>"``.
    """
    for name, value in metrics_dict.items():
        tag = f"{prefix}/{name}"
        writer.add_scalar(tag, value, step)
def train_one_epoch(model, criterion, optimizer, train_loader, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode here.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        epoch: Zero-based epoch index, used only in the progress messages.

    Returns:
        Tuple ``(avg_train_loss, train_accuracy)``: the mean per-batch loss
        over the whole epoch and the fraction of correctly classified samples.
    """
    model.train()
    # Two accumulators on purpose: ``window_loss`` feeds the periodic progress
    # print and is reset after each print, while ``epoch_loss`` is never reset
    # so the returned average really covers every batch of the epoch.
    window_loss = 0.0
    epoch_loss = 0.0
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        # Some models (e.g. GoogLeNet in training mode) return a structured
        # output that carries the main predictions in ``.logits``; plain
        # tensors have no such attribute and are used as-is.
        outputs = getattr(outputs, "logits", outputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss += batch_loss

        if (i + 1) % NUM_PRINT == 0:
            print(
                f"[Epoch {epoch + 1}, Batch {i + 1}] "
                f"Loss: {window_loss / NUM_PRINT:.6f}"
            )
            window_loss = 0.0

        predicted = outputs.max(dim=1).indices
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    avg_train_loss = epoch_loss / len(train_loader)
    return avg_train_loss, correct_train / total_train
def validate_model(model, criterion, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode here.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        valid_loader: Sized iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(avg_val_loss, val_accuracy)``: the mean per-batch loss and
        the fraction of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in valid_loader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()

            # Index of the highest score along dim 1 is the predicted class.
            predictions = torch.max(logits, 1).indices
            num_seen += batch_labels.size(0)
            num_correct += (predictions == batch_labels).sum().item()

    return loss_sum / len(valid_loader), num_correct / num_seen
def _save_training_curves(train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss and accuracy histories and save training_curves.png."""
    plt.figure(figsize=(12, 4))

    # Left panel: loss per epoch (epochs are numbered from 1 on the x axis).
    plt.subplot(1, 2, 1)
    plt.plot(
        range(1, len(train_losses) + 1),
        train_losses,
        label="Average Train Loss",
    )
    plt.plot(
        range(1, len(val_losses) + 1),
        val_losses,
        label="Average Validation Loss",
    )
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Loss Curves")

    # Right panel: accuracy per epoch.
    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(train_accs) + 1), train_accs, label="Train Accuracy")
    plt.plot(range(1, len(val_accs) + 1), val_accs, label="Validation Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title("Accuracy Curves")

    plt.tight_layout()
    plt.savefig("training_curves.png")


def main_training_loop():
    """Run the full training pipeline.

    For up to ``NUM_EPOCHS`` epochs this trains, steps the learning-rate
    scheduler, validates, logs both metric sets to TensorBoard and prints
    them. Training stops early once validation accuracy has failed to improve
    for ``EARLY_STOPPING_PATIENCE`` consecutive epochs. Afterwards the model's
    ``state_dict`` is saved to ``output/checkpoints/model.pth`` and the
    loss/accuracy curves are written to ``training_curves.png``.
    """
    writer = setup_tensorboard()
    # try/finally so the TensorBoard writer is closed even if training fails.
    try:
        train_loader, valid_loader = load_and_preprocess_data()
        model, criterion, optimizer, scheduler = (
            initialize_model_optimizer_scheduler()
        )

        best_val_accuracy = 0.0
        no_improvement_count = 0

        train_loss_hist = []
        val_loss_hist = []
        train_acc_hist = []
        val_acc_hist = []

        for epoch in range(NUM_EPOCHS):
            print(f"[Epoch: {epoch + 1}]")
            print("Learning rate:", scheduler.get_last_lr()[0])

            avg_train_loss, train_accuracy = train_one_epoch(
                model, criterion, optimizer, train_loader, epoch
            )
            train_loss_hist.append(avg_train_loss)
            train_acc_hist.append(train_accuracy)

            # Log training metrics
            train_metrics = {
                "Loss": avg_train_loss,
                "Accuracy": train_accuracy,
            }
            plot_and_log_metrics(train_metrics, epoch, writer=writer, prefix="Train")

            # Learning rate scheduling
            scheduler.step()

            avg_val_loss, val_accuracy = validate_model(model, criterion, valid_loader)
            val_loss_hist.append(avg_val_loss)
            val_acc_hist.append(val_accuracy)

            # Log validation metrics under their own prefix so they show up as
            # separate series instead of overwriting the training scalars.
            val_metrics = {
                "Loss": avg_val_loss,
                "Accuracy": val_accuracy,
            }
            plot_and_log_metrics(
                val_metrics, epoch, writer=writer, prefix="Validation"
            )

            # Print average training and validation metrics
            print(f"Average Training Loss: {avg_train_loss:.6f}")
            print(f"Average Validation Loss: {avg_val_loss:.6f}")
            print(f"Training Accuracy: {train_accuracy:.6f}")
            print(f"Validation Accuracy: {val_accuracy:.6f}")

            # Early stopping is driven by validation accuracy.
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= EARLY_STOPPING_PATIENCE:
                print(
                    "Early stopping: Validation accuracy did not improve for "
                    f"{EARLY_STOPPING_PATIENCE} consecutive epochs."
                )
                break

        # Save the model weights as they are after the last epoch that ran.
        model_save_path = "output/checkpoints/model.pth"
        # Ensure the parent directory exists
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        torch.save(model.state_dict(), model_save_path)
        print("Model saved at", model_save_path)

        _save_training_curves(
            train_loss_hist, val_loss_hist, train_acc_hist, val_acc_hist
        )
    finally:
        writer.close()
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main_training_loop()