from torch.utils.data import Dataset
from utils.misc_utils import read_pfm
import os
import numpy as np
import cv2
from PIL import Image
import torch
from torchvision import transforms as T
from data.scene import get_boundingbox

from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
import json
from termcolor import colored
import imageio
from kornia import create_meshgrid
import open3d as o3d


|
def get_ray_directions(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinates.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        H, W: image height and width
        focal: (fx, fy) focal lengths in pixels
        center: optional principal point (cx, cy); defaults to the image center

    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinates
    """
    # Pixel-center grid of (x, y) coordinates, shape (H, W, 2).
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5

    i, j = grid.unbind(-1)

    cent = center if center is not None else [W / 2, H / 2]
    # Unproject through the pinhole model; z is fixed to 1, so rays are not normalized.
    directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)

    return directions


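# Usage note for get_ray_directions (illustrative): the returned directions have z = 1,
# so multiplying by a per-pixel z-depth map gives camera-space points, and the norm of
# those points converts z-depth into distance along the ray. __getitem__ below applies
# exactly this conversion to the loaded depth maps, e.g.
#
#   dirs = get_ray_directions(256, 256, (fx, fy))    # fx, fy come from the pose JSON
#   pts = dirs.numpy() * z_depth[..., None]          # z_depth: (256, 256) array
#   ray_dist = np.linalg.norm(pts, axis=-1)          # per-pixel distance along the ray
#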
def load_K_Rt_from_P(filename, P=None):
    # If no projection matrix is given, read a 3x4 matrix from a text file.
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    # Decompose P into intrinsics K, world-to-camera rotation R and camera center t.
    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    # Normalize K so that K[2, 2] = 1 and embed it into a 4x4 matrix.
    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    # Camera-to-world pose: transpose the rotation and dehomogenize the camera center.
    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose


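# Sanity-check sketch for load_K_Rt_from_P (illustrative, kept as a comment): composing a
# 3x4 projection matrix P = K @ [R | t] from known world-to-camera extrinsics and feeding
# it back through the function should recover the same (normalized) intrinsics and the
# camera-to-world pose, e.g.
#
#   K = np.array([[280.0, 0.0, 128.0], [0.0, 280.0, 128.0], [0.0, 0.0, 1.0]])
#   w2c = np.eye(4); w2c[:3, 3] = [0.0, 0.0, 2.0]
#   intr, c2w = load_K_Rt_from_P(None, (K @ w2c[:3, :4]).astype(np.float32))
#   # intr[:3, :3] ~ K and c2w ~ np.linalg.inv(w2c)
#
# The focal length and principal point above are placeholders, not values used by this dataset.
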
class BlenderPerView(Dataset):
    def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
                 split_filepath=None, pair_filepath=None,
                 N_rays=512,
                 vol_dims=[128, 128, 128], batch_size=1,
                 clean_image=False, importance_sample=False, test_ref_views=[]):

        self.root_dir = root_dir
        self.split = split
        self.downSample = downSample

        self.n_views = n_views
        self.N_rays = N_rays
        self.batch_size = batch_size

        self.clean_image = clean_image
        self.importance_sample = importance_sample
        self.test_ref_views = test_ref_views
        self.scale_factor = 1.0
        self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))

        # Split file listing the object folders/uids used for training and validation.
        lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json'
        with open(lvis_json_path, 'r') as f:
            lvis_paths = json.load(f)
        if self.split == 'train':
            self.lvis_paths = lvis_paths['train']
        else:
            self.lvis_paths = lvis_paths['val']
        if img_wh is not None:
            assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
                'img_wh must both be multiples of 32!'

        # Shared camera poses, intrinsics and near/far range for all renderings.
        pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
        with open(pose_json_path, 'r') as f:
            meta = json.load(f)

        self.img_ids = list(meta["c2ws"].keys())
        self.img_wh = (256, 256)  # images are rendered at a fixed 256x256 resolution
        self.input_poses = np.array(list(meta["c2ws"].values()))
        intrinsic = np.eye(4)
        intrinsic[:3, :3] = np.array(meta["intrinsics"])
        self.intrinsic = intrinsic
        self.near_far = np.array(meta["near_far"])
        self.near_far[1] = 1.8  # override the far bound
        self.define_transforms()
        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        self.c2ws = []
        self.w2cs = []
        self.near_fars = []

        # Convert the Blender camera convention to OpenCV and cache both directions.
        for idx, img_id in enumerate(self.img_ids):
            pose = self.input_poses[idx]
            c2w = pose @ self.blender2opencv
            self.c2ws.append(c2w)
            self.w2cs.append(np.linalg.inv(c2w))
            self.near_fars.append(self.near_far)
        self.c2ws = np.stack(self.c2ws, axis=0)
        self.w2cs = np.stack(self.w2cs, axis=0)

        # Per-view camera info for the whole scene.
        self.all_intrinsics = []
        self.all_extrinsics = []
        self.all_near_fars = []
        self.load_cam_info()

        # Scene bounding box in the normalized object space.
        self.bbox_min = np.array([-1.0, -1.0, -1.0])
        self.bbox_max = np.array([1.0, 1.0, 1.0])

        # Feature-volume configuration.
        self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
        self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)

    def define_transforms(self):
        self.transform = T.Compose([T.ToTensor()])

    def load_cam_info(self):
        # All views share the same intrinsics and near/far range; extrinsics are the
        # world-to-camera matrices derived from the cached camera-to-world poses.
        for vid, img_id in enumerate(self.img_ids):
            intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
            self.all_intrinsics.append(intrinsic)
            self.all_extrinsics.append(extrinsic)
            self.all_near_fars.append(near_far)

    def read_depth(self, filename, near_bound, noisy_factor=1.0):
        # Not used for this dataset; depth is read directly in __getitem__.
        pass

    def read_mask(self, filename):
        mask_h = cv2.imread(filename, 0)
        mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
                            interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
                          interpolation=cv2.INTER_NEAREST)

        # Binarize both the full-resolution and the quarter-resolution masks.
        mask[mask > 0] = 1
        mask_h[mask_h > 0] = 1

        return mask, mask_h

    def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
        # Bounding sphere of the visible scene across all given views.
        center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)

        radius = radius * factor
        # scale_mat maps the unit sphere to the scene's bounding sphere (scale by the
        # radius, then translate to the scene center); its inverse normalizes world
        # coordinates into the unit sphere.
        scale_mat = np.diag([radius, radius, radius, 1.0])
        scale_mat[:3, 3] = center.cpu().numpy()
        scale_mat = scale_mat.astype(np.float32)

        return scale_mat, 1. / radius.cpu().numpy()

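    # Worked example for cal_scale_mat (illustrative numbers, not outputs of
    # get_boundingbox): if the bounding sphere has center c = (0, 0, 0.5) and radius 2.0,
    # then with factor=1.1 the method returns scale_mat = diag(2.2, 2.2, 2.2, 1) with
    # translation c, and scale_factor = 1 / 2.2. Applying inv(scale_mat) maps any point
    # on that bounding sphere onto the unit sphere, which matches how __getitem__
    # recomputes near/far below as the camera distance to the origin minus/plus one
    # in the normalized frame.
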
    def __len__(self):
        # Each object is rendered from 8 views; training iterates over 6 of them
        # (view indices are remapped in __getitem__), validation/test over all 8.
        if self.split == 'train':
            return 6 * len(self.lvis_paths)
        else:
            return 8 * len(self.lvis_paths)

    def __getitem__(self, idx):
        sample = {}
        origin_idx = idx
        imgs, depths_h, masks_h = [], [], []
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []

        if self.split == 'train':
            # 6 reference views per object during training; views 4 and 6 are skipped,
            # so the last two slots are remapped to views 5 and 7.
            folder_uid_dict = self.lvis_paths[idx // 6]
            idx = idx % 6  # [0, 5]
            if idx == 4:
                idx = 5
            elif idx == 5:
                idx = 7
        else:
            # All 8 reference views per object for validation/test.
            folder_uid_dict = self.lvis_paths[idx // 8]
            idx = idx % 8  # [0, 7]

        folder_id = folder_uid_dict['folder_id']
        uid = folder_uid_dict['uid']

        # Reference (query) view pose; all cameras are expressed relative to it.
        c2w = self.c2ws[idx]
        w2c = np.linalg.inv(c2w)
        w2c_ref = w2c
        w2c_ref_inv = np.linalg.inv(w2c_ref)

        w2cs.append(w2c @ w2c_ref_inv)  # identity for the reference view itself
        c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))

        img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
        depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')

        img = Image.open(img_filename)
        img = self.transform(img)  # (C, H, W) in [0, 1]

        # Composite RGBA onto a white background.
        if img.shape[0] == 4:
            img = img[:3] * img[-1:] + (1 - img[-1:])
        imgs += [img]

        # Depth is stored in millimetres; convert to metres and build the valid-pixel mask.
        depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
        mask_h = depth_h > 0

        # Convert z-depth into Euclidean distance along each ray.
        directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]])
        surface_points = directions.numpy() * depth_h[..., None]
        distance = np.linalg.norm(surface_points, axis=-1)
        depth_h = distance

        depths_h.append(depth_h)
        masks_h.append(mask_h)

        intrinsic = self.intrinsic
        intrinsics.append(intrinsic)

        near_fars.append(self.near_fars[idx])
        image_perm = 0  # only one reference image per sample

        mask_dilated = None

        # Source views: four auxiliary images for each of the 8 reference views
        # (flat indices 8..39); the groups belonging to reference views 4 and 6 are skipped.
        src_views = range(8, 8 + 8 * 4)

        for vid in src_views:
            if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6):
                continue
            img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid % 4}_10.png')

            img = Image.open(img_filename)
            img_wh = self.img_wh

            img = self.transform(img)
            if img.shape[0] == 4:
                img = img[:3] * img[-1:] + (1 - img[-1:])

            imgs += [img]
            # Source views carry no ground-truth depth; use dummy depths and full masks.
            depth_h = np.ones(img.shape[1:], dtype=np.float32)
            depths_h.append(depth_h)
            masks_h.append(np.ones(img.shape[1:], dtype=np.int32))

            near_fars.append(self.all_near_fars[vid])
            intrinsics.append(self.all_intrinsics[vid])

            w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)

        # Normalize the scene into a unit sphere centred on the visible geometry.
        scale_mat, scale_factor = self.cal_scale_mat(
            img_hw=[img_wh[1], img_wh[0]],
            intrinsics=intrinsics, extrinsics=w2cs,
            near_fars=near_fars, factor=1.1
        )

        # Re-derive poses, projection matrices, near/far bounds and depths
        # in the normalized coordinate frame.
        new_near_fars = []
        new_w2cs = []
        new_c2ws = []
        new_affine_mats = []
        new_depths_h = []
        for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
            P = intrinsic @ extrinsic @ scale_mat
            P = P[:3, :4]

            c2w = load_K_Rt_from_P(None, P)[1]
            w2c = np.linalg.inv(c2w)
            new_w2cs.append(w2c)
            new_c2ws.append(c2w)
            affine_mat = np.eye(4)
            affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
            new_affine_mats.append(affine_mat)

            # In the normalized frame the object lies inside the unit sphere, so the
            # near/far range is the camera distance to the origin minus/plus one.
            camera_o = c2w[:3, 3]
            dist = np.sqrt(np.sum(camera_o ** 2))
            near = dist - 1
            far = dist + 1

            new_near_fars.append([0.95 * near, 1.05 * far])
            new_depths_h.append(depth * scale_factor)

        imgs = torch.stack(imgs).float()
        depths_h = np.stack(new_depths_h)
        masks_h = np.stack(masks_h)

        affine_mats = np.stack(new_affine_mats)
        intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(new_near_fars)

        # Keep the reference view in the image stack for training; drop it otherwise.
        if self.split == 'train':
            start_idx = 0
        else:
            start_idx = 1

        view_ids = [idx] + list(src_views)
        sample['origin_idx'] = origin_idx
        sample['images'] = imgs  # (V, 3, H, W)
        sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32))  # (V, H, W)
        sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32))  # (V, H, W)
        sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32))  # (V, 4, 4)
        sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32))  # (V, 4, 4)
        sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32))  # (V, 2)
        sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3]  # (V, 3, 3)
        sample['view_ids'] = torch.from_numpy(np.array(view_ids))
        sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32))  # (V, 4, 4)

        sample['scan'] = folder_id
        sample['scale_factor'] = torch.tensor(scale_factor)
        sample['img_wh'] = torch.from_numpy(np.array(img_wh))
        sample['render_img_idx'] = torch.tensor(image_perm)
        sample['partial_vol_origin'] = self.partial_vol_origin
        sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])

        # The first entry is the reference (query) view.
        sample['query_image'] = sample['images'][0]
        sample['query_c2w'] = sample['c2ws'][0]
        sample['query_w2c'] = sample['w2cs'][0]
        sample['query_intrinsic'] = sample['intrinsics'][0]
        sample['query_depth'] = sample['depths_h'][0]
        sample['query_mask'] = sample['masks_h'][0]
        sample['query_near_far'] = sample['near_fars'][0]

        sample['images'] = sample['images'][start_idx:]
        sample['depths_h'] = sample['depths_h'][start_idx:]
        sample['masks_h'] = sample['masks_h'][start_idx:]
        sample['w2cs'] = sample['w2cs'][start_idx:]
        sample['c2ws'] = sample['c2ws'][start_idx:]
        sample['intrinsics'] = sample['intrinsics'][start_idx:]
        sample['view_ids'] = sample['view_ids'][start_idx:]
        sample['affine_mats'] = sample['affine_mats'][start_idx:]

        sample['scale_mat'] = torch.from_numpy(scale_mat)
        sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)

        # Rays for the query view: the full image grid for val/test, a random
        # (optionally importance-sampled) subset of N_rays for training.
        if ('val' in self.split) or ('test' in self.split):
            sample_rays = gen_rays_from_single_image(
                img_wh[1], img_wh[0],
                sample['query_image'],
                sample['query_intrinsic'],
                sample['query_c2w'],
                depth=sample['query_depth'],
                mask=sample['query_mask'] if self.clean_image else None)
        else:
            sample_rays = gen_random_rays_from_single_image(
                img_wh[1], img_wh[0],
                self.N_rays,
                sample['query_image'],
                sample['query_intrinsic'],
                sample['query_c2w'],
                depth=sample['query_depth'],
                mask=sample['query_mask'] if self.clean_image else None,
                dilated_mask=mask_dilated,
                importance_sample=self.importance_sample)

        sample['rays'] = sample_rays

        return sample
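

# Minimal usage sketch, assuming the split/pose JSON files referenced in __init__ are
# present and that `root_dir` (a placeholder here, not defined by this module) points
# at the corresponding per-object image folders.
if __name__ == '__main__':
    dataset = BlenderPerView(
        root_dir='/objaverse-processed/zero12345_img',  # placeholder path
        split='val',
        img_wh=(256, 256),
        N_rays=512,
    )
    sample = dataset[0]
    # One reference (query) view plus the source images and their cameras.
    print(sample['meta'])
    print('images:', tuple(sample['images'].shape),
          'w2cs:', tuple(sample['w2cs'].shape),
          'query depth:', tuple(sample['query_depth'].shape))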