File size: 2,517 Bytes
5f028d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize, ToTensor, Compose
import numpy as np
import cv2

from lib.core import constants
from lib.utils.imutils import crop, boxes_2_cs


class TrackDatasetEval(Dataset):
    """
    Track Dataset Class - Load images/crops of the tracked boxes.

    Each item is a square crop around one tracked box, together with the
    crop parameters (center, scale) used to produce it.

    Args:
        imgfiles: Sequence of image paths; ``len(imgfiles)`` is the dataset
            length and ``imgfiles[i]`` is paired with the i-th box.
        boxes: Boxes converted by ``boxes_2_cs`` into per-item centers and
            scales (NOTE(review): box format is whatever ``boxes_2_cs``
            expects — confirm against lib.utils.imutils).
        crop_size: Side length, in pixels, of the square output crop.
        dilate: Multiplier applied to each box scale before cropping.
        img_focal: Value stored as ``item['img_focal']`` (omitted when None).
        img_center: Value stored as ``item['img_center']`` (omitted when None).
        normalization: If True, crops go through ``ToTensor`` + ``Normalize``
            with ``constants.IMG_NORM_MEAN/STD``; otherwise the raw uint8
            crop is returned via ``torch.from_numpy`` (no channel reordering).
        item_idx: Stored on the instance; not used by this class.
        do_flip: If True, each image is mirrored horizontally before
            cropping and the crop center is mirrored to match.
    """

    def __init__(self, imgfiles, boxes,
                 crop_size=256, dilate=1.0,
                 img_focal=None, img_center=None, normalization=True,
                 item_idx=0, do_flip=False):
        super().__init__()

        self.imgfiles = imgfiles
        self.crop_size = crop_size
        self.normalization = normalization
        self.normalize_img = Compose([
            ToTensor(),
            Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD),
        ])

        self.boxes = boxes
        self.box_dilate = dilate
        self.centers, self.scales = boxes_2_cs(boxes)

        self.img_focal = img_focal
        self.img_center = img_center
        self.item_idx = item_idx
        self.do_flip = do_flip

    def __len__(self):
        return len(self.imgfiles)

    def __getitem__(self, index):
        """Load image ``index``, crop it around its box and return a dict.

        Returned keys:
            'img': the crop (normalized tensor, or raw uint8 tensor when
                ``normalization`` is False).
            'do_flip': ``1.0`` — present only when ``do_flip`` is set.
            'img_idx': ``index`` as a long tensor.
            'scale': dilated box scale used for the crop.
            'center': crop center, in flipped-image coordinates when
                ``do_flip`` is set.
            'img_focal' / 'img_center': present only when the corresponding
                constructor argument is not None.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        imgfile = self.imgfiles[index]
        scale = self.scales[index] * self.box_dilate
        # Copy the center: indexing a NumPy array returns a view, so flipping
        # it in place would corrupt self.centers and re-flip on every access.
        center = np.array(self.centers[index], copy=True)

        img = self._load_rgb(imgfile)
        if self.do_flip:
            img = img[:, ::-1, :]
            img_width = img.shape[1]
            # Mirror the x-coordinate so the crop tracks the flipped image.
            center[0] = img_width - center[0] - 1

        img_crop = crop(img, center, scale,
                        [self.crop_size, self.crop_size],
                        rot=0).astype('uint8')

        if self.normalization:
            img_crop = self.normalize_img(img_crop)
        else:
            img_crop = torch.from_numpy(img_crop)

        item = {'img': img_crop}
        if self.do_flip:
            item['do_flip'] = torch.tensor(1).float()
        item['img_idx'] = torch.tensor(index).long()
        item['scale'] = torch.tensor(scale).float()
        item['center'] = torch.tensor(center).float()
        # torch.tensor(None) raises, so optional intrinsics are only emitted
        # when the caller actually provided them.
        if self.img_focal is not None:
            item['img_focal'] = torch.tensor(self.img_focal).float()
        if self.img_center is not None:
            item['img_center'] = torch.tensor(self.img_center).float()

        return item

    @staticmethod
    def _load_rgb(imgfile):
        """Read ``imgfile`` with OpenCV and reverse its channel order (BGR -> RGB)."""
        img = cv2.imread(imgfile)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError(f"Could not read image file: {imgfile}")
        return img[:, :, ::-1]