Text-to-3D
image-to-3d
code / SparseNeuS_demo_v1 /data /dtu_general.py
Chao Xu
sparseneus and elev est
854f0d0
raw
history blame
15.8 kB
from torch.utils.data import Dataset
from utils.misc_utils import read_pfm
import os
import numpy as np
import cv2
from PIL import Image
import torch
from torchvision import transforms as T
from data.scene import get_boundingbox
from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
from termcolor import colored
import pdb
import random
def load_K_Rt_from_P(filename, P=None):
    """Split a 3x4 projection matrix into a 4x4 intrinsic matrix and a cam2world pose.

    Args:
        filename: Path to a text file with the projection matrix written as
            whitespace-separated rows. Blank lines are ignored. If exactly four
            non-blank lines remain, the first is treated as a header and skipped.
            Only the first four values of each row are used. The file is read
            only when ``P`` is None.
        P: Optional 3x4 projection matrix. When given, ``filename`` is ignored.

    Returns:
        (intrinsics, pose):
        - ``intrinsics`` is a 4x4 float64 matrix whose upper-left 3x3 block is K,
          normalised so that K[2, 2] == 1.
        - ``pose`` is a 4x4 float32 camera-to-world matrix.
    """
    if P is None:
        # Use a context manager so the file handle is always released.
        with open(filename) as f:
            lines = [line for line in f.read().splitlines() if line.strip()]
        if len(lines) == 4:
            lines = lines[1:]
        # split() (no argument) tolerates repeated/leading spaces; keep 4 columns.
        rows = [line.split()[:4] for line in lines]
        P = np.asarray(rows).astype(np.float32).squeeze()

    K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    # OpenCV returns the rotation R of P = K [R | -R C], which maps world to
    # camera. Its transpose is therefore the camera-to-world rotation.
    pose[:3, :3] = R.transpose()
    # t is the camera centre in homogeneous coordinates (4x1); dehomogenise it.
    pose[:3, 3] = (t[:3] / t[3])[:, 0]
    return intrinsics, pose  # ! return cam2world matrix here
# ! load one ref-image with multiple src-images in camera coordinate system
class MVSDatasetDtuPerView(Dataset):
    """DTU dataset yielding one reference view together with its source views.

    Each item pairs a reference view with up to ``n_views`` source views taken
    from the view-pair file. Processing per item:

    - Every extrinsic is multiplied by the inverse of the reference view's
      world-to-camera matrix, so poses are relative to the reference camera.
    - The scene is normalised with a scale matrix estimated from the view frusta.
    - Rays are generated for the reference (query) view.
    """

    def __init__(self, root_dir, split, n_views=3, img_wh=(640, 512), downSample=1.0,
                 split_filepath=None, pair_filepath=None,
                 N_rays=512,
                 vol_dims=None, batch_size=1,
                 clean_image=False, importance_sample=False, test_ref_views=None):
        """Load scan list, view pairs and camera files.

        Args:
            root_dir: DTU root directory. It must contain ``Cameras/``,
                ``Rectified/``, ``Depths/`` and ``Masks_clean_dilated/``.
            split: Split name. It selects ``data/dtu/lists/{split}.txt`` when
                ``split_filepath`` is None. If it contains ``'train'``, all 7
                light conditions are used. Unless it contains ``'val'`` or
                ``'test'``, rays are sampled randomly.
            n_views: Maximum number of source views per sample.
            img_wh: (width, height). Only validated (multiples of 32) and stored;
                the per-sample size is taken from the image files.
            downSample: Scale factor applied to images, masks and depths.
            split_filepath: Optional override for the scan-list file.
            pair_filepath: Optional override for the view-pair file.
            N_rays: Number of random rays per non-val/test sample.
            vol_dims: Cost-volume resolution. None means [128, 128, 128].
            batch_size: Stored; used for constructing metas for GRU fusion training.
            clean_image: If True, load masks and zero out pixels outside them.
            importance_sample: Forwarded to the random ray sampler.
            test_ref_views: If non-empty, only these reference views are kept.
                None means no filtering.
        """
        # None sentinels avoid sharing one mutable default list across instances.
        if vol_dims is None:
            vol_dims = [128, 128, 128]
        if test_ref_views is None:
            test_ref_views = []

        self.root_dir = root_dir
        self.split = split
        self.img_wh = img_wh
        self.downSample = downSample
        self.num_all_imgs = 49  # this preprocessed DTU dataset has 49 images
        self.n_views = n_views
        self.N_rays = N_rays
        self.batch_size = batch_size  # - used for construct new metas for gru fusion training
        self.clean_image = clean_image
        self.importance_sample = importance_sample
        self.test_ref_views = test_ref_views  # used for testing
        self.scale_factor = 1.0
        self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
        if img_wh is not None:
            assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
                'img_wh must both be multiples of 32!'
        self.split_filepath = f'data/dtu/lists/{self.split}.txt' if split_filepath is None else split_filepath
        self.pair_filepath = f'data/dtu/dtu_pairs.txt' if pair_filepath is None else pair_filepath
        print(colored("loading all scenes together", 'red'))
        with open(self.split_filepath) as f:
            self.scans = [line.rstrip() for line in f.readlines()]

        self.all_intrinsics = []  # the cam info of the whole scene
        self.all_extrinsics = []
        self.all_near_fars = []
        self.metas, self.ref_src_pairs = self.build_metas()  # load ref-srcs view pairs info of the scene
        self.allview_ids = [i for i in range(self.num_all_imgs)]
        self.load_cam_info()  # load camera info of DTU, and estimate scale_mat
        self.build_remap()
        self.define_transforms()
        print(f'==> image down scale: {self.downSample}')

        # * bounding box for rendering
        self.bbox_min = np.array([-1.0, -1.0, -1.0])
        self.bbox_max = np.array([1.0, 1.0, 1.0])

        # - used for cost volume regularization
        self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
        self.partial_vol_origin = torch.Tensor([-1., -1., -1.])

    def build_remap(self):
        """Build ``self.remap``, mapping a view id to its index in ``self.allview_ids``."""
        self.remap = np.zeros(np.max(self.allview_ids) + 1).astype('int')
        for i, item in enumerate(self.allview_ids):
            self.remap[item] = i

    def define_transforms(self):
        """Set ``self.transform``, which converts a PIL image / ndarray to a tensor."""
        self.transform = T.Compose([T.ToTensor()])

    def build_metas(self):
        """Enumerate (scan, light_idx, ref_view, src_views) samples.

        Pair-file format, as parsed here:
        - first line: the number of viewpoints;
        - then, per viewpoint, one line with the reference id;
        - then one line whose odd-indexed tokens are the source-view ids.

        Returns:
            (metas, ref_src_pairs). ``metas`` is ordered by light, then scan,
            then pair-file order. ``ref_src_pairs`` maps ref_view to its
            source-view list.
        """
        # light conditions 0-6 for training
        # light condition 3 for testing (the brightest?)
        light_idxs = [3] if 'train' not in self.split else range(7)

        # Parse the pair file once; every (light, scan) combination reuses it.
        pairs = []
        with open(self.pair_filepath) as f:
            num_viewpoint = int(f.readline())
            # viewpoints (49)
            for _ in range(num_viewpoint):
                ref_view = int(f.readline().rstrip())
                src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                pairs.append((ref_view, src_views))
        ref_src_pairs = {ref_view: src_views for ref_view, src_views in pairs}

        metas = []
        for light_idx in light_idxs:
            for scan in self.scans:
                for ref_view, src_views in pairs:
                    # ! only for validation
                    if len(self.test_ref_views) > 0 and ref_view not in self.test_ref_views:
                        continue
                    # Each meta gets its own list, so they never alias each other.
                    metas.append((scan, light_idx, ref_view, list(src_views)))
        return metas, ref_src_pairs

    def read_cam_file(self, filename):
        """Parse a DTU camera file.

        Expected layout:
        - lines 1-4: the 4x4 extrinsic;
        - lines 7-9: the 3x3 intrinsic;
        - line 11: "depth_min depth_interval".

        Side effect: sets ``self.depth_interval``.

        Returns:
            (intrinsics_4x4, extrinsics_4x4, [depth_min, depth_max]).
            depth_max is depth_min + 192 * depth_interval.
        """
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))
        # depth_min & depth_interval: line 11
        depth_min = float(lines[11].split()[0])
        depth_max = depth_min + float(lines[11].split()[1]) * 192
        self.depth_interval = float(lines[11].split()[1])
        intrinsics_ = np.float32(np.diag([1, 1, 1, 1]))
        intrinsics_[:3, :3] = intrinsics
        return intrinsics_, extrinsics, [depth_min, depth_max]

    def load_cam_info(self):
        """Load the cameras of all ``num_all_imgs`` views into the ``all_*`` lists."""
        for vid in range(self.num_all_imgs):
            proj_mat_filename = os.path.join(self.root_dir,
                                             f'Cameras/train/{vid:08d}_cam.txt')
            intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
            intrinsic[:2] *= 4  # * the provided intrinsics is 4x downsampled, now keep the same scale with image
            self.all_intrinsics.append(intrinsic)
            self.all_extrinsics.append(extrinsic)
            self.all_near_fars.append(near_far)

    def read_depth(self, filename):
        """Return (depth at 1/4 of depth_h's size, depth_h at downSample scale).

        Processing steps:
        - halve the 1200x1600 map with nearest-neighbour interpolation;
        - centre-crop it to 512x640;
        - scale it by ``downSample``.
        """
        depth_h = np.array(read_pfm(filename)[0], dtype=np.float32)  # (1200, 1600)
        # NOTE(review): the depth just read is discarded and replaced by a constant
        # map of ones, so callers never see the real depth values. This looks like
        # a leftover debugging hack — confirm before relying on depth supervision.
        depth_h = np.ones((1200, 1600))
        depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
                             interpolation=cv2.INTER_NEAREST)  # (600, 800)
        depth_h = depth_h[44:556, 80:720]  # (512, 640)
        depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
                             interpolation=cv2.INTER_NEAREST)
        depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
                           interpolation=cv2.INTER_NEAREST)
        return depth, depth_h

    def read_mask(self, filename):
        """Return binary (0/1) masks: (1/4 of mask_h's size, mask_h at downSample scale)."""
        mask_h = cv2.imread(filename, 0)
        mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
                            interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
                          interpolation=cv2.INTER_NEAREST)
        mask[mask > 0] = 1  # the masks stored in png are not binary
        mask_h[mask_h > 0] = 1
        return mask, mask_h

    def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
        """Build the 4x4 scale/translate matrix from ``get_boundingbox``.

        The matrix is built from the bounding sphere's centre and radius, with the
        radius multiplied by ``factor``. Returns (scale_mat as float32,
        1 / radius). ``get_boundingbox`` is expected to return torch tensors
        (``.cpu()`` is called on them).
        """
        center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
        radius = radius * factor
        scale_mat = np.diag([radius, radius, radius, 1.0])
        scale_mat[:3, 3] = center.cpu().numpy()
        scale_mat = scale_mat.astype(np.float32)
        return scale_mat, 1. / radius.cpu().numpy()

    def __len__(self):
        return len(self.metas)

    @staticmethod
    def _rescale_cameras(intrinsics, w2cs, depths_h, scale_mat, scale_factor):
        """Re-express every view's camera and depth in the scaled scene frame.

        Returns lists (w2cs, c2ws, affine_mats, near_fars, depths_h), one entry
        per view. near/far are the camera distance to the origin -1 / +1 (with
        a 5% margin), which assumes the scaled scene lies in the unit sphere.
        """
        new_w2cs, new_c2ws, new_affine_mats, new_near_fars, new_depths_h = [], [], [], [], []
        for intrinsic, extrinsic, depth in zip(intrinsics, w2cs, depths_h):
            P = intrinsic @ extrinsic @ scale_mat
            P = P[:3, :4]
            # - should use load_K_Rt_from_P() to obtain c2w
            c2w = load_K_Rt_from_P(None, P)[1]
            w2c = np.linalg.inv(c2w)
            new_w2cs.append(w2c)
            new_c2ws.append(c2w)
            affine_mat = np.eye(4)
            affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
            new_affine_mats.append(affine_mat)

            dist = np.sqrt(np.sum(c2w[:3, 3] ** 2))
            near = dist - 1
            far = dist + 1
            new_near_fars.append([0.95 * near, 1.05 * far])
            new_depths_h.append(depth * scale_factor)
        return new_w2cs, new_c2ws, new_affine_mats, new_near_fars, new_depths_h

    def __getitem__(self, idx):
        """Load one reference view plus its source views.

        Returns a dict of tensors with the per-view ``images``, ``depths_h``,
        ``masks_h``, ``w2cs``, ``c2ws``, ``near_fars``, ``intrinsics`` and
        ``affine_mats``, plus ``query_*`` entries for the reference view and
        generated ``rays``. For non-train splits the reference view is removed
        from the per-view entries (``start_idx = 1``).
        """
        sample = {}
        scan, light_idx, ref_view, src_views = self.metas[idx % len(self.metas)]

        # generalized, load some images at once
        view_ids = [ref_view] + src_views[:self.n_views]

        # * transform from world system to camera system
        w2c_ref = self.all_extrinsics[self.remap[ref_view]]
        w2c_ref_inv = np.linalg.inv(w2c_ref)

        image_perm = 0  # only supervised on reference view

        imgs, depths_h, masks_h = [], [], []  # full size (640, 512)
        intrinsics, w2cs, near_fars = [], [], []  # record proj mats between views
        mask_dilated = None
        for i, vid in enumerate(view_ids):
            # NOTE that the id in image file names is from 1 to 49 (not 0~48)
            img_filename = os.path.join(self.root_dir,
                                        f'Rectified/{scan}_train/rect_{vid + 1:03d}_{light_idx}_r5000.png')
            depth_filename = os.path.join(self.root_dir,
                                          f'Depths/{scan}_train/depth_map_{vid:04d}.pfm')
            mask_filename = os.path.join(self.root_dir,
                                         f'Masks_clean_dilated/{scan}_train/mask_{vid:04d}.png')

            img = Image.open(img_filename)
            img_wh = np.round(np.array(img.size) * self.downSample).astype('int')
            img = img.resize(img_wh, Image.BILINEAR)

            if os.path.exists(mask_filename) and self.clean_image:
                _, mask_h = self.read_mask(mask_filename)
            else:
                mask_h = np.ones([img_wh[1], img_wh[0]])
            masks_h.append(mask_h)

            if i == 0:
                # Dilated reference-view mask; passed to the random ray sampler below.
                kernel_size = 101  # default 101
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
                mask_dilated = np.float32(cv2.dilate(np.uint8(mask_h * 255), kernel, iterations=1) > 128)

            if self.clean_image:
                img = np.array(img)
                img[mask_h < 0.5] = 0.0

            img = self.transform(img)
            imgs += [img]

            index_mat = self.remap[vid]
            near_fars.append(self.all_near_fars[index_mat])
            intrinsics.append(self.all_intrinsics[index_mat])
            w2cs.append(self.all_extrinsics[index_mat] @ w2c_ref_inv)

            if os.path.exists(depth_filename):
                _, depth_h = self.read_depth(depth_filename)
            else:
                # Placeholder keeps depths_h aligned one-to-one with intrinsics/w2cs;
                # without it, zip() in _rescale_cameras would silently drop views.
                depth_h = np.zeros([img_wh[1], img_wh[0]], dtype=np.float32)
            depths_h.append(depth_h)

        # ! estimate scale_mat
        scale_mat, scale_factor = self.cal_scale_mat(img_hw=[img_wh[1], img_wh[0]],
                                                     intrinsics=intrinsics, extrinsics=w2cs,
                                                     near_fars=near_fars, factor=1.1)

        # ! calculate the new w2cs after scaling
        new_w2cs, new_c2ws, new_affine_mats, new_near_fars, new_depths_h = self._rescale_cameras(
            intrinsics, w2cs, depths_h, scale_mat, scale_factor)

        imgs = torch.stack(imgs).float()
        depths_h = np.stack(new_depths_h)
        masks_h = np.stack(masks_h)
        affine_mats = np.stack(new_affine_mats)
        intrinsics = np.stack(intrinsics)
        w2cs = np.stack(new_w2cs)
        c2ws = np.stack(new_c2ws)
        near_fars = np.stack(new_near_fars)

        start_idx = 0 if 'train' in self.split else 1

        sample['images'] = imgs  # (V, 3, H, W)
        sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32))  # (V, H, W)
        sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32))  # (V, H, W)
        sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32))  # (V, 4, 4)
        sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32))  # (V, 4, 4)
        sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32))  # (V, 2)
        sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3]  # (V, 3, 3)
        sample['view_ids'] = torch.from_numpy(np.array(view_ids))
        sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32))  # ! in world space

        sample['light_idx'] = torch.tensor(light_idx)
        sample['scan'] = scan

        sample['scale_factor'] = torch.tensor(scale_factor)
        sample['img_wh'] = torch.from_numpy(img_wh)
        sample['render_img_idx'] = torch.tensor(image_perm)
        # as_tensor + clone copies without the copy-construct warning of torch.tensor(<tensor>).
        sample['partial_vol_origin'] = torch.as_tensor(self.partial_vol_origin, dtype=torch.float32).clone()
        sample['meta'] = str(scan) + "_light" + str(light_idx) + "_refview" + str(ref_view)

        # - image to render
        sample['query_image'] = sample['images'][0]
        sample['query_c2w'] = sample['c2ws'][0]
        sample['query_w2c'] = sample['w2cs'][0]
        sample['query_intrinsic'] = sample['intrinsics'][0]
        sample['query_depth'] = sample['depths_h'][0]
        sample['query_mask'] = sample['masks_h'][0]
        sample['query_near_far'] = sample['near_fars'][0]

        sample['images'] = sample['images'][start_idx:]  # (V, 3, H, W)
        sample['depths_h'] = sample['depths_h'][start_idx:]  # (V, H, W)
        sample['masks_h'] = sample['masks_h'][start_idx:]  # (V, H, W)
        sample['w2cs'] = sample['w2cs'][start_idx:]  # (V, 4, 4)
        sample['c2ws'] = sample['c2ws'][start_idx:]  # (V, 4, 4)
        sample['intrinsics'] = sample['intrinsics'][start_idx:]  # (V, 3, 3)
        sample['view_ids'] = sample['view_ids'][start_idx:]
        sample['affine_mats'] = sample['affine_mats'][start_idx:]  # ! in world space

        sample['scale_mat'] = torch.from_numpy(scale_mat)
        sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)

        # - generate rays
        if ('val' in self.split) or ('test' in self.split):
            sample_rays = gen_rays_from_single_image(
                img_wh[1], img_wh[0],
                sample['query_image'],
                sample['query_intrinsic'],
                sample['query_c2w'],
                depth=sample['query_depth'],
                mask=sample['query_mask'] if self.clean_image else None)
        else:
            sample_rays = gen_random_rays_from_single_image(
                img_wh[1], img_wh[0],
                self.N_rays,
                sample['query_image'],
                sample['query_intrinsic'],
                sample['query_c2w'],
                depth=sample['query_depth'],
                mask=sample['query_mask'] if self.clean_image else None,
                dilated_mask=mask_dilated,
                importance_sample=self.importance_sample)

        sample['rays'] = sample_rays

        return sample