File size: 1,427 Bytes
e6f2a04
 
 
 
 
9d7b040
e6f2a04
 
 
 
 
 
9d7b040
 
 
 
e6f2a04
 
 
 
 
 
 
 
 
 
 
 
9d7b040
e6f2a04
 
9d7b040
e6f2a04
9d7b040
e6f2a04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from configs import *
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader, Dataset


def load_data(raw_dir, augmented_dir, external_dir, preprocess, batch_size=BATCH_SIZE):
    """Build training and validation loaders from three image-folder roots.

    Each directory is read with ``ImageFolder`` using the same ``preprocess``
    transform. The three datasets are concatenated (raw, external, augmented,
    in that order) and randomly split 80% / 20% into train / validation.

    Args:
        raw_dir: Root directory of the raw images (one sub-folder per class).
        augmented_dir: Root directory of the augmented images.
        external_dir: Root directory of the external images.
        preprocess: Transform passed to every ``ImageFolder``.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(train_loader, valid_loader)`` tuple of ``DataLoader`` objects.
        Only the training loader shuffles.

    Note:
        ``random_split`` is called without an explicit generator, so the split
        depends on the global torch RNG state at call time.
    """
    import warnings

    # Load the dataset using ImageFolder
    raw_dataset = ImageFolder(root=raw_dir, transform=preprocess)
    external_dataset = ImageFolder(root=external_dir, transform=preprocess)
    augmented_dataset = ImageFolder(root=augmented_dir, transform=preprocess)

    # ImageFolder derives label indices from each root's own sub-folder names.
    # If the roots disagree, identical indices mean different classes once the
    # datasets are concatenated, so surface that instead of training silently
    # on mixed-up labels.
    reference_mapping = augmented_dataset.class_to_idx
    for source_name, source in (("raw", raw_dataset), ("external", external_dataset)):
        if source.class_to_idx != reference_mapping:
            warnings.warn(
                f"Class-to-index mapping of the {source_name} dataset differs "
                "from the augmented dataset; labels in the combined dataset "
                "may be inconsistent.",
                stacklevel=2,
            )

    # `+` on torch datasets yields a ConcatDataset.
    dataset = raw_dataset + external_dataset + augmented_dataset

    # Classes
    classes = augmented_dataset.classes

    # Join first: passing *classes with sep=", " would also put the separator
    # between the label and the first class name ("Classes: , a, b").
    print("Classes: ", ", ".join(classes))
    print("Length of raw dataset: ", len(raw_dataset))
    print("Length of external dataset: ", len(external_dataset))
    print("Length of augmented dataset: ", len(augmented_dataset))
    print("Length of total dataset: ", len(dataset))

    # Split the dataset into train and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size  # remainder, so the sizes always sum to len(dataset)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create data loaders for the custom dataset
    # NOTE(review): CustomDataset is not defined in this view; presumably it is
    # provided by `from configs import *` — confirm.
    train_loader = DataLoader(
        CustomDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=0
    )
    valid_loader = DataLoader(
        CustomDataset(val_dataset), batch_size=batch_size, num_workers=0
    )

    return train_loader, valid_loader