from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import (
    Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop,
)
from transformers import BertTokenizer
import pytorch_lightning as pl
from PIL import Image
import os


class flickr30k_CNA(Dataset):
    """Flickr30k-CNA (Chinese caption) image-text retrieval dataset.

    Each annotation line is "<image_key>#<caption_index>", a tab, then the caption;
    the corresponding image file is "<image_key>.jpg" under img_root_path.
    """

    def __init__(self, img_root_path, annot_path, transform=None):
        self.images = []
        self.captions = []
        self.labels = []
        self.root = img_root_path
        with open(annot_path, 'r') as f:
            for line in f:
                line = line.strip().split('\t')
                # "<image_key>#<caption_index>" -> keep only the image key
                key, caption = line[0].split('#')[0], line[1]
                img_path = key + '.jpg'
                self.images.append(img_path)
                self.captions.append(caption)
                self.labels.append(key)
        self.transforms = transform
        # Chinese captions are tokenized with the RoBERTa-wwm-ext vocabulary
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # CLIP-style fixed context length for the text encoder
        self.context_length = 77

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = str(self.images[idx])
        image = Image.open(os.path.join(self.root, img_path))
        if self.transforms is not None:
            image = self.transforms(image)
        # Tokenize the caption to a fixed-length sequence of input_ids
        text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length,
                              padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
        label = self.labels[idx]
        return image, text, label


def _convert_to_rgb(image):
    # PIL may yield grayscale or RGBA images; the vision encoder expects 3-channel RGB
    return image.convert('RGB')


def image_transform(
        image_size: int,
        is_train: bool,
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711)
):
    """Build the image preprocessing pipeline; the default mean/std are CLIP's statistics."""
    normalize = Normalize(mean=mean, std=std)
    if is_train:
        # Light random cropping as training augmentation
        return Compose([
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
    else:
        # Deterministic resize + center crop for validation and test
        return Compose([
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])


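# Illustration only: a small convenience helper (not referenced elsewhere in this
# file) showing how flickr30k_CNA and image_transform above are meant to be paired.
# The paths are supplied by the caller; constructing the dataset also loads the
# "hfl/chinese-roberta-wwm-ext" tokenizer.
def build_flickr_cna_dataset(img_root_path, annot_path, is_train=False, image_size=224):
    """Hypothetical helper: pair the dataset with the matching train/eval transform."""
    transform = image_transform(image_size, is_train)
    return flickr30k_CNA(img_root_path, annot_path, transform=transform)

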
class FlickrDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.batch_size = args.batch_size
        # Annotation files and image directories for each split
        self.train_filename = args.train_filename
        self.train_root = args.train_root
        self.val_filename = args.val_filename
        self.val_root = args.val_root
        self.test_filename = args.test_filename
        self.test_root = args.test_root

        self.pretrain_model = args.pretrain_model
        self.image_size = 224
        # Attributes read by the Lightning Trainer
        self.prepare_data_per_node = True
        self._log_hyperparams = False
        self.num_workers = args.num_workers

    def setup(self, stage=None):
        train_transform = image_transform(self.image_size, True)
        val_transform = image_transform(self.image_size, False)
        test_transform = image_transform(self.image_size, False)

        self.train_dataset = flickr30k_CNA(self.train_root, self.train_filename, transform=train_transform)
        self.val_dataset = flickr30k_CNA(self.val_root, self.val_filename, transform=val_transform)
        self.test_dataset = flickr30k_CNA(self.test_root, self.test_filename, transform=test_transform)

    def train_dataloader(self):
        # Shuffle training pairs; validation/test keep their original order
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
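

# A minimal smoke-test sketch for running this module directly. The paths below are
# hypothetical placeholders; point them at your own Flickr30k-CNA image directories
# and annotation files before running.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        batch_size=8,
        num_workers=0,
        pretrain_model="hfl/chinese-roberta-wwm-ext",   # placeholder; stored but unused here
        train_root="/path/to/flickr30k/train_images",   # placeholder paths
        train_filename="/path/to/flickr30k_cna/train.txt",
        val_root="/path/to/flickr30k/val_images",
        val_filename="/path/to/flickr30k_cna/val.txt",
        test_root="/path/to/flickr30k/test_images",
        test_filename="/path/to/flickr30k_cna/test.txt",
    )
    dm = FlickrDataModule(args)
    dm.setup()
    images, texts, labels = next(iter(dm.val_dataloader()))
    # images: [B, 3, 224, 224] float tensor, texts: [B, 77] token ids, labels: image keys
    print(images.shape, texts.shape, len(labels))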