# HaWoR / infiller / lib / model / skeleton.py
# ThunderVVV's picture
# update
# 5f028d6
# raw
# history blame
# 10.7 kB
import torch
import numpy as np
from cmib.data.quaternion import qmul, qrot
import torch.nn as nn
# Table of 27 rows, each a 3-component float offset (one row per joint).
# Rows 0, 5, 10, 16, 21 and 26 are all-zero; apart from row 0 these are the
# same indices listed in `sk_joints_to_remove`.
# NOTE(review): the name suggests these offsets were measured on an AMASS
# skeleton, and the magnitudes (well below 1) differ from `sk_offsets` by
# roughly two orders, so the units are presumably different — TODO confirm.
# Not referenced by any code visible in this file.
amass_offsets = [
[0.0, 0.0, 0.0],
[0.058581, -0.082280, -0.017664],
[0.043451, -0.386469, 0.008037],
[-0.014790, -0.426874, -0.037428],
[0.041054, -0.060286, 0.122042],
[0.0, 0.0, 0.0],
[-0.060310, -0.090513, -0.013543],
[-0.043257, -0.383688, -0.004843],
[0.019056, -0.420046, -0.034562],
[-0.034840, -0.062106, 0.130323],
[0.0, 0.0, 0.0],
[0.004439, 0.124404, -0.038385],
[0.004488, 0.137956, 0.026820],
[-0.002265, 0.056032, 0.002855],
[-0.013390, 0.211636, -0.033468],
[0.010113, 0.088937, 0.050410],
[0.0, 0.0, 0.0],
[0.071702, 0.114000, -0.018898],
[0.122921, 0.045205, -0.019046],
[0.255332, -0.015649, -0.022946],
[0.265709, 0.012698, -0.007375],
[0.0, 0.0, 0.0],
[-0.082954, 0.112472, -0.023707],
[-0.113228, 0.046853, -0.008472],
[-0.260127, -0.014369, -0.031269],
[-0.269108, 0.006794, -0.006027],
[0.0, 0.0, 0.0]
]
# Table of 27 rows, each a 3-component float offset (one row per joint, in the
# same order as `sk_parents`).
# Row 0 is non-zero, unlike the other tables' root rows; the all-zero rows
# (5, 10, 16, 21, 26) are exactly the indices in `sk_joints_to_remove`.
# NOTE(review): presumably each row is the joint's offset from its parent in
# the parent's local frame, which is how `Skeleton` consumes offsets — confirm
# against the data source. Units are not stated anywhere in this file.
sk_offsets = [
[-42.198200, 91.614723, -40.067841],
[0.103456, 1.857829, 10.548506],
[43.499992, -0.000038, -0.000002],
[42.372192, 0.000015, -0.000007],
[17.299999, -0.000002, 0.000003],
[0.000000, 0.000000, 0.000000],
[0.103457, 1.857829, -10.548503],
[43.500042, -0.000027, 0.000008],
[42.372257, -0.000008, 0.000014],
[17.299992, -0.000005, 0.000004],
[0.000000, 0.000000, 0.000000],
[6.901968, -2.603733, -0.000001],
[12.588099, 0.000002, 0.000000],
[12.343206, 0.000000, -0.000001],
[25.832886, -0.000004, 0.000003],
[11.766620, 0.000005, -0.000001],
[0.000000, 0.000000, 0.000000],
[19.745899, -1.480370, 6.000108],
[11.284125, -0.000009, -0.000018],
[33.000050, 0.000004, 0.000032],
[25.200008, 0.000015, 0.000008],
[0.000000, 0.000000, 0.000000],
[19.746099, -1.480375, -6.000073],
[11.284138, -0.000015, -0.000012],
[33.000092, 0.000017, 0.000013],
[25.199780, 0.000135, 0.000422],
[0.000000, 0.000000, 0.000000],
]
# Parent index for each of the 27 joints; -1 marks the root.
# Laid out one kinematic chain per row: three chains hang off joint 0 and two
# more branch from joint 13. The last joint of every chain (5, 10, 16, 21, 26)
# has no children.
sk_parents = [
    -1,                     # joint 0: root
    0, 1, 2, 3, 4,          # joints 1-5
    0, 6, 7, 8, 9,          # joints 6-10
    0, 11, 12, 13, 14, 15,  # joints 11-16
    13, 17, 18, 19, 20,     # joints 17-21
    13, 22, 23, 24, 25,     # joints 22-26
]
# Indices (into the 27-joint tables above) of the joints that have no children
# in `sk_parents` and an all-zero row in `sk_offsets`.
# NOTE(review): presumably passed to `Skeleton.remove_joints` to obtain the
# 22-joint skeleton named by `joint_names` — confirm against the caller.
sk_joints_to_remove = [5, 10, 16, 21, 26]
# Names of the 22 joints, grouped one body chain per row.
# NOTE(review): the count matches the 27-joint tables minus the five entries in
# `sk_joints_to_remove`, so the order is presumably that of the reduced
# skeleton — confirm against the caller.
joint_names = [
    "Hips",
    "LeftUpLeg", "LeftLeg", "LeftFoot", "LeftToe",
    "RightUpLeg", "RightLeg", "RightFoot", "RightToe",
    "Spine", "Spine1", "Spine2", "Neck", "Head",
    "LeftShoulder", "LeftArm", "LeftForeArm", "LeftHand",
    "RightShoulder", "RightArm", "RightForeArm", "RightHand",
]
class Skeleton:
    """Kinematic tree defined by per-joint offsets and parent indices.

    Joint ``i`` hangs off joint ``parents[i]``; a parent of ``-1`` marks a
    root. Offsets are stored as a float tensor (moved to ``device``) and
    parents as a NumPy array. The traversal methods append per-joint results
    in index order and look parents up by index, so every parent must appear
    before its children.
    """

    def __init__(
        self,
        offsets,
        parents,
        joints_left=None,
        joints_right=None,
        bone_length=None,
        device=None,
    ):
        """Build the skeleton and its derived child tables.

        Args:
            offsets: one 3-component offset per joint.
            parents: parent index per joint, ``-1`` for a root; must have the
                same length as ``offsets``.
            joints_left: stored as-is and returned by ``joints_left()``.
            joints_right: stored as-is and returned by ``joints_right()``.
            bone_length: accepted for interface compatibility; not used.
            device: target device for the offsets tensor.
        """
        assert len(offsets) == len(parents)
        self._offsets = torch.Tensor(offsets).to(device)
        self._parents = np.array(parents)
        self._joints_left = joints_left
        self._joints_right = joints_right
        self._compute_metadata()

    def num_joints(self):
        """Return the number of joints currently in the skeleton."""
        return self._offsets.shape[0]

    def offsets(self):
        """Return the (J, 3) offsets tensor."""
        return self._offsets

    def parents(self):
        """Return the parent-index array (``-1`` for a root)."""
        return self._parents

    def has_children(self):
        """Return a boolean array flagging joints that have a child."""
        return self._has_children

    def children(self):
        """Return, for every joint, the list of its child indices."""
        return self._children

    def convert_to_global_pos(self, unit_vec_rerp):
        """Convert a unit-offset representation to global joint positions.

        The input is reshaped to (batch, seq, J, 3), where J is this
        skeleton's joint count. A root row is taken as an absolute position;
        every other joint is placed at its parent's position plus the
        L2-normalised row scaled by that joint's bone length.

        Returns:
            Tensor of shape (batch, seq, J, 3).
        """
        bone_length = self.get_bone_length_weight()
        batch_size = unit_vec_rerp.size(0)
        seq_len = unit_vec_rerp.size(1)
        # Use the skeleton's own joint count instead of a hard-coded 22.
        unit_vec_table = unit_vec_rerp.reshape(
            batch_size, seq_len, len(self._parents), 3
        )
        # Normalise every row once; root rows of this tensor are never read.
        directions = nn.functional.normalize(unit_vec_table, p=2.0, dim=-1)
        global_position = torch.zeros_like(unit_vec_table)
        for i, parent in enumerate(self._parents):
            if parent == -1:  # root: absolute position
                global_position[:, :, i] = unit_vec_table[:, :, i]
            else:
                global_position[:, :, i] = (
                    global_position[:, :, parent]
                    + directions[:, :, i] * bone_length[i]
                )
        return global_position

    def convert_to_unit_offset_mat(self, global_position):
        """Convert global joint positions to a unit-offset representation.

        A root row keeps its absolute position. Every other row becomes the
        displacement from the parent divided by the joint's bone length. A
        joint whose bone length is zero gets a zero row instead of the
        NaN/inf a division by zero would produce.

        Args:
            global_position: tensor indexed as (batch, seq, joint, 3).

        Returns:
            Tensor with the same shape as ``global_position``.
        """
        bone_length = self.get_bone_length_weight()
        zero_length = (bone_length == 0).tolist()
        unit_offset_mat = torch.zeros_like(global_position)
        for i, parent in enumerate(self._parents):
            if parent == -1:  # root: absolute position
                unit_offset_mat[:, :, i] = global_position[:, :, i]
            elif not zero_length[i]:
                unit_offset_mat[:, :, i] = (
                    global_position[:, :, i] - global_position[:, :, parent]
                ) / bone_length[i]
            # else: zero-length bone, row stays zero.
        return unit_offset_mat

    def remove_joints(self, joints_to_remove):
        """Drop the given joints from the skeleton definition, in place.

        Offsets of removed joints are discarded and the remaining parent
        indices are shifted down to stay contiguous. Roots stay roots. A
        surviving joint whose parent was removed ends up attached to the
        nearest preceding surviving joint by index, so this is only
        well-defined for arbitrary removals when parents precede children;
        removing leaf joints is always safe.

        Args:
            joints_to_remove: iterable of joint indices to delete.
        """
        removed = set(joints_to_remove)
        num_joints = len(self._parents)
        valid_joints = [j for j in range(num_joints) if j not in removed]

        # index_offsets[k] counts the removed joints seen so far at index <= k.
        index_offsets = np.zeros(num_joints, dtype=int)
        new_parents = []
        for i, parent in enumerate(self._parents):
            if i in removed:
                index_offsets[i:] += 1
            elif parent == -1:
                # Never shift the root marker: index_offsets[-1] would alias
                # the last joint and corrupt it.
                new_parents.append(-1)
            else:
                new_parents.append(parent - index_offsets[parent])

        self._parents = np.array(new_parents)
        self._offsets = self._offsets[valid_joints]
        self._compute_metadata()

    def _world_transforms(self, rotations, root_positions):
        """Walk the tree once and collect per-joint world transforms.

        Returns two lists indexed by joint: world positions, and world
        rotations with ``None`` for joints that have no children.
        """
        assert len(rotations.shape) == 4
        assert rotations.shape[-1] == 4

        expanded_offsets = self._offsets.expand(
            rotations.shape[0],
            rotations.shape[1],
            self._offsets.shape[0],
            self._offsets.shape[1],
        )

        positions_world = []
        rotations_world = []
        # Parallelize along the batch and time dimensions
        for i in range(self._offsets.shape[0]):
            parent = self._parents[i]
            if parent == -1:
                positions_world.append(root_positions)
                rotations_world.append(rotations[:, :, 0])
                continue

            positions_world.append(
                qrot(rotations_world[parent], expanded_offsets[:, :, i])
                + positions_world[parent]
            )
            if self._has_children[i]:
                rotations_world.append(
                    qmul(rotations_world[parent], rotations[:, :, i])
                )
            else:
                # Terminal joint: nothing below it needs this rotation.
                rotations_world.append(None)
        return positions_world, rotations_world

    def forward_kinematics(self, rotations, root_positions):
        """Compute world joint positions from local rotations.

        Arguments (N = batch size, L = sequence length, J = number of joints):
         -- rotations: (N, L, J, 4) tensor of unit quaternions describing the
            local rotations of each joint.
         -- root_positions: (N, L, 3) tensor describing the root joint
            positions.

        Returns:
            (N, L, J, 3) tensor of world positions.
        """
        positions_world, _ = self._world_transforms(rotations, root_positions)
        return torch.stack(positions_world, dim=2)

    def forward_kinematics_with_rotation(self, rotations, root_positions):
        """Compute world joint positions and world rotations.

        Arguments (N = batch size, L = sequence length, J = number of joints):
         -- rotations: (N, L, J, 4) tensor of unit quaternions describing the
            local rotations of each joint.
         -- root_positions: (N, L, 3) tensor describing the root joint
            positions.

        Returns:
            Tuple of an (N, L, J, 3) positions tensor and an (N, L, J, 4)
            rotations tensor. Joints without children are filled with the
            quaternion [1, 0, 0, 0] rather than their accumulated rotation.
        """
        positions_world, rotations_world = self._world_transforms(
            rotations, root_positions
        )
        # Match dtype and device of the input so the stack below is uniform.
        identity = rotations.new_tensor([1.0, 0.0, 0.0, 0.0]).expand(
            rotations.shape[0], rotations.shape[1], 4
        )
        rotations_world = [
            identity if rot is None else rot for rot in rotations_world
        ]
        return (
            torch.stack(positions_world, dim=2),
            torch.stack(rotations_world, dim=2),
        )

    def get_bone_length_weight(self):
        """Return per-joint bone lengths as a 1-D CPU tensor.

        A root gets weight 1; every other joint gets the norm of its offset.
        """
        bone_length = [
            1
            if parent == -1
            else torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item()
            for i, parent in enumerate(self._parents)
        ]
        return torch.Tensor(bone_length)

    def joints_left(self):
        """Return the left-side joint indices given at construction."""
        return self._joints_left

    def joints_right(self):
        """Return the right-side joint indices given at construction."""
        return self._joints_right

    def _compute_metadata(self):
        """Rebuild ``_has_children`` and ``_children`` from ``_parents``."""
        num_joints = len(self._parents)
        self._has_children = np.zeros(num_joints, dtype=bool)
        self._children = [[] for _ in range(num_joints)]
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._has_children[parent] = True
                self._children[parent].append(i)