# ---------------------------------------------------------------------------------------------- # FastMETRO Official Code # Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved # Licensed under the MIT license. # ---------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------- # PostoMETRO Official Code # Copyright (c) MIRACLE Lab. All Rights Reserved # Licensed under the MIT license. # ---------------------------------------------------------------------------------------------- from __future__ import absolute_import, division, print_function import torch import numpy as np import argparse import os import os.path as osp from torch import nn from postometro_utils.smpl import Mesh from postometro_utils.transformer import build_transformer from postometro_utils.positional_encoding import build_position_encoding from postometro_utils.modules import FCBlock, MixerLayer from pct_utils.pct import PCT from pct_utils.pct_backbone import SwinV2TransformerRPE2FC from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet from postometro_utils.pose_resnet_config import config as resnet_config from postometro_utils.pose_hrnet import get_pose_hrnet from postometro_utils.pose_hrnet_config import _C as hrnet_config from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config CUR_DIR = osp.dirname(os.path.abspath(__file__)) class PostoMETRO(nn.Module): """PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image""" def __init__(self, args, backbone, mesh_sampler, pct = None, num_joints=14, num_vertices=431): """ Parameters: - args: Arguments - backbone: CNN Backbone used to extract image features from the given image - mesh_sampler: Mesh Sampler used in the coarse-to-fine mesh upsampling - num_joints: The number of joint tokens used in the transformer decoder - num_vertices: The number of vertex tokens used in the transformer decoder """ super().__init__() self.args = args self.backbone = backbone self.mesh_sampler = mesh_sampler self.num_joints = num_joints self.num_vertices = num_vertices # the number of transformer layers, set to default num_enc_layers = 3 num_dec_layers = 3 # configurations for the first transformer self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead, "feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers, "pos_type": args.pos_type} # configurations for the second transformer self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead, "feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers, "pos_type": args.pos_type} # build transformers self.transformer_1 = build_transformer(self.transformer_config_1) self.transformer_2 = build_transformer(self.transformer_config_2) # dimensionality reduction self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) # token embeddings self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"]) self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"]) self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"]) # positional encodings self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim']) self.position_encoding_2 = build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim']) # estimators self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3) self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3) # 1x1 Convolution self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1) # attention mask zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool)) zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool)) adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt')) adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt')) adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt')) adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense() temp_mask_1 = (adjacency_matrix == 0) temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1) self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0) # learnable upsampling layer is used (from coarse mesh to intermediate mesh); for visually pleasing mesh result ### pre-computed upsampling matrix is used (from intermediate mesh to fine mesh); to reduce optimization difficulty self.coarse2intermediate_upsample = nn.Linear(431, 1723) # using extra token self.pct = None if pct is not None: self.pct = pct # +1 to align with uncertainty score self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"]) self.start_embed = nn.Linear(512, args.enc_hidden_dim) self.encoder = nn.ModuleList( [MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim, args.num_joints, args.token_inter_dim, args.enc_dropout) for _ in range(args.enc_num_blocks)]) self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim) self.token_mlp = nn.Linear(args.num_joints, args.token_num) self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) def forward(self, images): device = images.device batch_size = images.size(0) # preparation cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512 jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512 attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices) pct_token = None if self.pct is not None: pct_out = self.pct(images, None, train=False) pct_pose = pct_out['part_token_feat'].clone() encode_feat = self.start_embed(pct_pose) # 2, 17, 512 for num_layer in self.encoder: encode_feat = num_layer(encode_feat) encode_feat = self.encoder_layer_norm(encode_feat) encode_feat = encode_feat.transpose(2, 1) encode_feat = self.token_mlp(encode_feat).transpose(2, 1) pct_token_out = encode_feat.permute(1,0,2) pct_score = pct_out['encoding_scores'] pct_score = pct_score.permute(1,0,2) pct_token = torch.cat([pct_token_out, pct_score], dim = -1) pct_token = self.token_mixer(pct_token) # [b, 34, 512] # extract image features through a CNN backbone _img_features = self.backbone(images) # batch_size X 2048 X 7 X 7 _, _, h, w = _img_features.shape img_features = self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512 # positional encodings pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512 pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 128 # first transformer encoder-decoder cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token = pct_token, attention_mask=attention_mask) # progressive dimensionality reduction reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128 reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # 49 X batch_size X 128 reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128 reduced_pct_features_1 = None if pct_features_1 is not None: reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1) # second transformer encoder-decoder cam_features_2, _, jv_features_2,_ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token = reduced_pct_features_1, attention_mask=attention_mask) # estimators pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3 pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3 pred_3d_joints = pred_3d_coordinates[:,:self.num_joints,:] # batch_size X num_joints X 3 pred_3d_vertices_coarse = pred_3d_coordinates[:,self.num_joints:,:] # batch_size X num_vertices(coarse) X 3 # coarse-to-intermediate mesh upsampling pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1,2)).transpose(1,2) # batch_size X num_vertices(intermediate) X 3 # intermediate-to-fine mesh upsampling pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3 out = {} out['pred_cam'] = pred_cam out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2)).cuda(device) out['pred_3d_joints'] = pred_3d_joints out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate out['pred_3d_vertices_fine'] = pred_3d_vertices_fine return out defaults_args = argparse.Namespace( pos_type = 'sine', transformer_dropout = 0.1, transformer_nhead = 8, conv_1x1_dim = 2048, tokenizer_codebook_token_dim = 512, model_dim_1 = 512, feedforward_dim_1 = 2048, model_dim_2 = 128, feedforward_dim_2 = 512, enc_hidden_dim = 512, enc_hidden_inter_dim = 512, token_inter_dim = 64, enc_dropout = 0.0, enc_num_blocks = 4, num_joints = 34, token_num = 34 ) default_pct_args = argparse.Namespace( pct_backbone_channel = 1536, tokenizer_guide_ratio=0.5, cls_head_conv_channels=256, cls_head_hidden_dim=64, cls_head_num_blocks=4, cls_head_hidden_inter_dim=256, cls_head_token_inter_dim=64, cls_head_dropout=0.0, cls_head_conv_num_blocks=2, cls_head_dilation=1, # tokenzier tokenizer_encoder_drop_rate=0.2, tokenizer_encoder_num_blocks=4, tokenizer_encoder_hidden_dim=512, tokenizer_encoder_token_inter_dim=64, tokenizer_encoder_hidden_inter_dim=512, tokenizer_encoder_dropout=0.0, tokenizer_decoder_num_blocks=1, tokenizer_decoder_hidden_dim=32, tokenizer_decoder_token_inter_dim=64, tokenizer_decoder_hidden_inter_dim=64, tokenizer_decoder_dropout=0.0, tokenizer_codebook_token_num=34, tokenizer_codebook_token_dim=512, tokenizer_codebook_token_class_num=2048, tokenizer_codebook_ema_decay=0.9, ) backbone_config=dict( embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=[16, 16, 16, 8], pretrain_window_size=[12, 12, 12, 6], ape=False, drop_path_rate=0.5, patch_norm=True, use_checkpoint=True, rpe_interpolation='geo', use_shift=[True, True, False, False], relative_coords_table_type='norm8_log_bylayer', attn_type='cosine_mh', rpe_output_type='sigmoid', postnorm=True, mlp_type='normal', out_indices=(3,), patch_embed_type='normal', patch_merge_type='normal', strid16=False, frozen_stages=5, ) def get_model(backbone_str = 'resnet50', device = torch.device('cpu'), checkpoint_file = None): if backbone_str == 'hrnet-w48': defaults_args.conv_1x1_dim = 384 # update hrnet config by yaml hrnet_yaml = osp.join(CUR_DIR,'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml') hrnet_update_config(hrnet_config, hrnet_yaml) backbone = get_pose_hrnet(hrnet_config, None) else: backbone = get_pose_resnet(resnet_config, is_train=False) mesh_upsampler = Mesh(device=device) pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config) # initialize pct head pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device) model = PostoMETRO(defaults_args, backbone, mesh_upsampler, pct=pct).to(device) print("[INFO] model loaded, params: {}, {}".format(backbone_str, device)) if checkpoint_file: cpu_device = torch.device('cpu') state_dict = torch.load(checkpoint_file, map_location=cpu_device) model.load_state_dict(state_dict, strict=True) del state_dict print("[INFO] checkpoint loaded, params: {}, {}".format(backbone_str, device)) return model if __name__ == '__main__': test_model = get_model(device=torch.device('cuda')) images = torch.randn(1,3,256,256).to(torch.device('cuda')) test_out = test_model(images) print("[TEST] resnet50, cuda : pass") test_model = get_model() images = torch.randn(1,3,256,256).to() test_out = test_model(images) print("[TEST] resnet50, cpu : pass") test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda')) images = torch.randn(1,3,256,256).to(torch.device('cuda')) test_out = test_model(images) print("[TEST] hrnet-w48, cuda : pass") test_model = get_model(backbone_str='hrnet-w48') images = torch.randn(1,3,256,256).to() test_out = test_model(images) print("[TEST] hrnet-w48, cpu : pass")