Update
- augment.py +13 -5
- configs.py +10 -3
- data_loader.py +8 -5
- eval.py +9 -17
- predict.py +5 -4
- train.py +44 -18
- tuning.py +84 -169
augment.py
CHANGED
@@ -14,10 +14,20 @@ for task in tasks:
         if not os.path.exists(f"data/temp/Task {task}/{disease}/"):
             os.makedirs(f"data/temp/Task {task}/{disease}/")
         for file in os.listdir(f"data/train/raw/Task {task}/{disease}"):
-            shutil.copy(
+            shutil.copy(
+                f"data/train/raw/Task {task}/{disease}/{file}",
+                f"data/temp/Task {task}/{disease}/{file}",
+            )
         for file in os.listdir(f"data/train/external/Task {task}/{disease}"):
-            shutil.copy(
-
+            shutil.copy(
+                f"data/train/external/Task {task}/{disease}/{file}",
+                f"data/temp/Task {task}/{disease}/{file}",
+            )
+        p = Augmentor.Pipeline(
+            f"data/temp/Task {task}/{disease}",
+            output_directory=f"{disease}/",
+            save_format="png",
+        )
         p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
         p.flip_left_right(probability=0.8)
         p.zoom_random(probability=0.8, percentage_area=0.8)
@@ -46,5 +56,3 @@ for task in tasks:
             f"data/train/augmented/Task {task}/{disease}/{file}",
             f"data/train/augmented/Task {task}/{disease}/{number}.png",
         )
-
-
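The rewritten block stages raw and external images into a shared temp folder and then builds the Augmentor pipeline over it. For readers unfamiliar with the library, here is a minimal self-contained sketch of the same pattern; the source directory and the sample count are illustrative assumptions, not values from this commit:

    import Augmentor

    # Hypothetical source folder; the commit derives this per task/disease.
    p = Augmentor.Pipeline("data/temp/Task 1/Healthy", save_format="png")

    # The same operations the commit configures; each is applied
    # independently with the given probability per generated image.
    p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
    p.flip_left_right(probability=0.8)
    p.zoom_random(probability=0.8, percentage_area=0.8)

    # Write 100 augmented images into the pipeline's output directory.
    p.sample(100)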

configs.py
CHANGED
@@ -6,9 +6,10 @@ from models import *
 
 # Constants
 RANDOM_SEED = 123
-BATCH_SIZE =
+BATCH_SIZE = 128
 NUM_EPOCHS = 100
-LEARNING_RATE =
+LEARNING_RATE = 0.04279442975996121
+OPTIMIZER_NAME = "LBFGS"
 STEP_SIZE = 10
 GAMMA = 0.5
 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -18,8 +19,13 @@ RAW_DATA_DIR = r"data/train/raw/Task " + str(TASK)
 AUG_DATA_DIR = r"data/train/augmented/Task " + str(TASK)
 EXTERNAL_DATA_DIR = r"data/train/external/Task " + str(TASK)
 NUM_CLASSES = 7
+# Define classes as listdir of augmented data
+CLASSES = os.listdir("data/train/augmented/Task 1/")
 MODEL_SAVE_PATH = "output/checkpoints/model.pth"
-MODEL =
+MODEL = mobilenet_v2(num_classes=NUM_CLASSES)
+
+print(CLASSES)
+
 
 preprocess = transforms.Compose(
     [
@@ -30,6 +36,7 @@ preprocess = transforms.Compose(
     ]
 )
 
+
 # Custom dataset class
 class CustomDataset(Dataset):
     def __init__(self, dataset):
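One caveat on the new CLASSES constant: os.listdir returns entries in arbitrary filesystem order, while torchvision's ImageFolder assigns label indices from the sorted class names. Since eval.py now maps folder names back to indices via CLASSES.index(...), sorting keeps the two consistent. A minimal fix, assuming the directory layout above:

    import os

    # Sorted, so indices match ImageFolder's class_to_idx ordering.
    CLASSES = sorted(os.listdir("data/train/augmented/Task 1/"))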

data_loader.py
CHANGED
@@ -3,14 +3,17 @@ from torchvision.datasets import ImageFolder
 from torch.utils.data import random_split, DataLoader, Dataset
 
 
-def load_data(raw_dir, augmented_dir, external_dir, preprocess):
+def load_data(raw_dir, augmented_dir, external_dir, preprocess, batch_size=BATCH_SIZE):
     # Load the dataset using ImageFolder
     raw_dataset = ImageFolder(root=raw_dir, transform=preprocess)
     external_dataset = ImageFolder(root=external_dir, transform=preprocess)
     augmented_dataset = ImageFolder(root=augmented_dir, transform=preprocess)
     dataset = raw_dataset + external_dataset + augmented_dataset
 
-
+    # Classes
+    classes = augmented_dataset.classes
+
+    print("Classes: ", *classes, sep=", ")
     print("Length of raw dataset: ", len(raw_dataset))
     print("Length of external dataset: ", len(external_dataset))
     print("Length of augmented dataset: ", len(augmented_dataset))
@@ -23,10 +26,10 @@ def load_data(raw_dir, augmented_dir, external_dir, preprocess):
 
     # Create data loaders for the custom dataset
     train_loader = DataLoader(
-        CustomDataset(train_dataset), batch_size=
+        CustomDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=0
     )
     valid_loader = DataLoader(
-        CustomDataset(val_dataset), batch_size=
+        CustomDataset(val_dataset), batch_size=batch_size, num_workers=0
     )
-
+
     return train_loader, valid_loader
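The + over the three ImageFolder datasets uses torch.utils.data.Dataset.__add__, which returns a ConcatDataset, so the three directories must contain the same class folders for the merged labels to stay meaningful. A hedged usage sketch of the new signature, assuming the constants from configs.py:

    from configs import RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
    import data_loader

    # Override the BATCH_SIZE default with an explicit batch size.
    train_loader, valid_loader = data_loader.load_data(
        RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess, batch_size=32
    )
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # e.g. [32, 3, H, W] and [32]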

eval.py
CHANGED
@@ -19,37 +19,27 @@ MODEL = MODEL.to(DEVICE)
 MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
 MODEL.eval()
 
-# Get class labels from the dataset
-class_labels = os.listdir(image_path)
-
-# Define transformation for preprocessing
-preprocess = transforms.Compose(
-    [
-        transforms.Resize((64, 64)),  # Resize images to 64x64
-        transforms.ToTensor(),  # Convert to tensor
-        transforms.Normalize((0.5,), (0.5,)),  # Normalize (for grayscale)
-    ]
-)
 
 def predict_image(image_path, model, transform):
     model.eval()
     correct_predictions = 0
-    total_predictions = len(images)
 
     # Get a list of image files
     images = list(pathlib.Path(image_path).rglob("*.png"))
 
+    total_predictions = len(images)
+
     true_classes = []
     predicted_labels = []
 
     with torch.no_grad():
         for image_file in images:
-            print(
+            print("---------------------------")
             # Check the true label of the image by checking the sequence of the folder in Task 1
-            true_class =
+            true_class = CLASSES.index(image_file.parts[-2])
             print("Image path:", image_file)
             print("True class:", true_class)
-            image = Image.open(image_file).convert(
+            image = Image.open(image_file).convert("RGB")
             image = transform(image).unsqueeze(0)
             image = image.to(DEVICE)
             output = model(image)
@@ -67,7 +57,7 @@ def predict_image(image_path, model, transform):
     # Calculate accuracy and f1 score
     accuracy = correct_predictions / total_predictions
     print("Accuracy:", accuracy)
-    f1 = f1_score(true_classes, predicted_labels, average=
+    f1 = f1_score(true_classes, predicted_labels, average="weighted")
     print("Weighted F1 Score:", f1)
 
     # Convert the lists to tensors
@@ -75,12 +65,14 @@ def predict_image(image_path, model, transform):
     true_classes_tensor = torch.tensor(true_classes)
 
     # Create a confusion matrix
-    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task=
+    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task="multiclass")
     conf_matrix.update(predicted_labels_tensor, true_classes_tensor)
 
     # Plot the confusion matrix
+    conf_matrix.compute()
     conf_matrix.plot()
     plt.show()
 
+
 # Call predict_image function
 predict_image(image_path, MODEL, preprocess)
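ConfusionMatrix here is presumably torchmetrics.ConfusionMatrix, where the task argument is required in recent releases; update() accumulates batches, compute() returns the count matrix, and plot() renders a matplotlib figure. A minimal sketch of that pattern with dummy tensors (the values are illustrative only):

    import torch
    import matplotlib.pyplot as plt
    from torchmetrics import ConfusionMatrix

    cm = ConfusionMatrix(task="multiclass", num_classes=7)
    preds = torch.tensor([0, 1, 2, 1])   # dummy predicted class indices
    target = torch.tensor([0, 1, 1, 1])  # dummy true class indices
    cm.update(preds, target)

    print(cm.compute())  # 7x7 tensor of counts
    fig, ax = cm.plot()  # renders the accumulated state
    plt.show()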

predict.py
CHANGED
@@ -20,13 +20,11 @@ torch.set_grad_enabled(False)
 
 
 def predict_image(image_path, model=MODEL, transform=preprocess):
-    classes = [
-        'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease'
-    ]
+    classes = CLASSES
 
     print("---------------------------")
     print("Image path:", image_path)
-    image = Image.open(image_path)
+    image = Image.open(image_path).convert("RGB")
     image = transform(image).unsqueeze(0)
     image = image.to(DEVICE)
     output = model(image)
@@ -55,3 +53,6 @@ def predict_image(image_path, model=MODEL, transform=preprocess):
     print("---------------------------")
 
     return predicted_label, sorted_classes
+
+
+predict_image("data/test/Task 1/Healthy/01.png")

train.py
CHANGED
@@ -5,12 +5,12 @@ import torch.optim as optim
 from torchvision.transforms import transforms
 from torch.utils.data import DataLoader
 from torchvision.utils import make_grid
-from scipy.ndimage import gaussian_filter1d
 import matplotlib.pyplot as plt
 from models import *
 from torch.utils.tensorboard import SummaryWriter
 from configs import *
 import data_loader
+import numpy as np
 
 # Set up TensorBoard writer
 writer = SummaryWriter(log_dir="output/tensorboard/training")
@@ -28,9 +28,21 @@ train_loader, valid_loader = data_loader.load_data(
 # Initialize model, criterion, optimizer, and scheduler
 MODEL = MODEL.to(DEVICE)
 criterion = nn.CrossEntropyLoss()
-
+if OPTIMIZER_NAME == "LBFGS":
+    optimizer = optim.LBFGS(MODEL.parameters(), lr=LEARNING_RATE)
+elif OPTIMIZER_NAME == "Adam":
+    optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
+elif OPTIMIZER_NAME == "SGD":
+    optimizer = optim.SGD(MODEL.parameters(), lr=LEARNING_RATE)
+
 scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
 
+# Define early stopping parameters
+early_stopping_patience = 20  # Number of epochs with no improvement to wait before stopping
+best_val_loss = float("inf")
+best_val_accuracy = 0.0
+no_improvement_count = 0
+
 # Lists to store training and validation loss history
 TRAIN_LOSS_HIST = []
 VAL_LOSS_HIST = []
@@ -41,6 +53,7 @@ VAL_ACC_HIST = []
 
 # Training loop
 for epoch in range(NUM_EPOCHS):
+    print(f"[Epoch: {epoch + 1}]")
     MODEL.train()  # Set model to training mode
     running_loss = 0.0
     total_train = 0
@@ -52,7 +65,10 @@ for epoch in range(NUM_EPOCHS):
         outputs = MODEL(inputs)
         loss = criterion(outputs, labels)
         loss.backward()
-
+        if OPTIMIZER_NAME == "LBFGS":
+            optimizer.step(closure=lambda: loss)
+        else:
+            optimizer.step()
         running_loss += loss.item()
 
         if (i + 1) % NUM_PRINT == 0:
@@ -67,9 +83,9 @@ for epoch in range(NUM_EPOCHS):
         correct_train += (predicted == labels).sum().item()
 
     avg_train_loss = running_loss / len(train_loader)
-
+    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
     TRAIN_ACC_HIST.append(correct_train / total_train)
-
+
     # Log training metrics
     train_metrics = {
         "Loss": avg_train_loss,
@@ -98,7 +114,7 @@ for epoch in range(NUM_EPOCHS):
         correct_val += (predicted == labels).sum().item()
 
     avg_val_loss = val_loss / len(valid_loader)
-
+    AVG_VAL_LOSS_HIST.append(avg_val_loss)
     VAL_ACC_HIST.append(correct_val / total_val)
 
     # Log validation metrics
@@ -107,14 +123,24 @@ for epoch in range(NUM_EPOCHS):
         "Accuracy": correct_val / total_val,
     }
     plot_and_log_metrics(val_metrics, epoch, prefix="Validation")
-
-    #
-
-
-
-
-
-
+
+    # Print average training and validation metrics
+    print(f"Average Training Loss: {avg_train_loss:.6f}")
+    print(f"Average Validation Loss: {avg_val_loss:.6f}")
+    print(f"Training Accuracy: {correct_train / total_train:.6f}")
+    print(f"Validation Accuracy: {correct_val / total_val:.6f}")
+
+    # Check for early stopping based on validation accuracy
+    if correct_val / total_val > best_val_accuracy:
+        best_val_accuracy = correct_val / total_val
+        no_improvement_count = 0
+    else:
+        no_improvement_count += 1
+
+    # Early stopping condition
+    if no_improvement_count >= early_stopping_patience:
+        print("Early stopping: Validation accuracy did not improve for {} consecutive epochs.".format(early_stopping_patience))
+        break  # Stop training
 
 # Save the model
 torch.save(MODEL.state_dict(), MODEL_SAVE_PATH)
@@ -123,16 +149,16 @@ print("Model saved at", MODEL_SAVE_PATH)
 # Plot loss and accuracy curves
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
-plt.plot(range(1,
-plt.plot(range(1,
+plt.plot(range(1, len(AVG_TRAIN_LOSS_HIST) + 1), AVG_TRAIN_LOSS_HIST, label="Average Train Loss")
+plt.plot(range(1, len(AVG_VAL_LOSS_HIST) + 1), AVG_VAL_LOSS_HIST, label="Average Validation Loss")
 plt.xlabel("Epochs")
 plt.ylabel("Loss")
 plt.legend()
 plt.title("Loss Curves")
 
 plt.subplot(1, 2, 2)
-plt.plot(range(1,
-plt.plot(range(1,
+plt.plot(range(1, len(TRAIN_ACC_HIST) + 1), TRAIN_ACC_HIST, label="Train Accuracy")
+plt.plot(range(1, len(VAL_ACC_HIST) + 1), VAL_ACC_HIST, label="Validation Accuracy")
 plt.xlabel("Epochs")
 plt.ylabel("Accuracy")
 plt.legend()
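One caution on the LBFGS branch: optimizer.step(closure=lambda: loss) hands LBFGS a closure that returns an already-backpropagated, stale loss, whereas torch.optim.LBFGS expects a closure that re-runs the forward and backward pass each time the optimizer calls it. A sketch of the conventional pattern, reusing the loop's names:

    def closure():
        # LBFGS may evaluate this several times per step, so it must
        # recompute the loss and gradients from scratch.
        optimizer.zero_grad()
        outputs = MODEL(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        return loss

    if OPTIMIZER_NAME == "LBFGS":
        loss = optimizer.step(closure)
    else:
        optimizer.step()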

tuning.py
CHANGED
@@ -1,186 +1,101 @@
 import os
+import optuna
+from optuna.trial import TrialState
 import torch
 import torch.nn as nn
 import torch.optim as optim
-
-from torch.utils.tensorboard import SummaryWriter
-from torchvision.utils import make_grid
-import optuna
+import torch.utils.data
 from configs import *
 import data_loader
 
+optuna.logging.set_verbosity(optuna.logging.DEBUG)
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+EPOCHS = 10
+
-# Initialize model, criterion, optimizer, and scheduler
-MODEL = MODEL.to(DEVICE)
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
-scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-    optimizer, mode="min", factor=0.1, patience=10, verbose=True
-)
-
-# Lists to store training and validation loss history
-TRAIN_LOSS_HIST = []
-VAL_LOSS_HIST = []
-TRAIN_ACC_HIST = []
-VAL_ACC_HIST = []
-AVG_TRAIN_LOSS_HIST = []
-AVG_VAL_LOSS_HIST = []
-
-# Create a TensorBoard writer for logging
-writer = SummaryWriter(
-    log_dir="output/tensorboard/tuning",
-)
-
-# Define early stopping parameters
-early_stopping_patience = 10  # Number of epochs to wait for improvement
-best_val_loss = float('inf')
-no_improvement_count = 0
-
-def train_epoch(epoch):
-    MODEL.train(True)
-    running_loss = 0.0
-    total_train = 0
-    correct_train = 0
-
-    for i, (inputs, labels) in enumerate(train_loader, 0):
-        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
-        optimizer.zero_grad()
-        outputs = MODEL(inputs)
-        loss = criterion(outputs, labels)
-        loss.backward()
-        optimizer.step()
-        running_loss += loss.item()
-
-        if (i + 1) % NUM_PRINT == 0:
-            print(
-                "[Epoch %d, Batch %d] Loss: %.6f"
-                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
-            )
-            running_loss = 0.0
-
-        _, predicted = torch.max(outputs, 1)
-        total_train += labels.size(0)
-        correct_train += (predicted == labels).sum().item()
-
-    TRAIN_LOSS_HIST.append(loss.item())
-    train_accuracy = correct_train / total_train
-    TRAIN_ACC_HIST.append(train_accuracy)
-    # Calculate the average training loss for the epoch
-    avg_train_loss = running_loss / len(train_loader)
-
-    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
-    writer.add_scalar("Accuracy/Train", train_accuracy, epoch)
-    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
-
-    # Print average training loss for the epoch
-    print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))
-
-    # Learning rate scheduling
-    lr_1 = optimizer.param_groups[0]["lr"]
-    print("Learning Rate: {:.15f}".format(lr_1))
-    scheduler.step(avg_train_loss)
-
-def validate_epoch(epoch):
-    global best_val_loss, no_improvement_count
-
-    MODEL.eval()
-    val_loss = 0.0
-    correct_val = 0
-    total_val = 0
-
-    with torch.no_grad():
-        for inputs, labels in valid_loader:
-            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
-            outputs = MODEL(inputs)
-            loss = criterion(outputs, labels)
-            val_loss += loss.item()
-            # Calculate accuracy
-            _, predicted = torch.max(outputs, 1)
-            total_val += labels.size(0)
-            correct_val += (predicted == labels).sum().item()
-
-    VAL_LOSS_HIST.append(loss.item())
-
-    # Calculate the average validation loss for the epoch
-    avg_val_loss = val_loss / len(valid_loader)
-    AVG_VAL_LOSS_HIST.append(loss.item())
-    print("Average Validation Loss: %.6f" % (avg_val_loss))
-
-    # Calculate the accuracy of the validation set
-    val_accuracy = correct_val / total_val
-    VAL_ACC_HIST.append(val_accuracy)
-    print("Validation Accuracy: %.6f" % (val_accuracy))
-    writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
-    writer.add_scalar("Accuracy/Validation", val_accuracy, epoch)
-
-    # Add sample images to TensorBoard
-    sample_images, _ = next(iter(valid_loader))  # Get a batch of sample images
-    sample_images = sample_images.to(DEVICE)
-    grid_image = make_grid(
-        sample_images, nrow=8, normalize=True
-    )  # Create a grid of images
-    writer.add_image("Sample Images", grid_image, global_step=epoch)
-
-    # Check for early stopping
-    if avg_val_loss < best_val_loss:
-        best_val_loss = avg_val_loss
-        no_improvement_count = 0
-    else:
-        no_improvement_count += 1
-
-    if no_improvement_count >= early_stopping_patience:
-        print(f"Early stopping after {epoch + 1} epochs without improvement.")
-        return True  # Return True to stop training
-
-def objective(trial):
-    global best_val_loss, no_improvement_count
-
-    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1)
-    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
-
-    # Modify the model and optimizer using suggested hyperparameters
-    optimizer = optim.Adam(MODEL.parameters(), lr=learning_rate)
-
-        # Check for early stopping
-        if early_stopping:
-            break
-
-    # Return the negative score as Optuna maximizes by default
-    return -validation_score
+
+def create_data_loaders(batch_size):
+    # Create or modify data loaders with the specified batch size
+    train_loader, valid_loader = data_loader.load_data(
+        RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess, batch_size=batch_size
+    )
+    return train_loader, valid_loader
+
+
+def objective(trial, model=MODEL):
+    # Generate the model.
+    model = model.to(DEVICE)
+
+    # Suggest batch size for tuning.
+    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
+
+    # Create data loaders with the suggested batch size.
+    train_loader, valid_loader = create_data_loaders(batch_size)
+
+    # Generate the optimizer.
+    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD"])
+    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
+    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+    criterion = nn.CrossEntropyLoss()
+
+    # Training of the model.
+    for epoch in range(EPOCHS):
+        print(f"[Epoch: {epoch} | Trial: {trial.number}]")
+        model.train()
+        for batch_idx, (data, target) in enumerate(train_loader, 0):
+            data, target = data.to(DEVICE), target.to(DEVICE)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = criterion(output, target)
+            loss.backward()
+            if optimizer_name == "LBFGS":
+                optimizer.step(closure=lambda: loss)
+            else:
+                optimizer.step()
+
+        # Validation of the model.
+        model.eval()
+        correct = 0
+        with torch.no_grad():
+            for batch_idx, (data, target) in enumerate(valid_loader, 0):
+                data, target = data.to(DEVICE), target.to(DEVICE)
+                output = model(data)
+                # Get the index of the max log-probability.
+                pred = output.argmax(dim=1, keepdim=True)
+                correct += pred.eq(target.view_as(pred)).sum().item()
+
+        accuracy = correct / len(valid_loader.dataset)
+
+        # Print hyperparameters and accuracy
+        print("Hyperparameters: ", trial.params)
+        print("Accuracy: ", accuracy)
+        trial.report(accuracy, epoch)
+
+        # Handle pruning based on the intermediate value.
+        if trial.should_prune():
+            raise optuna.exceptions.TrialPruned()
+
+    return accuracy
 
 
 if __name__ == "__main__":
-
-    study.
-
-    print("Number of
-
-    print("Number of complete trials: ", len(complete_trials))
-
-    # Print best trial
-    trial = study.best_trial
+    pruner = optuna.pruners.HyperbandPruner()
+    study = optuna.create_study(direction="maximize", pruner=pruner, study_name="handetect")
+    study.optimize(objective, n_trials=100, timeout=1000)
+
+    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
+    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
+
+    print("Study statistics: ")
+    print("  Number of finished trials: ", len(study.trials))
+    print("  Number of pruned trials: ", len(pruned_trials))
+    print("  Number of complete trials: ", len(complete_trials))
+
     print("Best trial:")
-
+    trial = study.best_trial
+
+    print("  Value: ", trial.value)
+
     print("  Params: ")
     for key, value in trial.params.items():
-        print(
-
-    # Close TensorBoard writer
-    writer.close()
+        print("    {}: {}".format(key, value))