File size: 3,836 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop
from transformers import BertTokenizer
import pytorch_lightning as pl
from PIL import Image
import os


class flickr30k_CNA(Dataset):
    """Image-caption dataset for Flickr30k-CNA (Chinese captions).

    Each non-blank line of the annotation file is tab-separated as
    ``<image_key>#<caption_index>\\t<caption>``. The image is loaded from
    ``<img_root_path>/<image_key>.jpg`` and ``<image_key>`` is used as the
    sample's label, so every caption of the same image shares one label.

    Args:
        img_root_path: Directory containing the ``.jpg`` images.
        annot_path: Path to the UTF-8, tab-separated annotation file.
        transform: Callable applied to each loaded PIL image (e.g. the
            pipeline returned by ``image_transform``). If ``None`` the loaded
            PIL image is returned unchanged.
    """

    def __init__(self, img_root_path,
                 annot_path,
                 transform=None):
        self.root = img_root_path
        self.images, self.captions, self.labels = self._read_annotations(annot_path)
        self.transforms = transform
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        # Captions are padded/truncated to exactly this many tokens.
        # NOTE(review): the original note here referred to the "large" model —
        # confirm this matches the text encoder's expected context length.
        self.context_length = 77

    @staticmethod
    def _read_annotations(annot_path):
        """Parse *annot_path* into parallel ``(images, captions, labels)`` lists.

        Raises:
            ValueError: if a non-blank line does not contain a tab-separated caption.
        """
        images, captions, labels = [], [], []
        # Captions are Chinese text: decode explicitly as UTF-8 rather than
        # relying on the platform's default locale encoding.
        with open(annot_path, 'r', encoding='utf-8') as f:
            for lineno, raw in enumerate(f, start=1):
                stripped = raw.strip()
                if not stripped:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                fields = stripped.split('\t')
                if len(fields) < 2:
                    raise ValueError(
                        f"{annot_path}:{lineno}: expected '<key>#<n>\\t<caption>', got {raw!r}"
                    )
                # '<key>#<n>' -> '<key>': all captions of one image share the key.
                key, caption = fields[0].split('#')[0], fields[1]
                images.append(key + '.jpg')
                captions.append(caption)
                labels.append(key)
        return images, captions, labels

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, token_ids, label)`` for sample *idx*.

        ``image`` is the output of ``transform`` (or the PIL image if no
        transform was given). ``token_ids`` is the first row of the tokenizer's
        ``input_ids``, padded/truncated to ``context_length``. ``label`` is the
        image key string.
        """
        img_path = os.path.join(self.root, str(self.images[idx]))
        # Open the file ourselves and force-load the pixels, so the descriptor is
        # closed before returning. A bare Image.open() keeps it open until
        # garbage collection, which can exhaust the fd limit across workers.
        with open(img_path, 'rb') as fh:
            img = Image.open(fh)
            img.load()
        image = self.transforms(img) if self.transforms is not None else img
        text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length,
                              padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
        label = self.labels[idx]
        return image, text, label


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(
        image_size: int,
        is_train: bool,
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711)
):
    """Build the torchvision preprocessing pipeline for one split.

    Training uses a mild random resized crop (scale 0.9-1.0). Evaluation uses
    a deterministic resize followed by a center crop. Both end with RGB
    conversion, tensor conversion and per-channel normalization by *mean* / *std*.

    Args:
        image_size: Output side length of the (square) crop.
        is_train: Select the random-crop training pipeline when true.
        mean: Per-channel normalization mean.
        std: Per-channel normalization standard deviation.

    Returns:
        A ``Compose`` transform to apply to a PIL image.
    """
    if not is_train:
        geometry_ops = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    else:
        geometry_ops = [
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
        ]
    shared_tail = [_convert_to_rgb, ToTensor(), Normalize(mean=mean, std=std)]
    return Compose(geometry_ops + shared_tail)


class FlickrDataModule(pl.LightningDataModule):
    """LightningDataModule serving the Flickr30k-CNA train/val/test splits.

    Args:
        args: Configuration namespace (presumably from argparse — confirm)
            providing ``batch_size``, ``num_workers``, ``pretrain_model`` and
            the ``train/val/test`` ``*_filename`` (annotation file) and
            ``*_root`` (image directory) attributes.
    """

    def __init__(self, args):
        # Initialise Lightning's DataModule/hook state (e.g. ``trainer`` and the
        # hyperparameter bookkeeping). Skipping this leaves attributes that
        # Lightning reads during fitting undefined.
        super().__init__()
        self.batch_size = args.batch_size
        self.train_filename = args.train_filename  # NOTE: annotation file
        self.train_root = args.train_root  # NOTE: image directory
        self.val_filename = args.val_filename
        self.val_root = args.val_root
        self.test_filename = args.test_filename
        self.test_root = args.test_root

        self.pretrain_model = args.pretrain_model  # stored; not used in this class
        self.image_size = 224
        self.prepare_data_per_node = True
        self._log_hyperparams = False
        self.num_workers = args.num_workers

    def setup(self, stage=None):
        """Build all three datasets. ``stage`` is accepted but does not restrict which are built."""
        train_transform = image_transform(self.image_size, True)
        val_transform = image_transform(self.image_size, False)
        test_transform = image_transform(self.image_size, False)

        self.train_dataset = flickr30k_CNA(self.train_root, self.train_filename, transform=train_transform)
        self.val_dataset = flickr30k_CNA(self.val_root, self.val_filename, transform=val_transform)
        self.test_dataset = flickr30k_CNA(self.test_root, self.test_filename, transform=test_transform)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        # Shuffle every epoch. Without it, batch composition is fixed, and when
        # an image's captions are listed consecutively (key#0, key#1, ...) a
        # batch holds many copies of the same image.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation loader (deterministic order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (deterministic order)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)