# HaWoR — lib/datasets/track_dataset.py
# (source mirrored from ThunderVVV's repository, commit 5f028d6)
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize, ToTensor, Compose
import numpy as np
import cv2
from lib.core import constants
from lib.utils.imutils import crop, boxes_2_cs
class TrackDatasetEval(Dataset):
    """
    Track Dataset Class - Load images/crops of the tracked boxes.

    Each item is a fixed-size crop around a tracked bounding box, optionally
    horizontally flipped and normalized for network input.
    """
    def __init__(self, imgfiles, boxes,
                 crop_size=256, dilate=1.0,
                 img_focal=None, img_center=None, normalization=True,
                 item_idx=0, do_flip=False):
        """
        Args:
            imgfiles: list of image file paths, one per frame.
            boxes: per-frame bounding boxes; converted to centers/scales
                via boxes_2_cs.
            crop_size: side length (px) of the square crop fed to the model.
            dilate: multiplicative dilation applied to each box scale.
            img_focal: camera focal length (must not be None if items are
                fetched — it is wrapped in a tensor per item).
            img_center: camera principal point (same caveat as img_focal).
            normalization: if True, apply ToTensor + ImageNet-style
                mean/std normalization; otherwise return a raw uint8 tensor.
            item_idx: opaque index tag kept on the dataset (unused here).
            do_flip: if True, horizontally flip each image and its center.
        """
        super(TrackDatasetEval, self).__init__()
        self.imgfiles = imgfiles
        self.crop_size = crop_size
        self.normalization = normalization
        self.normalize_img = Compose([
            ToTensor(),
            Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
        ])
        self.boxes = boxes
        self.box_dilate = dilate
        self.centers, self.scales = boxes_2_cs(boxes)
        self.img_focal = img_focal
        self.img_center = img_center
        self.item_idx = item_idx
        self.do_flip = do_flip

    def __len__(self):
        return len(self.imgfiles)

    def __getitem__(self, index):
        item = {}
        imgfile = self.imgfiles[index]
        scale = self.scales[index] * self.box_dilate
        # BUGFIX: take a copy — self.centers[index] is a view into the stored
        # centers array, and the flip branch below mutates center[0] in place.
        # Without the copy, fetching the same index more than once (e.g. over
        # multiple epochs) would re-flip the stored center each time.
        center = np.array(self.centers[index], copy=True)
        img_focal = self.img_focal
        img_center = self.img_center

        # BGR -> RGB (cv2 loads BGR)
        img = cv2.imread(imgfile)[:, :, ::-1]
        if self.do_flip:
            img = img[:, ::-1, :]
            img_width = img.shape[1]
            center[0] = img_width - center[0] - 1

        img_crop = crop(img, center, scale,
                        [self.crop_size, self.crop_size],
                        rot=0).astype('uint8')

        if self.normalization:
            img_crop = self.normalize_img(img_crop)
        else:
            img_crop = torch.from_numpy(img_crop)

        item['img'] = img_crop
        if self.do_flip:
            # Flag only present when flipping — downstream code may rely on
            # the key's absence, so it is intentionally not set to 0 otherwise.
            item['do_flip'] = torch.tensor(1).float()
        item['img_idx'] = torch.tensor(index).long()
        item['scale'] = torch.tensor(scale).float()
        item['center'] = torch.tensor(center).float()
        # NOTE(review): these raise if img_focal/img_center were left as None —
        # callers are expected to always supply them; confirm at call sites.
        item['img_focal'] = torch.tensor(img_focal).float()
        item['img_center'] = torch.tensor(img_center).float()
        return item