# Spaces:
# Running
# Running
import torch | |
import numpy as np | |
from cmib.data.quaternion import qmul, qrot | |
import torch.nn as nn | |
# Per-joint offset table with 27 rows of (x, y, z).
# NOTE(review): the row count and the all-zero rows (0, 5, 10, 16, 21, 26)
# line up with the root plus ``sk_joints_to_remove``, so this appears to use the
# same joint ordering as ``sk_parents``; magnitudes suggest metres — confirm.
amass_offsets = [
    # joint 0 (root)
    [0.0, 0.0, 0.0],
    # joints 1-5 (chain from the root)
    [0.058581, -0.082280, -0.017664],
    [0.043451, -0.386469, 0.008037],
    [-0.014790, -0.426874, -0.037428],
    [0.041054, -0.060286, 0.122042],
    [0.0, 0.0, 0.0],
    # joints 6-10 (chain from the root)
    [-0.060310, -0.090513, -0.013543],
    [-0.043257, -0.383688, -0.004843],
    [0.019056, -0.420046, -0.034562],
    [-0.034840, -0.062106, 0.130323],
    [0.0, 0.0, 0.0],
    # joints 11-16 (chain from the root)
    [0.004439, 0.124404, -0.038385],
    [0.004488, 0.137956, 0.026820],
    [-0.002265, 0.056032, 0.002855],
    [-0.013390, 0.211636, -0.033468],
    [0.010113, 0.088937, 0.050410],
    [0.0, 0.0, 0.0],
    # joints 17-21 (chain from joint 13)
    [0.071702, 0.114000, -0.018898],
    [0.122921, 0.045205, -0.019046],
    [0.255332, -0.015649, -0.022946],
    [0.265709, 0.012698, -0.007375],
    [0.0, 0.0, 0.0],
    # joints 22-26 (chain from joint 13)
    [-0.082954, 0.112472, -0.023707],
    [-0.113228, 0.046853, -0.008472],
    [-0.260127, -0.014369, -0.031269],
    [-0.269108, 0.006794, -0.006027],
    [0.0, 0.0, 0.0],
]
# Rest-pose offset of each of the 27 joints relative to its parent in
# ``sk_parents``; row 0 is the root. Rows 5, 10, 16, 21 and 26 are zero-length
# leaves (the indices listed in ``sk_joints_to_remove``).
# NOTE(review): magnitudes (e.g. ~43.5 for row 2) suggest centimetres — confirm.
sk_offsets = [
    # joint 0 (root)
    [-42.198200, 91.614723, -40.067841],
    # joints 1-5 (chain from the root)
    [0.103456, 1.857829, 10.548506],
    [43.499992, -0.000038, -0.000002],
    [42.372192, 0.000015, -0.000007],
    [17.299999, -0.000002, 0.000003],
    [0.000000, 0.000000, 0.000000],
    # joints 6-10 (chain from the root)
    [0.103457, 1.857829, -10.548503],
    [43.500042, -0.000027, 0.000008],
    [42.372257, -0.000008, 0.000014],
    [17.299992, -0.000005, 0.000004],
    [0.000000, 0.000000, 0.000000],
    # joints 11-16 (chain from the root)
    [6.901968, -2.603733, -0.000001],
    [12.588099, 0.000002, 0.000000],
    [12.343206, 0.000000, -0.000001],
    [25.832886, -0.000004, 0.000003],
    [11.766620, 0.000005, -0.000001],
    [0.000000, 0.000000, 0.000000],
    # joints 17-21 (chain from joint 13)
    [19.745899, -1.480370, 6.000108],
    [11.284125, -0.000009, -0.000018],
    [33.000050, 0.000004, 0.000032],
    [25.200008, 0.000015, 0.000008],
    [0.000000, 0.000000, 0.000000],
    # joints 22-26 (chain from joint 13)
    [19.746099, -1.480375, -6.000073],
    [11.284138, -0.000015, -0.000012],
    [33.000092, 0.000017, 0.000013],
    [25.199780, 0.000135, 0.000422],
    [0.000000, 0.000000, 0.000000],
]
# Parent index of each joint in ``sk_offsets`` (-1 marks the root). Every
# parent appears before its children. Joints 1-5, 6-10 and 11-16 form chains
# hanging off the root; joints 17-21 and 22-26 form chains hanging off joint 13.
sk_parents = [
    -1,                     # 0: root
    0, 1, 2, 3, 4,          # 1-5
    0, 6, 7, 8, 9,          # 6-10
    0, 11, 12, 13, 14, 15,  # 11-16
    13, 17, 18, 19, 20,     # 17-21
    13, 22, 23, 24, 25,     # 22-26
]
# Leaf joints (no entry of ``sk_parents`` points at them) whose offsets are all
# zero; presumably end sites that are dropped to get the 22 named joints.
sk_joints_to_remove = [
    5,   # child of joint 4
    10,  # child of joint 9
    16,  # child of joint 15
    21,  # child of joint 20
    26,  # child of joint 25
]
# Names of the 22 joints that remain once ``sk_joints_to_remove`` is removed
# from the 27-joint tables.
# NOTE(review): the order appears to follow the surviving ``sk_parents``
# indices; confirm against the data loader.
joint_names = [
    "Hips",
    # left leg
    "LeftUpLeg", "LeftLeg", "LeftFoot", "LeftToe",
    # right leg
    "RightUpLeg", "RightLeg", "RightFoot", "RightToe",
    # spine and head
    "Spine", "Spine1", "Spine2", "Neck", "Head",
    # left arm
    "LeftShoulder", "LeftArm", "LeftForeArm", "LeftHand",
    # right arm
    "RightShoulder", "RightArm", "RightForeArm", "RightHand",
]
class Skeleton:
    """
    Joint hierarchy with per-joint rest-pose offsets.

    Joint ``i`` has a parent index ``parents[i]`` (``-1`` marks a root) and an
    offset vector ``offsets[i]`` that is placed relative to its parent. Parents
    must be listed before their children: every per-joint loop below reads the
    parent's already-computed result.
    """

    def __init__(
        self,
        offsets,
        parents,
        joints_left=None,
        joints_right=None,
        bone_length=None,
        device=None,
    ):
        """
        Build the skeleton.

        Args:
            offsets: J per-joint offset vectors; converted with ``torch.Tensor``
                (default float dtype) and moved to ``device``.
            parents: J parent indices, ``-1`` for a root.
            joints_left: optional value returned unchanged by ``joints_left()``.
            joints_right: optional value returned unchanged by ``joints_right()``.
            bone_length: unused; kept so existing call sites keep working.
            device: torch device for the offsets tensor.
        """
        assert len(offsets) == len(parents)
        self._offsets = torch.Tensor(offsets).to(device)
        self._parents = np.array(parents)
        self._joints_left = joints_left
        self._joints_right = joints_right
        self._compute_metadata()

    def num_joints(self):
        """Return the number of joints J."""
        return self._offsets.shape[0]

    def offsets(self):
        """Return the (J, 3) offsets tensor."""
        return self._offsets

    def parents(self):
        """Return the numpy array of parent indices (-1 for roots)."""
        return self._parents

    def has_children(self):
        """Return a boolean numpy array: True where some joint has this one as parent."""
        return self._has_children

    def children(self):
        """Return, for every joint, the list of its direct child indices."""
        return self._children

    def convert_to_global_pos(self, unit_vec_rerp):
        """
        Convert a unit offset matrix back to global joint positions.

        The root row holds an absolute position and is copied unchanged. Every
        other row is a direction from the joint's parent: it is L2-normalised,
        scaled by the rest-pose bone length and added to the parent's global
        position.

        Args:
            unit_vec_rerp: tensor of shape (N, L, J * 3) or (N, L, J, 3), where
                J is this skeleton's joint count.

        Returns:
            Tensor of shape (N, L, J, 3).
        """
        bone_length = self.get_bone_length_weight()
        batch_size = unit_vec_rerp.size(0)
        seq_len = unit_vec_rerp.size(1)
        # Use the skeleton's own joint count rather than a hard-coded 22, so the
        # method also works before remove_joints() and for other skeletons.
        unit_vec_table = unit_vec_rerp.reshape(
            batch_size, seq_len, len(self._parents), 3
        )
        positions = []
        for i, parent in enumerate(self._parents):
            if parent == -1:
                positions.append(unit_vec_table[:, :, i])
            else:
                direction = nn.functional.normalize(
                    unit_vec_table[:, :, i], p=2.0, dim=-1
                )
                positions.append(positions[parent] + direction * bone_length[i])
        return torch.stack(positions, dim=2)

    def convert_to_unit_offset_mat(self, global_position):
        """
        Convert global joint positions to a unit offset matrix.

        The root row keeps its absolute position. Every other row becomes
        (joint position - parent position) / rest-pose bone length.

        Args:
            global_position: tensor indexed as [:, :, joint, ...], typically
                (N, L, J, 3).

        Returns:
            Tensor with the same shape as ``global_position``.
        """
        bone_length = self.get_bone_length_weight()
        rows = []
        for i, parent in enumerate(self._parents):
            if parent == -1:
                rows.append(global_position[:, :, i])
            else:
                # NOTE(review): a zero-length bone divides by zero here.
                rows.append(
                    (global_position[:, :, i] - global_position[:, :, parent])
                    / bone_length[i]
                )
        return torch.stack(rows, dim=2)

    def remove_joints(self, joints_to_remove):
        """
        Remove the joints in ``joints_to_remove`` from the skeleton definition.

        Surviving joints are renumbered contiguously in their original order. A
        surviving joint whose parent is removed is re-attached to its nearest
        surviving ancestor, or becomes a root if none survives. Offsets of the
        surviving joints are kept as-is: a removed joint's offset is NOT folded
        into its children. No motion data or rotations are modified here, and
        ``joints_left`` / ``joints_right`` are left untouched.

        Args:
            joints_to_remove: iterable of joint indices in the current numbering.
        """
        removed = {int(joint) for joint in joints_to_remove}
        valid_joints = [j for j in range(len(self._parents)) if j not in removed]
        new_index = {old: new for new, old in enumerate(valid_joints)}

        new_parents = []
        for joint in valid_joints:
            parent = int(self._parents[joint])
            # Skip removed ancestors so the tree stays connected.
            while parent != -1 and parent in removed:
                parent = int(self._parents[parent])
            new_parents.append(-1 if parent == -1 else new_index[parent])

        self._parents = np.array(new_parents, dtype=self._parents.dtype)
        self._offsets = self._offsets[valid_joints]
        self._compute_metadata()

    def forward_kinematics(self, rotations, root_positions):
        """
        Perform forward kinematics using the given trajectory and local rotations.

        Arguments (where N = batch size, L = sequence length, J = number of joints):
        -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
        -- root_positions: (N, L, 3) tensor describing the root joint positions.

        Returns:
            (N, L, J, 3) tensor of world-space joint positions.
        """
        positions_world, _ = self._accumulate_world_transforms(
            rotations, root_positions, fill_leaf_rotations=False
        )
        return torch.stack(positions_world, dim=2)

    def forward_kinematics_with_rotation(self, rotations, root_positions):
        """
        Perform forward kinematics and also return world-space joint rotations.

        Arguments (where N = batch size, L = sequence length, J = number of joints):
        -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
        -- root_positions: (N, L, 3) tensor describing the root joint positions.

        Returns:
            Tuple of (N, L, J, 3) world positions and (N, L, J, 4) world
            rotations. Joints without children report the identity quaternion
            [1, 0, 0, 0] instead of a composed rotation.
        """
        positions_world, rotations_world = self._accumulate_world_transforms(
            rotations, root_positions, fill_leaf_rotations=True
        )
        return torch.stack(positions_world, dim=2), torch.stack(rotations_world, dim=2)

    def _accumulate_world_transforms(self, rotations, root_positions, fill_leaf_rotations):
        """
        Walk the hierarchy once and collect per-joint world positions/rotations.

        Shared by forward_kinematics() and forward_kinematics_with_rotation().
        Returns two lists with one entry per joint: positions (N, L, 3) and
        rotations (N, L, 4). Joints without children get ``None`` as rotation,
        or the identity quaternion when ``fill_leaf_rotations`` is true.
        """
        assert len(rotations.shape) == 4
        assert rotations.shape[-1] == 4
        batch_size, seq_len = rotations.shape[0], rotations.shape[1]
        expanded_offsets = self._offsets.expand(
            batch_size, seq_len, self._offsets.shape[0], self._offsets.shape[1]
        )
        leaf_rotation = None
        if fill_leaf_rotations:
            # Identity quaternion built with the input's dtype and device.
            leaf_rotation = rotations.new_tensor([1.0, 0.0, 0.0, 0.0]).expand(
                batch_size, seq_len, 4
            )

        positions_world = []
        rotations_world = []
        # Loop over joints only; batch and time are handled by tensor ops.
        for i, parent in enumerate(self._parents):
            if parent == -1:
                positions_world.append(root_positions)
                # Use the root's own rotation (the original always read joint 0).
                rotations_world.append(rotations[:, :, i])
                continue
            parent_rotation = rotations_world[parent]
            positions_world.append(
                qrot(parent_rotation, expanded_offsets[:, :, i]) + positions_world[parent]
            )
            if self._has_children[i]:
                rotations_world.append(qmul(parent_rotation, rotations[:, :, i]))
            else:
                # Terminal joint: its world rotation is never needed by a child.
                rotations_world.append(leaf_rotation)
        return positions_world, rotations_world

    def get_bone_length_weight(self):
        """
        Return the rest-pose length of the bone ending at each joint.

        Each entry is the Euclidean norm of that joint's offset, except roots
        (parent -1), which get 1. The result is a CPU tensor of the default
        float dtype with one entry per joint.
        """
        # One host transfer for all joints instead of one .item() per joint.
        norms = torch.linalg.norm(self._offsets, dim=-1).tolist()
        return torch.Tensor(
            [1 if parent == -1 else norm for parent, norm in zip(self._parents, norms)]
        )

    def joints_left(self):
        """Return the ``joints_left`` value passed to the constructor."""
        return self._joints_left

    def joints_right(self):
        """Return the ``joints_right`` value passed to the constructor."""
        return self._joints_right

    def _compute_metadata(self):
        """Rebuild ``_has_children`` and ``_children`` from ``_parents``."""
        num = len(self._parents)
        self._has_children = np.zeros(num, dtype=bool)
        self._children = [[] for _ in range(num)]
        for child, parent in enumerate(self._parents):
            if parent != -1:
                self._has_children[parent] = True
                self._children[parent].append(child)