"""Gap-filling helpers for two-hand MANO motion sequences.

The module interpolates missing frames (linear for translations/shape,
spherical-linear for rotations), moves per-hand trajectories between a world
frame and a per-hand canonical frame, and packs/unpacks the flat feature
vector built from translation, betas and 6D rotations.
"""
import copy
import os

import joblib
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import torch

from hawor.utils.process import run_mano, run_mano_left
from hawor.utils.rotation import (
    angle_axis_to_quaternion,
    angle_axis_to_rotation_matrix,
    quaternion_to_rotation_matrix,
    rotation_matrix_to_angle_axis,
)
from lib.utils.geometry import rotmat_to_rot6d
from lib.utils.geometry import rot6d_to_rotmat


def _slerp_fill(pos, valid, to_rotation, from_rotation):
    """Fill invalid frames of ``pos`` by spherical linear interpolation.

    Args:
        pos: NumPy array indexed as ``[batch, time, joint, component]``.
        valid: Per-frame validity mask indexed as ``[batch, time]``.
        to_rotation: Callable turning an ``(k, D)`` array into a scipy
            ``Rotation`` holding ``k`` rotations.
        from_rotation: Callable turning a scipy ``Rotation`` back into an
            ``(k, D)`` array.

    Returns:
        A copy of ``pos`` in which invalid frames are replaced. Frames before
        the first / after the last valid frame copy that nearest valid frame;
        frames in between are slerped. Sequences with fewer than two valid
        frames are returned untouched.
    """
    # `~` on an integer 0/1 mask is a bitwise NOT (every frame would look
    # invalid), so work on a genuine boolean mask.
    valid = np.asarray(valid, dtype=bool)
    pos_interp = pos.copy()
    num_batches, _, num_joints, _ = pos.shape

    for b in range(num_batches):
        # The mask is shared by all joints of a batch entry: compute it once.
        valid_b = valid[b, :]
        invalid_idxs = np.where(~valid_b)[0]
        valid_idxs = np.where(valid_b)[0]
        if len(invalid_idxs) == 0 or len(valid_idxs) <= 1:
            continue

        first, last = valid_idxs[0], valid_idxs[-1]
        before = invalid_idxs[invalid_idxs < first]
        after = invalid_idxs[invalid_idxs > last]
        inside = invalid_idxs[(invalid_idxs > first) & (invalid_idxs < last)]

        for n in range(num_joints):
            series = pos[b, :, n, :]
            slerp = Slerp(valid_idxs, to_rotation(series[valid_idxs]))
            # Slerp cannot extrapolate: hold the nearest valid frame instead.
            pos_interp[b, before, n, :] = series[first]
            pos_interp[b, after, n, :] = series[last]
            if len(inside) > 0:
                pos_interp[b, inside, n, :] = from_rotation(slerp(inside))

    return pos_interp


def slerp_interpolation_aa(pos, valid):
    """Slerp-fill invalid frames of axis-angle (rotation vector) data.

    Args:
        pos: NumPy array indexed as ``[batch, time, joint, 3]`` holding
            rotation vectors.
        valid: Mask indexed as ``[batch, time]``; falsy entries are filled.

    Returns:
        A new array of the same shape as ``pos``; the input is not modified.
    """
    return _slerp_fill(
        pos,
        valid,
        Rotation.from_rotvec,
        lambda rot: rot.as_rotvec(),
    )


def slerp_interpolation_quat(pos, valid):
    """Slerp-fill invalid frames of quaternion data stored as ``wxyz``.

    Args:
        pos: NumPy array indexed as ``[batch, time, joint, 4]`` holding
            quaternions in ``wxyz`` order.
        valid: Mask indexed as ``[batch, time]``; falsy entries are filled.

    Returns:
        A new array of the same shape as ``pos`` in ``wxyz`` order.
    """
    # scipy's Rotation uses scalar-last (xyzw) quaternions.
    pos_xyzw = pos[:, :, :, [1, 2, 3, 0]]
    filled_xyzw = _slerp_fill(
        pos_xyzw,
        valid,
        Rotation.from_quat,
        lambda rot: rot.as_quat(),
    )
    # Back to the caller's wxyz layout.
    return filled_xyzw[:, :, :, [3, 0, 1, 2]]


def linear_interpolation_nd(pos, valid):
    """Linearly interpolate invalid frames of a ``[batch, time, feature]`` array.

    Each feature channel is interpolated independently over time with
    ``np.interp``, which holds the first/last valid value outside the valid
    range. Sequences with fewer than two valid frames are left untouched.

    Args:
        pos: NumPy array indexed as ``[batch, time, feature]``.
        valid: Mask indexed as ``[batch, time]``; falsy entries are filled.

    Returns:
        A new array of the same shape as ``pos``. The input array is NOT
        modified (the previous implementation wrote through a view of it).
    """
    valid = np.asarray(valid, dtype=bool)
    pos_interp = pos.copy()
    num_batches = pos.shape[0]
    feature_dim = pos.shape[2]

    for b in range(num_batches):
        valid_b = valid[b, :]
        invalid_idxs = np.where(~valid_b)[0]
        valid_idxs = np.where(valid_b)[0]
        # Nothing to fill, or not enough support points to interpolate.
        if len(invalid_idxs) == 0 or len(valid_idxs) <= 1:
            continue

        for idx in range(feature_dim):
            # Write into the copy only, never into a view of `pos`.
            pos_interp[b, invalid_idxs, idx] = np.interp(
                invalid_idxs, valid_idxs, pos[b, valid_idxs, idx]
            )

    return pos_interp


def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
    """Apply a per-frame rigid transform to a MANO hand trajectory.

    Despite its name, the function applies whatever transform it is given; in
    this file it is called both with world->canonical and canonical->world
    transforms.

    Args:
        R_c2w_sla: Per-frame rotation matrices, indexed ``[t, i, j]``.
        t_c2w_sla: Per-frame translations, indexed ``[t, i]``.
        data_out: Dict with keys ``"init_trans"``, ``"init_root_orient"``
            (rotation matrices, ``[b, t, 3, 3]``), ``"init_hand_pose"``
            (rotation matrices) and ``"init_betas"``.
        handedness: ``"left"`` or ``"right"``; selects the MANO model.

    Returns:
        Dict with the same keys where ``"init_root_orient"`` and
        ``"init_hand_pose"`` are axis-angle, ``"init_trans"`` is transformed
        and ``"init_betas"`` is passed through.

    Raises:
        ValueError: If ``handedness`` is neither ``"left"`` nor ``"right"``.
    """
    if handedness == "left":
        mano_fn = run_mano_left
    elif handedness == "right":
        mano_fn = run_mano
    else:
        raise ValueError(
            "handedness must be 'left' or 'right', got {!r}".format(handedness)
        )

    # einsum does not modify its operands, so no defensive copy is needed.
    init_rot_mat = torch.einsum(
        "tij,btjk->btik", R_c2w_sla, data_out["init_root_orient"]
    )
    init_rot = rotation_matrix_to_angle_axis(init_rot_mat)

    data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
    data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])

    init_trans = data_out["init_trans"]  # (B, T, 3)
    outputs = mano_fn(
        data_out["init_trans"],
        data_out_init_root_orient,
        data_out_init_hand_pose,
        betas=data_out["init_betas"],
    )
    root_loc = outputs["joints"][..., 0, :].cpu()  # (B, T, 3)
    offset = init_trans - root_loc  # It is a constant, no matter what the rotation is.
    # Rotate about the root joint, then restore the trans-to-root offset.
    init_trans = (
        torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
        + t_c2w_sla[None, :]
        + offset
    )

    data_world = {
        "init_root_orient": init_rot,  # (B, T, 3)
        "init_hand_pose": data_out_init_hand_pose,  # (B, T, 15, 3)
        "init_trans": init_trans,  # (B, T, 3)
        "init_betas": data_out["init_betas"],  # (B, T, 10)
    }

    return data_world


def _canonicalize_hand(global_trans, global_rot, hand_pose, betas, hand_idx, handedness):
    """Express one hand's world trajectory in its own first-frame frame.

    The canonical frame is defined by the hand's root orientation and root
    joint location at time step 0.

    Returns:
        Tuple ``(data_canonical, R_world2canonical, t_world2canonical)`` where
        the transform is repeated for every time step.
    """
    T = global_trans.shape[1]
    sel = slice(hand_idx, hand_idx + 1)  # keep the leading batch axis

    # NOTE(review): frame 0 is used even if it is flagged invalid - confirm
    # that callers guarantee a usable first frame.
    R_canonical2world_aa = torch.from_numpy(global_rot[hand_idx, 0])
    R_world2canonical = angle_axis_to_rotation_matrix(R_canonical2world_aa).t()

    data_world = {
        "init_trans": torch.from_numpy(global_trans[sel]),
        "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[sel])),
        "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[sel])),
        "init_betas": torch.from_numpy(betas[sel]),
    }

    root_orient_aa = rotation_matrix_to_angle_axis(data_world["init_root_orient"])
    hand_pose_aa = rotation_matrix_to_angle_axis(data_world["init_hand_pose"])
    mano_fn = run_mano_left if handedness == "left" else run_mano
    outputs = mano_fn(
        data_world["init_trans"],
        root_orient_aa,
        hand_pose_aa,
        betas=data_world["init_betas"],
    )

    init_trans = data_world["init_trans"][0, 0]  # (3,)
    root_loc = outputs["joints"][0, 0, 0, :].cpu()  # (3,)
    offset = init_trans - root_loc  # It is a constant, no matter what the rotation is.
    t_world2canonical = -torch.einsum("ij,j->i", R_world2canonical, root_loc) - offset

    # The same transform is applied at every time step.
    R_world2canonical = R_world2canonical.repeat(T, 1, 1)
    t_world2canonical = t_world2canonical.repeat(T, 1)

    data_canonical = world2canonical_convert(
        R_world2canonical, t_world2canonical, data_world, handedness
    )
    return data_canonical, R_world2canonical, t_world2canonical


def filling_preprocess(item):
    """Build the gap-filled, canonicalised feature sequence for both hands.

    Args:
        item: Dict of NumPy arrays with keys ``'trans'`` (2, seq_len, 3),
            ``'rot'`` (2, seq_len, 3, axis-angle), ``'hand_pose'``
            (2, seq_len, 45), ``'betas'`` (2, seq_len, 10) and ``'valid'``
            (2, seq_len). Index 0 is the left hand, index 1 the right hand.

    Returns:
        Tuple ``(global_pose_vec_input, transform_w_canon)``:
        ``global_pose_vec_input`` is a NumPy array with one row per time step
        holding, per hand, translation, betas, root rot6d and hand-pose rot6d;
        ``transform_w_canon`` is a dict of per-frame world<->canonical
        rotations and translations for each hand.
    """
    num_joints = 15

    global_trans = item['trans']  # (2, seq_len, 3)
    global_rot = item['rot']  # (2, seq_len, 3)
    hand_pose = item['hand_pose']  # (2, seq_len, 45)
    betas = item['betas']  # (2, seq_len, 10)
    valid = item['valid']  # (2, seq_len)

    N, T, _ = global_trans.shape
    hand_pose = hand_pose.reshape(N, T, num_joints, 3)

    # Transform each hand into its own canonical frame.
    data_canonical_left, R_world2canonical_left, t_world2canonical_left = _canonicalize_hand(
        global_trans, global_rot, hand_pose, betas, 0, "left"
    )
    data_canonical_right, R_world2canonical_right, t_world2canonical_right = _canonicalize_hand(
        global_trans, global_rot, hand_pose, betas, 1, "right"
    )

    # Merge left and right canonical data.
    global_rot = torch.cat(
        (data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient'])
    )
    global_trans = torch.cat(
        (data_canonical_left['init_trans'], data_canonical_right['init_trans'])
    ).numpy()
    global_rot = global_rot.reshape(N, T, 1, 3).numpy()

    # Lerp translations/shape, slerp rotations.
    global_trans_lerped = linear_interpolation_nd(global_trans, valid)
    betas_lerped = linear_interpolation_nd(betas, valid)
    global_rot_slerped = slerp_interpolation_aa(global_rot, valid)
    hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid)

    # Convert to rot6d.
    global_rot_slerped_mat = angle_axis_to_rotation_matrix(
        torch.from_numpy(global_rot_slerped.reshape(N * T, -1))
    )
    global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy()
    hand_pose_slerped_mat = angle_axis_to_rotation_matrix(
        torch.from_numpy(hand_pose_slerped.reshape(N * T * num_joints, -1))
    )
    hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy()

    # Concatenate per-hand features, then lay both hands side by side per
    # time step: (N, T, D) -> (T, N, D) -> (T, N * D).
    global_pose_vec_input = np.concatenate(
        (global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d),
        axis=-1,
    ).transpose(1, 0, 2).reshape(T, -1)

    # Inverse rigid transforms: R^-1 = R^T, t^-1 = -R^T t.
    R_canon2w_left = R_world2canonical_left.transpose(-1, -2)
    t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left)
    R_canon2w_right = R_world2canonical_right.transpose(-1, -2)
    t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, t_world2canonical_right)

    transform_w_canon = {
        "R_w2canon_left": R_world2canonical_left,
        "t_w2canon_left": t_world2canonical_left,
        "R_canon2w_left": R_canon2w_left,
        "t_canon2w_left": t_canon2w_left,

        "R_w2canon_right": R_world2canonical_right,
        "t_w2canon_right": t_world2canonical_right,
        "R_canon2w_right": R_canon2w_right,
        "t_canon2w_right": t_canon2w_right,
    }

    return global_pose_vec_input, transform_w_canon


def custom_rot6d_to_rotmat(rot6d):
    """Convert ``(..., 6)`` rot6d values to ``(..., 3, 3)`` rotation matrices.

    ``rot6d_to_rotmat`` is called on a flattened ``(-1, 6)`` view and the
    leading dimensions are restored afterwards.
    """
    original_shape = rot6d.shape[:-1]
    rot6d = rot6d.reshape(-1, 6)
    mat = rot6d_to_rotmat(rot6d)
    mat = mat.reshape(*original_shape, 3, 3)
    return mat


def filling_postprocess(output, transform_w_canon):
    """Unpack a feature sequence and map both hands back to the world frame.

    Args:
        output: 3-D torch tensor; after ``permute(1, 0, 2)`` it is read as
            (2, T, 109) with channels ``[0:3]`` translation, ``[3:13]`` betas,
            ``[13:19]`` root rot6d and ``[19:109]`` hand-pose rot6d
            (15 joints x 6).
        transform_w_canon: Dict as returned by ``filling_preprocess``; only
            the ``*_canon2w_*`` entries are read.

    Returns:
        Dict of NumPy arrays: ``"trans"`` (2, T, 3), ``"rot"`` (2, T, 3),
        ``"hand_pose"`` (2, T, 45) and ``"betas"`` (2, T, 10).
    """
    output = output.permute(1, 0, 2)  # (2, T, -1)
    N, T, _ = output.shape

    canon_trans = output[:, :, :3]
    betas = output[:, :, 3:13]
    canon_rot_rot6d = output[:, :, 13:19]
    hand_pose_rot6d = output[:, :, 19:109].reshape(N, T, 15, 6)
    canon_rot_mat = custom_rot6d_to_rotmat(canon_rot_rot6d)
    hand_pose_mat = custom_rot6d_to_rotmat(hand_pose_rot6d)

    world = {}
    for hand_idx, side in enumerate(("left", "right")):
        # List indexing keeps the leading batch axis of size 1.
        data_canonical = {
            "init_trans": canon_trans[[hand_idx], :, :],
            "init_root_orient": canon_rot_mat[[hand_idx], :, :, :],
            "init_hand_pose": hand_pose_mat[[hand_idx], :, :, :, :],
            "init_betas": betas[[hand_idx], :, :],
        }
        # The helper applies the supplied transform; here it is the
        # canonical->world one.
        world[side] = world2canonical_convert(
            transform_w_canon["R_canon2w_" + side],
            transform_w_canon["t_canon2w_" + side],
            data_canonical,
            side,
        )

    global_rot = torch.cat(
        (world["left"]['init_root_orient'], world["right"]['init_root_orient'])
    ).numpy()
    global_trans = torch.cat(
        (world["left"]['init_trans'], world["right"]['init_trans'])
    ).numpy()

    pred_data = {
        "trans": global_trans,  # (2, T, 3)
        "rot": global_rot,  # (2, T, 3)
        "hand_pose": rotation_matrix_to_angle_axis(hand_pose_mat).flatten(-2).numpy(),  # (2, T, 45)
        "betas": betas.numpy(),  # (2, T, 10)
    }
    return pred_data