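"""Hyperparameter tuning with Optuna.

Searches over batch size, learning rate, and StepLR gamma for the model
configured in configs.py, reports per-epoch validation accuracy so the
pruner can stop weak trials early, and logs results to TensorBoard.
"""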
import copy

import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.optim as optim

from configs import *  # provides MODEL, TASK, the data directories, and preprocess
import data_loader
from torch.utils.tensorboard import SummaryWriter

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 10
N_TRIALS = 50
TIMEOUT = None

# Create a TensorBoard writer
writer = SummaryWriter(log_dir="output/tensorboard/tuning")


def create_data_loaders(batch_size):
    # Build the training and validation loaders with the requested batch size.
    train_loader, valid_loader = data_loader.load_data(
        RAW_DATA_DIR + str(TASK),
        AUG_DATA_DIR + str(TASK),
        EXTERNAL_DATA_DIR + str(TASK),
        preprocess,
        batch_size=batch_size,
    )
    return train_loader, valid_loader


def objective(trial, model=MODEL):
    # Start each trial from an independent copy of the configured model so
    # weights trained in one trial do not leak into the next.
    model = copy.deepcopy(model).to(DEVICE)

    # Suggest batch size for tuning.
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    # Create data loaders with the suggested batch size.
    train_loader, valid_loader = create_data_loaders(batch_size)

    # Generate the optimizer.
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Suggest the gamma parameter for the learning rate scheduler.
    gamma = trial.suggest_float("gamma", 0.1, 1.0, step=0.1)

    # Decay the learning rate by the suggested gamma after every epoch
    # (step_size=1).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

    # Training of the model.
    for epoch in range(EPOCHS):
        print(f"[Epoch: {epoch} | Trial: {trial.number}]")
        model.train()
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            if model.__class__.__name__ == "GoogLeNet":
                # torchvision's GoogLeNet returns a namedtuple in training
                # mode; use its main classifier logits.
                output = model(data).logits
            else:
                output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # Update the learning rate using the scheduler.
        scheduler.step()

        # Validation of the model.
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in valid_loader:
                data, target = data.to(DEVICE), target.to(DEVICE)
                output = model(data)
                # Get the index of the max log-probability.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = correct / len(valid_loader.dataset)

        # Log this epoch's validation accuracy to TensorBoard; tagging the
        # scalar by trial number keeps the curves of different trials apart.
        writer.add_scalar(f"accuracy/trial_{trial.number}", accuracy, epoch)

        # Print progress for this trial.
        print(f"Hyperparameters: {trial.params}")
        print(f"Accuracy: {accuracy:.4f}")
        trial.report(accuracy, epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    # Log the trial's hyperparameters against its final accuracy once per
    # trial (calling add_hparams every epoch would open a new hparams run
    # each time).
    writer.add_hparams(
        {"batch_size": batch_size, "lr": lr, "gamma": gamma},
        {"accuracy": accuracy},
    )

    # Discard late, unpromising configurations by marking the trial as
    # pruned; returning a sentinel such as float("inf") would be ranked as
    # the best value in a maximize study.
    if trial.number > 10 and trial.params["lr"] < 1e-3 and accuracy < 0.7:
        raise optuna.exceptions.TrialPruned()

    return accuracy


if __name__ == "__main__":
    pruner = optuna.pruners.HyperbandPruner()
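    # Hyperband runs successive-halving brackets over the per-epoch accuracy
    # values reported via trial.report(), pruning trials that fall behind.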
    study = optuna.create_study(
        direction="maximize",  # maximize validation accuracy
        pruner=pruner,
        study_name="hyperparameter_tuning",
    )

    # Optimize the hyperparameters
    study.optimize(objective, n_trials=N_TRIALS, timeout=TIMEOUT)

    # Print the best trial
    best_trial = study.best_trial
    print("Best trial:")
    print("  Value: ", best_trial.value)
    print("  Params: ")
    for key, value in best_trial.params.items():
        print("    {}: {}".format(key, value))