# (removed non-code scrape residue: "Spaces: / Running / Running" page banner)
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Normalize, ToTensor

from lib.core import constants
from lib.utils.imutils import boxes_2_cs, crop
class TrackDatasetEval(Dataset):
    """Evaluation dataset that yields square image crops around tracked boxes.

    Each item is the region of frame ``index`` around its tracked box,
    cropped to ``crop_size`` x ``crop_size``, optionally mirrored
    horizontally, and (by default) normalized for network input.

    Args:
        imgfiles: list of image file paths, one per tracked box.
        boxes: per-frame boxes; converted to centers/scales via ``boxes_2_cs``.
        crop_size: side length in pixels of the square output crop.
        dilate: multiplicative dilation applied to each box scale.
        img_focal: camera focal length. Defaults to ``None``, but
            ``__getitem__`` calls ``torch.tensor`` on it, which raises for
            ``None`` — callers must supply a value before fetching items.
        img_center: camera principal point; same caveat as ``img_focal``.
        normalization: if True, apply ``ToTensor`` + ``Normalize`` with the
            project's image mean/std; otherwise return the raw uint8 crop
            as a tensor.
        item_idx: opaque index stored on the dataset (not used in
            ``__getitem__``).
        do_flip: if True, mirror the image and box center horizontally.
    """

    def __init__(self, imgfiles, boxes,
                 crop_size=256, dilate=1.0,
                 img_focal=None, img_center=None, normalization=True,
                 item_idx=0, do_flip=False):
        super(TrackDatasetEval, self).__init__()
        self.imgfiles = imgfiles
        self.crop_size = crop_size
        self.normalization = normalization
        # ToTensor scales to [0,1]; Normalize uses the project-wide stats.
        self.normalize_img = Compose([
            ToTensor(),
            Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
        ])
        self.boxes = boxes
        self.box_dilate = dilate
        self.centers, self.scales = boxes_2_cs(boxes)
        self.img_focal = img_focal
        self.img_center = img_center
        self.item_idx = item_idx
        self.do_flip = do_flip

    def __len__(self):
        """Number of items = number of image files."""
        return len(self.imgfiles)

    def __getitem__(self, index):
        """Load frame ``index``, crop around its box, and pack metadata.

        Returns:
            dict with keys ``img``, ``img_idx``, ``scale``, ``center``,
            ``img_focal``, ``img_center``, and ``do_flip`` (the latter only
            when flipping is enabled).
        """
        item = {}
        imgfile = self.imgfiles[index]
        scale = self.scales[index] * self.box_dilate
        # BUG FIX: copy the stored center before the flip branch mutates
        # center[0]. The original aliased self.centers[index], so iterating
        # the dataset more than once with do_flip=True re-mirrored the
        # stored centers in place on every pass.
        center = np.array(self.centers[index], dtype=np.float32)

        img_focal = self.img_focal
        img_center = self.img_center

        img = cv2.imread(imgfile)[:, :, ::-1]  # BGR -> RGB

        if self.do_flip:
            img = img[:, ::-1, :]
            img_width = img.shape[1]
            # Mirror the x-coordinate of the crop center to match the
            # flipped image (pixel indices run 0..img_width-1).
            center[0] = img_width - center[0] - 1

        img_crop = crop(img, center, scale,
                        [self.crop_size, self.crop_size],
                        rot=0).astype('uint8')
        # cv2.imwrite('debug_crop.png', img_crop[:,:,::-1])

        if self.normalization:
            img_crop = self.normalize_img(img_crop)
        else:
            # astype above returned a fresh contiguous array, so
            # torch.from_numpy is safe even after the negative-stride flips.
            img_crop = torch.from_numpy(img_crop)

        item['img'] = img_crop
        if self.do_flip:
            item['do_flip'] = torch.tensor(1).float()
        item['img_idx'] = torch.tensor(index).long()
        item['scale'] = torch.tensor(scale).float()
        item['center'] = torch.tensor(center).float()
        # NOTE(review): these raise if img_focal/img_center were left at
        # their None defaults — torch.tensor(None) is an error.
        item['img_focal'] = torch.tensor(img_focal).float()
        item['img_center'] = torch.tensor(img_center).float()
        return item