# HaWoR / lib / pipeline / masked_droid_slam.py (revision 5f028d6)
import sys
sys.path.insert(0, 'thirdparty/DROID-SLAM/droid_slam')
sys.path.insert(0, 'thirdparty/DROID-SLAM')
from tqdm import tqdm
import numpy as np
import torch
import os
import argparse
from PIL import Image
import cv2
from glob import glob
from droid import Droid
from torch.multiprocessing import Process
import evo
from evo.core.trajectory import PoseTrajectory3D
from evo.tools import file_interface
from evo.core import sync
import evo.main_ape as main_ape
from evo.core.metrics import PoseRelation
from pycocotools import mask as masktool
from torchvision.transforms import Resize
# Some default settings for DROID-SLAM.
# The options are declared on an argparse parser only to obtain a namespace of
# defaults; the functions in this module overwrite individual fields
# (filter_thresh, disable_vis, image_size) before constructing Droid.
parser = argparse.ArgumentParser()
parser.add_argument("--imagedir", type=str, help="path to image directory")
parser.add_argument("--calib", type=str, help="path to calibration file")
parser.add_argument("--t0", default=0, type=int, help="starting frame")
parser.add_argument("--stride", default=1, type=int, help="frame stride")
parser.add_argument("--weights", default="weights/external/droid.pth")
parser.add_argument("--buffer", type=int, default=512)
parser.add_argument("--image_size", default=[240, 320])
parser.add_argument("--disable_vis", action="store_true")
parser.add_argument("--beta", type=float, default=0.3, help="weight for translation / rotation components of flow")
parser.add_argument("--filter_thresh", type=float, default=2.4, help="how much motion before considering new keyframe")
parser.add_argument("--warmup", type=int, default=8, help="number of warmup frames")
parser.add_argument("--keyframe_thresh", type=float, default=4.0, help="threshold to create a new keyframe")
parser.add_argument("--frontend_thresh", type=float, default=16.0, help="add edges between frames whithin this distance")
parser.add_argument("--frontend_window", type=int, default=25, help="frontend optimization window")
parser.add_argument("--frontend_radius", type=int, default=2, help="force edges between frames within radius")
parser.add_argument("--frontend_nms", type=int, default=1, help="non-maximal supression of edges")
parser.add_argument("--backend_thresh", type=float, default=22.0)
parser.add_argument("--backend_radius", type=int, default=2)
parser.add_argument("--backend_nms", type=int, default=3)
parser.add_argument("--upsample", action="store_true")
parser.add_argument("--reconstruction_path", help="path to saved reconstruction")

# Parse an empty argv: sys.argv is ignored, so only the defaults declared
# above end up in the namespace.
args = parser.parse_args([])
args.stereo = False
args.upsample = True
args.disable_vis = True

try:
    torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
    # The start method can be fixed only once per process. If this module is
    # imported again (or another module already chose a start method), keep
    # the existing choice instead of failing the import.
    pass
def est_calib(imagedir):
    """Roughly estimate pinhole intrinsics from the first image's dimensions.

    The focal length is set to the longer image side and the principal point
    to the image centre.

    Args:
        imagedir: Either a directory, which is searched for ``*.jpg`` files
            (sorted by name), or an explicit list of image paths.

    Returns:
        ``[fx, fy, cx, cy]`` with ``fx == fy == max(height, width)``.

    Raises:
        FileNotFoundError: If there is no image to inspect.
        OSError: If the first image cannot be decoded.
    """
    if isinstance(imagedir, list):
        imgfiles = imagedir
    else:
        imgfiles = sorted(glob(f'{imagedir}/*.jpg'))
    if not imgfiles:
        raise FileNotFoundError(f'No images found for imagedir={imagedir!r}')

    image = cv2.imread(imgfiles[0])
    if image is None:
        # cv2.imread reports a failed read by returning None, not by raising.
        raise OSError(f'Could not read image {imgfiles[0]!r}')

    h0, w0 = image.shape[:2]
    focal = np.max([h0, w0])
    cx, cy = w0 / 2., h0 / 2.
    return [focal, focal, cx, cy]
def get_dimention(imagedir):
    """Return the (H, W) that frames of this sequence are resized to for DROID.

    The first image is rescaled so that its area is about 384 * 512 pixels
    while keeping the aspect ratio, then each side is rounded down to a
    multiple of 8. Only the resulting size is computed here; the image itself
    is not resized.

    Args:
        imagedir: Either a directory, which is searched for ``*.jpg`` files
            (sorted by name), or an explicit list of image paths.

    Returns:
        Tuple ``(H, W)`` of Python ints.

    Raises:
        FileNotFoundError: If there is no image to inspect.
        OSError: If the first image cannot be decoded.
    """
    if isinstance(imagedir, list):
        imgfiles = imagedir
    else:
        imgfiles = sorted(glob(f'{imagedir}/*.jpg'))
    if not imgfiles:
        raise FileNotFoundError(f'No images found for imagedir={imagedir!r}')

    image = cv2.imread(imgfiles[0])
    if image is None:
        # cv2.imread reports a failed read by returning None, not by raising.
        raise OSError(f'Could not read image {imgfiles[0]!r}')

    h0, w0 = image.shape[:2]
    scale = np.sqrt((384 * 512) / (h0 * w0))
    h1 = int(h0 * scale)
    w1 = int(w0 * scale)

    # Same size a resize to (w1, h1) followed by a crop to multiples of 8
    # would produce, without touching pixel data.
    return h1 - h1 % 8, w1 - w1 % 8
def image_stream(imagedir, calib, stride, max_frame=None):
    """Yield resized frames and matching intrinsics for DROID.

    Each image is optionally undistorted, rescaled to an area of about
    384 * 512 pixels (aspect ratio kept), cropped so both sides are multiples
    of 8, and converted to a channel-first tensor. The intrinsics are scaled
    by the same resize factors.

    Args:
        imagedir: Either a directory, which is searched for ``*.jpg`` files
            (sorted by name), or an explicit list of image paths.
        calib: Sequence starting with ``fx, fy, cx, cy``. Any further entries
            are passed to ``cv2.undistort`` as distortion coefficients.
        stride: Keep every ``stride``-th image.
        max_frame: If given, stop after this many (strided) images.

    Yields:
        ``(t, image, intrinsics)`` where ``t`` is the index within the strided
        list, ``image`` has shape ``(1, C, H, W)`` and ``intrinsics`` is a
        4-element tensor ``[fx, fy, cx, cy]`` for the resized image.

    Raises:
        OSError: If an image cannot be decoded.
    """
    fx, fy, cx, cy = calib[:4]

    K = np.eye(3)
    K[0, 0] = fx
    K[0, 2] = cx
    K[1, 1] = fy
    K[1, 2] = cy

    # Loop-invariant: convert the distortion coefficients once. An ndarray is
    # used so a plain Python list in ``calib`` is also usable by OpenCV.
    dist_coeffs = None
    if len(calib) > 4:
        dist_coeffs = np.asarray(calib[4:], dtype=np.float64)

    if isinstance(imagedir, list):
        image_list = imagedir
    else:
        image_list = sorted(glob(f'{imagedir}/*.jpg'))
    image_list = image_list[::stride]
    if max_frame is not None:
        image_list = image_list[:max_frame]

    for t, imfile in enumerate(image_list):
        image = cv2.imread(imfile)
        if image is None:
            # cv2.imread reports a failed read by returning None.
            raise OSError(f'Could not read image {imfile!r}')
        if dist_coeffs is not None:
            image = cv2.undistort(image, K, dist_coeffs)

        h0, w0 = image.shape[:2]
        scale = np.sqrt((384 * 512) / (h0 * w0))
        h1 = int(h0 * scale)
        w1 = int(w0 * scale)

        image = cv2.resize(image, (w1, h1))
        image = image[:h1 - h1 % 8, :w1 - w1 % 8]
        image = torch.as_tensor(image).permute(2, 0, 1)

        intrinsics = torch.as_tensor([fx, fy, cx, cy])
        if not intrinsics.is_floating_point():
            # An all-integer calibration would give an integer tensor, and
            # the in-place float scaling below would then raise.
            intrinsics = intrinsics.to(torch.get_default_dtype())
        intrinsics[0::2] *= (w1 / w0)
        intrinsics[1::2] *= (h1 / h0)

        yield t, image[None], intrinsics
def run_slam(imagedir, masks, calib=None, depth=None, stride=1,
             filter_thresh=2.4, disable_vis=True):
    """Run masked DROID-SLAM over an image sequence.

    Pixels whose resized image mask is >= 0.5 are zeroed before tracking, and
    the 1/8-resolution mask is handed to ``droid.track`` as ``mask``.

    Args:
        imagedir: Image directory or list of image paths.
        masks: Per-frame masks; subsampled with ``stride`` to line up with the
            streamed frames.
        calib: ``[fx, fy, cx, cy, ...]``; estimated from the first image when
            ``None``.
        depth: Ignored; tracking always runs with ``depth=None``.
        stride: Frame stride.
        filter_thresh: Written to the shared ``args.filter_thresh``.
        disable_vis: Written to the shared ``args.disable_vis``.

    Returns:
        ``(droid, traj)``: the Droid instance and the result of
        ``droid.terminate``.
    """
    # The caller-supplied depth is discarded on purpose of the original code.
    depth = None
    slam = None

    # These settings live on the module-level namespace shared by all runs.
    args.filter_thresh = filter_thresh
    args.disable_vis = disable_vis

    strided_masks = masks[::stride]
    img_msks, conf_msks = preprocess_masks(imagedir, strided_masks)

    if calib is None:
        calib = est_calib(imagedir)

    for t, frame, intrinsics in tqdm(image_stream(imagedir, calib, stride)):
        if slam is None:
            # Built lazily so the image size can be taken from the first frame.
            args.image_size = [frame.shape[2], frame.shape[3]]
            slam = Droid(args)

        pixel_mask = img_msks[t]
        conf_mask = conf_msks[t]
        masked_frame = frame * (pixel_mask < 0.5)
        slam.track(t, masked_frame, intrinsics=intrinsics, depth=depth, mask=conf_mask)

    traj = slam.terminate(image_stream(imagedir, calib, stride))
    return slam, traj
def run_droid_slam(imagedir, calib=None, depth=None, stride=1,
                   filter_thresh=2.4, disable_vis=True):
    """Run DROID-SLAM over an image sequence without any masking.

    Args:
        imagedir: Image directory or list of image paths.
        calib: ``[fx, fy, cx, cy, ...]``; estimated from the first image when
            ``None``.
        depth: Ignored; tracking always runs with ``depth=None``.
        stride: Frame stride.
        filter_thresh: Written to the shared ``args.filter_thresh``.
        disable_vis: Written to the shared ``args.disable_vis``.

    Returns:
        ``(droid, traj)``: the Droid instance and the result of
        ``droid.terminate``.
    """
    # The caller-supplied depth is discarded, as in the masked variant.
    depth = None
    slam = None

    # These settings live on the module-level namespace shared by all runs.
    args.filter_thresh = filter_thresh
    args.disable_vis = disable_vis

    if calib is None:
        calib = est_calib(imagedir)

    for t, frame, intrinsics in tqdm(image_stream(imagedir, calib, stride)):
        if slam is None:
            # Built lazily so the image size can be taken from the first frame.
            args.image_size = [frame.shape[2], frame.shape[3]]
            slam = Droid(args)
        slam.track(t, frame, intrinsics=intrinsics, depth=depth)

    traj = slam.terminate(image_stream(imagedir, calib, stride))
    return slam, traj
def eval_slam(traj_est, cam_t, cam_q, return_traj=True, correct_scale=False, align=True, align_origin=False):
    """Compute the translational absolute pose error of an estimated trajectory.

    Args:
        traj_est: Array whose first three columns are positions and whose
            remaining columns are orientations as wxyz quaternions.
        cam_t: Reference positions.
        cam_q: Reference orientations as wxyz quaternions.
        return_traj: Also return the associated trajectories.
        correct_scale: Forwarded to ``main_ape.ape``.
        align: Forwarded to ``main_ape.ape``.
        align_origin: Forwarded to ``main_ape.ape``.

    Returns:
        ``(stats, traj_ref, traj_est)`` when ``return_traj`` is true, otherwise
        just ``stats``.
    """
    # Both trajectories share synthetic timestamps 0, 1, 2, ...
    num_poses = len(traj_est)
    tstamps = np.arange(num_poses).astype(np.float32)

    estimated = PoseTrajectory3D(
        positions_xyz=traj_est[:, :3],
        orientations_quat_wxyz=traj_est[:, 3:],
        timestamps=tstamps)
    reference = PoseTrajectory3D(
        positions_xyz=cam_t.copy(),
        orientations_quat_wxyz=cam_q.copy(),
        timestamps=tstamps)

    reference, estimated = sync.associate_trajectories(reference, estimated)

    ape_result = main_ape.ape(
        reference, estimated, est_name='traj',
        pose_relation=PoseRelation.translation_part,
        align=align, align_origin=align_origin,
        correct_scale=correct_scale)

    stats = ape_result.stats
    if not return_traj:
        return stats
    return stats, reference, estimated
def test_slam(imagedir, img_msks, conf_msks, calib, stride=10, max_frame=50):
    """Run a short masked SLAM pass and return its reprojection error.

    Args:
        imagedir: Image directory or list of image paths.
        img_msks: Image-resolution masks indexed by streamed frame.
        conf_msks: Masks passed to ``droid.track`` as ``mask``, indexed by
            streamed frame.
        calib: ``[fx, fy, cx, cy, ...]`` calibration to evaluate.
        stride: Frame stride.
        max_frame: Maximum number of (strided) frames to track.

    Returns:
        The value of ``droid.compute_error()`` after tracking.
    """
    # A private namespace, so the module-level defaults stay untouched.
    cfg = parser.parse_args([])
    overrides = {
        'stereo': False,
        'upsample': False,
        'disable_vis': True,
        'frontend_window': 10,
        'frontend_thresh': 10,
    }
    for field, value in overrides.items():
        setattr(cfg, field, value)

    slam = None
    for t, frame, intrinsics in image_stream(imagedir, calib, stride, max_frame):
        if slam is None:
            # Built lazily so the image size can be taken from the first frame.
            cfg.image_size = [frame.shape[2], frame.shape[3]]
            slam = Droid(cfg)

        pixel_mask = img_msks[t]
        conf_mask = conf_msks[t]
        masked_frame = frame * (pixel_mask < 0.5)
        slam.track(t, masked_frame, intrinsics=intrinsics, mask=conf_mask)

    reprojection_error = slam.compute_error()
    # Drop the reference to the SLAM instance before returning.
    del slam
    return reprojection_error
def search_focal_length(img_folder, masks, stride=10, max_frame=50,
                        low=500, high=1500, step=100):
    """Pick the focal length with the lowest SLAM reprojection error.

    The image-size-based estimate is evaluated first, then every focal length
    in ``range(low, high, step)``. A candidate replaces the current best only
    if its error is strictly lower.

    Args:
        img_folder: Image directory or list of image paths.
        masks: Per-frame masks; subsampled by ``stride`` and truncated to
            ``max_frame`` entries.
        stride: Frame stride for the short SLAM runs.
        max_frame: Maximum number of frames per short SLAM run.
        low: First candidate focal length.
        high: Exclusive upper bound of the candidates.
        step: Spacing between candidates.

    Returns:
        The best focal length found.
    """
    subsampled = masks[::stride][:max_frame]
    img_msks, conf_msks = preprocess_masks(img_folder, subsampled)

    def reprojection_error(intrinsics):
        # One short SLAM run with the given calibration.
        return test_slam(img_folder, img_msks, conf_msks,
                         stride=stride, calib=intrinsics, max_frame=max_frame)

    # default estimate
    calib = np.array(est_calib(img_folder))
    best_focal = calib[0]
    best_err = reprojection_error(calib)

    # search based on slam reprojection error
    for candidate in range(low, high, step):
        calib[:2] = candidate
        candidate_err = reprojection_error(calib)
        if not candidate_err < best_err:
            continue
        best_err = candidate_err
        best_focal = candidate

    print('Best focal length:', best_focal)
    return best_focal
def preprocess_masks(img_folder, masks):
    """Resize masks to the DROID image size and to 1/8 of that size.

    Masks are resized in chunks of 500 and the chunks concatenated.

    Args:
        img_folder: Image directory or list of image paths; used to look up
            the target size.
        masks: Stack of masks, sliced along its first dimension.

    Returns:
        ``(img_msks, conf_msks)``: masks at ``(H, W)`` and at
        ``(H // 8, W // 8)``.
    """
    H, W = get_dimention(img_folder)
    chunk = 500
    starts = range(0, len(masks), chunk)

    to_full = Resize((H, W), antialias=True)
    to_eighth = Resize((H//8, W//8), antialias=True)

    # Full-resolution pass first, then the 1/8-resolution pass.
    img_msks = torch.cat([to_full(masks[s:s + chunk]) for s in starts])
    conf_msks = torch.cat([to_eighth(masks[s:s + chunk]) for s in starts])

    return img_msks, conf_msks