# ----------------------------------------------------------------------------------------------
# FastMETRO Official Code
# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------
# PostoMETRO Official Code
# Copyright (c) MIRACLE Lab. All Rights Reserved
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function
import torch
import numpy as np
import argparse
import os
import os.path as osp
from torch import nn
from postometro_utils.smpl import Mesh
from postometro_utils.transformer import build_transformer
from postometro_utils.positional_encoding import build_position_encoding
from postometro_utils.modules import FCBlock, MixerLayer
from pct_utils.pct import PCT
from pct_utils.pct_backbone import SwinV2TransformerRPE2FC
from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet
from postometro_utils.pose_resnet_config import config as resnet_config
from postometro_utils.pose_hrnet import get_pose_hrnet
from postometro_utils.pose_hrnet_config import _C as hrnet_config
from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config

CUR_DIR = osp.dirname(osp.abspath(__file__))


class PostoMETRO(nn.Module):
"""PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image"""
    def __init__(self, args, backbone, mesh_sampler, pct=None, num_joints=14, num_vertices=431):
        """
        Parameters:
            - args: Arguments
            - backbone: CNN backbone used to extract image features from the given image
            - mesh_sampler: Mesh sampler used in the coarse-to-fine mesh upsampling
            - pct: Optional PCT head providing extra pose tokens for the transformer
            - num_joints: The number of joint tokens used in the transformer decoder
            - num_vertices: The number of vertex tokens used in the transformer decoder
        """
super().__init__()
self.args = args
self.backbone = backbone
self.mesh_sampler = mesh_sampler
self.num_joints = num_joints
self.num_vertices = num_vertices
# the number of transformer layers, set to default
num_enc_layers = 3
num_dec_layers = 3
# configurations for the first transformer
self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
"feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
"pos_type": args.pos_type}
# configurations for the second transformer
self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
"feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
"pos_type": args.pos_type}
# build transformers
self.transformer_1 = build_transformer(self.transformer_config_1)
self.transformer_2 = build_transformer(self.transformer_config_2)
# dimensionality reduction
self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
# token embeddings
self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"])
self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"])
self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"])
# positional encodings
self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim'])
self.position_encoding_2 = build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim'])
# estimators
self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3)
self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3)
# 1x1 Convolution
self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1)
# attention mask
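        # mask convention: True entries are blocked. Joint queries may attend to all joint/vertex tokens,
        # while vertex queries attend to all joints but only to adjacent vertices of the coarse (431-vertex) SMPL mesh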
zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool))
zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool))
adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt'))
adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt'))
adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt'))
adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense()
temp_mask_1 = (adjacency_matrix == 0)
temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1)
self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0)
        # a learnable upsampling layer is used (from coarse mesh to intermediate mesh) for a visually pleasing mesh result
        # a pre-computed upsampling matrix is used (from intermediate mesh to fine mesh) to reduce optimization difficulty
self.coarse2intermediate_upsample = nn.Linear(431, 1723)
        # extra PCT pose tokens (only built when a PCT head is provided)
self.pct = None
if pct is not None:
self.pct = pct
# +1 to align with uncertainty score
self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"])
self.start_embed = nn.Linear(512, args.enc_hidden_dim)
self.encoder = nn.ModuleList(
[MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim,
args.num_joints, args.token_inter_dim,
args.enc_dropout) for _ in range(args.enc_num_blocks)])
self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim)
self.token_mlp = nn.Linear(args.num_joints, args.token_num)
self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])

    def forward(self, images):
device = images.device
batch_size = images.size(0)
# preparation
cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512
jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512
attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices)
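        # optional PCT branch: mix the PCT part-token features with their encoding scores to form extra tokens for the first transformer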
pct_token = None
if self.pct is not None:
pct_out = self.pct(images, None, train=False)
pct_pose = pct_out['part_token_feat'].clone()
            encode_feat = self.start_embed(pct_pose) # batch_size X num_tokens X enc_hidden_dim
            for encoder_layer in self.encoder:
                encode_feat = encoder_layer(encode_feat)
encode_feat = self.encoder_layer_norm(encode_feat)
encode_feat = encode_feat.transpose(2, 1)
encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
pct_token_out = encode_feat.permute(1,0,2)
pct_score = pct_out['encoding_scores']
pct_score = pct_score.permute(1,0,2)
pct_token = torch.cat([pct_token_out, pct_score], dim = -1)
pct_token = self.token_mixer(pct_token) # [b, 34, 512]
# extract image features through a CNN backbone
_img_features = self.backbone(images) # batch_size X 2048 X 7 X 7
_, _, h, w = _img_features.shape
img_features = self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
# positional encodings
pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 128
# first transformer encoder-decoder
cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token = pct_token, attention_mask=attention_mask)
# progressive dimensionality reduction
reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128
reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # 49 X batch_size X 128
reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128
reduced_pct_features_1 = None
if pct_features_1 is not None:
reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1)
# second transformer encoder-decoder
cam_features_2, _, jv_features_2,_ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token = reduced_pct_features_1, attention_mask=attention_mask)
# estimators
pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3
pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3
pred_3d_joints = pred_3d_coordinates[:,:self.num_joints,:] # batch_size X num_joints X 3
pred_3d_vertices_coarse = pred_3d_coordinates[:,self.num_joints:,:] # batch_size X num_vertices(coarse) X 3
# coarse-to-intermediate mesh upsampling
pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1,2)).transpose(1,2) # batch_size X num_vertices(intermediate) X 3
# intermediate-to-fine mesh upsampling
pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3
out = {}
out['pred_cam'] = pred_cam
        out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2), device=device)
out['pred_3d_joints'] = pred_3d_joints
out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse
out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate
out['pred_3d_vertices_fine'] = pred_3d_vertices_fine
return out
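

# default PostoMETRO arguments (consumed by get_model below)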
default_args = argparse.Namespace(
pos_type = 'sine',
transformer_dropout = 0.1,
transformer_nhead = 8,
conv_1x1_dim = 2048,
tokenizer_codebook_token_dim = 512,
model_dim_1 = 512,
feedforward_dim_1 = 2048,
model_dim_2 = 128,
feedforward_dim_2 = 512,
enc_hidden_dim = 512,
enc_hidden_inter_dim = 512,
token_inter_dim = 64,
enc_dropout = 0.0,
enc_num_blocks = 4,
num_joints = 34,
token_num = 34
)
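
# default arguments for the PCT head (classification head and tokenizer)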
default_pct_args = argparse.Namespace(
pct_backbone_channel = 1536,
tokenizer_guide_ratio=0.5,
cls_head_conv_channels=256,
cls_head_hidden_dim=64,
cls_head_num_blocks=4,
cls_head_hidden_inter_dim=256,
cls_head_token_inter_dim=64,
cls_head_dropout=0.0,
cls_head_conv_num_blocks=2,
cls_head_dilation=1,
    # tokenizer
tokenizer_encoder_drop_rate=0.2,
tokenizer_encoder_num_blocks=4,
tokenizer_encoder_hidden_dim=512,
tokenizer_encoder_token_inter_dim=64,
tokenizer_encoder_hidden_inter_dim=512,
tokenizer_encoder_dropout=0.0,
tokenizer_decoder_num_blocks=1,
tokenizer_decoder_hidden_dim=32,
tokenizer_decoder_token_inter_dim=64,
tokenizer_decoder_hidden_inter_dim=64,
tokenizer_decoder_dropout=0.0,
tokenizer_codebook_token_num=34,
tokenizer_codebook_token_dim=512,
tokenizer_codebook_token_class_num=2048,
tokenizer_codebook_ema_decay=0.9,
)
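
# SwinV2 backbone configuration for the PCT head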
backbone_config=dict(
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=[16, 16, 16, 8],
pretrain_window_size=[12, 12, 12, 6],
ape=False,
drop_path_rate=0.5,
patch_norm=True,
use_checkpoint=True,
rpe_interpolation='geo',
use_shift=[True, True, False, False],
relative_coords_table_type='norm8_log_bylayer',
attn_type='cosine_mh',
rpe_output_type='sigmoid',
postnorm=True,
mlp_type='normal',
out_indices=(3,),
patch_embed_type='normal',
patch_merge_type='normal',
strid16=False,
frozen_stages=5,
)


def get_model(backbone_str='resnet50', device=torch.device('cpu'), checkpoint_file=None):
if backbone_str == 'hrnet-w48':
        default_args.conv_1x1_dim = 384  # HRNet features have a different channel dimension than the ResNet-50 default (2048)
# update hrnet config by yaml
hrnet_yaml = osp.join(CUR_DIR,'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml')
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_pose_hrnet(hrnet_config, None)
else:
backbone = get_pose_resnet(resnet_config, is_train=False)
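    # SMPL mesh sampler used for coarse-to-fine vertex upsampling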
mesh_upsampler = Mesh(device=device)
pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config)
# initialize pct head
pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device)
    model = PostoMETRO(default_args, backbone, mesh_upsampler, pct=pct).to(device)
    print("[INFO] model loaded, backbone: {}, device: {}".format(backbone_str, device))
if checkpoint_file:
cpu_device = torch.device('cpu')
state_dict = torch.load(checkpoint_file, map_location=cpu_device)
model.load_state_dict(state_dict, strict=True)
del state_dict
        print("[INFO] checkpoint loaded, backbone: {}, device: {}".format(backbone_str, device))
return model


if __name__ == '__main__':
    # quick smoke tests: run a random input through each backbone on CPU and (if available) CUDA
    if torch.cuda.is_available():
        test_model = get_model(device=torch.device('cuda'))
        images = torch.randn(1, 3, 256, 256).to(torch.device('cuda'))
        test_out = test_model(images)
        print("[TEST] resnet50, cuda : pass")
    test_model = get_model()
    images = torch.randn(1, 3, 256, 256)
    test_out = test_model(images)
    print("[TEST] resnet50, cpu : pass")
    if torch.cuda.is_available():
        test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda'))
        images = torch.randn(1, 3, 256, 256).to(torch.device('cuda'))
        test_out = test_model(images)
        print("[TEST] hrnet-w48, cuda : pass")
    test_model = get_model(backbone_str='hrnet-w48')
    images = torch.randn(1, 3, 256, 256)
    test_out = test_model(images)
    print("[TEST] hrnet-w48, cpu : pass")