# ----------------------------------------------------------------------------------------------
# FastMETRO Official Code
# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------
# PostoMETRO Official Code
# Copyright (c) MIRACLE Lab. All Rights Reserved
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function
import torch
import numpy as np
import argparse
import os
import os.path as osp
from torch import nn
from postometro_utils.smpl import Mesh
from postometro_utils.transformer import build_transformer
from postometro_utils.positional_encoding import build_position_encoding
from postometro_utils.modules import FCBlock, MixerLayer
from pct_utils.pct import PCT
from pct_utils.pct_backbone import SwinV2TransformerRPE2FC
from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet
from postometro_utils.pose_resnet_config import config as resnet_config
from postometro_utils.pose_hrnet import get_pose_hrnet
from postometro_utils.pose_hrnet_config import _C as hrnet_config
from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config

CUR_DIR = osp.dirname(os.path.abspath(__file__))

class PostoMETRO(nn.Module):
    """PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image"""
    def __init__(self, args, backbone, mesh_sampler, pct=None, num_joints=14, num_vertices=431):
        """
        Parameters:
            - args: Arguments
            - backbone: CNN backbone used to extract image features from the given image
            - mesh_sampler: Mesh sampler used in the coarse-to-fine mesh upsampling
            - pct: Optional PCT module that supplies extra pose tokens (see forward())
            - num_joints: The number of joint tokens used in the transformer decoder
            - num_vertices: The number of vertex tokens used in the transformer decoder
        """
        super().__init__()
        self.args = args
        self.backbone = backbone
        self.mesh_sampler = mesh_sampler
        self.num_joints = num_joints
        self.num_vertices = num_vertices
        # the number of transformer layers, set to the default
        num_enc_layers = 3
        num_dec_layers = 3
        # configurations for the first transformer
        self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
                                     "feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
                                     "pos_type": args.pos_type}
        # configurations for the second transformer
        self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
                                     "feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
                                     "pos_type": args.pos_type}
        # build transformers
        self.transformer_1 = build_transformer(self.transformer_config_1)
        self.transformer_2 = build_transformer(self.transformer_config_2)
        # dimensionality reduction
        self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
        self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
        self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
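        # tokens are progressively reduced from model_dim_1 (512 by default) to model_dim_2
        # (128 by default) between the two transformer stages, as in FastMETRO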
        # token embeddings
        self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"])
        self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"])
        self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"])
        # positional encodings
        self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim'])
        self.position_encoding_2 = build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim'])
        # estimators
        self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3)
        self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3)
        # 1x1 convolution
        self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1)
        # attention mask
        zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool))
        zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool))
        adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt'))
        adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt'))
        adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt'))
        adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense()
        temp_mask_1 = (adjacency_matrix == 0)
        temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1)
        self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0)
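        # resulting mask layout (True = masked): joint queries (the first num_joints rows) attend
        # to every token, while each vertex query attends to all joints but only to those vertices
        # that are adjacent to it in the coarse SMPL mesh graph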
        # a learnable upsampling layer is used (from coarse mesh to intermediate mesh) for visually pleasing mesh results
        # a pre-computed upsampling matrix is used (from intermediate mesh to fine mesh) to reduce optimization difficulty
        self.coarse2intermediate_upsample = nn.Linear(431, 1723)
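        # mesh hierarchy: 431 vertices (coarse) -> 1723 (intermediate) -> 6890 (fine, the full SMPL mesh)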
        # extra pose tokens predicted by PCT
        self.pct = None
        if pct is not None:
            self.pct = pct
            # +1 to make room for the per-token uncertainty score concatenated in forward()
            self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"])
            self.start_embed = nn.Linear(512, args.enc_hidden_dim)
            self.encoder = nn.ModuleList(
                [MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim,
                            args.num_joints, args.token_inter_dim,
                            args.enc_dropout) for _ in range(args.enc_num_blocks)])
            self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim)
            self.token_mlp = nn.Linear(args.num_joints, args.token_num)
            self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])

    def forward(self, images):
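        """Run pose and mesh regression on a batch of RGB images.

        Returns a dict with the predicted weak-perspective camera ('pred_cam'), the PCT 2D pose
        ('pct_pose'), the 3D joints ('pred_3d_joints'), and the coarse / intermediate / fine
        3D vertices ('pred_3d_vertices_coarse' / '_intermediate' / '_fine').
        """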
        device = images.device
        batch_size = images.size(0)
        # preparation
        cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512
        jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512
        attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices)
        pct_token = None
        if self.pct is not None:
            pct_out = self.pct(images, None, train=False)
            pct_pose = pct_out['part_token_feat'].clone()
            encode_feat = self.start_embed(pct_pose) # batch_size X num_tokens X enc_hidden_dim
            for encoder_layer in self.encoder:
                encode_feat = encoder_layer(encode_feat)
            encode_feat = self.encoder_layer_norm(encode_feat)
            encode_feat = encode_feat.transpose(2, 1)
            encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
            pct_token_out = encode_feat.permute(1, 0, 2)
            pct_score = pct_out['encoding_scores']
            pct_score = pct_score.permute(1, 0, 2)
            pct_token = torch.cat([pct_token_out, pct_score], dim=-1)
            pct_token = self.token_mixer(pct_token) # token_num X batch_size X 512
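            # pct_token now matches the (sequence, batch, channel) layout of the other tokens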
        # extract image features through a CNN backbone
        _img_features = self.backbone(images) # batch_size X 2048 X 7 X 7
        _, _, h, w = _img_features.shape
        img_features = self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
        # positional encodings
        pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
        pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 128
        # first transformer encoder-decoder
        cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token=pct_token, attention_mask=attention_mask)
        # progressive dimensionality reduction
        reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128
        reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # 49 X batch_size X 128
        reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128
        reduced_pct_features_1 = None
        if pct_features_1 is not None:
            reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1)
        # second transformer encoder-decoder
        cam_features_2, _, jv_features_2, _ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token=reduced_pct_features_1, attention_mask=attention_mask)
        # estimators
        pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3
        pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3
        pred_3d_joints = pred_3d_coordinates[:, :self.num_joints, :] # batch_size X num_joints X 3
        pred_3d_vertices_coarse = pred_3d_coordinates[:, self.num_joints:, :] # batch_size X num_vertices(coarse) X 3
        # coarse-to-intermediate mesh upsampling
        pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1, 2)).transpose(1, 2) # batch_size X num_vertices(intermediate) X 3
        # intermediate-to-fine mesh upsampling
        pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3
        out = {}
        out['pred_cam'] = pred_cam
        # fall back to a zero pose when no PCT module is attached (device-agnostic, works on CPU too)
        out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2), device=device)
        out['pred_3d_joints'] = pred_3d_joints
        out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse
        out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate
        out['pred_3d_vertices_fine'] = pred_3d_vertices_fine
        return out

default_args = argparse.Namespace(
    pos_type='sine',
    transformer_dropout=0.1,
    transformer_nhead=8,
    conv_1x1_dim=2048,
    tokenizer_codebook_token_dim=512,
    model_dim_1=512,
    feedforward_dim_1=2048,
    model_dim_2=128,
    feedforward_dim_2=512,
    enc_hidden_dim=512,
    enc_hidden_inter_dim=512,
    token_inter_dim=64,
    enc_dropout=0.0,
    enc_num_blocks=4,
    num_joints=34,
    token_num=34
)

default_pct_args = argparse.Namespace(
    pct_backbone_channel=1536,
    tokenizer_guide_ratio=0.5,
    cls_head_conv_channels=256,
    cls_head_hidden_dim=64,
    cls_head_num_blocks=4,
    cls_head_hidden_inter_dim=256,
    cls_head_token_inter_dim=64,
    cls_head_dropout=0.0,
    cls_head_conv_num_blocks=2,
    cls_head_dilation=1,
    # tokenizer
    tokenizer_encoder_drop_rate=0.2,
    tokenizer_encoder_num_blocks=4,
    tokenizer_encoder_hidden_dim=512,
    tokenizer_encoder_token_inter_dim=64,
    tokenizer_encoder_hidden_inter_dim=512,
    tokenizer_encoder_dropout=0.0,
    tokenizer_decoder_num_blocks=1,
    tokenizer_decoder_hidden_dim=32,
    tokenizer_decoder_token_inter_dim=64,
    tokenizer_decoder_hidden_inter_dim=64,
    tokenizer_decoder_dropout=0.0,
    tokenizer_codebook_token_num=34,
    tokenizer_codebook_token_dim=512,
    tokenizer_codebook_token_class_num=2048,
    tokenizer_codebook_ema_decay=0.9,
)
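
# SwinV2 configuration for the PCT pose backbone (consumed by SwinV2TransformerRPE2FC below)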
backbone_config = dict(
    embed_dim=192,
    depths=[2, 2, 18, 2],
    num_heads=[6, 12, 24, 48],
    window_size=[16, 16, 16, 8],
    pretrain_window_size=[12, 12, 12, 6],
    ape=False,
    drop_path_rate=0.5,
    patch_norm=True,
    use_checkpoint=True,
    rpe_interpolation='geo',
    use_shift=[True, True, False, False],
    relative_coords_table_type='norm8_log_bylayer',
    attn_type='cosine_mh',
    rpe_output_type='sigmoid',
    postnorm=True,
    mlp_type='normal',
    out_indices=(3,),
    patch_embed_type='normal',
    patch_merge_type='normal',
    strid16=False,
    frozen_stages=5,
)

def get_model(backbone_str='resnet50', device=torch.device('cpu'), checkpoint_file=None):
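    """Build a PostoMETRO model with the default arguments above.

    - backbone_str: 'resnet50' (default) or 'hrnet-w48'
    - device: torch.device on which to place the model
    - checkpoint_file: optional path to a saved state dict, loaded strictly
    """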
    if backbone_str == 'hrnet-w48':
        default_args.conv_1x1_dim = 384
        # update hrnet config by yaml
        hrnet_yaml = osp.join(CUR_DIR, 'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml')
        hrnet_update_config(hrnet_config, hrnet_yaml)
        backbone = get_pose_hrnet(hrnet_config, None)
    else:
        backbone = get_pose_resnet(resnet_config, is_train=False)
    mesh_upsampler = Mesh(device=device)
    pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config)
    # initialize pct head
    pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device)
    model = PostoMETRO(default_args, backbone, mesh_upsampler, pct=pct).to(device)
    print("[INFO] model loaded, params: {}, {}".format(backbone_str, device))
    if checkpoint_file:
        cpu_device = torch.device('cpu')
        state_dict = torch.load(checkpoint_file, map_location=cpu_device)
        model.load_state_dict(state_dict, strict=True)
        del state_dict
        print("[INFO] checkpoint loaded, params: {}, {}".format(backbone_str, device))
    return model

if __name__ == '__main__':
    # smoke tests: run a dummy batch through every backbone / device combination
    test_model = get_model(device=torch.device('cuda'))
    images = torch.randn(1, 3, 256, 256).to(torch.device('cuda'))
    test_out = test_model(images)
    print("[TEST] resnet50, cuda : pass")

    test_model = get_model()
    images = torch.randn(1, 3, 256, 256)
    test_out = test_model(images)
    print("[TEST] resnet50, cpu : pass")

    test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda'))
    images = torch.randn(1, 3, 256, 256).to(torch.device('cuda'))
    test_out = test_model(images)
    print("[TEST] hrnet-w48, cuda : pass")

    test_model = get_model(backbone_str='hrnet-w48')
    images = torch.randn(1, 3, 256, 256)
    test_out = test_model(images)
    print("[TEST] hrnet-w48, cpu : pass")