# HaWoR — lib/eval_utils/filling_utils.py
# (Hugging Face page header removed; original listing: uploader "ThunderVVV",
#  commit 5f028d6, "raw / history / blame", 14.1 kB)
import copy
import os
import joblib
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import torch
from hawor.utils.process import run_mano, run_mano_left
from hawor.utils.rotation import angle_axis_to_quaternion, angle_axis_to_rotation_matrix, quaternion_to_rotation_matrix, rotation_matrix_to_angle_axis
from lib.utils.geometry import rotmat_to_rot6d
from lib.utils.geometry import rot6d_to_rotmat
def slerp_interpolation_aa(pos, valid):
    """Fill invalid frames of axis-angle rotations via SLERP.

    Args:
        pos: (B, T, N, 3) axis-angle rotations (rotation vectors).
        valid: (B, T) boolean mask; True marks frames with usable data.

    Returns:
        A new (B, T, N, 3) array where each invalid frame is spherically
        interpolated between its neighbouring valid frames; frames before
        the first / after the last valid frame are clamped to the nearest
        valid rotation. Batches with fewer than two valid frames are
        returned unchanged.
    """
    B, T, N, _ = pos.shape
    pos_interp = pos.copy()  # never mutate the caller's array
    for b in range(B):
        # The validity mask depends only on the batch, so compute it once
        # per batch instead of once per joint (hoisted loop invariant).
        valid_b = valid[b, :]
        invalid_idxs = np.where(~valid_b)[0]
        valid_idxs = np.where(valid_b)[0]
        if len(invalid_idxs) == 0 or len(valid_idxs) <= 1:
            continue  # nothing to fill, or not enough keyframes to slerp
        for n in range(N):
            aa_b_n = pos[b, :, n, :]
            slerp = Slerp(valid_idxs, Rotation.from_rotvec(aa_b_n[valid_idxs]))
            for idx in invalid_idxs:
                if idx < valid_idxs[0]:
                    # Before the first valid frame: constant extrapolation.
                    pos_interp[b, idx, n, :] = aa_b_n[valid_idxs[0]]
                elif idx > valid_idxs[-1]:
                    # After the last valid frame: constant extrapolation.
                    pos_interp[b, idx, n, :] = aa_b_n[valid_idxs[-1]]
                else:
                    pos_interp[b, idx, n, :] = slerp([idx]).as_rotvec()[0]
    return pos_interp
def slerp_interpolation_quat(pos, valid):
    """Fill invalid frames of wxyz quaternions via SLERP.

    Args:
        pos: (B, T, N, 4) quaternions in (w, x, y, z) order.
        valid: (B, T) boolean mask; True marks valid frames.

    Returns:
        A new (B, T, N, 4) array, still in (w, x, y, z) order, with
        invalid frames slerped between surrounding valid frames and
        clamped to the nearest valid frame outside the valid range.
    """
    # scipy expects xyzw; fancy indexing copies, so the input is untouched.
    pos = pos[:, :, :, [1, 2, 3, 0]]
    B, T, N, _ = pos.shape
    pos_interp = pos.copy()
    for b in range(B):
        for n in range(N):
            quats = pos[b, :, n, :]
            mask = valid[b, :]
            missing = np.where(~mask)[0]
            present = np.where(mask)[0]
            if missing.size == 0:
                continue
            if present.size > 1:
                slerp = Slerp(present, Rotation.from_quat(quats[present]))
                first, last = present[0], present[-1]
                for t in missing:
                    if t < first:
                        # Constant extrapolation before the first keyframe.
                        pos_interp[b, t, n, :] = quats[first]
                    elif t > last:
                        # Constant extrapolation after the last keyframe.
                        pos_interp[b, t, n, :] = quats[last]
                    else:
                        pos_interp[b, t, n, :] = slerp([t]).as_quat()[0]
    # Back to wxyz for the caller.
    return pos_interp[:, :, :, [3, 0, 1, 2]]
def linear_interpolation_nd(pos, valid):
    """Linearly interpolate invalid time steps of a (B, T, D) sequence.

    Args:
        pos: (B, T, D) array of per-frame features.
        valid: (B, T) boolean mask; True marks valid frames.

    Returns:
        A new (B, T, D) array where invalid frames are linearly
        interpolated from the valid ones (np.interp clamps to the
        first/last valid value outside the valid range). Batches with
        fewer than two valid frames are returned unchanged.
    """
    B, T = pos.shape[:2]
    feature_dim = pos.shape[2]
    pos_interp = pos.copy()
    for b in range(B):
        valid_b = valid[b, :]
        invalid_idxs = np.where(~valid_b)[0]
        valid_idxs = np.where(valid_b)[0]
        if len(invalid_idxs) == 0 or len(valid_idxs) <= 1:
            continue
        for d in range(feature_dim):
            # BUGFIX: the previous version assigned through a view of
            # `pos` (pos[b, :, d] is a view), silently mutating the
            # caller's input array. Write only into the copy.
            pos_interp[b, invalid_idxs, d] = np.interp(
                invalid_idxs, valid_idxs, pos[b, valid_idxs, d]
            )
    return pos_interp
def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
    """Apply a per-frame rigid transform (R, t) to a MANO hand trajectory.

    Despite the name, this is a generic frame change: callers pass either
    the world->canonical or the canonical->world transform.

    Args:
        R_c2w_sla: (T, 3, 3) rotation matrices, one per frame.
        t_c2w_sla: (T, 3) translations, one per frame.
        data_out: dict with "init_root_orient" (B, T, 3, 3) rotation
            matrices, "init_hand_pose" (B, T, 15, 3, 3) rotation matrices,
            "init_trans" (B, T, 3) and "init_betas" (B, T, 10).
        handedness: "left" or "right", selects the MANO model to run.

    Returns:
        dict with axis-angle "init_root_orient" (B, T, 3), axis-angle
        "init_hand_pose" (B, T, 15, 3), transformed "init_trans" (B, T, 3)
        and the untouched "init_betas".

    Raises:
        ValueError: if handedness is neither "left" nor "right"
            (previously this fell through to an unbound-name error).
    """
    init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
    # Rotate each frame's root orientation into the target frame.
    init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
    init_rot = rotation_matrix_to_angle_axis(init_rot_mat)

    data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
    data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])

    init_trans = data_out["init_trans"]  # (B, T, 3)
    if handedness == "left":
        outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
    elif handedness == "right":
        outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
    else:
        raise ValueError(f"handedness must be 'left' or 'right', got {handedness!r}")
    root_loc = outputs["joints"][..., 0, :].cpu()  # (B, T, 3) root joint
    # MANO's translation parameter and its root joint differ by an offset
    # that is constant w.r.t. rotation, so transform the joint and re-add it.
    offset = init_trans - root_loc
    init_trans = (
        torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
        + t_c2w_sla[None, :]
        + offset
    )
    data_world = {
        "init_root_orient": init_rot,            # (B, T, 3)
        "init_hand_pose": data_out_init_hand_pose,  # (B, T, 15, 3)
        "init_trans": init_trans,                # (B, T, 3)
        "init_betas": data_out["init_betas"]     # (B, T, 10)
    }
    return data_world
def filling_preprocess(item):
    """Canonicalize a two-hand trajectory and pack it for in-filling.

    Each hand is re-expressed in a canonical frame anchored at its
    first-frame root pose, invalid frames are filled by lerp (trans,
    betas) and slerp (rotations), and everything is concatenated into a
    per-frame feature vector in rot6d form.

    Args:
        item: dict of numpy arrays — 'trans' (2, T, 3), 'rot' (2, T, 3)
            axis-angle, 'hand_pose' (2, T, 45) axis-angle,
            'betas' (2, T, 10), 'valid' (2, T) boolean.
            Index 0 is the left hand, index 1 the right hand.

    Returns:
        global_pose_vec_input: (T, 2*109) array; per hand the layout is
            3 trans + 10 betas + 6 root rot6d + 90 hand-pose rot6d.
        transform_w_canon: dict of per-frame world<->canonical rotations
            and translations for both hands.
    """
    num_joints = 15
    global_trans = item['trans']  # (2, seq_len, 3)
    global_rot = item['rot']  # (2, seq_len, 3)
    hand_pose = item['hand_pose']  # (2, seq_len, 45)
    betas = item['betas']  # (2, seq_len, 10)
    valid = item['valid']  # (2, seq_len)
    N, T, _ = global_trans.shape

    # World->canonical rotation is the inverse (transpose) of each hand's
    # first-frame root orientation.
    R_canonical2world_left_aa = torch.from_numpy(global_rot[0, 0])
    R_canonical2world_right_aa = torch.from_numpy(global_rot[1, 0])
    R_world2canonical_left = angle_axis_to_rotation_matrix(R_canonical2world_left_aa).t()
    R_world2canonical_right = angle_axis_to_rotation_matrix(R_canonical2world_right_aa).t()

    # ---- transform left hand to canonical ----
    hand_pose = hand_pose.reshape(N, T, num_joints, 3)
    data_world_left = {
        "init_trans": torch.from_numpy(global_trans[0:1]),
        "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[0:1])),
        "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[0:1])),
        "init_betas": torch.from_numpy(betas[0:1]),
    }
    data_left_init_root_orient = rotation_matrix_to_angle_axis(data_world_left["init_root_orient"])
    data_left_init_hand_pose = rotation_matrix_to_angle_axis(data_world_left["init_hand_pose"])
    outputs = run_mano_left(data_world_left["init_trans"], data_left_init_root_orient, data_left_init_hand_pose, betas=data_world_left["init_betas"])
    init_trans = data_world_left["init_trans"][0, 0]  # (3,)
    root_loc = outputs["joints"][0, 0, 0, :].cpu()  # (3,) first-frame root joint
    # MANO trans parameter vs. root joint: a rotation-independent offset.
    offset = init_trans - root_loc
    # Canonical translation puts the first-frame root at the origin.
    t_world2canonical_left = -torch.einsum("ij,j->i", R_world2canonical_left, root_loc) - offset
    # Broadcast the single first-frame transform to all T frames.
    R_world2canonical_left = R_world2canonical_left.repeat(T, 1, 1)
    t_world2canonical_left = t_world2canonical_left.repeat(T, 1)
    data_canonical_left = world2canonical_convert(R_world2canonical_left, t_world2canonical_left, data_world_left, "left")

    # ---- transform right hand to canonical (mirrors the left path) ----
    data_world_right = {
        "init_trans": torch.from_numpy(global_trans[1:2]),
        "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[1:2])),
        "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[1:2])),
        "init_betas": torch.from_numpy(betas[1:2]),
    }
    data_right_init_root_orient = rotation_matrix_to_angle_axis(data_world_right["init_root_orient"])
    data_right_init_hand_pose = rotation_matrix_to_angle_axis(data_world_right["init_hand_pose"])
    outputs = run_mano(data_world_right["init_trans"], data_right_init_root_orient, data_right_init_hand_pose, betas=data_world_right["init_betas"])
    init_trans = data_world_right["init_trans"][0, 0]  # (3,)
    root_loc = outputs["joints"][0, 0, 0, :].cpu()  # (3,)
    offset = init_trans - root_loc  # constant, independent of rotation
    t_world2canonical_right = -torch.einsum("ij,j->i", R_world2canonical_right, root_loc) - offset
    R_world2canonical_right = R_world2canonical_right.repeat(T, 1, 1)
    t_world2canonical_right = t_world2canonical_right.repeat(T, 1)
    data_canonical_right = world2canonical_convert(R_world2canonical_right, t_world2canonical_right, data_world_right, "right")

    # ---- merge left and right canonical data ----
    global_rot = torch.cat((data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient']))
    global_trans = torch.cat((data_canonical_left['init_trans'], data_canonical_right['init_trans'])).numpy()
    global_rot = global_rot.reshape(N, T, 1, 3).numpy()
    hand_pose = hand_pose.reshape(N, T, 15, 3)

    # ---- fill invalid frames: lerp for linear data, slerp for rotations ----
    global_trans_lerped = linear_interpolation_nd(global_trans, valid)
    betas_lerped = linear_interpolation_nd(betas, valid)
    global_rot_slerped = slerp_interpolation_aa(global_rot, valid)
    hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid)

    # ---- convert rotations to rot6d for the network input ----
    global_rot_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
    global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy()
    hand_pose_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
    hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy()

    # Concatenate to (T, N * feature_dim): hands interleaved per frame.
    global_pose_vec_input = np.concatenate((global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d), axis=-1).transpose(1, 0, 2).reshape(T, -1)

    # Inverse transforms (canonical->world) for postprocessing:
    # R^-1 = R^T, t^-1 = -R^T t.
    R_canon2w_left = R_world2canonical_left.transpose(-1, -2)
    t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left)
    R_canon2w_right = R_world2canonical_right.transpose(-1, -2)
    t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, t_world2canonical_right)
    transform_w_canon = {
        "R_w2canon_left": R_world2canonical_left,
        "t_w2canon_left": t_world2canonical_left,
        "R_canon2w_left": R_canon2w_left,
        "t_canon2w_left": t_canon2w_left,
        "R_w2canon_right": R_world2canonical_right,
        "t_w2canon_right": t_world2canonical_right,
        "R_canon2w_right": R_canon2w_right,
        "t_canon2w_right": t_canon2w_right,
    }
    return global_pose_vec_input, transform_w_canon
def custom_rot6d_to_rotmat(rot6d):
    """Convert rot6d with arbitrary leading dimensions to (..., 3, 3).

    Wraps rot6d_to_rotmat, which only accepts a flat (M, 6) input, by
    flattening the leading dims and restoring them afterwards.
    """
    lead_dims = rot6d.shape[:-1]
    flat = rot6d.reshape(-1, 6)
    return rot6d_to_rotmat(flat).reshape(*lead_dims, 3, 3)
def filling_postprocess(output, transform_w_canon):
    """Unpack the in-filling network output back to world-frame MANO params.

    Args:
        output: (T, 2, 109) tensor; per hand the layout is
            3 trans + 10 betas + 6 root rot6d + 90 hand-pose rot6d
            (the layout produced by filling_preprocess).
        transform_w_canon: dict of canonical->world transforms as
            returned by filling_preprocess.

    Returns:
        dict of numpy arrays: 'trans' (2, T, 3), 'rot' (2, T, 3)
        axis-angle, 'hand_pose' (2, T, 45) axis-angle, 'betas' (2, T, 10).
        Index 0 is the left hand, index 1 the right hand.
    """
    output = output.permute(1, 0, 2)  # (2, T, -1)
    N, T, _ = output.shape
    # Slice the packed feature vector back into its components.
    canon_trans = output[:, :, :3]
    betas = output[:, :, 3:13]
    canon_rot_rot6d = output[:, :, 13:19]
    hand_pose_rot6d = output[:, :, 19:109].reshape(N, T, 15, 6)
    canon_rot_mat = custom_rot6d_to_rotmat(canon_rot_rot6d)
    hand_pose_mat = custom_rot6d_to_rotmat(hand_pose_rot6d)
    # [0] = left hand, [1] = right hand; list-indexing keeps the batch dim.
    data_canonical_left = {
        "init_trans": canon_trans[[0], :, :],
        "init_root_orient": canon_rot_mat[[0], :, :, :],
        "init_hand_pose": hand_pose_mat[[0], :, :, :, :],
        "init_betas": betas[[0], :, :]
    }
    data_canonical_right = {
        "init_trans": canon_trans[[1], :, :],
        "init_root_orient": canon_rot_mat[[1], :, :, :],
        "init_hand_pose": hand_pose_mat[[1], :, :, :, :],
        "init_betas": betas[[1], :, :]
    }
    R_canon2w_left = transform_w_canon['R_canon2w_left']
    t_canon2w_left = transform_w_canon['t_canon2w_left']
    R_canon2w_right = transform_w_canon['R_canon2w_right']
    t_canon2w_right = transform_w_canon['t_canon2w_right']
    # world2canonical_convert applies a generic rigid transform; here we
    # feed it the inverse transforms to map canonical -> world.
    world_left = world2canonical_convert(R_canon2w_left, t_canon2w_left, data_canonical_left, "left")
    world_right = world2canonical_convert(R_canon2w_right, t_canon2w_right, data_canonical_right, "right")
    global_rot = torch.cat((world_left['init_root_orient'], world_right['init_root_orient'])).numpy()
    global_trans = torch.cat((world_left['init_trans'], world_right['init_trans'])).numpy()
    pred_data = {
        "trans": global_trans,  # (2, T, 3)
        "rot": global_rot,  # (2, T, 3)
        "hand_pose": rotation_matrix_to_angle_axis(hand_pose_mat).flatten(-2).numpy(),  # (2, T, 45)
        "betas": betas.numpy(),  # (2, T, 10)
    }
    return pred_data