File size: 1,375 Bytes
e6f2a04
 
 
97dcf92
e6f2a04
97dcf92
e6f2a04
97dcf92
 
 
 
 
e6f2a04
9d7b040
97dcf92
9d7b040
 
e6f2a04
 
 
 
 
 
 
 
 
9d7b040
e6f2a04
 
97dcf92
e6f2a04
9d7b040
e6f2a04
eb01e5b
97dcf92
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from configs import *
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader, Dataset
import torch

# Seed torch's global RNG once at import time. random_split (without an
# explicit generator) and shuffling DataLoaders draw from this RNG, so the
# train/validation split is reproducible across runs -- provided nothing else
# consumes the global RNG before load_data is called.
torch.manual_seed(RANDOM_SEED)

def load_data(combined_dir, preprocess, batch_size=BATCH_SIZE, *, train_ratio=0.8):
    """Load an image folder and split it into train/validation DataLoaders.

    Args:
        combined_dir: Root directory handed to ``torchvision.datasets.ImageFolder``.
        preprocess: Transform applied by ``ImageFolder`` to every loaded image.
        batch_size: Batch size used by both returned loaders.
        train_ratio: Keyword-only fraction of samples assigned to the training
            split (default 0.8). The validation split receives the remainder.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. The training loader shuffles;
        the validation loader does not. Both wrap their subset in
        ``CustomDataset`` and load in the main process (``num_workers=0``).

    Raises:
        ValueError: If ``train_ratio`` is not within ``[0, 1]``.
    """
    # Fail fast, before touching the filesystem, on an impossible split ratio.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    dataset = ImageFolder(combined_dir, transform=preprocess)
    classes = dataset.classes

    # Join the names explicitly: passing them as separate print args with
    # sep=", " would also put the separator after the label ("Classes: , a").
    print("Classes:", ", ".join(classes))
    print("Length of total dataset: ", len(dataset))

    # Validation takes the remainder so the two sizes always sum to
    # len(dataset). random_split uses torch's global RNG (seeded at import).
    train_size = int(train_ratio * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Wrap each subset in CustomDataset before batching.
    # NOTE(review): CustomDataset is not defined in this file -- presumably it
    # comes from `from configs import *`; confirm.
    train_loader = DataLoader(
        CustomDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=0
    )
    valid_loader = DataLoader(
        CustomDataset(val_dataset), batch_size=batch_size, num_workers=0, shuffle=False
    )

    return train_loader, valid_loader


def load_test_data(test_dir, preprocess, batch_size=BATCH_SIZE):
    """Return a DataLoader over the image folder rooted at ``test_dir``.

    The folder is read with ``torchvision.datasets.ImageFolder`` using
    ``preprocess`` as the per-image transform. The dataset is wrapped in
    ``CustomDataset`` and batched in file order (``shuffle=False``) in the
    main process (``num_workers=0``).
    """
    # NOTE(review): CustomDataset is not defined in this file -- presumably it
    # is provided by `from configs import *`; confirm.
    wrapped = CustomDataset(ImageFolder(test_dir, transform=preprocess))
    return DataLoader(wrapped, batch_size=batch_size, shuffle=False, num_workers=0)