cycool29 committed
Commit 9d7b040 · 1 Parent(s): e8ebf3d

Files changed (7)
  1. augment.py +13 -5
  2. configs.py +10 -3
  3. data_loader.py +8 -5
  4. eval.py +9 -17
  5. predict.py +5 -4
  6. train.py +44 -18
  7. tuning.py +84 -169
augment.py CHANGED
@@ -14,10 +14,20 @@ for task in tasks:
     if not os.path.exists(f"data/temp/Task {task}/{disease}/"):
         os.makedirs(f"data/temp/Task {task}/{disease}/")
     for file in os.listdir(f"data/train/raw/Task {task}/{disease}"):
-        shutil.copy(f"data/train/raw/Task {task}/{disease}/{file}", f"data/temp/Task {task}/{disease}/{file}")
+        shutil.copy(
+            f"data/train/raw/Task {task}/{disease}/{file}",
+            f"data/temp/Task {task}/{disease}/{file}",
+        )
     for file in os.listdir(f"data/train/external/Task {task}/{disease}"):
-        shutil.copy(f"data/train/external/Task {task}/{disease}/{file}", f"data/temp/Task {task}/{disease}/{file}")
-    p = Augmentor.Pipeline(f"data/temp/Task {task}/{disease}", output_directory=f"{disease}/", save_format="png")
+        shutil.copy(
+            f"data/train/external/Task {task}/{disease}/{file}",
+            f"data/temp/Task {task}/{disease}/{file}",
+        )
+    p = Augmentor.Pipeline(
+        f"data/temp/Task {task}/{disease}",
+        output_directory=f"{disease}/",
+        save_format="png",
+    )
     p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
     p.flip_left_right(probability=0.8)
     p.zoom_random(probability=0.8, percentage_area=0.8)
@@ -46,5 +56,3 @@ for task in tasks:
             f"data/train/augmented/Task {task}/{disease}/{file}",
             f"data/train/augmented/Task {task}/{disease}/{number}.png",
         )
-
-
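Note that an Augmentor pipeline only defines operations; images are written when Pipeline.sample(n) is called, which the later renaming hunk suggests happens further down this file. A minimal sketch of the full flow, with an assumed task/disease pair and an assumed sample count:

import Augmentor

# Assumed example values; the script derives these from its task/disease loops.
p = Augmentor.Pipeline(
    "data/temp/Task 1/Healthy", output_directory="Healthy/", save_format="png"
)
p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
p.flip_left_right(probability=0.8)
p.zoom_random(probability=0.8, percentage_area=0.8)
p.sample(100)  # assumed count; writes 100 augmented PNGs to the output directory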
 
 
configs.py CHANGED
@@ -6,9 +6,10 @@ from models import *
 
 # Constants
 RANDOM_SEED = 123
-BATCH_SIZE = 16
+BATCH_SIZE = 128
 NUM_EPOCHS = 100
-LEARNING_RATE = 0.05585974668605116
+LEARNING_RATE = 0.04279442975996121
+OPTIMIZER_NAME = "LBFGS"
 STEP_SIZE = 10
 GAMMA = 0.5
 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -18,8 +19,13 @@ RAW_DATA_DIR = r"data/train/raw/Task " + str(TASK)
 AUG_DATA_DIR = r"data/train/augmented/Task " + str(TASK)
 EXTERNAL_DATA_DIR = r"data/train/external/Task " + str(TASK)
 NUM_CLASSES = 7
+# Define classes as listdir of augmented data
+CLASSES = os.listdir("data/train/augmented/Task 1/")
 MODEL_SAVE_PATH = "output/checkpoints/model.pth"
-MODEL = mobilenet_v3_small(num_classes=NUM_CLASSES)
+MODEL = mobilenet_v2(num_classes=NUM_CLASSES)
+
+print(CLASSES)
+
 
 preprocess = transforms.Compose(
     [
@@ -30,6 +36,7 @@ preprocess = transforms.Compose(
     ]
 )
 
+
 # Custom dataset class
 class CustomDataset(Dataset):
    def __init__(self, dataset):
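One caveat with the new CLASSES constant: os.listdir returns entries in arbitrary, platform-dependent order, while torchvision's ImageFolder assigns label indices from the sorted class-folder names. Since eval.py and predict.py now use CLASSES to map indices back to names, sorting keeps the two in agreement; a one-line sketch:

# Sorted to match the index order ImageFolder assigns during training.
CLASSES = sorted(os.listdir("data/train/augmented/Task 1/"))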
data_loader.py CHANGED
@@ -3,14 +3,17 @@ from torchvision.datasets import ImageFolder
 from torch.utils.data import random_split, DataLoader, Dataset
 
 
-def load_data(raw_dir, augmented_dir, external_dir, preprocess):
+def load_data(raw_dir, augmented_dir, external_dir, preprocess, batch_size=BATCH_SIZE):
     # Load the dataset using ImageFolder
     raw_dataset = ImageFolder(root=raw_dir, transform=preprocess)
     external_dataset = ImageFolder(root=external_dir, transform=preprocess)
     augmented_dataset = ImageFolder(root=augmented_dir, transform=preprocess)
     dataset = raw_dataset + external_dataset + augmented_dataset
 
-    print("Classes: ", *raw_dataset.classes, sep = ', ')
+    # Classes
+    classes = augmented_dataset.classes
+
+    print("Classes: ", *classes, sep=", ")
     print("Length of raw dataset: ", len(raw_dataset))
     print("Length of external dataset: ", len(external_dataset))
     print("Length of augmented dataset: ", len(augmented_dataset))
@@ -23,10 +26,10 @@ def load_data(raw_dir, augmented_dir, external_dir, preprocess):
 
     # Create data loaders for the custom dataset
     train_loader = DataLoader(
-        CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
+        CustomDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=0
     )
     valid_loader = DataLoader(
-        CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
+        CustomDataset(val_dataset), batch_size=batch_size, num_workers=0
     )
-
+
     return train_loader, valid_loader
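For reference, a call site consistent with the new signature; the paths and preprocess come from configs.py, and the explicit batch size shown here is only an example (callers such as tuning.py pass whatever value a trial suggests):

train_loader, valid_loader = data_loader.load_data(
    RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess, batch_size=32
)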
eval.py CHANGED
@@ -19,37 +19,27 @@ MODEL = MODEL.to(DEVICE)
 MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
 MODEL.eval()
 
-# Get class labels from the dataset
-class_labels = os.listdir(image_path)
-
-# Define transformation for preprocessing
-preprocess = transforms.Compose(
-    [
-        transforms.Resize((64, 64)),  # Resize images to 64x64
-        transforms.ToTensor(),  # Convert to tensor
-        transforms.Normalize((0.5,), (0.5,)),  # Normalize (for grayscale)
-    ]
-)
 
 def predict_image(image_path, model, transform):
     model.eval()
     correct_predictions = 0
-    total_predictions = len(images)
 
     # Get a list of image files
     images = list(pathlib.Path(image_path).rglob("*.png"))
 
+    total_predictions = len(images)
+
     true_classes = []
     predicted_labels = []
 
     with torch.no_grad():
         for image_file in images:
-            print('---------------------------')
+            print("---------------------------")
             # Check the true label of the image by checking the sequence of the folder in Task 1
-            true_class = class_labels.index(image_file.parts[-2])
+            true_class = CLASSES.index(image_file.parts[-2])
             print("Image path:", image_file)
             print("True class:", true_class)
-            image = Image.open(image_file).convert('RGB')
+            image = Image.open(image_file).convert("RGB")
             image = transform(image).unsqueeze(0)
             image = image.to(DEVICE)
             output = model(image)
@@ -67,7 +57,7 @@ def predict_image(image_path, model, transform):
     # Calculate accuracy and f1 score
     accuracy = correct_predictions / total_predictions
     print("Accuracy:", accuracy)
-    f1 = f1_score(true_classes, predicted_labels, average='weighted')
+    f1 = f1_score(true_classes, predicted_labels, average="weighted")
     print("Weighted F1 Score:", f1)
 
     # Convert the lists to tensors
@@ -75,12 +65,14 @@ def predict_image(image_path, model, transform):
     true_classes_tensor = torch.tensor(true_classes)
 
     # Create a confusion matrix
-    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task='multiclass')
+    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task="multiclass")
     conf_matrix.update(predicted_labels_tensor, true_classes_tensor)
 
     # Plot the confusion matrix
+    conf_matrix.compute()
     conf_matrix.plot()
     plt.show()
 
+
 # Call predict_image function
 predict_image(image_path, MODEL, preprocess)
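A note on the torchmetrics calls added above: compute() returns the accumulated matrix (its return value is discarded here), and, in torchmetrics versions that provide the plot API (1.0+), plot() returns a (fig, ax) pair, so the figure can be captured and saved rather than only shown. A sketch, with an assumed output path:

cm = conf_matrix.compute()  # (NUM_CLASSES, NUM_CLASSES) tensor of counts
fig, ax = conf_matrix.plot(cm)
fig.savefig("output/confusion_matrix.png")  # assumed path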
predict.py CHANGED
@@ -20,13 +20,11 @@ torch.set_grad_enabled(False)
 
 
 def predict_image(image_path, model=MODEL, transform=preprocess):
-    classes = [
-        'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease'
-    ]
+    classes = CLASSES
 
     print("---------------------------")
     print("Image path:", image_path)
-    image = Image.open(image_path)
+    image = Image.open(image_path).convert("RGB")
     image = transform(image).unsqueeze(0)
     image = image.to(DEVICE)
     output = model(image)
@@ -55,3 +53,6 @@ def predict_image(image_path, model=MODEL, transform=preprocess):
     print("---------------------------")
 
     return predicted_label, sorted_classes
+
+
+predict_image("data/test/Task 1/Healthy/01.png")
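The new module-level call runs on every import of predict.py, not only when the script is executed directly. If that is unintended, the usual guard is (a sketch, not part of this commit):

if __name__ == "__main__":
    predict_image("data/test/Task 1/Healthy/01.png")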
train.py CHANGED
@@ -5,12 +5,12 @@ import torch.optim as optim
 from torchvision.transforms import transforms
 from torch.utils.data import DataLoader
 from torchvision.utils import make_grid
-from scipy.ndimage import gaussian_filter1d
 import matplotlib.pyplot as plt
 from models import *
 from torch.utils.tensorboard import SummaryWriter
 from configs import *
 import data_loader
+import numpy as np
 
 # Set up TensorBoard writer
 writer = SummaryWriter(log_dir="output/tensorboard/training")
@@ -28,9 +28,21 @@ train_loader, valid_loader = data_loader.load_data(
 # Initialize model, criterion, optimizer, and scheduler
 MODEL = MODEL.to(DEVICE)
 criterion = nn.CrossEntropyLoss()
-optimizer = optim.SGD(MODEL.parameters(), lr=LEARNING_RATE)
+if OPTIMIZER_NAME == "LBFGS":
+    optimizer = optim.LBFGS(MODEL.parameters(), lr=LEARNING_RATE)
+elif OPTIMIZER_NAME == "Adam":
+    optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
+elif OPTIMIZER_NAME == "SGD":
+    optimizer = optim.SGD(MODEL.parameters(), lr=LEARNING_RATE)
+
 scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
 
+# Define early stopping parameters
+early_stopping_patience = 20  # Number of epochs with no improvement to wait before stopping
+best_val_loss = float("inf")
+best_val_accuracy = 0.0
+no_improvement_count = 0
+
 # Lists to store training and validation loss history
 TRAIN_LOSS_HIST = []
 VAL_LOSS_HIST = []
@@ -41,6 +53,7 @@ VAL_ACC_HIST = []
 
 # Training loop
 for epoch in range(NUM_EPOCHS):
+    print(f"[Epoch: {epoch + 1}]")
     MODEL.train()  # Set model to training mode
     running_loss = 0.0
     total_train = 0
@@ -52,7 +65,10 @@ for epoch in range(NUM_EPOCHS):
         outputs = MODEL(inputs)
         loss = criterion(outputs, labels)
         loss.backward()
-        optimizer.step()
+        if OPTIMIZER_NAME == "LBFGS":
+            optimizer.step(closure=lambda: loss)
+        else:
+            optimizer.step()
         running_loss += loss.item()
 
         if (i + 1) % NUM_PRINT == 0:
@@ -67,9 +83,9 @@ for epoch in range(NUM_EPOCHS):
         correct_train += (predicted == labels).sum().item()
 
     avg_train_loss = running_loss / len(train_loader)
-    TRAIN_LOSS_HIST.append(avg_train_loss)
+    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
     TRAIN_ACC_HIST.append(correct_train / total_train)
-
+
     # Log training metrics
     train_metrics = {
         "Loss": avg_train_loss,
@@ -98,7 +114,7 @@ for epoch in range(NUM_EPOCHS):
         correct_val += (predicted == labels).sum().item()
 
     avg_val_loss = val_loss / len(valid_loader)
-    VAL_LOSS_HIST.append(avg_val_loss)
+    AVG_VAL_LOSS_HIST.append(avg_val_loss)
     VAL_ACC_HIST.append(correct_val / total_val)
 
     # Log validation metrics
@@ -107,14 +123,24 @@ for epoch in range(NUM_EPOCHS):
         "Accuracy": correct_val / total_val,
     }
     plot_and_log_metrics(val_metrics, epoch, prefix="Validation")
-
-    # Add sample images to TensorBoard
-    sample_images, _ = next(iter(valid_loader))
-    sample_images = sample_images.to(DEVICE)
-    grid_image = make_grid(
-        sample_images, nrow=8, normalize=True
-    )
-    writer.add_image("Sample Images", grid_image, global_step=epoch)
+
+    # Print average training and validation metrics
+    print(f"Average Training Loss: {avg_train_loss:.6f}")
+    print(f"Average Validation Loss: {avg_val_loss:.6f}")
+    print(f"Training Accuracy: {correct_train / total_train:.6f}")
+    print(f"Validation Accuracy: {correct_val / total_val:.6f}")
+
+    # Check for early stopping based on validation accuracy
+    if correct_val / total_val > best_val_accuracy:
+        best_val_accuracy = correct_val / total_val
+        no_improvement_count = 0
+    else:
+        no_improvement_count += 1
+
+    # Early stopping condition
+    if no_improvement_count >= early_stopping_patience:
+        print("Early stopping: Validation accuracy did not improve for {} consecutive epochs.".format(early_stopping_patience))
+        break  # Stop training
 
 # Save the model
 torch.save(MODEL.state_dict(), MODEL_SAVE_PATH)
@@ -123,16 +149,16 @@ print("Model saved at", MODEL_SAVE_PATH)
 # Plot loss and accuracy curves
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
-plt.plot(range(1, NUM_EPOCHS + 1), TRAIN_LOSS_HIST, label="Train Loss")
-plt.plot(range(1, NUM_EPOCHS + 1), VAL_LOSS_HIST, label="Validation Loss")
+plt.plot(range(1, len(AVG_TRAIN_LOSS_HIST) + 1), AVG_TRAIN_LOSS_HIST, label="Average Train Loss")
+plt.plot(range(1, len(AVG_VAL_LOSS_HIST) + 1), AVG_VAL_LOSS_HIST, label="Average Validation Loss")
 plt.xlabel("Epochs")
 plt.ylabel("Loss")
 plt.legend()
 plt.title("Loss Curves")
 
 plt.subplot(1, 2, 2)
-plt.plot(range(1, NUM_EPOCHS + 1), TRAIN_ACC_HIST, label="Train Accuracy")
-plt.plot(range(1, NUM_EPOCHS + 1), VAL_ACC_HIST, label="Validation Accuracy")
+plt.plot(range(1, len(TRAIN_ACC_HIST) + 1), TRAIN_ACC_HIST, label="Train Accuracy")
+plt.plot(range(1, len(VAL_ACC_HIST) + 1), VAL_ACC_HIST, label="Validation Accuracy")
 plt.xlabel("Epochs")
 plt.ylabel("Accuracy")
 plt.legend()
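Two observations on this diff. First, the LBFGS branch passes closure=lambda: loss, which returns a stale value: per the PyTorch documentation, LBFGS may evaluate the closure several times per step and expects it to re-run the forward and backward passes. A sketch of the documented pattern, using the loop's own names:

def closure():
    # LBFGS may call this multiple times per step; re-evaluate the model each time.
    optimizer.zero_grad()
    outputs = MODEL(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    return loss

if OPTIMIZER_NAME == "LBFGS":
    loss = optimizer.step(closure)
else:
    loss = closure()
    optimizer.step()

Second, the loop now appends to AVG_TRAIN_LOSS_HIST and AVG_VAL_LOSS_HIST and plots them, but the hunks shown still initialize only TRAIN_LOSS_HIST and VAL_LOSS_HIST, so the averaged histories presumably need matching empty-list initializations alongside them.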
tuning.py CHANGED
@@ -1,186 +1,101 @@
 import os
+import optuna
+from optuna.trial import TrialState
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from models import *  # Import your model here
-from torch.utils.tensorboard import SummaryWriter
-from torchvision.utils import make_grid
-import optuna
+import torch.utils.data
 from configs import *
 import data_loader
 
-# Data loader
-train_loader, valid_loader = data_loader.load_data(
-    RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
-)
-
-# Initialize model, criterion, optimizer, and scheduler
-MODEL = MODEL.to(DEVICE)
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
-scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-    optimizer, mode="min", factor=0.1, patience=10, verbose=True
-)
-
-# Lists to store training and validation loss history
-TRAIN_LOSS_HIST = []
-VAL_LOSS_HIST = []
-TRAIN_ACC_HIST = []
-VAL_ACC_HIST = []
-AVG_TRAIN_LOSS_HIST = []
-AVG_VAL_LOSS_HIST = []
-
-# Create a TensorBoard writer for logging
-writer = SummaryWriter(
-    log_dir="output/tensorboard/tuning",
-)
-
-# Define early stopping parameters
-early_stopping_patience = 10  # Number of epochs to wait for improvement
-best_val_loss = float('inf')
-no_improvement_count = 0
-
-def train_epoch(epoch):
-    MODEL.train(True)
-    running_loss = 0.0
-    total_train = 0
-    correct_train = 0
-
-    for i, (inputs, labels) in enumerate(train_loader, 0):
-        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
-        optimizer.zero_grad()
-        outputs = MODEL(inputs)
-        loss = criterion(outputs, labels)
-        loss.backward()
-        optimizer.step()
-        running_loss += loss.item()
-
-        if (i + 1) % NUM_PRINT == 0:
-            print(
-                "[Epoch %d, Batch %d] Loss: %.6f"
-                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
-            )
-            running_loss = 0.0
-
-        _, predicted = torch.max(outputs, 1)
-        total_train += labels.size(0)
-        correct_train += (predicted == labels).sum().item()
-
-    TRAIN_LOSS_HIST.append(loss.item())
-    train_accuracy = correct_train / total_train
-    TRAIN_ACC_HIST.append(train_accuracy)
-    # Calculate the average training loss for the epoch
-    avg_train_loss = running_loss / len(train_loader)
-
-    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
-    writer.add_scalar("Accuracy/Train", train_accuracy, epoch)
-    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
-
-    # Print average training loss for the epoch
-    print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))
-
-    # Learning rate scheduling
-    lr_1 = optimizer.param_groups[0]["lr"]
-    print("Learning Rate: {:.15f}".format(lr_1))
-    scheduler.step(avg_train_loss)
-
-def validate_epoch(epoch):
-    global best_val_loss, no_improvement_count
-
-    MODEL.eval()
-    val_loss = 0.0
-    correct_val = 0
-    total_val = 0
-
-    with torch.no_grad():
-        for inputs, labels in valid_loader:
-            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
-            outputs = MODEL(inputs)
-            loss = criterion(outputs, labels)
-            val_loss += loss.item()
-            # Calculate accuracy
-            _, predicted = torch.max(outputs, 1)
-            total_val += labels.size(0)
-            correct_val += (predicted == labels).sum().item()
-
-    VAL_LOSS_HIST.append(loss.item())
-
-    # Calculate the average validation loss for the epoch
-    avg_val_loss = val_loss / len(valid_loader)
-    AVG_VAL_LOSS_HIST.append(loss.item())
-    print("Average Validation Loss: %.6f" % (avg_val_loss))
-
-    # Calculate the accuracy of the validation set
-    val_accuracy = correct_val / total_val
-    VAL_ACC_HIST.append(val_accuracy)
-    print("Validation Accuracy: %.6f" % (val_accuracy))
-    writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
-    writer.add_scalar("Accuracy/Validation", val_accuracy, epoch)
-
-    # Add sample images to TensorBoard
-    sample_images, _ = next(iter(valid_loader))  # Get a batch of sample images
-    sample_images = sample_images.to(DEVICE)
-    grid_image = make_grid(
-        sample_images, nrow=8, normalize=True
-    )  # Create a grid of images
-    writer.add_image("Sample Images", grid_image, global_step=epoch)
-
-    # Check for early stopping
-    if avg_val_loss < best_val_loss:
-        best_val_loss = avg_val_loss
-        no_improvement_count = 0
-    else:
-        no_improvement_count += 1
-
-    if no_improvement_count >= early_stopping_patience:
-        print(f"Early stopping after {epoch + 1} epochs without improvement.")
-        return True  # Return True to stop training
-
-def objective(trial):
-    global best_val_loss, no_improvement_count
-
-    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1)
-    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
-
-    # Modify the model and optimizer using suggested hyperparameters
-    optimizer = optim.Adam(MODEL.parameters(), lr=learning_rate)
-
-    for epoch in range(20):
-        train_epoch(epoch)
-        early_stopping = validate_epoch(epoch)
-
-        # Check for early stopping
-        if early_stopping:
-            break
-
-    # Calculate a weighted score based on validation accuracy and loss
-    validation_score = VAL_ACC_HIST[-1] - AVG_VAL_LOSS_HIST[-1]
-
-    # Return the negative score as Optuna maximizes by default
-    return -validation_score
+optuna.logging.set_verbosity(optuna.logging.DEBUG)
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+EPOCHS = 10
+
+
+def create_data_loaders(batch_size):
+    # Create or modify data loaders with the specified batch size
+    train_loader, valid_loader = data_loader.load_data(
+        RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess, batch_size=batch_size
+    )
+    return train_loader, valid_loader
+
+
+def objective(trial, model=MODEL):
+    # Generate the model.
+    model = model.to(DEVICE)
+
+    # Suggest batch size for tuning.
+    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
+
+    # Create data loaders with the suggested batch size.
+    train_loader, valid_loader = create_data_loaders(batch_size)
+
+    # Generate the optimizer.
+    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD"])
+    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
+    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+    criterion = nn.CrossEntropyLoss()
+
+    # Training of the model.
+    for epoch in range(EPOCHS):
+        print(f"[Epoch: {epoch} | Trial: {trial.number}]")
+        model.train()
+        for batch_idx, (data, target) in enumerate(train_loader, 0):
+            data, target = data.to(DEVICE), target.to(DEVICE)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = criterion(output, target)
+            loss.backward()
+            if optimizer_name == "LBFGS":
+                optimizer.step(closure=lambda: loss)
+            else:
+                optimizer.step()
+
+        # Validation of the model.
+        model.eval()
+        correct = 0
+        with torch.no_grad():
+            for batch_idx, (data, target) in enumerate(valid_loader, 0):
+                data, target = data.to(DEVICE), target.to(DEVICE)
+                output = model(data)
+                # Get the index of the max log-probability.
+                pred = output.argmax(dim=1, keepdim=True)
+                correct += pred.eq(target.view_as(pred)).sum().item()
+
+        accuracy = correct / len(valid_loader.dataset)
+
+        # Print hyperparameters and accuracy
+        print("Hyperparameters: ", trial.params)
+        print("Accuracy: ", accuracy)
+        trial.report(accuracy, epoch)
+
+        # Handle pruning based on the intermediate value.
+        if trial.should_prune():
+            raise optuna.exceptions.TrialPruned()
+
+    return accuracy
 
 if __name__ == "__main__":
-    study = optuna.create_study(direction="maximize")
-    study.optimize(objective, timeout=3600)
-
-    # Print statistics
-    print("Number of finished trials: ", len(study.trials))
-    pruned_trials = [
-        t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED
-    ]
-    print("Number of pruned trials: ", len(pruned_trials))
-    complete_trials = [
-        t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
-    ]
-    print("Number of complete trials: ", len(complete_trials))
-
-    # Print best trial
-    trial = study.best_trial
+    pruner = optuna.pruners.HyperbandPruner()
+    study = optuna.create_study(direction="maximize", pruner=pruner, study_name="handetect")
+    study.optimize(objective, n_trials=100, timeout=1000)
+
+    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
+    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
+
+    print("Study statistics: ")
+    print(" Number of finished trials: ", len(study.trials))
+    print(" Number of pruned trials: ", len(pruned_trials))
+    print(" Number of complete trials: ", len(complete_trials))
+
     print("Best trial:")
-    print(" Value: ", -trial.value)  # Negate the value as it was maximized
+    trial = study.best_trial
+
+    print(" Value: ", trial.value)
+
     print(" Params: ")
     for key, value in trial.params.items():
-        print(f" {key}: {value}")
-
-    # Close TensorBoard writer
-    writer.close()
+        print(" {}: {}".format(key, value))
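Two closing notes. The optimizer_name == "LBFGS" branch in the training loop is unreachable as written, since the trial only suggests "Adam" or "SGD". And with study_name now set, the study could also be persisted and resumed across runs; a sketch using Optuna's SQLite storage, where the database path is an assumed example:

study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.HyperbandPruner(),
    study_name="handetect",
    storage="sqlite:///output/handetect.db",  # assumed path
    load_if_exists=True,  # resume if the study already exists
)
study.optimize(objective, n_trials=100, timeout=1000)
print(study.best_trial.params)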