import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize, ToTensor, Compose
import numpy as np
import cv2
from lib.core import constants
from lib.utils.imutils import crop, boxes_2_cs
class TrackDatasetEval(Dataset):
    """Evaluation-time dataset that yields image crops of tracked boxes.

    Each item is a dict with the crop (normalized or raw uint8 tensor) plus
    the metadata (index, scale, center, focal, principal point) needed to
    map predictions back to the full image.

    Args:
        imgfiles: list of image file paths, one per frame.
        boxes: per-frame boxes; converted to (center, scale) via boxes_2_cs.
        crop_size: side length in pixels of the square crop.
        dilate: multiplicative dilation applied to the box scale.
        img_focal: camera focal length (scalar). NOTE(review): the default
            None will make torch.tensor(None) fail in __getitem__ — callers
            appear to always pass a value; confirm.
        img_center: camera principal point (e.g. [cx, cy]); same caveat.
        normalization: if True, apply ToTensor + ImageNet-style Normalize;
            otherwise return the raw uint8 crop as a torch tensor.
        item_idx: opaque tag stored on the dataset (not used in __getitem__).
        do_flip: if True, horizontally flip the image (and the box center)
            before cropping.
    """
    def __init__(self, imgfiles, boxes,
                 crop_size=256, dilate=1.0,
                 img_focal=None, img_center=None, normalization=True,
                 item_idx=0, do_flip=False):
        super(TrackDatasetEval, self).__init__()
        self.imgfiles = imgfiles
        self.crop_size = crop_size
        self.normalization = normalization
        # Standard ToTensor + mean/std normalization using project constants.
        self.normalize_img = Compose([
            ToTensor(),
            Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
        ])
        self.boxes = boxes
        self.box_dilate = dilate
        self.centers, self.scales = boxes_2_cs(boxes)
        self.img_focal = img_focal
        self.img_center = img_center
        self.item_idx = item_idx
        self.do_flip = do_flip

    def __len__(self):
        return len(self.imgfiles)

    def __getitem__(self, index):
        """Load frame `index`, crop around its (dilated) box, and return a dict.

        Returns:
            dict with keys 'img', 'img_idx', 'scale', 'center', 'img_focal',
            'img_center', and — only when do_flip is set — 'do_flip'.
        """
        item = {}
        imgfile = self.imgfiles[index]
        scale = self.scales[index] * self.box_dilate
        # BUGFIX: copy the stored center before any mutation. The original
        # code bound `center = self.centers[index]` and then, under do_flip,
        # assigned center[0] in place — corrupting self.centers so that each
        # repeated __getitem__ on the same index re-flipped the center.
        center = np.array(self.centers[index], dtype=float)
        img_focal = self.img_focal
        img_center = self.img_center

        img = cv2.imread(imgfile)
        if img is None:
            # cv2.imread signals failure by returning None; fail loudly with
            # the offending path instead of a cryptic NoneType TypeError.
            raise FileNotFoundError('Cannot read image: {}'.format(imgfile))
        img = img[:, :, ::-1]  # BGR (OpenCV) -> RGB

        if self.do_flip:
            img = img[:, ::-1, :]
            img_width = img.shape[1]
            # Mirror the box center to match the flipped image.
            center[0] = img_width - center[0] - 1

        img_crop = crop(img, center, scale,
                        [self.crop_size, self.crop_size],
                        rot=0).astype('uint8')
        # cv2.imwrite('debug_crop.png', img_crop[:,:,::-1])

        if self.normalization:
            img_crop = self.normalize_img(img_crop)
        else:
            img_crop = torch.from_numpy(img_crop)

        item['img'] = img_crop
        if self.do_flip:
            item['do_flip'] = torch.tensor(1).float()
        item['img_idx'] = torch.tensor(index).long()
        item['scale'] = torch.tensor(scale).float()
        item['center'] = torch.tensor(center).float()
        item['img_focal'] = torch.tensor(img_focal).float()
        item['img_center'] = torch.tensor(img_center).float()
        return item