|
from torch.utils.data import Dataset |
|
from utils.misc_utils import read_pfm |
|
import os |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import torch |
|
from torchvision import transforms as T |
|
from data.scene import get_boundingbox |
|
|
|
from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image |
|
|
|
from termcolor import colored |
|
|
|
|
|
|
def load_K_Rt_from_P(filename, P=None): |
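    """Decompose a 3x4 projection matrix into a 4x4 intrinsic matrix and a
    4x4 camera-to-world pose via cv2.decomposeProjectionMatrix. If P is None,
    the projection matrix is parsed from the given text file."""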
|
if P is None: |
|
        with open(filename) as f:
            lines = f.read().splitlines()
|
if len(lines) == 4: |
|
lines = lines[1:] |
|
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]  # keep the first four tokens of each row
|
P = np.asarray(lines).astype(np.float32).squeeze() |
|
|
|
out = cv2.decomposeProjectionMatrix(P) |
|
K = out[0] |
|
R = out[1] |
|
t = out[2] |
|
|
|
    K = K / K[2, 2]  # normalize so that K[2, 2] == 1
|
intrinsics = np.eye(4) |
|
intrinsics[:3, :3] = K |
|
|
|
    pose = np.eye(4, dtype=np.float32)

    pose[:3, :3] = R.transpose()  # R is world-to-camera; transpose gives camera-to-world

    pose[:3, 3] = (t[:3] / t[3])[:, 0]  # de-homogenize to recover the camera center
|
|
|
return intrinsics, pose |
|
|
|
|
|
|
|
class MVSDatasetDtuPerView(Dataset): |
|
def __init__(self, root_dir, split, n_views=3, img_wh=(640, 512), downSample=1.0, |
|
split_filepath=None, pair_filepath=None, |
|
N_rays=512, |
|
vol_dims=[128, 128, 128], batch_size=1, |
|
clean_image=False, importance_sample=False, test_ref_views=[]): |
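        """DTU multi-view stereo dataset that serves one reference (query) view
        plus its source views per sample, with poses normalized so the object
        of interest fits inside a unit sphere."""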
|
|
|
self.root_dir = root_dir |
|
self.split = split |
|
|
|
self.img_wh = img_wh |
|
self.downSample = downSample |
|
        self.num_all_imgs = 49  # DTU provides 49 camera viewpoints per scan
|
self.n_views = n_views |
|
self.N_rays = N_rays |
|
self.batch_size = batch_size |
|
|
|
self.clean_image = clean_image |
|
self.importance_sample = importance_sample |
|
self.test_ref_views = test_ref_views |
|
self.scale_factor = 1.0 |
|
self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) |
|
|
|
if img_wh is not None: |
|
            assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \

                'both dimensions of img_wh must be multiples of 32!'
|
|
|
self.split_filepath = f'data/dtu/lists/{self.split}.txt' if split_filepath is None else split_filepath |
|
        self.pair_filepath = 'data/dtu/dtu_pairs.txt' if pair_filepath is None else pair_filepath
|
|
|
print(colored("loading all scenes together", 'red')) |
|
with open(self.split_filepath) as f: |
|
self.scans = [line.rstrip() for line in f.readlines()] |
|
|
|
self.all_intrinsics = [] |
|
self.all_extrinsics = [] |
|
self.all_near_fars = [] |
|
|
|
self.metas, self.ref_src_pairs = self.build_metas() |
|
|
|
        self.allview_ids = list(range(self.num_all_imgs))
|
|
|
self.load_cam_info() |
|
|
|
self.build_remap() |
|
self.define_transforms() |
|
print(f'==> image down scale: {self.downSample}') |
|
|
|
|
|
self.bbox_min = np.array([-1.0, -1.0, -1.0]) |
|
self.bbox_max = np.array([1.0, 1.0, 1.0]) |
|
|
|
|
|
self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) |
|
self.partial_vol_origin = torch.Tensor([-1., -1., -1.]) |
|
|
|
def build_remap(self): |
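        """Map original view ids to contiguous indices into the camera lists."""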
|
self.remap = np.zeros(np.max(self.allview_ids) + 1).astype('int') |
|
for i, item in enumerate(self.allview_ids): |
|
self.remap[item] = i |
|
|
|
def define_transforms(self): |
|
self.transform = T.Compose([T.ToTensor()]) |
|
|
|
def build_metas(self): |
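        """Build (scan, light_idx, ref_view, src_views) metas from the pair file.

        Pair-file format: the first line gives the number of viewpoints; then,
        per viewpoint, one line with the reference view id and one line of the
        form "num_src src_0 score_0 src_1 score_1 ..." (the [1::2] slice below
        keeps the source view ids and drops the scores).
        """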
|
metas = [] |
|
ref_src_pairs = {} |
|
|
|
|
|
        light_idxs = [3] if 'train' not in self.split else range(7)  # DTU has 7 lighting conditions; fix light 3 for val/test
|
|
|
with open(self.pair_filepath) as f: |
|
num_viewpoint = int(f.readline()) |
|
|
|
for _ in range(num_viewpoint): |
|
ref_view = int(f.readline().rstrip()) |
|
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] |
|
|
|
ref_src_pairs[ref_view] = src_views |
|
|
|
for light_idx in light_idxs: |
|
for scan in self.scans: |
|
with open(self.pair_filepath) as f: |
|
num_viewpoint = int(f.readline()) |
|
|
|
for _ in range(num_viewpoint): |
|
ref_view = int(f.readline().rstrip()) |
|
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] |
|
|
|
|
|
if len(self.test_ref_views) > 0 and ref_view not in self.test_ref_views: |
|
continue |
|
|
|
metas += [(scan, light_idx, ref_view, src_views)] |
|
|
|
return metas, ref_src_pairs |
|
|
|
def read_cam_file(self, filename): |
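        """Parse a DTU cam.txt file: lines 1-4 hold the 4x4 extrinsic
        (world-to-camera) matrix, lines 7-9 the 3x3 intrinsic matrix, and
        line 11 holds "depth_min depth_interval" for a 192-plane sweep."""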
|
with open(filename) as f: |
|
lines = [line.rstrip() for line in f.readlines()] |
|
|
|
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') |
|
extrinsics = extrinsics.reshape((4, 4)) |
|
|
|
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') |
|
intrinsics = intrinsics.reshape((3, 3)) |
|
|
|
depth_min = float(lines[11].split()[0]) |
|
depth_max = depth_min + float(lines[11].split()[1]) * 192 |
|
self.depth_interval = float(lines[11].split()[1]) |
|
intrinsics_ = np.float32(np.diag([1, 1, 1, 1])) |
|
intrinsics_[:3, :3] = intrinsics |
|
return intrinsics_, extrinsics, [depth_min, depth_max] |
|
|
|
def load_cam_info(self): |
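        """Load intrinsics, extrinsics and near/far ranges for all 49 views."""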
|
for vid in range(self.num_all_imgs): |
|
proj_mat_filename = os.path.join(self.root_dir, |
|
f'Cameras/train/{vid:08d}_cam.txt') |
|
intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename) |
|
            intrinsic[:2] *= 4  # cam files store intrinsics at 1/4 resolution; scale up to match the full-size images
|
self.all_intrinsics.append(intrinsic) |
|
self.all_extrinsics.append(extrinsic) |
|
self.all_near_fars.append(near_far) |
|
|
|
def read_depth(self, filename): |
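        """Read a 1600x1200 pfm depth map, halve it to 800x600, crop to
        640x512 to match the RGB crop, then return 1/4- and full-resolution
        maps after applying `downSample`."""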
|
|
|
        depth_h = np.array(read_pfm(filename)[0], dtype=np.float32)  # (1200, 1600)
|
|
|
depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5, |
|
interpolation=cv2.INTER_NEAREST) |
|
depth_h = depth_h[44:556, 80:720] |
|
|
|
|
|
depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample, |
|
interpolation=cv2.INTER_NEAREST) |
|
depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4, |
|
interpolation=cv2.INTER_NEAREST) |
|
|
|
return depth, depth_h |
|
|
|
def read_mask(self, filename): |
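        """Read a mask image; return binarized masks at 1/4 and full resolution."""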
|
mask_h = cv2.imread(filename, 0) |
|
mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, |
|
interpolation=cv2.INTER_NEAREST) |
|
mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, |
|
interpolation=cv2.INTER_NEAREST) |
|
|
|
mask[mask > 0] = 1 |
|
mask_h[mask_h > 0] = 1 |
|
|
|
return mask, mask_h |
|
|
|
def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): |
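        """Build a scale matrix mapping the scene bounding sphere to the unit
        sphere; also return the scalar 1/radius used to rescale depths."""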
|
center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) |
|
radius = radius * factor |
|
scale_mat = np.diag([radius, radius, radius, 1.0]) |
|
scale_mat[:3, 3] = center.cpu().numpy() |
|
scale_mat = scale_mat.astype(np.float32) |
|
|
|
return scale_mat, 1. / radius.cpu().numpy() |
|
|
|
def __len__(self): |
|
return len(self.metas) |
|
|
|
def __getitem__(self, idx): |
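        """Assemble one sample: a reference (query) view plus up to
        `n_views` source views, all in a normalized coordinate system."""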
|
sample = {} |
|
scan, light_idx, ref_view, src_views = self.metas[idx % len(self.metas)] |
|
|
|
|
|
view_ids = [ref_view] + src_views[:self.n_views] |
|
|
|
        # re-define the world frame to be the reference camera frame
        w2c_ref = self.all_extrinsics[self.remap[ref_view]]

        w2c_ref_inv = np.linalg.inv(w2c_ref)
|
|
|
        image_perm = 0  # index of the view to render; only the reference view is supervised
|
|
|
imgs, depths_h, masks_h = [], [], [] |
|
intrinsics, w2cs, near_fars = [], [], [] |
|
mask_dilated = None |
|
for i, vid in enumerate(view_ids): |
|
|
|
img_filename = os.path.join(self.root_dir, |
|
f'Rectified/{scan}_train/rect_{vid + 1:03d}_{light_idx}_r5000.png') |
|
depth_filename = os.path.join(self.root_dir, |
|
f'Depths/{scan}_train/depth_map_{vid:04d}.pfm') |
|
|
|
mask_filename = os.path.join(self.root_dir, |
|
f'Masks_clean_dilated/{scan}_train/mask_{vid:04d}.png') |
|
|
|
img = Image.open(img_filename) |
|
img_wh = np.round(np.array(img.size) * self.downSample).astype('int') |
|
img = img.resize(img_wh, Image.BILINEAR) |
|
|
|
if os.path.exists(mask_filename) and self.clean_image: |
|
                _, mask_h = self.read_mask(mask_filename)  # 1/4-resolution mask unused here
|
            else:

                mask_h = np.ones([img_wh[1], img_wh[0]])  # no mask on disk: keep every pixel
|
masks_h.append(mask_h) |
|
|
|
            if i == 0:  # build a dilated mask around the reference object for ray sampling

                kernel_size = 101  # dilation radius; tune with the image resolution

                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

                mask_dilated = np.float32(cv2.dilate(np.uint8(mask_h * 255), kernel, iterations=1) > 128)
|
|
|
if self.clean_image: |
|
img = np.array(img) |
|
img[mask_h < 0.5] = 0.0 |
|
|
|
img = self.transform(img) |
|
|
|
imgs += [img] |
|
|
|
index_mat = self.remap[vid] |
|
near_fars.append(self.all_near_fars[index_mat]) |
|
intrinsics.append(self.all_intrinsics[index_mat]) |
|
|
|
w2cs.append(self.all_extrinsics[index_mat] @ w2c_ref_inv) |
|
|
|
|
|
            if os.path.exists(depth_filename):  # only keep depth when a GT depth map exists
|
|
|
                _, depth_h = self.read_depth(depth_filename)  # 1/4-resolution depth unused here
|
depths_h.append(depth_h) |
|
|
|
scale_mat, scale_factor = self.cal_scale_mat(img_hw=[img_wh[1], img_wh[0]], |
|
intrinsics=intrinsics, extrinsics=w2cs, |
|
near_fars=near_fars, factor=1.1) |
|
|
|
|
|
new_near_fars = [] |
|
new_w2cs = [] |
|
new_c2ws = [] |
|
new_affine_mats = [] |
|
new_depths_h = [] |
|
for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): |
|
P = intrinsic @ extrinsic @ scale_mat |
|
P = P[:3, :4] |
|
|
|
            # recover the camera-to-world pose in the normalized space
            c2w = load_K_Rt_from_P(None, P)[1]
|
w2c = np.linalg.inv(c2w) |
|
new_w2cs.append(w2c) |
|
new_c2ws.append(c2w) |
|
affine_mat = np.eye(4) |
|
affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] |
|
new_affine_mats.append(affine_mat) |
|
|
|
            camera_o = c2w[:3, 3]

            dist = np.sqrt(np.sum(camera_o ** 2))

            # the scene is normalized into a unit sphere, so the camera-to-origin
            # distance bounds the visible depth range
            near = dist - 1

            far = dist + 1

            new_near_fars.append([0.95 * near, 1.05 * far])
|
new_depths_h.append(depth * scale_factor) |
|
|
|
imgs = torch.stack(imgs).float() |
|
|
depths_h = np.stack(new_depths_h) |
|
masks_h = np.stack(masks_h) |
|
|
|
affine_mats = np.stack(new_affine_mats) |
|
intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( |
|
new_near_fars) |
|
|
|
        if 'train' in self.split:

            start_idx = 0  # training: the reference view stays among the source views

        else:

            start_idx = 1  # val/test: drop the reference (query) view from the sources
|
|
|
sample['images'] = imgs |
|
sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) |
|
sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) |
|
sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) |
|
sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) |
|
sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) |
|
sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] |
|
sample['view_ids'] = torch.from_numpy(np.array(view_ids)) |
|
sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) |
|
|
|
sample['light_idx'] = torch.tensor(light_idx) |
|
sample['scan'] = scan |
|
|
|
sample['scale_factor'] = torch.tensor(scale_factor) |
|
sample['img_wh'] = torch.from_numpy(img_wh) |
|
sample['render_img_idx'] = torch.tensor(image_perm) |
|
        sample['partial_vol_origin'] = self.partial_vol_origin.clone()  # already a float32 tensor
|
sample['meta'] = str(scan) + "_light" + str(light_idx) + "_refview" + str(ref_view) |
|
|
|
|
|
sample['query_image'] = sample['images'][0] |
|
sample['query_c2w'] = sample['c2ws'][0] |
|
sample['query_w2c'] = sample['w2cs'][0] |
|
sample['query_intrinsic'] = sample['intrinsics'][0] |
|
sample['query_depth'] = sample['depths_h'][0] |
|
sample['query_mask'] = sample['masks_h'][0] |
|
sample['query_near_far'] = sample['near_fars'][0] |
|
|
|
sample['images'] = sample['images'][start_idx:] |
|
sample['depths_h'] = sample['depths_h'][start_idx:] |
|
sample['masks_h'] = sample['masks_h'][start_idx:] |
|
sample['w2cs'] = sample['w2cs'][start_idx:] |
|
sample['c2ws'] = sample['c2ws'][start_idx:] |
|
sample['intrinsics'] = sample['intrinsics'][start_idx:] |
|
sample['view_ids'] = sample['view_ids'][start_idx:] |
|
sample['affine_mats'] = sample['affine_mats'][start_idx:] |
|
|
|
sample['scale_mat'] = torch.from_numpy(scale_mat) |
|
sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) |
|
|
|
|
|
        if ('val' in self.split) or ('test' in self.split):

            # evaluation: cast a ray through every pixel of the query view
            sample_rays = gen_rays_from_single_image(
|
img_wh[1], img_wh[0], |
|
sample['query_image'], |
|
sample['query_intrinsic'], |
|
sample['query_c2w'], |
|
depth=sample['query_depth'], |
|
mask=sample['query_mask'] if self.clean_image else None) |
|
        else:

            # training: sample N_rays random rays from the query view
            sample_rays = gen_random_rays_from_single_image(
|
img_wh[1], img_wh[0], |
|
self.N_rays, |
|
sample['query_image'], |
|
sample['query_intrinsic'], |
|
sample['query_c2w'], |
|
depth=sample['query_depth'], |
|
mask=sample['query_mask'] if self.clean_image else None, |
|
dilated_mask=mask_dilated, |
|
importance_sample=self.importance_sample) |
|
|
|
sample['rays'] = sample_rays |
|
|
|
return sample |
|
|