import cv2
import numpy as np
# np.set_printoptions(threshold=np.inf)
import random
import torch
import torchvision.transforms as transforms
# from visualization import plot_img_and_mask, plot_one_box, show_seg_result
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from ..utils import letterbox, augment_hsv, random_perspective, xyxy2xywh, cutout


class AutoDriveDataset(Dataset):
    """
    A general Dataset base class implementing the functionality shared by the
    concrete child datasets.
    """
    def __init__(self, cfg, is_train, inputsize=640, transform=None):
        """
        Initialize all attributes of the dataset.

        Inputs:
        -cfg: configuration object
        -is_train(bool): whether this is the training set or not
        -inputsize: target input size of the network
        -transform: ToTensor and Normalize

        Returns:
        None
        """
        self.is_train = is_train
        self.cfg = cfg
        self.transform = transform
        self.inputsize = inputsize
        self.Tensor = transforms.ToTensor()
        img_root = Path(cfg.DATASET.DATAROOT)
        label_root = Path(cfg.DATASET.LABELROOT)
        mask_root = Path(cfg.DATASET.MASKROOT)
        lane_root = Path(cfg.DATASET.LANEROOT)
        if is_train:
            indicator = cfg.DATASET.TRAIN_SET
        else:
            indicator = cfg.DATASET.TEST_SET
        self.img_root = img_root / indicator
        self.label_root = label_root / indicator
        self.mask_root = mask_root / indicator
        self.lane_root = lane_root / indicator
        # self.label_list = self.label_root.iterdir()
        self.mask_list = self.mask_root.iterdir()
        self.db = []
        self.data_format = cfg.DATASET.DATA_FORMAT
        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        self.rotation_factor = cfg.DATASET.ROT_FACTOR
        self.flip = cfg.DATASET.FLIP
        self.color_rgb = cfg.DATASET.COLOR_RGB
        # self.target_type = cfg.MODEL.TARGET_TYPE
        self.shapes = np.array(cfg.DATASET.ORG_IMG_SIZE)

    def _get_db(self):
        """
        Implemented in the child Dataset. For a dataset that is not in the
        BDD100K format, override this method in the child class.
        """
        raise NotImplementedError

    def evaluate(self, cfg, preds, output_dir):
        """
        Implemented in the child Dataset.
        """
        raise NotImplementedError

    def __len__(self):
        """
        Number of samples in the dataset.
        """
        return len(self.db)

    def __getitem__(self, idx):
        """
        Get the input and ground truth from the database and apply data
        augmentation to the input.

        Inputs:
        -idx: index of the image in self.db (a list of records, where each
              record is a dictionary such as {'image':, 'label':, ...})

        Returns:
        -image: transformed image; the data augmentation in __getitem__ is
                applied first (as a numpy array), then self.transform
        -target: ground truth (det_gt, seg_gt)

        Functions that may be useful:
        cv2.imread
        cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        cv2.warpAffine
        """
        data = self.db[idx]
        img = cv2.imread(data["image"], cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.cfg.num_seg_class == 3:
            seg_label = cv2.imread(data["mask"])
        else:
            seg_label = cv2.imread(data["mask"], 0)
        lane_label = cv2.imread(data["lane"], 0)

        resized_shape = self.inputsize
        if isinstance(resized_shape, list):
            resized_shape = max(resized_shape)
        h0, w0 = img.shape[:2]  # original hw
        r = resized_shape / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down; only resize up when training with augmentation
            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
            seg_label = cv2.resize(seg_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
            lane_label = cv2.resize(lane_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
        h, w = img.shape[:2]

        (img, seg_label, lane_label), ratio, pad = letterbox((img, seg_label, lane_label), resized_shape, auto=True, scaleup=self.is_train)
        shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

        det_label = data["label"]
        labels = []
        if det_label.size > 0:
            # Normalized xywh to pixel xyxy format
            labels = det_label.copy()
            labels[:, 1] = ratio[0] * w * (det_label[:, 1] - det_label[:, 3] / 2) + pad[0]  # pad width
            labels[:, 2] = ratio[1] * h * (det_label[:, 2] - det_label[:, 4] / 2) + pad[1]  # pad height
            labels[:, 3] = ratio[0] * w * (det_label[:, 1] + det_label[:, 3] / 2) + pad[0]
            labels[:, 4] = ratio[1] * h * (det_label[:, 2] + det_label[:, 4] / 2) + pad[1]
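            # Worked example (hypothetical numbers): with ratio = (1.0, 1.0),
            # w = 640 and pad = (0, 12), a box with normalized center-x 0.5 and
            # width 0.25 maps to x1 = 640 * (0.5 - 0.125) + 0 = 240 and
            # x2 = 640 * (0.5 + 0.125) + 0 = 400.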

        if self.is_train:
            combination = (img, seg_label, lane_label)
            (img, seg_label, lane_label), labels = random_perspective(
                combination=combination,
                targets=labels,
                degrees=self.cfg.DATASET.ROT_FACTOR,
                translate=self.cfg.DATASET.TRANSLATE,
                scale=self.cfg.DATASET.SCALE_FACTOR,
                shear=self.cfg.DATASET.SHEAR
            )
            augment_hsv(img, hgain=self.cfg.DATASET.HSV_H, sgain=self.cfg.DATASET.HSV_S, vgain=self.cfg.DATASET.HSV_V)
            # img, seg_label, labels = cutout(combination=combination, labels=labels)

            if len(labels):
                # convert xyxy to xywh
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
                # normalize coordinates to 0 - 1
                labels[:, [2, 4]] /= img.shape[0]  # height
                labels[:, [1, 3]] /= img.shape[1]  # width

            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                seg_label = np.fliplr(seg_label)
                lane_label = np.fliplr(lane_label)
                if len(labels):
                    # mirror the normalized x-center
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                seg_label = np.flipud(seg_label)
                lane_label = np.flipud(lane_label)
                if len(labels):
                    # mirror the normalized y-center
                    labels[:, 2] = 1 - labels[:, 2]

        else:
            if len(labels):
                # convert xyxy to xywh
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
                # normalize coordinates to 0 - 1
                labels[:, [2, 4]] /= img.shape[0]  # height
                labels[:, [1, 3]] /= img.shape[1]  # width

        labels_out = torch.zeros((len(labels), 6))
        if len(labels):
            # column 0 is left as zero here; collate_fn() later fills it with
            # the image index within the batch
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        # img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        # img = img.transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        # seg_label = np.ascontiguousarray(seg_label)

        if self.cfg.num_seg_class == 3:
            _, seg0 = cv2.threshold(seg_label[:, :, 0], 128, 255, cv2.THRESH_BINARY)
            _, seg1 = cv2.threshold(seg_label[:, :, 1], 1, 255, cv2.THRESH_BINARY)
            _, seg2 = cv2.threshold(seg_label[:, :, 2], 1, 255, cv2.THRESH_BINARY)
        else:
            _, seg1 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY)
            _, seg2 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY_INV)
        _, lane1 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY)
        _, lane2 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY_INV)
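        # The grayscale masks are turned into complementary binary maps:
        # THRESH_BINARY marks foreground pixels and THRESH_BINARY_INV marks
        # background pixels, so stacking the pair (below) yields a per-pixel
        # one-hot encoding of (background, foreground).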
        if self.cfg.num_seg_class == 3:
            seg0 = self.Tensor(seg0)
        seg1 = self.Tensor(seg1)
        seg2 = self.Tensor(seg2)
        lane1 = self.Tensor(lane1)
        lane2 = self.Tensor(lane2)

        if self.cfg.num_seg_class == 3:
            seg_label = torch.stack((seg0[0], seg1[0], seg2[0]), 0)
        else:
            seg_label = torch.stack((seg2[0], seg1[0]), 0)
        lane_label = torch.stack((lane2[0], lane1[0]), 0)

        target = [labels_out, seg_label, lane_label]
        img = self.transform(img)

        return img, target, data["image"], shapes

    def select_data(self, db):
        """
        Filter out unusable images from the dataset. The base implementation
        is a stub; override it in a child dataset if filtering is needed.

        Inputs:
        -db: (list) database

        Returns:
        -db_selected: (list) filtered database
        """
        db_selected = ...
        return db_selected
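    # A hypothetical override (not part of the original file): a child dataset
    # could, for example, keep only records that contain at least one
    # detection box:
    #
    #     def select_data(self, db):
    #         return [rec for rec in db if rec["label"].size > 0]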

    @staticmethod
    def collate_fn(batch):
        img, label, paths, shapes = zip(*batch)
        label_det, label_seg, label_lane = [], [], []
        for i, l in enumerate(label):
            l_det, l_seg, l_lane = l
            l_det[:, 0] = i  # add target image index for build_targets()
            label_det.append(l_det)
            label_seg.append(l_seg)
            label_lane.append(l_lane)
        return torch.stack(img, 0), [torch.cat(label_det, 0), torch.stack(label_seg, 0), torch.stack(label_lane, 0)], paths, shapes
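

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `BddDataset` is a hypothetical
# concrete subclass that implements _get_db() and evaluate(), and `cfg` must
# provide the DATASET.* keys referenced above):
#
#     import torchvision.transforms as transforms
#     from torch.utils.data import DataLoader
#
#     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225])
#     dataset = BddDataset(cfg, is_train=True, inputsize=640,
#                          transform=transforms.Compose([transforms.ToTensor(),
#                                                        normalize]))
#     loader = DataLoader(dataset, batch_size=8, shuffle=True,
#                         collate_fn=AutoDriveDataset.collate_fn)
#     for img, (det_gt, seg_gt, lane_gt), paths, shapes in loader:
#         pass  # det_gt: (N, 6) rows of [batch_idx, class, x, y, w, h]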