# ----------------------------------------------------------------------------------------------
# FastMETRO Official Code
# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved 
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------

# ----------------------------------------------------------------------------------------------
# PostoMETRO Official Code
# Copyright (c) MIRACLE Lab. All Rights Reserved 
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------

from __future__ import absolute_import, division, print_function
import torch
import numpy as np
import argparse
import os
import os.path as osp
from torch import nn
from postometro_utils.smpl import Mesh
from postometro_utils.transformer import build_transformer
from postometro_utils.positional_encoding import build_position_encoding
from postometro_utils.modules import FCBlock, MixerLayer
from pct_utils.pct import PCT
from pct_utils.pct_backbone import SwinV2TransformerRPE2FC
from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet
from postometro_utils.pose_resnet_config import config as resnet_config
from postometro_utils.pose_hrnet import get_pose_hrnet
from postometro_utils.pose_hrnet_config import _C as hrnet_config
from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config

CUR_DIR = osp.dirname(os.path.abspath(__file__))

class PostoMETRO(nn.Module):
    """PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image"""
    def __init__(self, args, backbone, mesh_sampler, pct=None, num_joints=14, num_vertices=431):
        """
        Parameters:
            - args: Arguments
            - backbone: CNN backbone used to extract image features from the given image
            - mesh_sampler: Mesh sampler used in the coarse-to-fine mesh upsampling
            - pct: Optional PCT module; when provided, its pose tokens are injected into the transformers as extra tokens
            - num_joints: The number of joint tokens used in the transformer decoder
            - num_vertices: The number of vertex tokens used in the transformer decoder
        """
        super().__init__()
        self.args = args
        self.backbone = backbone
        self.mesh_sampler = mesh_sampler
        self.num_joints = num_joints
        self.num_vertices = num_vertices
        
        # number of transformer encoder/decoder layers (default setting)
        num_enc_layers = 3
        num_dec_layers = 3
    
        # configurations for the first transformer
        self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead, 
                                     "feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers, 
                                     "pos_type": args.pos_type}
        # configurations for the second transformer
        self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
                                     "feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers, 
                                     "pos_type": args.pos_type}
        # build transformers
        self.transformer_1 = build_transformer(self.transformer_config_1)
        self.transformer_2 = build_transformer(self.transformer_config_2)

        # dimensionality reduction
        self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
        self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
        self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
        
        # token embeddings
        self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"])
        self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"])
        self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"])
        # positional encodings
        self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim'])
        self.position_encoding_2 = build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim'])
        # estimators
        self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3)
        self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3)
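        # the 3 predicted camera values are commonly interpreted as a weak-perspective camera (scale and 2D translation), as in FastMETRO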
        
        # 1x1 Convolution
        self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1)

        # attention mask
        zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool)) 
        zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool)) 
        adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt'))
        adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt'))
        adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt'))
        adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense()
        temp_mask_1 = (adjacency_matrix == 0)
        temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1)
        self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0)
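        # resulting mask, shape (num_joints + num_vertices) X (num_joints + num_vertices):
        #   - joint-token rows are all False, so joint queries may attend to every joint/vertex token
        #   - vertex-token rows are False for all joint tokens, and True (blocked, following the usual
        #     PyTorch attention-mask convention) for vertex pairs that are not adjacent in the coarse SMPL mesh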
        
        # a learnable upsampling layer is used from the coarse mesh to the intermediate mesh (for visually pleasing results);
        # a pre-computed upsampling matrix is used from the intermediate mesh to the fine mesh (to reduce optimization difficulty)
        self.coarse2intermediate_upsample = nn.Linear(431, 1723)

        # using extra token
        self.pct = None
        if pct is not None:
            self.pct = pct
            # +1 for the per-token encoding (uncertainty) score that is concatenated to each token feature
            self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"])
            self.start_embed = nn.Linear(512, args.enc_hidden_dim)
            self.encoder = nn.ModuleList(
            [MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim, 
                args.num_joints, args.token_inter_dim,
                args.enc_dropout) for _ in range(args.enc_num_blocks)])
            self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim)
            self.token_mlp = nn.Linear(args.num_joints, args.token_num)
            self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
    

    def forward(self, images):
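        """
        Parameters:
            - images: batch of RGB images, batch_size X 3 X H X W
        Returns:
            - out: dictionary with the predicted camera parameters, PCT 2D pose, 3D joints,
                   and coarse / intermediate / fine 3D vertices
        """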
        device = images.device
        batch_size = images.size(0)

        # preparation
        cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512 
        jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512
        attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices)

        pct_token = None
        if self.pct is not None:
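            # re-encode the PCT part-token features with a small mixer stack, append each token's
            # encoding (confidence) score, and project to model_dim_1 so the result can be fed to
            # transformer_1 as extra tokens alongside the camera/joint/vertex tokens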
            pct_out = self.pct(images, None, train=False)
            pct_pose = pct_out['part_token_feat'].clone()

            encode_feat = self.start_embed(pct_pose) # batch_size X num_tokens X enc_hidden_dim
            for encoder_layer in self.encoder:
                encode_feat = encoder_layer(encode_feat)
            encode_feat = self.encoder_layer_norm(encode_feat)
            encode_feat = encode_feat.transpose(2, 1)
            encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
            pct_token_out = encode_feat.permute(1,0,2)

            pct_score = pct_out['encoding_scores']
            pct_score = pct_score.permute(1,0,2)
            pct_token = torch.cat([pct_token_out, pct_score], dim = -1)
            pct_token = self.token_mixer(pct_token) # num_tokens X batch_size X 512
        
        # extract image features through a CNN backbone
        _img_features = self.backbone(images) # batch_size X C X H X W (C = args.conv_1x1_dim)
        _, _, h, w = _img_features.shape
        img_features = self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # (H*W) X batch_size X 512

        # positional encodings
        pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # (H*W) X batch_size X 512
        pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # (H*W) X batch_size X 128

        # first transformer encoder-decoder
        cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token = pct_token, attention_mask=attention_mask)
        
        # progressive dimensionality reduction
        reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128 
        reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # (H*W) X batch_size X 128
        reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128
        reduced_pct_features_1 = None
        if pct_features_1 is not None:
            reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1)

        # second transformer encoder-decoder
        cam_features_2, _, jv_features_2,_ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token = reduced_pct_features_1, attention_mask=attention_mask) 

        # estimators
        pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3

        pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3
        pred_3d_joints = pred_3d_coordinates[:,:self.num_joints,:] # batch_size X num_joints X 3
        pred_3d_vertices_coarse = pred_3d_coordinates[:,self.num_joints:,:] # batch_size X num_vertices(coarse) X 3
        
        # coarse-to-intermediate mesh upsampling
        pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1,2)).transpose(1,2) # batch_size X num_vertices(intermediate) X 3
        # intermediate-to-fine mesh upsampling
        pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3

        out = {}
        out['pred_cam'] = pred_cam
        out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2), device=device)
        out['pred_3d_joints'] = pred_3d_joints
        out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse
        out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate
        out['pred_3d_vertices_fine'] = pred_3d_vertices_fine

        return out


defaults_args = argparse.Namespace(
    pos_type = 'sine',
    transformer_dropout = 0.1,
    transformer_nhead = 8,
    conv_1x1_dim = 2048,
    tokenizer_codebook_token_dim = 512,
    model_dim_1 = 512,
    feedforward_dim_1 = 2048,
    model_dim_2 = 128,
    feedforward_dim_2 = 512,
    enc_hidden_dim = 512,
    enc_hidden_inter_dim = 512,
    token_inter_dim = 64,
    enc_dropout = 0.0,
    enc_num_blocks = 4,
    num_joints = 34,
    token_num = 34
)

default_pct_args = argparse.Namespace(
    pct_backbone_channel = 1536,
    tokenizer_guide_ratio=0.5,
    cls_head_conv_channels=256,
    cls_head_hidden_dim=64,
    cls_head_num_blocks=4, 
    cls_head_hidden_inter_dim=256,
    cls_head_token_inter_dim=64,
    cls_head_dropout=0.0,
    cls_head_conv_num_blocks=2, 
    cls_head_dilation=1, 
    # tokenizer
    tokenizer_encoder_drop_rate=0.2, 
    tokenizer_encoder_num_blocks=4, 
    tokenizer_encoder_hidden_dim=512, 
    tokenizer_encoder_token_inter_dim=64, 
    tokenizer_encoder_hidden_inter_dim=512, 
    tokenizer_encoder_dropout=0.0, 
    tokenizer_decoder_num_blocks=1, 
    tokenizer_decoder_hidden_dim=32, 
    tokenizer_decoder_token_inter_dim=64, 
    tokenizer_decoder_hidden_inter_dim=64, 
    tokenizer_decoder_dropout=0.0, 
    tokenizer_codebook_token_num=34,
    tokenizer_codebook_token_dim=512,  
    tokenizer_codebook_token_class_num=2048, 
    tokenizer_codebook_ema_decay=0.9, 
)

backbone_config=dict(
    embed_dim=192,
    depths=[2, 2, 18, 2],
    num_heads=[6, 12, 24, 48],
    window_size=[16, 16, 16, 8],
    pretrain_window_size=[12, 12, 12, 6],
    ape=False,
    drop_path_rate=0.5,
    patch_norm=True,
    use_checkpoint=True,
    rpe_interpolation='geo',
    use_shift=[True, True, False, False],
    relative_coords_table_type='norm8_log_bylayer',
    attn_type='cosine_mh',
    rpe_output_type='sigmoid',
    postnorm=True,
    mlp_type='normal',
    out_indices=(3,),
    patch_embed_type='normal',
    patch_merge_type='normal',
    strid16=False,
    frozen_stages=5,
)

def get_model(backbone_str = 'resnet50', device = torch.device('cpu'), checkpoint_file = None):
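    """Build a PostoMETRO model on the given device with the requested CNN backbone
    ('resnet50' by default, or 'hrnet-w48'); optionally load weights from checkpoint_file."""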
    if backbone_str == 'hrnet-w48':
        defaults_args.conv_1x1_dim = 384
        # update hrnet config by yaml
        hrnet_yaml = osp.join(CUR_DIR,'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml')
        hrnet_update_config(hrnet_config, hrnet_yaml)
        backbone = get_pose_hrnet(hrnet_config, None)
    else:
        backbone = get_pose_resnet(resnet_config, is_train=False)
    mesh_upsampler = Mesh(device=device)
    pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config)
    # initialize pct head
    pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device)
    model = PostoMETRO(defaults_args, backbone, mesh_upsampler, pct=pct).to(device)
    print("[INFO] model loaded, params: {}, {}".format(backbone_str, device))
    if checkpoint_file:
        cpu_device = torch.device('cpu')
        state_dict = torch.load(checkpoint_file, map_location=cpu_device)
        model.load_state_dict(state_dict, strict=True)
        del state_dict
        print("[INFO] checkpoint loaded, params: {}, {}".format(backbone_str, device))
    return model

if __name__ == '__main__':
    test_model = get_model(device=torch.device('cuda'))
    images = torch.randn(1,3,256,256).to(torch.device('cuda'))
    test_out = test_model(images)
    print("[TEST] resnet50, cuda : pass")

    test_model = get_model()
    images = torch.randn(1,3,256,256)
    test_out = test_model(images)
    print("[TEST] resnet50, cpu : pass")

    test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda'))
    images = torch.randn(1,3,256,256).to(torch.device('cuda'))
    test_out = test_model(images)
    print("[TEST] hrnet-w48, cuda : pass")

    test_model = get_model(backbone_str='hrnet-w48')
    images = torch.randn(1,3,256,256)
    test_out = test_model(images)
    print("[TEST] hrnet-w48, cpu : pass")