diff --git a/app.py b/app.py
index 8ee0a9f7de2d664fcdbb768ad90da1b8efb768e5..8af72faa0092c69210f873b83ae34009278dced8 100644
--- a/app.py
+++ b/app.py
@@ -33,40 +33,7 @@ def infer(image_input, in_threshold=0.5, num_people="Single person", render_mesh
os.system(f'rm -rf {OUT_FOLDER}/*')
multi_person = False if (num_people == "Single person") else True
-    vis_img, num_bbox, mmdet_box = inferer.infer(image_input, in_threshold, 0, multi_person, not(render_mesh))
+    vis_img, num_bbox, mmdet_box = inferer.infer(image_input, in_threshold, multi_person, not(render_mesh))
-
- # cap = cv2.VideoCapture(video_input)
- # fps = math.ceil(cap.get(5))
- # width = int(cap.get(3))
- # height = int(cap.get(4))
- # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- # video_path = osp.join(OUT_FOLDER, f'out.m4v')
- # final_video_path = osp.join(OUT_FOLDER, f'out.mp4')
- # video_output = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
- # success = 1
- # frame = 0
- # while success:
- # success, original_img = cap.read()
- # if not success:
- # break
- # frame += 1
- # img, mesh_paths, smplx_paths = inferer.infer(original_img, in_threshold, frame, multi_person, not(render_mesh))
- # video_output.write(img)
- # yield img, None, None, None
- # cap.release()
- # video_output.release()
- # cv2.destroyAllWindows()
- # os.system(f'ffmpeg -i {video_path} -c copy {final_video_path}')
-
- # #Compress mesh and smplx files
- # save_path_mesh = os.path.join(OUT_FOLDER, 'mesh')
- # save_mesh_file = os.path.join(OUT_FOLDER, 'mesh.zip')
- # os.makedirs(save_path_mesh, exist_ok= True)
- # save_path_smplx = os.path.join(OUT_FOLDER, 'smplx')
- # save_smplx_file = os.path.join(OUT_FOLDER, 'smplx.zip')
- # os.makedirs(save_path_smplx, exist_ok= True)
- # os.system(f'zip -r {save_mesh_file} {save_path_mesh}')
- # os.system(f'zip -r {save_smplx_file} {save_path_smplx}')
- # yield img, video_path, save_mesh_file, save_smplx_file
+
return vis_img, "bbox num: {}, bbox meta: {}".format(num_bbox, mmdet_box)
TITLE = '''
PostoMETRO: Pose Token Enhanced Mesh Transformer for Robust 3D Human Mesh Recovery
'''
@@ -113,6 +80,9 @@ with gr.Blocks(title="PostoMETRO", css=".gradio-container") as demo:
['/home/user/app/assets/02.jpg'],
['/home/user/app/assets/03.jpg'],
['/home/user/app/assets/04.jpg'],
+ ['/home/user/app/assets/05.jpg'],
+ ['/home/user/app/assets/06.jpg'],
+ ['/home/user/app/assets/07.jpg'],
],
inputs=[image_input, 0.2])
diff --git a/assets/02.jpg b/assets/02.jpg
index e3176774fcf49344c3b644cb7ba4b03d581bbb8e..6c13965baa0fe4a8e27dee7178dc921ca931cc5d 100644
Binary files a/assets/02.jpg and b/assets/02.jpg differ
diff --git a/assets/04.jpg b/assets/04.jpg
index 846802cc2c91b0bd93b0aa3606e5a8a124eca151..e58fde18c3c9f558b7d894d470f0a5662721b5cb 100644
Binary files a/assets/04.jpg and b/assets/04.jpg differ
diff --git a/assets/05.jpg b/assets/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a4f520bc9c426c09bb834f8d0ed28992ba718482
Binary files /dev/null and b/assets/05.jpg differ
diff --git a/assets/06.jpg b/assets/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c58f911ffc33611b7133bc16bff82b81403b24aa
Binary files /dev/null and b/assets/06.jpg differ
diff --git a/assets/07.jpg b/assets/07.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b99df507302d7193d4d30b6e136636a78cfa4940
Binary files /dev/null and b/assets/07.jpg differ
diff --git a/common/utils/__pycache__/__init__.cpython-39.pyc b/common/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f038c1d5b690eefd173a0276cb2eaea13d87b765
Binary files /dev/null and b/common/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/inference_utils.cpython-39.pyc b/common/utils/__pycache__/inference_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfd27635a9573b1b638134c057d81378afa4e736
Binary files /dev/null and b/common/utils/__pycache__/inference_utils.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/preprocessing.cpython-39.pyc b/common/utils/__pycache__/preprocessing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb5436efaaddef63f1fa8b7f2060c3d7c4fcf8aa
Binary files /dev/null and b/common/utils/__pycache__/preprocessing.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/transforms.cpython-39.pyc b/common/utils/__pycache__/transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8f53f05922da528904b78eeb8193f83b4665253
Binary files /dev/null and b/common/utils/__pycache__/transforms.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/vis.cpython-39.pyc b/common/utils/__pycache__/vis.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ed1d51ef5a2cd9c2a73e6ab0ad4242bf8735736
Binary files /dev/null and b/common/utils/__pycache__/vis.cpython-39.pyc differ
diff --git a/common/utils/vis.py b/common/utils/vis.py
index f5b7dd3b6775e16bff638c8383ed04ab916978c1..10d2dc9f98713d9faca95388d80ca27eaad7a569 100644
--- a/common/utils/vis.py
+++ b/common/utils/vis.py
@@ -5,7 +5,7 @@ from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
-os.environ["PYOPENGL_PLATFORM"] = "egl"
+os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import pyrender
import trimesh
from config import cfg
@@ -138,6 +138,20 @@ def perspective_projection(vertices, cam_param):
vertices[:, 1] = vertices[:, 1] * fy / vertices[:, 2] + cy
return vertices
+class WeakPerspectiveCamera(pyrender.Camera):
+ def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None):
+ super(WeakPerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name)
+ self.scale = scale
+ self.translation = translation
+
+ def get_projection_matrix(self, width=None, height=None):
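+        # Weak-perspective (scaled orthographic) projection: x/y are scaled and
+        # translated directly in normalised device coordinates, while z is only
+        # sign-flipped (P[2, 2] = -1) to match pyrender's OpenGL-style camera.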
+ P = np.eye(4)
+ P[0, 0] = self.scale[0]
+ P[1, 1] = self.scale[1]
+ P[0, 3] = self.translation[0] * self.scale[0]
+ P[1, 3] = -self.translation[1] * self.scale[1]
+ P[2, 2] = -1
+ return P
def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
if mesh_as_vertices:
@@ -150,28 +164,32 @@ def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
rot = trimesh.transformations.rotation_matrix(
np.radians(180), [1, 0, 0])
mesh.apply_transform(rot)
- material = pyrender.MetallicRoughnessMaterial(metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0))
- mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=False)
- scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
+    color = [0.7, 0.7, 0.6]
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor=0.2,
+ roughnessFactor=1.0,
+ alphaMode='OPAQUE',
+ baseColorFactor=(color[0], color[1], color[2], 1.0)
+ )
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+ scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.05, 0.05, 0.05))
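+    # two directional lights at different angles give the mesh simple shading
+    # instead of a flat, ambient-only look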
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(-45), [1, 0, 0])
+ scene.add(light, pose=light_pose)
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(45), [0, 1, 0])
+ scene.add(light, pose=light_pose)
scene.add(mesh, 'mesh')
- focal, princpt = cam_param['focal'], cam_param['princpt']
- camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
- scene.add(camera)
+ # focal, princpt = cam_param['focal'], cam_param['princpt']
+ # camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
+ sx, sy, tx, ty = cam_param
+ camera = WeakPerspectiveCamera(scale=[sx, sy], translation=[tx, ty], zfar=1000.0)
+ camera_pose = np.eye(4)
+ scene.add(camera, pose=camera_pose)
# renderer
renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
- # light
- light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
- light_pose = np.eye(4)
- light_pose[:3, 3] = np.array([0, -1, 1])
- scene.add(light, pose=light_pose)
- light_pose[:3, 3] = np.array([0, 1, 1])
- scene.add(light, pose=light_pose)
- light_pose[:3, 3] = np.array([1, 1, 2])
- scene.add(light, pose=light_pose)
-
# render
rgb, depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
rgb = rgb[:,:,:3].astype(np.float32)
diff --git a/main/__pycache__/config.cpython-39.pyc b/main/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e0d91d6cbb13956029fded68bee60952a27873b
Binary files /dev/null and b/main/__pycache__/config.cpython-39.pyc differ
diff --git a/main/__pycache__/postometro.cpython-39.pyc b/main/__pycache__/postometro.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03d340f47d4f0e13d4b2e18799e1caa874c92882
Binary files /dev/null and b/main/__pycache__/postometro.cpython-39.pyc differ
diff --git a/main/config/config_postometro.py b/main/config/config_postometro.py
index 00a09484bf30e29bdce054d284d65e019bf9799e..ac522eb12ef6a146ef7106ab3f9a32af4e6121df 100644
--- a/main/config/config_postometro.py
+++ b/main/config/config_postometro.py
@@ -3,109 +3,12 @@ import os.path as osp
# will be update in exp
num_gpus = -1
-exp_name = 'output/exp1/pre_analysis'
-
-# quick access
-save_epoch = 1
-lr = 1e-5
-end_epoch = 10
-train_batch_size = 16
-
-syncbn = True
-bbox_ratio = 1.2
-
-# continue
-continue_train = False
-start_over = True
-
-# dataset setting
-agora_fix_betas = True
-agora_fix_global_orient_transl = True
-agora_valid_root_pose = True
-
-# all
-dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
-trainset_3d = ['MSCOCO','AGORA', 'UBody']
-trainset_2d = ['PW3D', 'MPII', 'Human36M']
-trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
- 'Behave', 'UP3D', 'ARCTIC',
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
-testset = 'EHF'
-
-use_cache = True
-# downsample
-BEDLAM_train_sample_interval = 5
-EgoBody_Kinect_train_sample_interval = 10
-train_sample_interval = 10 # UBody
-MPI_INF_3DHP_train_sample_interval = 5
-InstaVariety_train_sample_interval = 10
-RenBody_HiRes_train_sample_interval = 5
-ARCTIC_train_sample_interval = 10
-# RenBody_train_sample_interval = 10
-FIT3D_train_sample_interval = 10
-Talkshow_train_sample_interval = 10
-
-# strategy
-data_strategy = 'balance' # 'balance' need to define total_data_len
-total_data_len = 4500000
-
-# model
-smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
-smplx_pose_weight = 10.0
-
-smplx_kps_3d_weight = 100.0
-smplx_kps_2d_weight = 1.0
-net_kps_2d_weight = 1.0
-
-agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
-
-model_type = 'smpler_x_h'
-encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_huge.py'
-encoder_pretrained_model_path = 'pretrained_models/vitpose_huge.pth'
-feat_dim = 1280
-
-## =====FIXED ARGS============================================================
-## model setting
-upscale = 4
-hand_pos_joint_num = 20
-face_pos_joint_num = 72
-num_task_token = 24
-num_noise_sample = 0
-
-## UBody setting
-train_sample_interval = 10
-test_sample_interval = 100
-make_same_len = False
## input, output size
input_img_shape = (256, 256)
input_body_shape = (256, 256)
-output_hm_shape = (16, 16, 12)
-input_hand_shape = (256, 256)
-output_hand_hm_shape = (16, 16, 16)
-output_face_hm_shape = (8, 8, 8)
-input_face_shape = (192, 192)
-focal = (5000, 5000) # virtual focal lengths
-princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
-body_3d_size = 2
-hand_3d_size = 0.3
-face_3d_size = 0.3
-camera_3d_size = 2.5
-
-## training config
-print_iters = 100
-lr_mult = 1
-## testing config
-test_batch_size = 32
-
-## others
-num_thread = 2
-vis = False
+renderer_input_body_shape = (256, 256)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (renderer_input_body_shape[1] / 2, renderer_input_body_shape[0] / 2) # virtual principal point position
-## directory
-output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
diff --git a/main/config/config_smpler_x_b32.py b/main/config/config_smpler_x_b32.py
deleted file mode 100644
index b737e5307a76fbeaf6eefa6e2bc775c52760fab4..0000000000000000000000000000000000000000
--- a/main/config/config_smpler_x_b32.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import os
-import os.path as osp
-
-# will be update in exp
-num_gpus = -1
-exp_name = 'output/exp1/pre_analysis'
-
-# quick access
-save_epoch = 1
-lr = 1e-5
-end_epoch = 10
-train_batch_size = 32
-
-syncbn = True
-bbox_ratio = 1.2
-
-# continue
-continue_train = False
-start_over = True
-
-# dataset setting
-agora_fix_betas = True
-agora_fix_global_orient_transl = True
-agora_valid_root_pose = True
-
-# all
-dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
-trainset_3d = ['MSCOCO','AGORA', 'UBody']
-trainset_2d = ['PW3D', 'MPII', 'Human36M']
-trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
- 'Behave', 'UP3D', 'ARCTIC',
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
-testset = 'EHF'
-
-use_cache = True
-# downsample
-BEDLAM_train_sample_interval = 5
-EgoBody_Kinect_train_sample_interval = 10
-train_sample_interval = 10 # UBody
-MPI_INF_3DHP_train_sample_interval = 5
-InstaVariety_train_sample_interval = 10
-RenBody_HiRes_train_sample_interval = 5
-ARCTIC_train_sample_interval = 10
-# RenBody_train_sample_interval = 10
-FIT3D_train_sample_interval = 10
-Talkshow_train_sample_interval = 10
-
-# strategy
-data_strategy = 'balance' # 'balance' need to define total_data_len
-total_data_len = 4500000
-
-# model
-smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
-smplx_pose_weight = 10.0
-
-smplx_kps_3d_weight = 100.0
-smplx_kps_2d_weight = 1.0
-net_kps_2d_weight = 1.0
-
-agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
-
-model_type = 'smpler_x_b'
-encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_base.py'
-encoder_pretrained_model_path = 'pretrained_models/vitpose_base.pth'
-feat_dim = 768
-
-
-## =====FIXED ARGS============================================================
-## model setting
-upscale = 4
-hand_pos_joint_num = 20
-face_pos_joint_num = 72
-num_task_token = 24
-num_noise_sample = 0
-
-## UBody setting
-train_sample_interval = 10
-test_sample_interval = 100
-make_same_len = False
-
-## input, output size
-input_img_shape = (512, 384)
-input_body_shape = (256, 192)
-output_hm_shape = (16, 16, 12)
-input_hand_shape = (256, 256)
-output_hand_hm_shape = (16, 16, 16)
-output_face_hm_shape = (8, 8, 8)
-input_face_shape = (192, 192)
-focal = (5000, 5000) # virtual focal lengths
-princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
-body_3d_size = 2
-hand_3d_size = 0.3
-face_3d_size = 0.3
-camera_3d_size = 2.5
-
-## training config
-print_iters = 100
-lr_mult = 1
-
-## testing config
-test_batch_size = 32
-
-## others
-num_thread = 2
-vis = False
-
-## directory
-output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
diff --git a/main/config/config_smpler_x_h32.py b/main/config/config_smpler_x_h32.py
deleted file mode 100644
index 2ffd86e9e965f9f2d3fd5efdb98ad2cb83fa81ed..0000000000000000000000000000000000000000
--- a/main/config/config_smpler_x_h32.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import os
-import os.path as osp
-
-# will be update in exp
-num_gpus = -1
-exp_name = 'output/exp1/pre_analysis'
-
-# quick access
-save_epoch = 1
-lr = 1e-5
-end_epoch = 10
-train_batch_size = 16
-
-syncbn = True
-bbox_ratio = 1.2
-
-# continue
-continue_train = False
-start_over = True
-
-# dataset setting
-agora_fix_betas = True
-agora_fix_global_orient_transl = True
-agora_valid_root_pose = True
-
-# all
-dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
-trainset_3d = ['MSCOCO','AGORA', 'UBody']
-trainset_2d = ['PW3D', 'MPII', 'Human36M']
-trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
- 'Behave', 'UP3D', 'ARCTIC',
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
-testset = 'EHF'
-
-use_cache = True
-# downsample
-BEDLAM_train_sample_interval = 5
-EgoBody_Kinect_train_sample_interval = 10
-train_sample_interval = 10 # UBody
-MPI_INF_3DHP_train_sample_interval = 5
-InstaVariety_train_sample_interval = 10
-RenBody_HiRes_train_sample_interval = 5
-ARCTIC_train_sample_interval = 10
-# RenBody_train_sample_interval = 10
-FIT3D_train_sample_interval = 10
-Talkshow_train_sample_interval = 10
-
-# strategy
-data_strategy = 'balance' # 'balance' need to define total_data_len
-total_data_len = 4500000
-
-# model
-smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
-smplx_pose_weight = 10.0
-
-smplx_kps_3d_weight = 100.0
-smplx_kps_2d_weight = 1.0
-net_kps_2d_weight = 1.0
-
-agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
-
-model_type = 'smpler_x_h'
-encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_huge.py'
-encoder_pretrained_model_path = 'pretrained_models/vitpose_huge.pth'
-feat_dim = 1280
-
-## =====FIXED ARGS============================================================
-## model setting
-upscale = 4
-hand_pos_joint_num = 20
-face_pos_joint_num = 72
-num_task_token = 24
-num_noise_sample = 0
-
-## UBody setting
-train_sample_interval = 10
-test_sample_interval = 100
-make_same_len = False
-
-## input, output size
-input_img_shape = (512, 384)
-input_body_shape = (256, 192)
-output_hm_shape = (16, 16, 12)
-input_hand_shape = (256, 256)
-output_hand_hm_shape = (16, 16, 16)
-output_face_hm_shape = (8, 8, 8)
-input_face_shape = (192, 192)
-focal = (5000, 5000) # virtual focal lengths
-princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
-body_3d_size = 2
-hand_3d_size = 0.3
-face_3d_size = 0.3
-camera_3d_size = 2.5
-
-## training config
-print_iters = 100
-lr_mult = 1
-
-## testing config
-test_batch_size = 32
-
-## others
-num_thread = 2
-vis = False
-
-## directory
-output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
diff --git a/main/config/config_smpler_x_l32.py b/main/config/config_smpler_x_l32.py
deleted file mode 100644
index 1cfedc0b6b59d17d2b666bfdfabff6c45069456b..0000000000000000000000000000000000000000
--- a/main/config/config_smpler_x_l32.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import os
-import os.path as osp
-
-# will be update in exp
-num_gpus = -1
-exp_name = 'output/exp1/pre_analysis'
-
-# quick access
-save_epoch = 1
-lr = 1e-5
-end_epoch = 10
-train_batch_size = 32
-
-syncbn = True
-bbox_ratio = 1.2
-
-# continue
-continue_train = False
-start_over = True
-
-# dataset setting
-agora_fix_betas = True
-agora_fix_global_orient_transl = True
-agora_valid_root_pose = True
-
-# all
-dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
-trainset_3d = ['MSCOCO','AGORA', 'UBody']
-trainset_2d = ['PW3D', 'MPII', 'Human36M']
-trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
- 'Behave', 'UP3D', 'ARCTIC',
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
-testset = 'EHF'
-
-use_cache = True
-# downsample
-BEDLAM_train_sample_interval = 5
-EgoBody_Kinect_train_sample_interval = 10
-train_sample_interval = 10 # UBody
-MPI_INF_3DHP_train_sample_interval = 5
-InstaVariety_train_sample_interval = 10
-RenBody_HiRes_train_sample_interval = 5
-ARCTIC_train_sample_interval = 10
-# RenBody_train_sample_interval = 10
-FIT3D_train_sample_interval = 10
-Talkshow_train_sample_interval = 10
-
-# strategy
-data_strategy = 'balance' # 'balance' need to define total_data_len
-total_data_len = 4500000
-
-# model
-smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
-smplx_pose_weight = 10.0
-
-smplx_kps_3d_weight = 100.0
-smplx_kps_2d_weight = 1.0
-net_kps_2d_weight = 1.0
-
-agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
-
-model_type = 'smpler_x_l'
-encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_large.py'
-encoder_pretrained_model_path = 'pretrained_models/vitpose_large.pth'
-feat_dim = 1024
-
-
-## =====FIXED ARGS============================================================
-## model setting
-upscale = 4
-hand_pos_joint_num = 20
-face_pos_joint_num = 72
-num_task_token = 24
-num_noise_sample = 0
-
-## UBody setting
-train_sample_interval = 10
-test_sample_interval = 100
-make_same_len = False
-
-## input, output size
-input_img_shape = (512, 384)
-input_body_shape = (256, 192)
-output_hm_shape = (16, 16, 12)
-input_hand_shape = (256, 256)
-output_hand_hm_shape = (16, 16, 16)
-output_face_hm_shape = (8, 8, 8)
-input_face_shape = (192, 192)
-focal = (5000, 5000) # virtual focal lengths
-princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
-body_3d_size = 2
-hand_3d_size = 0.3
-face_3d_size = 0.3
-camera_3d_size = 2.5
-
-## training config
-print_iters = 100
-lr_mult = 1
-
-## testing config
-test_batch_size = 32
-
-## others
-num_thread = 2
-vis = False
-
-## directory
-output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
diff --git a/main/config/config_smpler_x_s32.py b/main/config/config_smpler_x_s32.py
deleted file mode 100644
index 090501bef40b1130e733d9567c05dd11b22b9ed1..0000000000000000000000000000000000000000
--- a/main/config/config_smpler_x_s32.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import os
-import os.path as osp
-
-# will be update in exp
-num_gpus = -1
-exp_name = 'output/exp1/pre_analysis'
-
-# quick access
-save_epoch = 1
-lr = 1e-5
-end_epoch = 10
-train_batch_size = 32
-
-syncbn = True
-bbox_ratio = 1.2
-
-# continue
-continue_train = False
-start_over = True
-
-# dataset setting
-agora_fix_betas = True
-agora_fix_global_orient_transl = True
-agora_valid_root_pose = True
-
-# all data
-dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
-trainset_3d = ['MSCOCO','AGORA', 'UBody']
-trainset_2d = ['PW3D', 'MPII', 'Human36M']
-trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
- 'Behave', 'UP3D', 'ARCTIC',
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
-testset = 'EHF'
-
-use_cache = True
-# downsample
-BEDLAM_train_sample_interval = 5
-EgoBody_Kinect_train_sample_interval = 10
-train_sample_interval = 10 # UBody
-MPI_INF_3DHP_train_sample_interval = 5
-InstaVariety_train_sample_interval = 10
-RenBody_HiRes_train_sample_interval = 5
-ARCTIC_train_sample_interval = 10
-# RenBody_train_sample_interval = 10
-FIT3D_train_sample_interval = 10
-Talkshow_train_sample_interval = 10
-
-# strategy
-data_strategy = 'balance' # 'balance' need to define total_data_len
-total_data_len = 4500000
-
-# model
-smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
-smplx_pose_weight = 10.0
-
-smplx_kps_3d_weight = 100.0
-smplx_kps_2d_weight = 1.0
-net_kps_2d_weight = 1.0
-
-agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
-
-model_type = 'smpler_x_s'
-encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_small.py'
-encoder_pretrained_model_path = 'pretrained_models/vitpose_small.pth'
-feat_dim = 384
-
-## =====FIXED ARGS============================================================
-## model setting
-upscale = 4
-hand_pos_joint_num = 20
-face_pos_joint_num = 72
-num_task_token = 24
-num_noise_sample = 0
-
-## UBody setting
-train_sample_interval = 10
-test_sample_interval = 100
-make_same_len = False
-
-## input, output size
-input_img_shape = (512, 384)
-input_body_shape = (256, 192)
-output_hm_shape = (16, 16, 12)
-input_hand_shape = (256, 256)
-output_hand_hm_shape = (16, 16, 16)
-output_face_hm_shape = (8, 8, 8)
-input_face_shape = (192, 192)
-focal = (5000, 5000) # virtual focal lengths
-princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
-body_3d_size = 2
-hand_3d_size = 0.3
-face_3d_size = 0.3
-camera_3d_size = 2.5
-
-## training config
-print_iters = 100
-lr_mult = 1
-
-## testing config
-test_batch_size = 32
-
-## others
-num_thread = 2
-vis = False
-
-## directory
-output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
diff --git a/main/inference.py b/main/inference.py
index 6e9445582d9fb6fc200913efc519516d84b58988..af4d1c505367471e52c876a26a5b76f6412ce4e3 100644
--- a/main/inference.py
+++ b/main/inference.py
@@ -11,11 +11,12 @@ sys.path.insert(0, osp.join(CUR_DIR, '..', 'main'))
sys.path.insert(0, osp.join(CUR_DIR , '..', 'common'))
from config import cfg
import cv2
-from tqdm import tqdm
-import json
-from typing import Literal, Union
from mmdet.apis import init_detector, inference_detector
from utils.inference_utils import process_mmdet_results, non_max_suppression
+from postometro_utils.smpl import SMPL
+import data.config as smpl_cfg
+from postometro import get_model
+from postometro_utils.renderer_pyrender import PyRender_Renderer
class Inferer:
@@ -29,16 +30,18 @@ class Inferer:
# ckpt_path = osp.join(CUR_DIR, '../pretrained_models', f'{pretrained_model}.pth.tar')
ckpt_path = None # for config
cfg.get_config_fromfile(config_path)
+        # update config
cfg.update_config(num_gpus, ckpt_path, output_folder, self.device)
self.cfg = cfg
cudnn.benchmark = True
- # # load model
- # from base import Demoer
- # demoer = Demoer()
- # demoer._make_model()
- # demoer.model.eval()
- # self.demoer = demoer
+ # load SMPL
+ self.smpl = SMPL().to(self.device)
+ self.faces = self.smpl.faces.cpu().numpy()
+
+ # load model
+ hmr_model_checkpoint_file = osp.join(CUR_DIR, '../pretrained_models/postometro/resnet_state_dict.bin')
+        self.hmr_model = get_model(backbone_str='resnet50', device=self.device, checkpoint_file=hmr_model_checkpoint_file)
# load faster-rcnn as human detector
checkpoint_file = osp.join(CUR_DIR, '../pretrained_models/mmdet/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')
@@ -46,17 +49,20 @@ class Inferer:
model = init_detector(config_file, checkpoint_file, device=self.device) # or device='cuda:0'
self.model = model
- def infer(self, original_img, iou_thr, frame, multi_person=False, mesh_as_vertices=False):
+ def infer(self, original_img, iou_thr, multi_person=False, mesh_as_vertices=False):
from utils.preprocessing import process_bbox, generate_patch_image
- # from utils.vis import render_mesh, save_obj
+ from utils.vis import render_mesh
# from utils.human_models import smpl_x
- mesh_paths = []
- smplx_paths = []
+
# prepare input image
- transform = transforms.ToTensor()
+ transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
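+        # ImageNet mean/std statistics; the cropped patch is scaled to [0, 1] manually below before this transform is applied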
vis_img = original_img.copy()
original_img_height, original_img_width = original_img.shape[:2]
+ # load renderer
+ # self.renderer = PyRender_Renderer(resolution=(original_img_width, original_img_height), faces=self.faces)
+
## mmdet inference
mmdet_results = inference_detector(self.model, original_img)
mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True)
@@ -99,51 +105,69 @@ class Inferer:
top_left = (int(bbox[0]), int(bbox[1]))
bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
cv2.rectangle(vis_img, top_left, bottom_right, (0, 0, 255), 2)
-
-
+
# human model inference
- # img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape)
- # img = transform(img.astype(np.float32))/255
- # img = img.to(cfg.device)[None,:,:,:]
- # inputs = {'img': img}
- # targets = {}
- # meta_info = {}
-
- # # mesh recovery
- # with torch.no_grad():
- # out = self.demoer.model(inputs, targets, meta_info, 'test')
- # mesh = out['smplx_mesh_cam'].detach().cpu().numpy()[0]
-
- # ## save mesh
- # save_path_mesh = os.path.join(self.output_folder, 'mesh')
- # os.makedirs(save_path_mesh, exist_ok= True)
- # obj_path = os.path.join(save_path_mesh, f'{frame:05}_{bbox_id}.obj')
- # save_obj(mesh, smpl_x.face, obj_path)
- # mesh_paths.append(obj_path)
- # ## save single person param
- # smplx_pred = {}
- # smplx_pred['global_orient'] = out['smplx_root_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['body_pose'] = out['smplx_body_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['left_hand_pose'] = out['smplx_lhand_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['right_hand_pose'] = out['smplx_rhand_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['jaw_pose'] = out['smplx_jaw_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['leye_pose'] = np.zeros((1, 3))
- # smplx_pred['reye_pose'] = np.zeros((1, 3))
- # smplx_pred['betas'] = out['smplx_shape'].reshape(-1,10).cpu().numpy()
- # smplx_pred['expression'] = out['smplx_expr'].reshape(-1,10).cpu().numpy()
- # smplx_pred['transl'] = out['cam_trans'].reshape(-1,3).cpu().numpy()
- # save_path_smplx = os.path.join(self.output_folder, 'smplx')
- # os.makedirs(save_path_smplx, exist_ok= True)
-
- # npz_path = os.path.join(save_path_smplx, f'{frame:05}_{bbox_id}.npz')
- # np.savez(npz_path, **smplx_pred)
- # smplx_paths.append(npz_path)
-
- # ## render single person mesh
- # focal = [self.cfg.focal[0] / self.cfg.input_body_shape[1] * bbox[2], self.cfg.focal[1] / self.cfg.input_body_shape[0] * bbox[3]]
- # princpt = [self.cfg.princpt[0] / self.cfg.input_body_shape[1] * bbox[2] + bbox[0], self.cfg.princpt[1] / self.cfg.input_body_shape[0] * bbox[3] + bbox[1]]
- # vis_img = render_mesh(vis_img, mesh, smpl_x.face, {'focal': focal, 'princpt': princpt},
+ img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape)
+ vis_patched_images = img.copy()
+ # here we pre-process images
+ img = img.transpose((2,0,1)) # h,w,c -> c,h,w
+ img = torch.from_numpy(img).float() / 255.0
+ # Store image before normalization to use it in visualization
+ img = transform(img)
+ img = img.to(cfg.device)[None,:,:,:]
+
+ self.renderer = PyRender_Renderer(resolution=(bbox[2], bbox[3]), faces=self.faces)
+
+ # mesh recovery
+ with torch.no_grad():
+ out = self.hmr_model(img)
+ pred_cam, pred_3d_vertices_fine = out['pred_cam'], out['pred_3d_vertices_fine']
+ pred_3d_joints_from_smpl = self.smpl.get_h36m_joints(pred_3d_vertices_fine) # batch_size X 17 X 3
+ pred_3d_joints_from_smpl_pelvis = pred_3d_joints_from_smpl[:,smpl_cfg.H36M_J17_NAME.index('Pelvis'),:]
+ pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,smpl_cfg.H36M_J17_TO_J14,:] # batch_size X 14 X 3
+ # normalize predicted vertices
+ pred_3d_vertices_fine = pred_3d_vertices_fine - pred_3d_joints_from_smpl_pelvis[:, None, :] # batch_size X 6890 X 3
+ pred_3d_vertices_fine = pred_3d_vertices_fine.detach().cpu().numpy()[0] # 6890 X 3
+ pred_cam = pred_cam.detach().cpu().numpy()[0]
+ bbox_cx, bbox_cy = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
+ img_cx, img_cy = original_img_width / 2, original_img_height / 2
+ cx_delta, cy_delta = bbox_cx / img_cx - 1, bbox_cy / img_cy - 1
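+            # pred_cam = (s, tx, ty) is expressed in the cropped-bbox frame; rescale the
+            # scale by bbox/image size and shift the translation by the offset of the
+            # bbox centre from the image centre so the mesh is rendered over the full image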
+
+ # render single person mesh
+ # focal = [self.cfg.focal[0] / self.cfg.renderer_input_body_shape[1] * bbox[2], self.cfg.focal[1] / self.cfg.renderer_input_body_shape[0] * bbox[3]]
+ # princpt = [self.cfg.princpt[0] / self.cfg.renderer_input_body_shape[1] * bbox[2] + bbox[0], self.cfg.princpt[1] / self.cfg.renderer_input_body_shape[0] * bbox[3] + bbox[1]]
+ # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, {'focal': focal, 'princpt': princpt},
# mesh_as_vertices=mesh_as_vertices)
- # vis_img = vis_img.astype('uint8')
- return vis_img, num_bbox, ok_bboxes
+ # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), pred_cam[1], pred_cam[2]], mesh_as_vertices=mesh_as_vertices)
+ vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]),
+ pred_cam[1] + cx_delta / (pred_cam[0] / (original_img_width / bbox[2])),
+ pred_cam[2] + cy_delta / (pred_cam[0] / (original_img_height / bbox[3]))],
+ mesh_as_vertices=mesh_as_vertices)
+ # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), 0, 0], mesh_as_vertices=mesh_as_vertices)
+
+ # bbox_meta = {'bbox': bbox, 'img_hw': [original_img_height, original_img_width]}
+ # vis_img = self.renderer(pred_3d_vertices_fine, bbox_meta, vis_img, pred_cam)
+ vis_img = vis_img.astype('uint8')
+ return vis_img, len(ok_bboxes), ok_bboxes
+
+
+if __name__ == '__main__':
+ from PIL import Image
+ inferer = Inferer('postometro', 0, './out_folder') # gpu
+ image_path = f'../assets/07.jpg'
+ image = Image.open(image_path)
+ # Convert the PIL image to a NumPy array
+ image_np = np.array(image)
+ vis_img, _ , _ = inferer.infer(image_np, 0.2, multi_person=True, mesh_as_vertices=False)
+ save_path = f'./saved_vis_07.jpg'
+
+ # Ensure the image is in the correct format (PIL expects uint8)
+ if vis_img.dtype != np.uint8:
+ vis_img = vis_img.astype('uint8')
+ # Convert the Numpy array (if RGB) to a PIL image and save
+ image = Image.fromarray(vis_img)
+ image.save(save_path)
+
diff --git a/main/pct_utils/__pycache__/modules.cpython-39.pyc b/main/pct_utils/__pycache__/modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..421fd611fd1622031e1a82917072fb06efc4ed9f
Binary files /dev/null and b/main/pct_utils/__pycache__/modules.cpython-39.pyc differ
diff --git a/main/pct_utils/__pycache__/pct.cpython-39.pyc b/main/pct_utils/__pycache__/pct.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25330b0d344355415e0196f7fdabc6d5f1063b7c
Binary files /dev/null and b/main/pct_utils/__pycache__/pct.cpython-39.pyc differ
diff --git a/main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc b/main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2cd494fc2451478426a1b2b6fa72e61fb0a12b7
Binary files /dev/null and b/main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc differ
diff --git a/main/pct_utils/__pycache__/pct_head.cpython-39.pyc b/main/pct_utils/__pycache__/pct_head.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..132027917b0b4337373b740be122468af2477065
Binary files /dev/null and b/main/pct_utils/__pycache__/pct_head.cpython-39.pyc differ
diff --git a/main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc b/main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c31244caba23440d0c9250b64cc69f33b4816c94
Binary files /dev/null and b/main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc differ
diff --git a/main/pct_utils/modules.py b/main/pct_utils/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c868a46adb650284343178e1ea8c9a5c51ff73a
--- /dev/null
+++ b/main/pct_utils/modules.py
@@ -0,0 +1,117 @@
+# --------------------------------------------------------
+# Borrow from unofficial MLPMixer (https://github.com/920232796/MlpMixer-pytorch)
+# Borrow from ResNet
+# Modified by Zigang Geng (zigang@mail.ustc.edu.cn)
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+
+
+class FCBlock(nn.Module):
+ def __init__(self, dim, out_dim):
+ super().__init__()
+
+ self.ff = nn.Sequential(
+ nn.Linear(dim, out_dim),
+ nn.LayerNorm(out_dim),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ return self.ff(x)
+
+
+class MLPBlock(nn.Module):
+ def __init__(self, dim, inter_dim, dropout_ratio):
+ super().__init__()
+
+ self.ff = nn.Sequential(
+ nn.Linear(dim, inter_dim),
+ nn.GELU(),
+ nn.Dropout(dropout_ratio),
+ nn.Linear(inter_dim, dim),
+ nn.Dropout(dropout_ratio)
+ )
+
+ def forward(self, x):
+ return self.ff(x)
+
+
+class MixerLayer(nn.Module):
+ def __init__(self,
+ hidden_dim,
+ hidden_inter_dim,
+ token_dim,
+ token_inter_dim,
+ dropout_ratio):
+ super().__init__()
+
+ self.layernorm1 = nn.LayerNorm(hidden_dim)
+ self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio)
+ self.layernorm2 = nn.LayerNorm(hidden_dim)
+ self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio)
+
+ def forward(self, x):
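+        # x: (batch, num_tokens, hidden_dim). The first MLP mixes information across
+        # tokens (hence the transposes), the second across channels; both branches
+        # are added back to the input as residuals.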
+ y = self.layernorm1(x)
+ y = y.transpose(2, 1)
+ y = self.MLP_token(y)
+ y = y.transpose(2, 1)
+ z = self.layernorm2(x + y)
+ z = self.MLP_channel(z)
+ out = x + y + z
+ return out
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1,
+ downsample=None, dilation=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
+ padding=dilation, bias=False, dilation=dilation)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=0.1)
+ self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
+                               padding=dilation, bias=False, dilation=dilation)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
+ self.downsample = downsample
+ self.stride = stride
+
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
+ layers = []
+ for i in range(len(feat_dims)-1):
+ layers.append(
+ nn.Conv2d(
+ in_channels=feat_dims[i],
+ out_channels=feat_dims[i+1],
+ kernel_size=kernel,
+ stride=stride,
+ padding=padding
+ ))
+ # Do not use BN and ReLU for final estimation
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
\ No newline at end of file
diff --git a/main/pct_utils/pct.py b/main/pct_utils/pct.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff60e0326a6f96bcf06f2ef3fa0c1c0d0bff8eb1
--- /dev/null
+++ b/main/pct_utils/pct.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from pct_utils.pct_head import PCT_Head
+
+class PCT(nn.Module):
+ def __init__(self,
+ args,
+ backbone,
+ stage_pct,
+ in_channels,
+ image_size,
+ num_joints,
+ pretrained=None,
+ tokenizer_pretrained=None):
+ super().__init__()
+ self.stage_pct = stage_pct
+ assert self.stage_pct in ["tokenizer", "classifier"]
+ self.guide_ratio = args.tokenizer_guide_ratio
+ self.image_guide = self.guide_ratio > 0.0
+ self.num_joints = num_joints
+
+ self.backbone = backbone
+ if self.image_guide:
+ self.extra_backbone = backbone
+
+ self.keypoint_head = PCT_Head(args,stage_pct,in_channels,image_size,num_joints)
+
+ if (pretrained is not None) or (tokenizer_pretrained is not None):
+ self.init_weights(pretrained, tokenizer_pretrained)
+
+ def init_weights(self, pretrained, tokenizer):
+ """Weight initialization for model."""
+ if self.stage_pct == "classifier":
+ self.backbone.init_weights(pretrained)
+ if self.image_guide:
+ self.extra_backbone.init_weights(pretrained)
+ self.keypoint_head.init_weights()
+ self.keypoint_head.tokenizer.init_weights(tokenizer)
+
+    def forward(self, img, joints, train=True):
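+        # In the "tokenizer" stage the main image backbone is skipped (its output is
+        # None) and the head works from the joints input; in the "classifier" stage
+        # the backbone features drive the prediction. The extra (guide) backbone is
+        # only used when tokenizer_guide_ratio > 0.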
+ if train:
+ output = None if self.stage_pct == "tokenizer" else self.backbone(img)
+ extra_output = self.extra_backbone(img) if self.image_guide else None
+
+ p_logits, p_joints, g_logits, e_latent_loss = \
+ self.keypoint_head(output, extra_output, joints, train=True)
+ return {
+ 'cls_logits': p_logits,
+ 'pred_pose': p_joints,
+ 'encoding_indices': g_logits,
+ 'e_latent_loss': e_latent_loss
+ }
+ else:
+ results = {}
+
+ batch_size, _, img_height, img_width = img.shape
+
+ output = None if self.stage_pct == "tokenizer" \
+ else self.backbone(img)
+ extra_output = self.extra_backbone(img) \
+ if self.image_guide and self.stage_pct == "tokenizer" else None
+
+ p_joints, encoding_scores, out_part_token_feat = \
+ self.keypoint_head(output, extra_output, joints, train=False)
+ return {
+ 'pred_pose': p_joints,
+ 'encoding_scores': encoding_scores,
+ 'part_token_feat': out_part_token_feat
+ }
diff --git a/main/pct_utils/pct_backbone.py b/main/pct_utils/pct_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cacc11c1dfe29830a0a031b814c55e2dd1593f6
--- /dev/null
+++ b/main/pct_utils/pct_backbone.py
@@ -0,0 +1,1475 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+import torch.utils.checkpoint as checkpoint
+from torch.nn.utils import weight_norm
+from torch import Tensor, Size
+from typing import Union, List
+import numpy as np
+import logging
+
+# Copyright (c) Open-MMLab. All rights reserved.
+# Copy from mmcv source code.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+import numpy as np
+from scipy import interpolate
+
+import torch
+import torchvision
+import torch.distributed as dist
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+
+def _load_checkpoint(filename, map_location=None):
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+def load_checkpoint_swin(model,
+ filename,
+ map_location='cpu',
+ strict=False,
+ rpe_interpolation='outer_mask',
+ logger=None):
+ """Load checkpoint from a file or URI.
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ elif 'module' in checkpoint:
+ state_dict = checkpoint['module']
+ else:
+ state_dict = checkpoint
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith('module.'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ if list(state_dict.keys())[0].startswith('backbone.'):
+ state_dict = {k[9:]: v for k, v in state_dict.items()}
+
+ # for MoBY, load model of online branch
+ if sorted(list(state_dict.keys()))[2].startswith('encoder'):
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+ # directly load here
+
+ model.load_state_dict(state_dict, strict=True)
+
+ return checkpoint
+
+
+_shape_t = Union[int, List[int], Size]
+
+from itertools import repeat
+import collections.abc
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+ return parse
+
+to_2tuple = _ntuple(2)
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
+ applied while sampling the normal with mean/std applied, therefore a, b args
+ should be adjusted to match the range of mean, std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ with torch.no_grad():
+ return _trunc_normal_(tensor, mean, std, a, b)
+
+
+class LayerNorm2D(nn.Module):
+ def __init__(self, normalized_shape, norm_layer=None):
+ super().__init__()
+ self.ln = norm_layer(normalized_shape) if norm_layer is not None else nn.Identity()
+
+ def forward(self, x):
+ """
+ x: N C H W
+ """
+ x = x.permute(0, 2, 3, 1)
+ x = self.ln(x)
+ x = x.permute(0, 3, 1, 2)
+ return x
+
+
+class LayerNormFP32(nn.LayerNorm):
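+    # LayerNorm computed in float32 (then cast back) for numerical stability under mixed precision.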
+ def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
+ super(LayerNormFP32, self).__init__(normalized_shape, eps, elementwise_affine)
+
+ def forward(self, input: Tensor) -> Tensor:
+ return F.layer_norm(
+ input.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(input)
+
+
+class LinearFP32(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True):
+ super(LinearFP32, self).__init__(in_features, out_features, bias)
+
+ def forward(self, input: Tensor) -> Tensor:
+ return F.linear(input.float(), self.weight.float(),
+ self.bias.float() if self.bias is not None else None)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
+ norm_layer=None, mlpfp32=False):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.mlpfp32 = mlpfp32
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+ if norm_layer is not None:
+ self.norm = norm_layer(hidden_features)
+ else:
+ self.norm = None
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ if self.norm:
+ x = self.norm(x)
+ x = self.act(x)
+ x = self.drop(x)
+ if self.mlpfp32:
+ x = self.fc2.float()(x.type(torch.float32))
+ x = self.drop.float()(x)
+ # print(f"======>[MLP FP32]")
+ else:
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvMlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
+ norm_layer=None, mlpfp32=False, proj_ln=False):
+ super().__init__()
+ self.mlp = Mlp(in_features=in_features, hidden_features=hidden_features, out_features=out_features,
+ act_layer=act_layer, drop=drop, norm_layer=norm_layer, mlpfp32=mlpfp32)
+ self.conv_proj = nn.Conv2d(in_features,
+ in_features,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ bias=False,
+ groups=in_features)
+ self.proj_ln = LayerNorm2D(in_features, LayerNormFP32) if proj_ln else None
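+        # conv_proj is a depthwise 3x3 convolution (groups == channels) that adds
+        # local spatial mixing before the token-wise MLP.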
+
+ def forward(self, x, H, W):
+ B, L, C = x.shape
+ assert L == H * W
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2) # B C H W
+ x = self.conv_proj(x)
+ if self.proj_ln:
+ x = self.proj_ln(x)
+ x = x.permute(0, 2, 3, 1) # B H W C
+ x = x.reshape(B, L, C)
+ x = self.mlp(x, H, W)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+ relative_coords_table_type='norm8_log', rpe_hidden_dim=512,
+ rpe_output_type='normal', attn_type='normal', mlpfp32=False, pretrain_window_size=-1):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ self.mlpfp32 = mlpfp32
+ self.attn_type = attn_type
+ self.rpe_output_type = rpe_output_type
+ self.relative_coords_table_type = relative_coords_table_type
+
+ if self.attn_type == 'cosine_mh':
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+ elif self.attn_type == 'normal':
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ else:
+ raise NotImplementedError()
+ if self.relative_coords_table_type != "none":
+ # mlp to generate table of relative position bias
+ self.rpe_mlp = nn.Sequential(nn.Linear(2, rpe_hidden_dim, bias=True),
+ nn.ReLU(inplace=True),
+ LinearFP32(rpe_hidden_dim, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if relative_coords_table_type == 'linear':
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ elif relative_coords_table_type == 'linear_bylayer':
+ relative_coords_table[:, :, :, 0] /= (pretrain_window_size - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrain_window_size - 1)
+ elif relative_coords_table_type == 'norm8_log':
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
+ elif relative_coords_table_type == 'norm8_log_192to640':
+ if self.window_size[0] == 40:
+ relative_coords_table[:, :, :, 0] /= (11)
+ relative_coords_table[:, :, :, 1] /= (11)
+ elif self.window_size[0] == 20:
+ relative_coords_table[:, :, :, 0] /= (5)
+ relative_coords_table[:, :, :, 1] /= (5)
+ else:
+ raise NotImplementedError
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
+ # check
+ elif relative_coords_table_type == 'norm8_log_256to640':
+ if self.window_size[0] == 40:
+ relative_coords_table[:, :, :, 0] /= (15)
+ relative_coords_table[:, :, :, 1] /= (15)
+ elif self.window_size[0] == 20:
+ relative_coords_table[:, :, :, 0] /= (7)
+ relative_coords_table[:, :, :, 1] /= (7)
+ else:
+ raise NotImplementedError
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
+ elif relative_coords_table_type == 'norm8_log_bylayer':
+ relative_coords_table[:, :, :, 0] /= (pretrain_window_size - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrain_window_size - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
+ else:
+ raise NotImplementedError
+ self.register_buffer("relative_coords_table", relative_coords_table)
+ else:
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
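+ # fold the 2D offset into a single index: (shifted row offset) * (2*Ww - 1) + (shifted column offset)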
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ if self.attn_type == 'cosine_mh':
+ q = F.normalize(q.float(), dim=-1)
+ k = F.normalize(k.float(), dim=-1)
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01, device=self.logit_scale.device))).exp()
+ attn = (q @ k.transpose(-2, -1)) * logit_scale.float()
+ elif self.attn_type == 'normal':
+ q = q * self.scale
+ attn = (q.float() @ k.float().transpose(-2, -1))
+ else:
+ raise NotImplementedError()
+
+ if self.relative_coords_table_type != "none":
+ # relative_position_bias_table: 2*Wh-1 * 2*Ww-1, nH
+ relative_position_bias_table = self.rpe_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ else:
+ relative_position_bias_table = self.relative_position_bias_table
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ if self.rpe_output_type == 'normal':
+ pass
+ elif self.rpe_output_type == 'sigmoid':
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ else:
+ raise NotImplementedError
+
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+
+ attn = self.softmax(attn)
+ attn = attn.type_as(x)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ if self.mlpfp32:
+ x = self.proj.float()(x.type(torch.float32))
+ x = self.proj_drop.float()(x)
+ # print(f"======>[ATTN FP32]")
+ else:
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
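+ # A minimal, hypothetical usage sketch for WindowAttention (the values below are illustrative
+ # assumptions, not taken from any PostoMETRO config). Window tokens of shape
+ # (num_windows * B, Wh * Ww, C) go in and come out with the same shape:
+ #
+ #   >>> attn = WindowAttention(96, window_size=(7, 7), num_heads=3,
+ #   ...                        qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+ #   ...                        relative_coords_table_type='norm8_log', rpe_output_type='sigmoid',
+ #   ...                        rpe_hidden_dim=512, attn_type='cosine_mh', mlpfp32=False,
+ #   ...                        pretrain_window_size=-1)
+ #   >>> out = attn(torch.randn(4, 49, 96), mask=None)   # -> shape (4, 49, 96)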
+
+class SwinTransformerBlockPost(nn.Module):
+ """ Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ use_mlp_norm=False, endnorm=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ relative_coords_table_type='norm8_log', rpe_hidden_dim=512,
+ rpe_output_type='normal', attn_type='normal', mlp_type='normal', mlpfp32=False,
+ pretrain_window_size=-1):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.use_mlp_norm = use_mlp_norm
+ self.endnorm = endnorm
+ self.mlpfp32 = mlpfp32
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
+ relative_coords_table_type=relative_coords_table_type, rpe_output_type=rpe_output_type,
+ rpe_hidden_dim=rpe_hidden_dim, attn_type=attn_type, mlpfp32=mlpfp32,
+ pretrain_window_size=pretrain_window_size)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ if mlp_type == 'normal':
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
+ elif mlp_type == 'conv':
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
+ elif mlp_type == 'conv_ln':
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32, proj_ln=True)
+
+ if self.endnorm:
+ self.enorm = norm_layer(dim)
+ else:
+ self.enorm = None
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ H, W = self.H, self.W
+ B, L, C = x.shape
+ assert L == H * W, f"input feature has wrong size, with L = {L}, H = {H}, W = {W}"
+
+ shortcut = x
+
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ if pad_r > 0 or pad_b > 0:
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ orig_type = x.dtype # attn may force to fp32
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
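+ # residual post-norm: norm1 is applied to the attention output (optionally in fp32) before the shortcut is added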
+ if self.mlpfp32:
+ x = self.norm1.float()(x)
+ x = x.type(orig_type)
+ else:
+ x = self.norm1(x)
+ x = shortcut + self.drop_path(x)
+ shortcut = x
+
+ orig_type = x.dtype
+ x = self.mlp(x, H, W)
+ if self.mlpfp32:
+ x = self.norm2.float()(x)
+ x = x.type(orig_type)
+ else:
+ x = self.norm2(x)
+ x = shortcut + self.drop_path(x)
+
+ if self.endnorm:
+ x = self.enorm(x)
+
+ return x
+
+
+class SwinTransformerBlockPre(nn.Module):
+ """ Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ use_mlp_norm=False, endnorm=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ init_values=None, relative_coords_table_type='norm8_log', rpe_hidden_dim=512,
+ rpe_output_type='normal', attn_type='normal', mlp_type='normal', mlpfp32=False,
+ pretrain_window_size=-1):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.use_mlp_norm = use_mlp_norm
+ self.endnorm = endnorm
+ self.mlpfp32 = mlpfp32
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
+ relative_coords_table_type=relative_coords_table_type, rpe_output_type=rpe_output_type,
+ rpe_hidden_dim=rpe_hidden_dim, attn_type=attn_type, mlpfp32=mlpfp32,
+ pretrain_window_size=pretrain_window_size)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ if mlp_type == 'normal':
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
+ elif mlp_type == 'conv':
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
+ elif mlp_type == 'conv_ln':
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32, proj_ln=True)
+
+ if init_values is not None and init_values >= 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = 1.0, 1.0
+
+ if self.endnorm:
+ self.enorm = norm_layer(dim)
+ else:
+ self.enorm = None
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ H, W = self.H, self.W
+ B, L, C = x.shape
+ assert L == H * W, f"input feature has wrong size, with L = {L}, H = {H}, W = {W}"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ if pad_r > 0 or pad_b > 0:
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ orig_type = x.dtype # attn may force to fp32
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ if self.mlpfp32:
+ x = self.gamma_1 * x
+ x = x.type(orig_type)
+ else:
+ x = self.gamma_1 * x
+ x = shortcut + self.drop_path(x)
+ shortcut = x
+
+ orig_type = x.dtype
+ x = self.norm2(x)
+ if self.mlpfp32:
+ x = self.gamma_2 * self.mlp(x, H, W)
+ x = x.type(orig_type)
+ else:
+ x = self.gamma_2 * self.mlp(x, H, W)
+ x = shortcut + self.drop_path(x)
+
+ if self.endnorm:
+ x = self.enorm(x)
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+
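+ Downsamples by 2x in each spatial dimension: (B, H*W, C) -> (B, ceil(H/2)*ceil(W/2), 2*C).
+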
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True):
+ super().__init__()
+ self.dim = dim
+ self.postnorm = postnorm
+
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim) if postnorm else norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ if self.postnorm:
+ x = self.reduction(x)
+ x = self.norm(x)
+ else:
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class PatchReduction1C(nn.Module):
+ r""" Patch Reduction Layer.
+
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True):
+ super().__init__()
+ self.dim = dim
+ self.postnorm = postnorm
+
+ self.reduction = nn.Linear(dim, dim, bias=False)
+ self.norm = norm_layer(dim)
+
+ def forward(self, x, H, W):
+ """
+ x: B, H*W, C
+ """
+ if self.postnorm:
+ x = self.reduction(x)
+ x = self.norm(x)
+ else:
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class ConvPatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
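+ Downsamples with a stride-2 3x3 convolution: (B, H*W, C) -> (B, ceil(H/2)*ceil(W/2), 2*C).
+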
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True):
+ super().__init__()
+ self.dim = dim
+ self.postnorm = postnorm
+
+ self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=3, stride=2, padding=1)
+ self.norm = norm_layer(2 * dim) if postnorm else norm_layer(dim)
+
+ def forward(self, x, H, W):
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ if self.postnorm:
+ x = x.permute(0, 3, 1, 2) # B C H W
+ x = self.reduction(x).flatten(2).transpose(1, 2) # B H//2*W//2 2*C
+ x = self.norm(x)
+ else:
+ x = self.norm(x)
+ x = x.permute(0, 3, 1, 2) # B C H W
+ x = self.reduction(x).flatten(2).transpose(1, 2) # B H//2*W//2 2*C
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Number of blocks in this stage.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ use_shift (bool): Whether to use shifted window. Default: True.
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ checkpoint_blocks=255,
+ init_values=None,
+ endnorm_interval=-1,
+ use_mlp_norm=False,
+ use_shift=True,
+ relative_coords_table_type='norm8_log',
+ rpe_hidden_dim=512,
+ rpe_output_type='normal',
+ attn_type='normal',
+ mlp_type='normal',
+ mlpfp32_blocks=[-1],
+ postnorm=True,
+ pretrain_window_size=-1):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+ self.checkpoint_blocks = checkpoint_blocks
+ self.init_values = init_values if init_values is not None else 0.0
+ self.endnorm_interval = endnorm_interval
+ self.mlpfp32_blocks = mlpfp32_blocks
+ self.postnorm = postnorm
+
+ # build blocks
+ if self.postnorm:
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlockPost(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) or (not use_shift) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ use_mlp_norm=use_mlp_norm,
+ endnorm=True if ((i + 1) % endnorm_interval == 0) and (
+ endnorm_interval > 0) else False,
+ relative_coords_table_type=relative_coords_table_type,
+ rpe_hidden_dim=rpe_hidden_dim,
+ rpe_output_type=rpe_output_type,
+ attn_type=attn_type,
+ mlp_type=mlp_type,
+ mlpfp32=True if i in mlpfp32_blocks else False,
+ pretrain_window_size=pretrain_window_size)
+ for i in range(depth)])
+ else:
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlockPre(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) or (not use_shift) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ init_values=init_values,
+ use_mlp_norm=use_mlp_norm,
+ endnorm=True if ((i + 1) % endnorm_interval == 0) and (
+ endnorm_interval > 0) else False,
+ relative_coords_table_type=relative_coords_table_type,
+ rpe_hidden_dim=rpe_hidden_dim,
+ rpe_output_type=rpe_output_type,
+ attn_type=attn_type,
+ mlp_type=mlp_type,
+ mlpfp32=True if i in mlpfp32_blocks else False,
+ pretrain_window_size=pretrain_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer, postnorm=postnorm)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
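+ # the padded image is split into regions by h_slices/w_slices; after the cyclic shift, tokens that
+ # share a window but came from different regions get a -100 bias so they do not attend to each other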
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ for idx, blk in enumerate(self.blocks):
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ if isinstance(self.downsample, PatchReduction1C):
+ return x, H, W, x_down, H, W
+ else:
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+ def _init_block_norm_weights(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, self.init_values)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, self.init_values)
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class ResNetDLNPatchEmbed(nn.Module):
+ def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(4)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.conv1 = nn.Sequential(nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False),
+ LayerNorm2D(64, norm_layer),
+ nn.GELU(),
+ nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False),
+ LayerNorm2D(64, norm_layer),
+ nn.GELU(),
+ nn.Conv2d(64, embed_dim, 3, stride=1, padding=1, bias=False))
+ self.norm = LayerNorm2D(embed_dim, norm_layer if norm_layer is not None else LayerNormFP32) # use ln always
+ self.act = nn.GELU()
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, x):
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.conv1(x)
+ x = self.norm(x)
+ x = self.act(x)
+ x = self.maxpool(x)
+ # x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class SwinV2TransformerRPE2FC(nn.Module):
+ """ Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute position embedding. Default: 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ use_shift (bool): Whether to use shifted window. Default: True.
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=partial(LayerNormFP32, eps=1e-6),
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ init_values=1e-5,
+ endnorm_interval=-1,
+ use_mlp_norm_layers=[],
+ relative_coords_table_type='norm8_log',
+ rpe_hidden_dim=512,
+ attn_type='cosine_mh',
+ rpe_output_type='sigmoid',
+ rpe_wd=False,
+ postnorm=True,
+ mlp_type='normal',
+ patch_embed_type='normal',
+ patch_merge_type='normal',
+ strid16=False,
+ checkpoint_blocks=[255, 255, 255, 255],
+ mlpfp32_layer_blocks=[[-1], [-1], [-1], [-1]],
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ use_shift=True,
+ rpe_interpolation='geo',
+ pretrain_window_size=[-1, -1, -1, -1],
+ **kwargs):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.rpe_interpolation = rpe_interpolation
+ self.mlp_ratio = mlp_ratio
+ self.endnorm_interval = endnorm_interval
+ self.use_mlp_norm_layers = use_mlp_norm_layers
+ self.relative_coords_table_type = relative_coords_table_type
+ self.rpe_hidden_dim = rpe_hidden_dim
+ self.rpe_output_type = rpe_output_type
+ self.rpe_wd = rpe_wd
+ self.attn_type = attn_type
+ self.postnorm = postnorm
+ self.mlp_type = mlp_type
+ self.strid16 = strid16
+
+ if isinstance(window_size, list):
+ pass
+ elif isinstance(window_size, int):
+ window_size = [window_size] * self.num_layers
+ else:
+ raise TypeError("We only support list or int for window size")
+
+ if isinstance(use_shift, list):
+ pass
+ elif isinstance(use_shift, bool):
+ use_shift = [use_shift] * self.num_layers
+ else:
+ raise TypeError("We only support list or bool for use_shift")
+
+ if isinstance(use_checkpoint, list):
+ pass
+ elif isinstance(use_checkpoint, bool):
+ use_checkpoint = [use_checkpoint] * self.num_layers
+ else:
+ raise TypeError("We only support list or bool for use_checkpoint")
+
+ # split image into non-overlapping patches
+ if patch_embed_type == 'normal':
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ elif patch_embed_type == 'resnetdln':
+ assert patch_size == 4, "check"
+ self.patch_embed = ResNetDLNPatchEmbed(
+ in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer)
+ elif patch_embed_type == 'resnetdnf':
+ assert patch_size == 4, "check"
+ self.patch_embed = ResNetDLNPatchEmbed(
+ in_chans=in_chans, embed_dim=embed_dim, norm_layer=None)
+ else:
+ raise NotImplementedError()
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ if patch_merge_type == 'normal':
+ downsample_layer = PatchMerging
+ elif patch_merge_type == 'conv':
+ downsample_layer = ConvPatchMerging
+ else:
+ raise NotImplementedError()
+ # build layers
+ self.layers = nn.ModuleList()
+ num_features = []
+ for i_layer in range(self.num_layers):
+ cur_dim = int(embed_dim * 2 ** (i_layer - 1)) \
+ if (i_layer == self.num_layers - 1 and strid16) else \
+ int(embed_dim * 2 ** i_layer)
+ num_features.append(cur_dim)
+ if i_layer < self.num_layers - 2:
+ cur_downsample_layer = downsample_layer
+ elif i_layer == self.num_layers - 2:
+ if strid16:
+ cur_downsample_layer = PatchReduction1C
+ else:
+ cur_downsample_layer = downsample_layer
+ else:
+ cur_downsample_layer = None
+ layer = BasicLayer(
+ dim=cur_dim,
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size[i_layer],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=cur_downsample_layer,
+ use_checkpoint=use_checkpoint[i_layer],
+ checkpoint_blocks=checkpoint_blocks[i_layer],
+ init_values=init_values,
+ endnorm_interval=endnorm_interval,
+ use_mlp_norm=True if i_layer in use_mlp_norm_layers else False,
+ use_shift=use_shift[i_layer],
+ relative_coords_table_type=self.relative_coords_table_type,
+ rpe_hidden_dim=self.rpe_hidden_dim,
+ rpe_output_type=self.rpe_output_type,
+ attn_type=self.attn_type,
+ mlp_type=self.mlp_type,
+ mlpfp32_blocks=mlpfp32_layer_blocks[i_layer],
+ postnorm=self.postnorm,
+ pretrain_window_size=pretrain_window_size[i_layer]
+ )
+ self.layers.append(layer)
+
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ self.norm3.eval()
+ for param in self.norm3.parameters():
+ param.requires_grad = False
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ self.apply(_init_weights)
+ for bly in self.layers:
+ bly._init_block_norm_weights()
+
+ if isinstance(pretrained, str):
+ logger = None
+ load_checkpoint_swin(self, pretrained, strict=False, map_location='cpu',
+ logger=logger, rpe_interpolation=self.rpe_interpolation)
+ elif pretrained is None:
+ pass
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer.float()(x_out.float())
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+
+ outs.append(out)
+
+ return outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinV2TransformerRPE2FC, self).train(mode)
+ self._freeze_stages()
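+
+ # A minimal, hypothetical construction sketch for this backbone (hyper-parameters below are
+ # illustrative assumptions, not the values of any released PostoMETRO checkpoint):
+ #
+ #   >>> backbone = SwinV2TransformerRPE2FC(embed_dim=96, depths=[2, 2, 6, 2],
+ #   ...                                    num_heads=[3, 6, 12, 24], window_size=7,
+ #   ...                                    out_indices=(0, 1, 2, 3))
+ #   >>> feats = backbone(torch.randn(1, 3, 224, 224))
+ #   >>> [tuple(f.shape) for f in feats]
+ #   [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]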
diff --git a/main/pct_utils/pct_head.py b/main/pct_utils/pct_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53f58e08287a5c1e558212f89723fa0c2533b99
--- /dev/null
+++ b/main/pct_utils/pct_head.py
@@ -0,0 +1,208 @@
+import torch
+import torch.nn as nn
+
+from pct_utils.pct_tokenizer import PCT_Tokenizer
+from pct_utils.modules import MixerLayer, FCBlock, BasicBlock
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+def normal_init(module, mean=0, std=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+class PCT_Head(nn.Module):
+ """ Head of Pose Compositional Tokens.
+ paper ref: Zigang Geng et al. "Human Pose as
+ Compositional Tokens"
+
+ The pipelines of two stage during training and inference:
+
+ Tokenizer Stage & Train:
+ Joints -> (Img Guide) -> Encoder -> Codebook -> Decoder -> Recovered Joints
+ Loss: (Joints, Recovered Joints)
+ Tokenizer Stage & Test:
+ Joints -> (Img Guide) -> Encoder -> Codebook -> Decoder -> Recovered Joints
+
+ Classifier Stage & Train:
+ Img -> Classifier -> Predict Class -> Codebook -> Decoder -> Recovered Joints
+ Joints -> (Img Guide) -> Encoder -> Codebook -> Groundtruth Class
+ Loss: (Predict Class, Groundtruth Class), (Joints, Recovered Joints)
+ Classifier Stage & Test:
+ Img -> Classifier -> Predict Class -> Codebook -> Decoder -> Recovered Joints
+
+ Args:
+ args: Config namespace carrying the PCT head and tokenizer hyper-parameters.
+ stage_pct (str): Training stage (Tokenizer or Classifier).
+ in_channels (int): Feature Dim of the backbone feature.
+ image_size (tuple): Input image size.
+ num_joints (int): Number of annotated joints in the dataset.
+ cls_head (dict): Config for PCT classification head. Default: None.
+ tokenizer (dict): Config for PCT tokenizer. Default: None.
+ loss_keypoint (dict): Config for loss for training classifier. Default: None.
+ """
+
+ def __init__(self,
+ args,
+ stage_pct,
+ in_channels,
+ image_size,
+ num_joints,
+ cls_head=None,
+ tokenizer=None,
+ loss_keypoint=None,):
+ super().__init__()
+
+ self.image_size = image_size
+ self.stage_pct = stage_pct
+
+ self.guide_ratio = args.tokenizer_guide_ratio
+ self.img_guide = self.guide_ratio > 0.0
+
+ self.conv_channels = args.cls_head_conv_channels
+ self.hidden_dim = args.cls_head_hidden_dim
+
+ self.num_blocks = args.cls_head_num_blocks
+ self.hidden_inter_dim = args.cls_head_hidden_inter_dim
+ self.token_inter_dim = args.cls_head_token_inter_dim
+ self.dropout = args.cls_head_dropout
+
+ self.token_num = args.tokenizer_codebook_token_num
+ self.token_class_num = args.tokenizer_codebook_token_class_num
+
+ if stage_pct == "classifier":
+ self.conv_trans = self._make_transition_for_head(
+ in_channels, self.conv_channels)
+ self.conv_head = self._make_cls_head(args)
+
+ input_size = (image_size[0]//32)*(image_size[1]//32)
+ self.mixer_trans = FCBlock(
+ self.conv_channels * input_size,
+ self.token_num * self.hidden_dim)
+
+ self.mixer_head = nn.ModuleList(
+ [MixerLayer(self.hidden_dim, self.hidden_inter_dim,
+ self.token_num, self.token_inter_dim,
+ self.dropout) for _ in range(self.num_blocks)])
+ self.mixer_norm_layer = FCBlock(
+ self.hidden_dim, self.hidden_dim)
+
+ self.cls_pred_layer = nn.Linear(
+ self.hidden_dim, self.token_class_num)
+
+ self.tokenizer = PCT_Tokenizer(
+ args = args, stage_pct=stage_pct, num_joints=num_joints,
+ guide_ratio=self.guide_ratio, guide_channels=in_channels)
+
+ def forward(self, x, extra_x, joints=None, train=True):
+ """Forward function."""
+
+ if self.stage_pct == "classifier":
+ batch_size = x[-1].shape[0]
+ cls_feat = self.conv_head[0](self.conv_trans(x[-1]))
+
+ cls_feat = cls_feat.flatten(2).transpose(2,1).flatten(1)
+ cls_feat = self.mixer_trans(cls_feat)
+ cls_feat = cls_feat.reshape(batch_size, self.token_num, -1)
+
+ for mixer_layer in self.mixer_head:
+ cls_feat = mixer_layer(cls_feat)
+ cls_feat = self.mixer_norm_layer(cls_feat)
+
+ cls_logits = self.cls_pred_layer(cls_feat)
+
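+ # per-token confidence score = the highest class logit; the logits are then flattened to
+ # (batch_size * token_num, token_class_num) and softmaxed before being passed to the tokenizer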
+ encoding_scores = cls_logits.topk(1, dim=2)[0]
+ cls_logits = cls_logits.flatten(0,1)
+ cls_logits_softmax = cls_logits.clone().softmax(1)
+ else:
+ encoding_scores = None
+ cls_logits = None
+ cls_logits_softmax = None
+
+ if not self.img_guide or \
+ (self.stage_pct == "classifier" and not train):
+ joints_feat = None
+ else:
+ joints_feat = self.extract_joints_feat(extra_x[-1], joints)
+
+ output_joints, cls_label, e_latent_loss, out_part_token_feat = \
+ self.tokenizer(joints, joints_feat, cls_logits_softmax, train=train)
+
+ if train:
+ return cls_logits, output_joints, cls_label, e_latent_loss
+ else:
+ return output_joints, encoding_scores, out_part_token_feat
+
+ def _make_transition_for_head(self, inplanes, outplanes):
+ transition_layer = [
+ nn.Conv2d(inplanes, outplanes, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(outplanes),
+ nn.ReLU(True)
+ ]
+ return nn.Sequential(*transition_layer)
+
+ def _make_cls_head(self, args):
+ feature_convs = []
+ feature_conv = self._make_layer(
+ BasicBlock,
+ args.cls_head_conv_channels,
+ args.cls_head_conv_channels,
+ args.cls_head_conv_num_blocks,
+ dilation=args.cls_head_dilation
+ )
+ feature_convs.append(feature_conv)
+
+ return nn.ModuleList(feature_convs)
+
+ def _make_layer(
+ self, block, inplanes, planes, blocks, stride=1, dilation=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion, momentum=0.1),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes,
+ stride, downsample, dilation=dilation))
+ inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(inplanes, planes, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+ def extract_joints_feat(self, feature_map, joint_coords):
+ assert self.image_size[1] == self.image_size[0], \
+ 'Rectangular inputs are not supported here: ' \
+ 'check the stride and index computation below before using unequal height and width.'
+ batch_size, _, _, width = feature_map.shape
+ stride = self.image_size[0] / feature_map.shape[-1]
+ joint_x = (joint_coords[:,:,0] / stride + 0.5).int()
+ joint_y = (joint_coords[:,:,1] / stride + 0.5).int()
+ joint_x = joint_x.clamp(0, feature_map.shape[-1] - 1)
+ joint_y = joint_y.clamp(0, feature_map.shape[-2] - 1)
+ joint_indices = (joint_y * width + joint_x).long()
+
+ flattened_feature_map = feature_map.clone().flatten(2)
+ joint_features = flattened_feature_map[
+ torch.arange(batch_size).unsqueeze(1), :, joint_indices]
+
+ return joint_features
+
+ def init_weights(self):
+ if self.stage_pct == "classifier":
+ self.tokenizer.eval()
+ for name, params in self.tokenizer.named_parameters():
+ params.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001, bias=0)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
\ No newline at end of file
diff --git a/main/pct_utils/pct_tokenizer.py b/main/pct_utils/pct_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b05318b9d8fc6ec6161eff69647608161db4ab1
--- /dev/null
+++ b/main/pct_utils/pct_tokenizer.py
@@ -0,0 +1,315 @@
+# --------------------------------------------------------
+# Pose Compositional Tokens
+# Written by Zigang Geng (zigang@mail.ustc.edu.cn)
+# --------------------------------------------------------
+
+import os
+import math
+import warnings
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from pct_utils.modules import MixerLayer
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
+ applied while sampling the normal with mean/std applied, therefore a, b args
+ should be adjusted to match the range of mean, std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ with torch.no_grad():
+ return _trunc_normal_(tensor, mean, std, a, b)
+
+class PCT_Tokenizer(nn.Module):
+ """ Tokenizer of Pose Compositional Tokens.
+ paper ref: Zigang Geng et al. "Human Pose as
+ Compositional Tokens"
+
+ Args:
+ stage_pct (str): Training stage (Tokenizer or Classifier).
+ args: Config namespace carrying the tokenizer hyper-parameters.
+ num_joints (int): Number of annotated joints in the dataset.
+ guide_ratio (float): The ratio of image guidance.
+ guide_channels (int): Feature Dim of the image guidance.
+ """
+
+ def __init__(self,
+ args,
+ stage_pct,
+ num_joints=14,
+ theta_dim=2,
+ guide_ratio=0,
+ guide_channels=0):
+ super().__init__()
+
+ self.stage_pct = stage_pct
+ self.guide_ratio = guide_ratio
+ self.num_joints = num_joints
+ self.theta_dim = theta_dim
+
+ self.drop_rate = args.tokenizer_encoder_drop_rate
+ self.enc_num_blocks = args.tokenizer_encoder_num_blocks
+ self.enc_hidden_dim = args.tokenizer_encoder_hidden_dim
+ self.enc_token_inter_dim = args.tokenizer_encoder_token_inter_dim
+ self.enc_hidden_inter_dim = args.tokenizer_encoder_hidden_inter_dim
+ self.enc_dropout = args.tokenizer_encoder_dropout
+
+ self.dec_num_blocks = args.tokenizer_decoder_num_blocks
+ self.dec_hidden_dim = args.tokenizer_decoder_hidden_dim
+ self.dec_token_inter_dim = args.tokenizer_decoder_token_inter_dim
+ self.dec_hidden_inter_dim = args.tokenizer_decoder_hidden_inter_dim
+ self.dec_dropout = args.tokenizer_decoder_dropout
+
+ self.token_num = args.tokenizer_codebook_token_num
+ self.token_class_num = args.tokenizer_codebook_token_class_num
+ self.token_dim = args.tokenizer_codebook_token_dim
+ self.decay = args.tokenizer_codebook_ema_decay
+
+ self.invisible_token = nn.Parameter(
+ torch.zeros(1, 1, self.enc_hidden_dim))
+ trunc_normal_(self.invisible_token, mean=0., std=0.02, a=-0.02, b=0.02)
+
+ if self.guide_ratio > 0:
+ self.start_img_embed = nn.Linear(
+ guide_channels, int(self.enc_hidden_dim*self.guide_ratio))
+ self.start_embed = nn.Linear(
+ 2, int(self.enc_hidden_dim*(1-self.guide_ratio)))
+
+ self.encoder = nn.ModuleList(
+ [MixerLayer(self.enc_hidden_dim, self.enc_hidden_inter_dim,
+ self.num_joints, self.enc_token_inter_dim,
+ self.enc_dropout) for _ in range(self.enc_num_blocks)])
+ self.encoder_layer_norm = nn.LayerNorm(self.enc_hidden_dim)
+
+ self.token_mlp = nn.Linear(
+ self.num_joints, self.token_num)
+ self.feature_embed = nn.Linear(
+ self.enc_hidden_dim, self.token_dim)
+
+ self.register_buffer('codebook',
+ torch.empty(self.token_class_num, self.token_dim))
+ self.codebook.data.normal_()
+ self.register_buffer('ema_cluster_size',
+ torch.zeros(self.token_class_num))
+ self.register_buffer('ema_w',
+ torch.empty(self.token_class_num, self.token_dim))
+ self.ema_w.data.normal_()
+
+ self.decoder_token_mlp = nn.Linear(
+ self.token_num, self.num_joints)
+ self.decoder_start = nn.Linear(
+ self.token_dim, self.dec_hidden_dim)
+
+ self.decoder = nn.ModuleList(
+ [MixerLayer(self.dec_hidden_dim, self.dec_hidden_inter_dim,
+ self.num_joints, self.dec_token_inter_dim,
+ self.dec_dropout) for _ in range(self.dec_num_blocks)])
+ self.decoder_layer_norm = nn.LayerNorm(self.dec_hidden_dim)
+
+ self.recover_embed = nn.Linear(self.dec_hidden_dim, 2)
+
+ def forward(self, joints, joints_feature, cls_logits, train=True):
+ """Forward function. """
+
+ if train or self.stage_pct == "tokenizer":
+ # Tokenizer encoder: quantize the input joints to obtain the PCT ground-truth class labels.
+ bs, num_joints, _ = joints.shape
+ device = joints.device
+ joints_coord, joints_visible, bs \
+ = joints[:,:,:-1], joints[:,:,-1].bool(), joints.shape[0]
+
+ encode_feat = self.start_embed(joints_coord)
+ if self.guide_ratio > 0:
+ encode_img_feat = self.start_img_embed(joints_feature)
+ encode_feat = torch.cat((encode_feat, encode_img_feat), dim=2)
+
+ if train and self.stage_pct == "tokenizer":
+ rand_mask_ind = torch.rand(
+ joints_visible.shape, device=joints.device) > self.drop_rate
+ joints_visible = torch.logical_and(rand_mask_ind, joints_visible)
+
+ mask_tokens = self.invisible_token.expand(bs, joints.shape[1], -1)
+ w = joints_visible.unsqueeze(-1).type_as(mask_tokens)
+ encode_feat = encode_feat * w + mask_tokens * (1 - w)
+
+ for num_layer in self.encoder:
+ encode_feat = num_layer(encode_feat)
+ encode_feat = self.encoder_layer_norm(encode_feat)
+
+ encode_feat = encode_feat.transpose(2, 1)
+ encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
+ encode_feat = self.feature_embed(encode_feat).flatten(0,1)
+
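+ # nearest-neighbour codebook lookup: squared Euclidean distances expanded as
+ # ||z||^2 + ||e||^2 - 2 * z @ e.T between encoder outputs z and codebook entries e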
+ distances = torch.sum(encode_feat**2, dim=1, keepdim=True) \
+ + torch.sum(self.codebook**2, dim=1) \
+ - 2 * torch.matmul(encode_feat, self.codebook.t())
+
+ encoding_indices = torch.argmin(distances, dim=1)
+ encodings = torch.zeros(
+ encoding_indices.shape[0], self.token_class_num, device=joints.device)
+ encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+ else:
+ # cls_logits is expected to have shape (bs * token_num, token_class_num):
+ # for each of the M tokens, the predicted probability that it belongs to
+ # each of the V codebook entries (see the PCT paper)
+ bs = cls_logits.shape[0] // self.token_num
+ encoding_indices = None
+
+ if self.stage_pct == "classifier":
+ part_token_feat = torch.matmul(cls_logits, self.codebook)
+ else:
+ part_token_feat = torch.matmul(encodings, self.codebook)
+
+ if train and self.stage_pct == "tokenizer":
+ # Updating Codebook using EMA
+ dw = torch.matmul(encodings.t(), encode_feat.detach())
+ # sync
+ n_encodings, n_dw = encodings.numel(), dw.numel()
+ encodings_shape, dw_shape = encodings.shape, dw.shape
+ combined = torch.cat((encodings.flatten(), dw.flatten()))
+ dist.all_reduce(combined)  # sum across all distributed processes (requires torch.distributed to be initialized)
+ sync_encodings, sync_dw = torch.split(combined, [n_encodings, n_dw])
+ sync_encodings, sync_dw = \
+ sync_encodings.view(encodings_shape), sync_dw.view(dw_shape)
+
+ self.ema_cluster_size = self.ema_cluster_size * self.decay + \
+ (1 - self.decay) * torch.sum(sync_encodings, 0)
+
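+ # Laplace smoothing of the cluster sizes so that rarely used codebook entries never divide by ~0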
+ n = torch.sum(self.ema_cluster_size.data)
+ self.ema_cluster_size = (
+ (self.ema_cluster_size + 1e-5)
+ / (n + self.token_class_num * 1e-5) * n)
+
+ self.ema_w = self.ema_w * self.decay + (1 - self.decay) * sync_dw
+ self.codebook = self.ema_w / self.ema_cluster_size.unsqueeze(1)
+ e_latent_loss = F.mse_loss(part_token_feat.detach(), encode_feat)
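+ # straight-through estimator: the forward pass uses the quantized tokens while gradients
+ # flow back to the encoder features, bypassing the non-differentiable codebook lookup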
+ part_token_feat = encode_feat + (part_token_feat - encode_feat).detach()
+ else:
+ e_latent_loss = None
+
+ # Decoder of Tokenizer, Recover the joints.
+ part_token_feat = part_token_feat.view(bs, -1, self.token_dim)
+
+ # Store part token
+ out_part_token_feat = part_token_feat.clone().detach()
+
+ part_token_feat = part_token_feat.transpose(2,1)
+ part_token_feat = self.decoder_token_mlp(part_token_feat).transpose(2,1)
+ decode_feat = self.decoder_start(part_token_feat)
+
+ for num_layer in self.decoder:
+ decode_feat = num_layer(decode_feat)
+ decode_feat = self.decoder_layer_norm(decode_feat)
+
+ recovered_joints = self.recover_embed(decode_feat)
+
+ return recovered_joints, encoding_indices, e_latent_loss, out_part_token_feat
+
+ def init_weights(self, pretrained=""):
+ """Initialize model weights."""
+
+ parameters_names = set()
+ for name, _ in self.named_parameters():
+ parameters_names.add(name)
+
+ buffers_names = set()
+ for name, _ in self.named_buffers():
+ buffers_names.add(name)
+
+ if os.path.isfile(pretrained):
+ assert (self.stage_pct == "classifier"), \
+ "Training tokenizer does not need to load model"
+ pretrained_state_dict = torch.load(pretrained,
+ map_location=lambda storage, loc: storage)
+
+ need_init_state_dict = {}
+
+ if 'state_dict' in pretrained_state_dict:
+ key = 'state_dict'
+ else:
+ key = 'model'
+ for name, m in pretrained_state_dict[key].items():
+ if 'keypoint_head.tokenizer.' in name:
+ name = name.replace('keypoint_head.tokenizer.', '')
+ if name in parameters_names or name in buffers_names:
+ need_init_state_dict[name] = m
+ self.load_state_dict(need_init_state_dict, strict=True)
+ else:
+ if self.stage_pct == "classifier":
+ print('If you are training a classifier, '
+ 'make sure the pretrained tokenizer checkpoint '
+ 'is located at the given path.')
+
+
+def save_checkpoint(model, optimizer, epoch, loss, filepath):
+ checkpoint = {
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': loss
+ }
+ torch.save(checkpoint, filepath)
+ print(f"Checkpoint saved at {filepath}")
+
+def load_checkpoint(model, optimizer, filepath):
+ checkpoint = torch.load(filepath)
+ model.load_state_dict(checkpoint['model_state_dict'])
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ epoch = checkpoint['epoch']
+ loss = checkpoint['loss']
+
+ print(f"Checkpoint loaded from {filepath}. Resuming training from epoch {epoch} with loss {loss}")
+
+ return epoch, loss
diff --git a/main/postometro.py b/main/postometro.py
new file mode 100644
index 0000000000000000000000000000000000000000..aacd40b450df7299b9a5dc9dd100b97910a60cc5
--- /dev/null
+++ b/main/postometro.py
@@ -0,0 +1,305 @@
+# ----------------------------------------------------------------------------------------------
+# FastMETRO Official Code
+# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
+# Licensed under the MIT license.
+# ----------------------------------------------------------------------------------------------
+
+# ----------------------------------------------------------------------------------------------
+# PostoMETRO Official Code
+# Copyright (c) MIRACLE Lab. All Rights Reserved
+# Licensed under the MIT license.
+# ----------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import, division, print_function
+import torch
+import numpy as np
+import argparse
+import os
+import os.path as osp
+from torch import nn
+from postometro_utils.smpl import Mesh
+from postometro_utils.transformer import build_transformer
+from postometro_utils.positional_encoding import build_position_encoding
+from postometro_utils.modules import FCBlock, MixerLayer
+from pct_utils.pct import PCT
+from pct_utils.pct_backbone import SwinV2TransformerRPE2FC
+from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet
+from postometro_utils.pose_resnet_config import config as resnet_config
+from postometro_utils.pose_hrnet import get_pose_hrnet
+from postometro_utils.pose_hrnet_config import _C as hrnet_config
+from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config
+
+CUR_DIR = osp.dirname(os.path.abspath(__file__))
+
+class PostoMETRO(nn.Module):
+ """PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image"""
+ def __init__(self, args, backbone, mesh_sampler, pct = None, num_joints=14, num_vertices=431):
+ """
+ Parameters:
+ - args: Arguments
+ - backbone: CNN Backbone used to extract image features from the given image
+ - mesh_sampler: Mesh Sampler used in the coarse-to-fine mesh upsampling
+ - num_joints: The number of joint tokens used in the transformer decoder
+ - num_vertices: The number of vertex tokens used in the transformer decoder
+ """
+ super().__init__()
+ self.args = args
+ self.backbone = backbone
+ self.mesh_sampler = mesh_sampler
+ self.num_joints = num_joints
+ self.num_vertices = num_vertices
+
+ # the number of transformer layers, set to default
+ num_enc_layers = 3
+ num_dec_layers = 3
+
+ # configurations for the first transformer
+ self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
+ "feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
+ "pos_type": args.pos_type}
+ # configurations for the second transformer
+ self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
+ "feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
+ "pos_type": args.pos_type}
+ # build transformers
+ self.transformer_1 = build_transformer(self.transformer_config_1)
+ self.transformer_2 = build_transformer(self.transformer_config_2)
+
+ # dimensionality reduction
+ self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
+ self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
+ self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
+
+ # token embeddings
+ self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"])
+ self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"])
+ self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"])
+ # positional encodings
+ self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim'])
+ self.position_encoding_2 = build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim'])
+ # estimators
+ self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3)
+ self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3)
+
+ # 1x1 Convolution
+ self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1)
+
+ # attention mask
+ zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool))
+ zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool))
+ adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt'))
+ adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt'))
+ adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt'))
+ adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense()
+ temp_mask_1 = (adjacency_matrix == 0)
+ temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1)
+ self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0)
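+        # Layout of the mask just built (editorial reading of the code above): rows and
+        # columns follow the token order [joints | vertices]. Joint rows (zeros_2) leave
+        # every position unmasked; vertex rows may attend to all joint tokens (zeros_1)
+        # and only to vertices that are adjacent in the coarse SMPL-431 mesh; True marks
+        # a blocked position under the usual attention-mask convention.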
+
+        # a learnable upsampling layer is used (from coarse mesh to intermediate mesh) for a visually pleasing mesh result;
+        # a pre-computed upsampling matrix is used (from intermediate mesh to fine mesh) to reduce optimization difficulty
+ self.coarse2intermediate_upsample = nn.Linear(431, 1723)
+
+ # using extra token
+ self.pct = None
+ if pct is not None:
+ self.pct = pct
+ # +1 to align with uncertainty score
+ self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"])
+ self.start_embed = nn.Linear(512, args.enc_hidden_dim)
+ self.encoder = nn.ModuleList(
+ [MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim,
+ args.num_joints, args.token_inter_dim,
+ args.enc_dropout) for _ in range(args.enc_num_blocks)])
+ self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim)
+ self.token_mlp = nn.Linear(args.num_joints, args.token_num)
+ self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
+
+
+ def forward(self, images):
+ device = images.device
+ batch_size = images.size(0)
+
+ # preparation
+ cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512
+ jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512
+ attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices)
+
+ pct_token = None
+ if self.pct is not None:
+ pct_out = self.pct(images, None, train=False)
+ pct_pose = pct_out['part_token_feat'].clone()
+
+ encode_feat = self.start_embed(pct_pose) # 2, 17, 512
+ for num_layer in self.encoder:
+ encode_feat = num_layer(encode_feat)
+ encode_feat = self.encoder_layer_norm(encode_feat)
+ encode_feat = encode_feat.transpose(2, 1)
+ encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
+ pct_token_out = encode_feat.permute(1,0,2)
+
+ pct_score = pct_out['encoding_scores']
+ pct_score = pct_score.permute(1,0,2)
+ pct_token = torch.cat([pct_token_out, pct_score], dim = -1)
+ pct_token = self.token_mixer(pct_token) # [b, 34, 512]
+
+ # extract image features through a CNN backbone
+ _img_features = self.backbone(images) # batch_size X 2048 X 7 X 7
+ _, _, h, w = _img_features.shape
+ img_features = self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
+
+ # positional encodings
+ pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
+ pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 128
+
+ # first transformer encoder-decoder
+ cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token = pct_token, attention_mask=attention_mask)
+
+ # progressive dimensionality reduction
+ reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128
+ reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # 49 X batch_size X 128
+ reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128
+ reduced_pct_features_1 = None
+ if pct_features_1 is not None:
+ reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1)
+
+ # second transformer encoder-decoder
+ cam_features_2, _, jv_features_2,_ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token = reduced_pct_features_1, attention_mask=attention_mask)
+
+ # estimators
+ pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3
+
+ pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3
+ pred_3d_joints = pred_3d_coordinates[:,:self.num_joints,:] # batch_size X num_joints X 3
+ pred_3d_vertices_coarse = pred_3d_coordinates[:,self.num_joints:,:] # batch_size X num_vertices(coarse) X 3
+
+ # coarse-to-intermediate mesh upsampling
+ pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1,2)).transpose(1,2) # batch_size X num_vertices(intermediate) X 3
+ # intermediate-to-fine mesh upsampling
+ pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3
+
+ out = {}
+ out['pred_cam'] = pred_cam
+        out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2), device=device)
+ out['pred_3d_joints'] = pred_3d_joints
+ out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse
+ out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate
+ out['pred_3d_vertices_fine'] = pred_3d_vertices_fine
+
+ return out
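+
+    # Hedged usage sketch (editorial note, not part of the original code): 2D keypoints
+    # are assumed to follow from weak-perspective projection of the predicted joints with
+    # pred_cam = [s, tx, ty], e.g. using postometro_utils.geometric_layers:
+    #   out = model(images)  # images: batch_size X 3 X 256 X 256
+    #   pred_2d_joints = orthographic_projection(out['pred_3d_joints'], out['pred_cam'])
+    #   pred_2d_vertices = orthographic_projection(out['pred_3d_vertices_fine'], out['pred_cam'])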
+
+
+defaults_args = argparse.Namespace(
+ pos_type = 'sine',
+ transformer_dropout = 0.1,
+ transformer_nhead = 8,
+ conv_1x1_dim = 2048,
+ tokenizer_codebook_token_dim = 512,
+ model_dim_1 = 512,
+ feedforward_dim_1 = 2048,
+ model_dim_2 = 128,
+ feedforward_dim_2 = 512,
+ enc_hidden_dim = 512,
+ enc_hidden_inter_dim = 512,
+ token_inter_dim = 64,
+ enc_dropout = 0.0,
+ enc_num_blocks = 4,
+ num_joints = 34,
+ token_num = 34
+)
+
+default_pct_args = argparse.Namespace(
+ pct_backbone_channel = 1536,
+ tokenizer_guide_ratio=0.5,
+ cls_head_conv_channels=256,
+ cls_head_hidden_dim=64,
+ cls_head_num_blocks=4,
+ cls_head_hidden_inter_dim=256,
+ cls_head_token_inter_dim=64,
+ cls_head_dropout=0.0,
+ cls_head_conv_num_blocks=2,
+ cls_head_dilation=1,
+    # tokenizer
+ tokenizer_encoder_drop_rate=0.2,
+ tokenizer_encoder_num_blocks=4,
+ tokenizer_encoder_hidden_dim=512,
+ tokenizer_encoder_token_inter_dim=64,
+ tokenizer_encoder_hidden_inter_dim=512,
+ tokenizer_encoder_dropout=0.0,
+ tokenizer_decoder_num_blocks=1,
+ tokenizer_decoder_hidden_dim=32,
+ tokenizer_decoder_token_inter_dim=64,
+ tokenizer_decoder_hidden_inter_dim=64,
+ tokenizer_decoder_dropout=0.0,
+ tokenizer_codebook_token_num=34,
+ tokenizer_codebook_token_dim=512,
+ tokenizer_codebook_token_class_num=2048,
+ tokenizer_codebook_ema_decay=0.9,
+)
+
+backbone_config=dict(
+ embed_dim=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=[16, 16, 16, 8],
+ pretrain_window_size=[12, 12, 12, 6],
+ ape=False,
+ drop_path_rate=0.5,
+ patch_norm=True,
+ use_checkpoint=True,
+ rpe_interpolation='geo',
+ use_shift=[True, True, False, False],
+ relative_coords_table_type='norm8_log_bylayer',
+ attn_type='cosine_mh',
+ rpe_output_type='sigmoid',
+ postnorm=True,
+ mlp_type='normal',
+ out_indices=(3,),
+ patch_embed_type='normal',
+ patch_merge_type='normal',
+ strid16=False,
+ frozen_stages=5,
+)
+
+def get_model(backbone_str = 'resnet50', device = torch.device('cpu'), checkpoint_file = None):
+ if backbone_str == 'hrnet-w48':
+ defaults_args.conv_1x1_dim = 384
+ # update hrnet config by yaml
+ hrnet_yaml = osp.join(CUR_DIR,'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml')
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_pose_hrnet(hrnet_config, None)
+ else:
+ backbone = get_pose_resnet(resnet_config, is_train=False)
+ mesh_upsampler = Mesh(device=device)
+ pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config)
+ # initialize pct head
+ pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device)
+ model = PostoMETRO(defaults_args, backbone, mesh_upsampler, pct=pct).to(device)
+    print("[INFO] model loaded, backbone: {}, device: {}".format(backbone_str, device))
+ if checkpoint_file:
+ cpu_device = torch.device('cpu')
+ state_dict = torch.load(checkpoint_file, map_location=cpu_device)
+ model.load_state_dict(state_dict, strict=True)
+ del state_dict
+        print("[INFO] checkpoint loaded, backbone: {}, device: {}".format(backbone_str, device))
+ return model
+
+if __name__ == '__main__':
+ test_model = get_model(device=torch.device('cuda'))
+ images = torch.randn(1,3,256,256).to(torch.device('cuda'))
+ test_out = test_model(images)
+ print("[TEST] resnet50, cuda : pass")
+
+ test_model = get_model()
+    images = torch.randn(1,3,256,256)
+ test_out = test_model(images)
+ print("[TEST] resnet50, cpu : pass")
+
+ test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda'))
+ images = torch.randn(1,3,256,256).to(torch.device('cuda'))
+ test_out = test_model(images)
+ print("[TEST] hrnet-w48, cuda : pass")
+
+ test_model = get_model(backbone_str='hrnet-w48')
+    images = torch.randn(1,3,256,256)
+ test_out = test_model(images)
+ print("[TEST] hrnet-w48, cpu : pass")
diff --git a/main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc b/main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a84550f3a4bdd7fef159896e31002a97dc8b171a
Binary files /dev/null and b/main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/modules.cpython-39.pyc b/main/postometro_utils/__pycache__/modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98c14c460770f9de794e1635584e84a55d01480c
Binary files /dev/null and b/main/postometro_utils/__pycache__/modules.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f0c47a615e65a6db352920d52d52a49107492e5
Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b0fcdf7c9c0294f5b0142084741debbcf4481cd
Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7928135ef3c7b5be652144ed03a3024dca85127
Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c121203e49e494758b826a7022484360bbad314d
Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc b/main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..197a81e24a0f508e670f9413f7c51aa82f22907f
Binary files /dev/null and b/main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc b/main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f56450a6e9b6213a210ed9b21b23eac5429f6839
Binary files /dev/null and b/main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/smpl.cpython-39.pyc b/main/postometro_utils/__pycache__/smpl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22203612db68c6bac6c09426e832ef46dad83786
Binary files /dev/null and b/main/postometro_utils/__pycache__/smpl.cpython-39.pyc differ
diff --git a/main/postometro_utils/__pycache__/transformer.cpython-39.pyc b/main/postometro_utils/__pycache__/transformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e4b98396b2ae5d4d89c7461be3875058397574d
Binary files /dev/null and b/main/postometro_utils/__pycache__/transformer.cpython-39.pyc differ
diff --git a/main/postometro_utils/geometric_layers.py b/main/postometro_utils/geometric_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8b59ca8abcac164ff695dbfa700f273d0d156ef
--- /dev/null
+++ b/main/postometro_utils/geometric_layers.py
@@ -0,0 +1,679 @@
+# ----------------------------------------------------------------------------------------------
+# METRO (https://github.com/microsoft/MeshTransformer)
+# Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details]
+# Licensed under the MIT license.
+# ----------------------------------------------------------------------------------------------
+"""
+Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
+
+Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
+"""
+
+import torch
+import torch.nn.functional as F
+
+def rodrigues(theta):
+ """Convert axis-angle representation to rotation matrix.
+ Args:
+ theta: size = [B, 3]
+ Returns:
+        Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3]
+ """
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat2mat(quat)
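+
+# Editorial sanity check (not from the original source): a zero axis-angle vector maps,
+# up to the 1e-8 epsilon above, to the identity rotation:
+#   rodrigues(torch.zeros(1, 3))  # approximately torch.eye(3), shape [1, 3, 3]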
+
+def quat2mat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def orthographic_projection(X, camera):
+ """Perform orthographic projection of 3D points X using the camera parameters
+ Args:
+ X: size = [B, N, 3]
+ camera: size = [B, 3]
+ Returns:
+ Projected 2D points -- size = [B, N, 2]
+ """
+ camera = camera.view(-1, 1, 3)
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
+ shape = X_trans.shape
+ X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
+ return X_2d
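+
+# Worked example for orthographic_projection (editorial illustration): with camera given
+# as [s, tx, ty], the code above computes X_2d = s * (X[..., :2] + [tx, ty]), i.e. a
+# weak-perspective camera with isotropic scale s and 2D translation (tx, ty):
+#   X = torch.tensor([[[1.0, 2.0, 5.0]]])   # B=1, N=1
+#   cam = torch.tensor([[2.0, 0.5, -0.5]])  # [s, tx, ty]
+#   orthographic_projection(X, cam)         # tensor([[[3.0, 3.0]]])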
+
+def orthographic_projection_reshape(X, camera):
+ """Perform orthographic projection of 3D points X using the camera parameters
+ Args:
+ X: size = [B, N, 3]
+ camera: size = [B, 3]
+ Returns:
+ Projected 2D points -- size = [B, N, 2]
+ """
+ camera = camera.reshape(-1, 1, 3)
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
+ shape = X_trans.shape
+ X_2d = (camera[:, :, 0] * X_trans.reshape(shape[0], -1)).reshape(shape)
+ return X_2d
+
+
+def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+    Return a tensor where each element has the absolute value taken from the
+ corresponding element of a, with sign taken from the corresponding
+ element of b. This is like the standard copysign floating-point operation,
+ but is not careful about negative 0 and NaN.
+
+ Args:
+ a: source tensor.
+ b: tensor whose signs will be used, of the same shape as a.
+
+ Returns:
+ Tensor of the same shape as a with the signs of b.
+ """
+ signs_differ = (a < 0) != (b < 0)
+ return torch.where(signs_differ, -a, a)
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
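+
+# Editorial check for rotation_6d_to_matrix: the 6D vector [1, 0, 0, 0, 1, 0] encodes the
+# first two rows of the identity matrix, so Gram--Schmidt recovers the identity rotation:
+#   rotation_6d_to_matrix(torch.tensor([[1.0, 0, 0, 0, 1, 0]]))  # identity, shape [1, 3, 3]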
+
+
+def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+ by dropping the last row. Note that 6D representation is not unique.
+ Args:
+ matrix: batch of rotation matrices of size (*, 3, 3)
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ batch_dim = matrix.size()[:-2]
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
+
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
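+
+# Editorial note: the small-angle branch above keeps axis_angle_to_quaternion well defined
+# and differentiable at zero rotation, e.g. a zero vector yields the identity quaternion:
+#   axis_angle_to_quaternion(torch.zeros(1, 3))  # tensor([[1., 0., 0., 0.]])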
+
+
+def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+ half_angles = torch.atan2(norms, quaternions[..., :1])
+ angles = 2 * half_angles
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ return quaternions[..., 1:] / sin_half_angles_over_angles
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+
+ return quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
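+
+# Editorial sanity check: for the identity matrix, q_abs is largest in its real component,
+# so matrix_to_quaternion picks the first candidate and returns the identity quaternion:
+#   matrix_to_quaternion(torch.eye(3)[None])  # tensor([[1., 0., 0., 0.]])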
+
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to axis/angle.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+def axis_angle_to_rotation_6d(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+    Convert rotations given as axis/angle to the 6D rotation representation.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+ """
+ return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
+
+def rotation_6d_to_axis_angle(d6):
+ """
+    Converts 6D rotation representation by Zhou et al. [1] to axis/angle, via the rotation
+    matrix obtained with Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ return matrix_to_axis_angle(rotation_6d_to_matrix(d6))
\ No newline at end of file
diff --git a/main/postometro_utils/modules.py b/main/postometro_utils/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c868a46adb650284343178e1ea8c9a5c51ff73a
--- /dev/null
+++ b/main/postometro_utils/modules.py
@@ -0,0 +1,117 @@
+# --------------------------------------------------------
+# Borrow from unofficial MLPMixer (https://github.com/920232796/MlpMixer-pytorch)
+# Borrow from ResNet
+# Modified by Zigang Geng (zigang@mail.ustc.edu.cn)
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+
+
+class FCBlock(nn.Module):
+ def __init__(self, dim, out_dim):
+ super().__init__()
+
+ self.ff = nn.Sequential(
+ nn.Linear(dim, out_dim),
+ nn.LayerNorm(out_dim),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ return self.ff(x)
+
+
+class MLPBlock(nn.Module):
+ def __init__(self, dim, inter_dim, dropout_ratio):
+ super().__init__()
+
+ self.ff = nn.Sequential(
+ nn.Linear(dim, inter_dim),
+ nn.GELU(),
+ nn.Dropout(dropout_ratio),
+ nn.Linear(inter_dim, dim),
+ nn.Dropout(dropout_ratio)
+ )
+
+ def forward(self, x):
+ return self.ff(x)
+
+
+class MixerLayer(nn.Module):
+ def __init__(self,
+ hidden_dim,
+ hidden_inter_dim,
+ token_dim,
+ token_inter_dim,
+ dropout_ratio):
+ super().__init__()
+
+ self.layernorm1 = nn.LayerNorm(hidden_dim)
+ self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio)
+ self.layernorm2 = nn.LayerNorm(hidden_dim)
+ self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio)
+
+ def forward(self, x):
+ y = self.layernorm1(x)
+ y = y.transpose(2, 1)
+ y = self.MLP_token(y)
+ y = y.transpose(2, 1)
+ z = self.layernorm2(x + y)
+ z = self.MLP_channel(z)
+ out = x + y + z
+ return out
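+
+# Editorial note on MixerLayer: inputs are expected as (batch, token_dim, hidden_dim);
+# layernorm1 + MLP_token mix across the token axis (via the two transposes), then
+# layernorm2 + MLP_channel mix across the hidden axis, and the output keeps residual
+# paths from both stages (out = x + y + z). A hedged shape check:
+#   layer = MixerLayer(hidden_dim=512, hidden_inter_dim=512, token_dim=34,
+#                      token_inter_dim=64, dropout_ratio=0.0)
+#   layer(torch.randn(2, 34, 512)).shape  # torch.Size([2, 34, 512])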
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1,
+ downsample=None, dilation=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
+ padding=dilation, bias=False, dilation=dilation)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=0.1)
+ self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
+                               padding=dilation, bias=False, dilation=dilation)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
+ self.downsample = downsample
+ self.stride = stride
+
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
+ layers = []
+ for i in range(len(feat_dims)-1):
+ layers.append(
+ nn.Conv2d(
+ in_channels=feat_dims[i],
+ out_channels=feat_dims[i+1],
+ kernel_size=kernel,
+ stride=stride,
+ padding=padding
+ ))
+ # Do not use BN and ReLU for final estimation
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
\ No newline at end of file
diff --git a/main/postometro_utils/pose_hrnet.py b/main/postometro_utils/pose_hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..007c29447896f05273b56c64451c7a8840d608bb
--- /dev/null
+++ b/main/postometro_utils/pose_hrnet.py
@@ -0,0 +1,502 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+
+import torch
+import torch.nn as nn
+
+
+BN_MOMENTUM = 0.1
+logger = logging.getLogger(__name__)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method, multi_scale_output=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(True)
+
+ def _check_branches(self, num_branches, blocks, num_blocks,
+ num_inchannels, num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ num_branches, len(num_blocks))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ num_branches, len(num_channels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ num_branches, len(num_inchannels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False
+ ),
+ nn.BatchNorm2d(
+ num_channels[branch_index] * block.expansion,
+ momentum=BN_MOMENTUM
+ ),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.num_inchannels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample
+ )
+ )
+ self.num_inchannels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.num_inchannels[branch_index],
+ num_channels[branch_index]
+ )
+ )
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels)
+ )
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_inchannels[j],
+ num_inchannels[i],
+ 1, 1, 0, bias=False
+ ),
+ nn.BatchNorm2d(num_inchannels[i]),
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')
+ )
+ )
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i-j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False
+ ),
+ nn.BatchNorm2d(num_outchannels_conv3x3)
+ )
+ )
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False
+ ),
+ nn.BatchNorm2d(num_outchannels_conv3x3),
+ nn.ReLU(True)
+ )
+ )
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+
+blocks_dict = {
+ 'BASIC': BasicBlock,
+ 'BOTTLENECK': Bottleneck
+}
+
+
+class PoseHighResolutionNet(nn.Module):
+
+ def __init__(self, cfg, **kwargs):
+ self.inplanes = 64
+ extra = cfg['MODEL']['EXTRA']
+ super(PoseHighResolutionNet, self).__init__()
+
+ # stem net
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(Bottleneck, 64, 4)
+
+ self.stage2_cfg = extra['STAGE2']
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))
+ ]
+ self.transition1 = self._make_transition_layer([256], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ self.stage3_cfg = extra['STAGE3']
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))
+ ]
+ self.transition2 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ self.stage4_cfg = extra['STAGE4']
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))
+ ]
+ self.transition3 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels,
+ multi_scale_output=True)
+ # multi_scale_output=False)
+
+ self.final_layer = nn.Conv2d(
+ in_channels=pre_stage_channels[0],
+ out_channels=cfg['MODEL']['NUM_JOINTS'],
+ kernel_size=extra['FINAL_CONV_KERNEL'],
+ stride=1,
+ padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
+ )
+
+ self.pretrained_layers = extra['PRETRAINED_LAYERS']
+
+ def _make_transition_layer(
+ self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ 3, 1, 1, bias=False
+ ),
+ nn.BatchNorm2d(num_channels_cur_layer[i]),
+ nn.ReLU(inplace=True)
+ )
+ )
+ else:
+ transition_layers.append(None)
+ else:
+ conv3x3s = []
+ for j in range(i+1-num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] \
+ if j == i-num_branches_pre else inchannels
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(
+ inchannels, outchannels, 3, 2, 1, bias=False
+ ),
+ nn.BatchNorm2d(outchannels),
+ nn.ReLU(inplace=True)
+ )
+ )
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False
+ ),
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, num_inchannels,
+ multi_scale_output=True):
+ num_modules = layer_config['NUM_MODULES']
+ num_branches = layer_config['NUM_BRANCHES']
+ num_blocks = layer_config['NUM_BLOCKS']
+ num_channels = layer_config['NUM_CHANNELS']
+ block = blocks_dict[layer_config['BLOCK']]
+ fuse_method = layer_config['FUSE_METHOD']
+
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+
+ modules.append(
+ HighResolutionModule(
+ num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ fuse_method,
+ reset_multi_scale_output
+ )
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return y_list[-1]
+ # x = self.final_layer(y_list[0])
+ # return x
+
+ def init_weights(self, pretrained=''):
+ logger.info('=> init weights from normal distribution')
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.normal_(m.weight, std=0.001)
+ for name, _ in m.named_parameters():
+ if name in ['bias']:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ nn.init.normal_(m.weight, std=0.001)
+ for name, _ in m.named_parameters():
+ if name in ['bias']:
+ nn.init.constant_(m.bias, 0)
+
+ if os.path.isfile(pretrained):
+ pretrained_state_dict = torch.load(pretrained)
+ logger.info('=> loading pretrained model {}'.format(pretrained))
+
+ need_init_state_dict = {}
+ for name, m in pretrained_state_dict.items():
+ if name.split('.')[0] in self.pretrained_layers \
+                   or self.pretrained_layers[0] == '*':
+ need_init_state_dict[name] = m
+ out = self.load_state_dict(need_init_state_dict, strict=False)
+ elif pretrained:
+ logger.error('=> please download pre-trained models first!')
+            raise ValueError('{} does not exist!'.format(pretrained))
+
+
+def get_pose_hrnet(cfg, pretrained, **kwargs):
+ model = PoseHighResolutionNet(cfg, **kwargs)
+ if pretrained is not None:
+ model.init_weights(pretrained=pretrained)
+
+ return model
\ No newline at end of file
diff --git a/main/postometro_utils/pose_hrnet_config.py b/main/postometro_utils/pose_hrnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..00c542fa56cd875f6655b844d8a49d6ce061838f
--- /dev/null
+++ b/main/postometro_utils/pose_hrnet_config.py
@@ -0,0 +1,137 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from yacs.config import CfgNode as CN
+
+
+_C = CN()
+
+_C.OUTPUT_DIR = ''
+_C.LOG_DIR = ''
+_C.DATA_DIR = ''
+_C.GPUS = (0,)
+_C.WORKERS = 4
+_C.PRINT_FREQ = 20
+_C.AUTO_RESUME = False
+_C.PIN_MEMORY = True
+_C.RANK = 0
+
+# Cudnn related params
+_C.CUDNN = CN()
+_C.CUDNN.BENCHMARK = True
+_C.CUDNN.DETERMINISTIC = False
+_C.CUDNN.ENABLED = True
+
+# common params for NETWORK
+_C.MODEL = CN()
+_C.MODEL.NAME = 'cls_hrnet'
+_C.MODEL.INIT_WEIGHTS = True
+_C.MODEL.PRETRAINED = ''
+_C.MODEL.NUM_JOINTS = 17
+_C.MODEL.NUM_CLASSES = 1000
+_C.MODEL.TAG_PER_JOINT = True
+_C.MODEL.TARGET_TYPE = 'gaussian'
+_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
+_C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
+_C.MODEL.SIGMA = 2
+_C.MODEL.EXTRA = CN(new_allowed=True)
+
+_C.LOSS = CN()
+_C.LOSS.USE_OHKM = False
+_C.LOSS.TOPK = 8
+_C.LOSS.USE_TARGET_WEIGHT = True
+_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
+
+# DATASET related params
+_C.DATASET = CN()
+_C.DATASET.ROOT = ''
+_C.DATASET.DATASET = 'mpii'
+_C.DATASET.TRAIN_SET = 'train'
+_C.DATASET.TEST_SET = 'valid'
+_C.DATASET.DATA_FORMAT = 'jpg'
+_C.DATASET.HYBRID_JOINTS_TYPE = ''
+_C.DATASET.SELECT_DATA = False
+
+# training data augmentation
+_C.DATASET.FLIP = True
+_C.DATASET.SCALE_FACTOR = 0.25
+_C.DATASET.ROT_FACTOR = 30
+_C.DATASET.PROB_HALF_BODY = 0.0
+_C.DATASET.NUM_JOINTS_HALF_BODY = 8
+_C.DATASET.COLOR_RGB = False
+
+# train
+_C.TRAIN = CN()
+
+_C.TRAIN.LR_FACTOR = 0.1
+_C.TRAIN.LR_STEP = [90, 110]
+_C.TRAIN.LR = 0.001
+
+_C.TRAIN.OPTIMIZER = 'adam'
+_C.TRAIN.MOMENTUM = 0.9
+_C.TRAIN.WD = 0.0001
+_C.TRAIN.NESTEROV = False
+_C.TRAIN.GAMMA1 = 0.99
+_C.TRAIN.GAMMA2 = 0.0
+
+_C.TRAIN.BEGIN_EPOCH = 0
+_C.TRAIN.END_EPOCH = 140
+
+_C.TRAIN.RESUME = False
+_C.TRAIN.CHECKPOINT = ''
+
+_C.TRAIN.BATCH_SIZE_PER_GPU = 32
+_C.TRAIN.SHUFFLE = True
+
+# testing
+_C.TEST = CN()
+
+# size of images for each device
+_C.TEST.BATCH_SIZE_PER_GPU = 32
+# Test Model Epoch
+_C.TEST.FLIP_TEST = False
+_C.TEST.POST_PROCESS = False
+_C.TEST.SHIFT_HEATMAP = False
+
+_C.TEST.USE_GT_BBOX = False
+
+# nms
+_C.TEST.IMAGE_THRE = 0.1
+_C.TEST.NMS_THRE = 0.6
+_C.TEST.SOFT_NMS = False
+_C.TEST.OKS_THRE = 0.5
+_C.TEST.IN_VIS_THRE = 0.0
+_C.TEST.COCO_BBOX_FILE = ''
+_C.TEST.BBOX_THRE = 1.0
+_C.TEST.MODEL_FILE = ''
+
+# debug
+_C.DEBUG = CN()
+_C.DEBUG.DEBUG = False
+_C.DEBUG.SAVE_BATCH_IMAGES_GT = False
+_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
+_C.DEBUG.SAVE_HEATMAPS_GT = False
+_C.DEBUG.SAVE_HEATMAPS_PRED = False
+
+
+def update_config(cfg, config_file):
+ cfg.defrost()
+ cfg.merge_from_file(config_file)
+ cfg.freeze()
+
+
+if __name__ == '__main__':
+ import sys
+ with open(sys.argv[1], 'w') as f:
+ print(_C, file=f)
+
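+
+# Hedged usage sketch: clone the defaults, merge the bundled HRNet-W48 yaml, and read a
+# couple of values back. The yaml path is an assumption based on this repo layout.
+def _example_load_cfg(yaml_path='main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml'):
+    cfg = _C.clone()                  # leave the module-level defaults untouched
+    update_config(cfg, yaml_path)     # defrost -> merge_from_file -> freeze
+    return cfg.MODEL.NAME, cfg.MODEL.EXTRA.STAGE4.NUM_CHANNELS   # ('pose_hrnet', [48, 96, 192, 384])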
diff --git a/main/postometro_utils/pose_resnet.py b/main/postometro_utils/pose_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c423f36418ef5ac0dcc85f082a029f0280b6b52f
--- /dev/null
+++ b/main/postometro_utils/pose_resnet.py
@@ -0,0 +1,318 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+BN_MOMENTUM = 0.1
+logger = logging.getLogger(__name__)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck_CAFFE(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck_CAFFE, self).__init__()
+ # add stride to conv1x1
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class PoseResNet(nn.Module):
+
+ def __init__(self, block, layers, cfg, **kwargs):
+ self.inplanes = 64
+ extra = cfg.MODEL.EXTRA
+ self.deconv_with_bias = extra.DECONV_WITH_BIAS
+
+ super(PoseResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ # used for deconv layers
+ # self.deconv_layers = self._make_deconv_layer(
+ # extra.NUM_DECONV_LAYERS,
+ # extra.NUM_DECONV_FILTERS,
+ # extra.NUM_DECONV_KERNELS,
+ # )
+
+ # self.final_layer = nn.Conv2d(
+ # in_channels=extra.NUM_DECONV_FILTERS[-1],
+ # out_channels=cfg.MODEL.NUM_JOINTS,
+ # kernel_size=extra.FINAL_CONV_KERNEL,
+ # stride=1,
+ # padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
+ # )
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _get_deconv_cfg(self, deconv_kernel, index):
+ if deconv_kernel == 4:
+ padding = 1
+ output_padding = 0
+ elif deconv_kernel == 3:
+ padding = 1
+ output_padding = 1
+        elif deconv_kernel == 2:
+            padding = 0
+            output_padding = 0
+        else:
+            raise ValueError('deconv kernel {} is not supported'.format(deconv_kernel))
+
+ return deconv_kernel, padding, output_padding
+
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
+        assert num_layers == len(num_filters), \
+            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
+        assert num_layers == len(num_kernels), \
+            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
+
+ layers = []
+ for i in range(num_layers):
+ kernel, padding, output_padding = \
+ self._get_deconv_cfg(num_kernels[i], i)
+
+ planes = num_filters[i]
+ layers.append(
+ nn.ConvTranspose2d(
+ in_channels=self.inplanes,
+ out_channels=planes,
+ kernel_size=kernel,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=self.deconv_with_bias))
+ layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
+ layers.append(nn.ReLU(inplace=True))
+ self.inplanes = planes
+
+ return nn.Sequential(*layers)
+
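+    # Forward routing implemented below (the PCT naming presumably refers to Pose as Compositional Tokens):
+    #   use_pct=False                   -> full backbone: stem (conv1/bn1/relu/maxpool) + layer1..layer4
+    #   use_pct=True,  skip_early=True  -> stem only; returns the early feature map
+    #   use_pct=True,  skip_early=False -> layer1..layer4 only; `x` is expected to already be a stem output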
+ def forward(self, x, skip_early = False, use_pct = False):
+ if not use_pct:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ return x
+
+ if skip_early:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ return x
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ return x
+
+ def init_weights(self, pretrained=''):
+ if os.path.isfile(pretrained):
+ # pretrained_state_dict = torch.load(pretrained)
+ logger.info('=> loading pretrained model {}'.format(pretrained))
+ # self.load_state_dict(pretrained_state_dict, strict=False)
+ checkpoint = torch.load(pretrained)
+ if isinstance(checkpoint, OrderedDict):
+ state_dict = checkpoint
+ elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+ state_dict_old = checkpoint['state_dict']
+ state_dict = OrderedDict()
+ # delete 'module.' because it is saved from DataParallel module
+ for key in state_dict_old.keys():
+ if key.startswith('module.'):
+ # state_dict[key[7:]] = state_dict[key]
+ # state_dict.pop(key)
+ state_dict[key[7:]] = state_dict_old[key]
+ else:
+ state_dict[key] = state_dict_old[key]
+ else:
+ raise RuntimeError(
+ 'No state_dict found in checkpoint file {}'.format(pretrained))
+ state_dict_old = state_dict
+ state_dict = OrderedDict()
+ for k,v in state_dict_old.items():
+ if 'deconv_layers' in k or 'final_layer' in k:
+ continue
+ else:
+ state_dict[k] = state_dict_old[k]
+ self.load_state_dict(state_dict, strict=True)
+ else:
+            logger.error('=> imagenet pretrained model does not exist')
+ logger.error('=> please download it first')
+ raise ValueError('imagenet pretrained model does not exist')
+
+
+resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
+ 34: (BasicBlock, [3, 4, 6, 3]),
+ 50: (Bottleneck, [3, 4, 6, 3]),
+ 101: (Bottleneck, [3, 4, 23, 3]),
+ 152: (Bottleneck, [3, 8, 36, 3])}
+
+
+def get_pose_net(cfg, is_train, **kwargs):
+ num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
+ style = cfg.MODEL.STYLE
+
+ block_class, layers = resnet_spec[num_layers]
+
+ if style == 'caffe':
+ block_class = Bottleneck_CAFFE
+
+ model = PoseResNet(block_class, layers, cfg, **kwargs)
+
+ if is_train and cfg.MODEL.INIT_WEIGHTS:
+ model.init_weights(cfg.MODEL.PRETRAINED)
+
+ return model
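+
+
+# Hedged usage sketch: build a ResNet-50 backbone from the easydict defaults in
+# pose_resnet_config (the module path is an assumption) and run a dummy forward pass.
+def _example_pose_resnet50():
+    from postometro_utils.pose_resnet_config import config as resnet_cfg
+    model = get_pose_net(resnet_cfg, is_train=False)   # is_train=False -> no pretrained loading
+    feats = model(torch.randn(1, 3, 256, 256))         # default path: stem + layer1..layer4
+    return feats.shape                                 # torch.Size([1, 2048, 8, 8])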
diff --git a/main/postometro_utils/pose_resnet_config.py b/main/postometro_utils/pose_resnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..06bc863d38bf6fad4f00390ebb11c984ed9dd89e
--- /dev/null
+++ b/main/postometro_utils/pose_resnet_config.py
@@ -0,0 +1,229 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import yaml
+
+import numpy as np
+from easydict import EasyDict as edict
+
+
+config = edict()
+
+config.OUTPUT_DIR = ''
+config.LOG_DIR = ''
+config.DATA_DIR = ''
+config.GPUS = '0'
+config.WORKERS = 4
+config.PRINT_FREQ = 20
+
+# Cudnn related params
+config.CUDNN = edict()
+config.CUDNN.BENCHMARK = True
+config.CUDNN.DETERMINISTIC = False
+config.CUDNN.ENABLED = True
+
+# pose_resnet related params
+POSE_RESNET = edict()
+POSE_RESNET.NUM_LAYERS = 50
+POSE_RESNET.DECONV_WITH_BIAS = False
+POSE_RESNET.NUM_DECONV_LAYERS = 3
+POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
+POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
+POSE_RESNET.FINAL_CONV_KERNEL = 1
+POSE_RESNET.TARGET_TYPE = 'gaussian'
+POSE_RESNET.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
+POSE_RESNET.SIGMA = 2
+
+MODEL_EXTRAS = {
+ 'pose_resnet': POSE_RESNET,
+}
+
+# common params for NETWORK
+config.MODEL = edict()
+config.MODEL.NAME = 'pose_resnet'
+config.MODEL.INIT_WEIGHTS = True
+config.MODEL.PRETRAINED = ''
+config.MODEL.NUM_JOINTS = 16
+config.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
+config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
+
+config.MODEL.STYLE = 'pytorch'
+
+config.LOSS = edict()
+config.LOSS.USE_TARGET_WEIGHT = True
+
+# DATASET related params
+config.DATASET = edict()
+config.DATASET.ROOT = ''
+config.DATASET.DATASET = 'mpii'
+config.DATASET.TRAIN_SET = 'train'
+config.DATASET.TEST_SET = 'valid'
+config.DATASET.DATA_FORMAT = 'jpg'
+config.DATASET.HYBRID_JOINTS_TYPE = ''
+config.DATASET.SELECT_DATA = False
+
+# training data augmentation
+config.DATASET.FLIP = True
+config.DATASET.SCALE_FACTOR = 0.25
+config.DATASET.ROT_FACTOR = 30
+
+# train
+config.TRAIN = edict()
+
+config.TRAIN.LR_FACTOR = 0.1
+config.TRAIN.LR_STEP = [90, 110]
+config.TRAIN.LR = 0.001
+
+config.TRAIN.OPTIMIZER = 'adam'
+config.TRAIN.MOMENTUM = 0.9
+config.TRAIN.WD = 0.0001
+config.TRAIN.NESTEROV = False
+config.TRAIN.GAMMA1 = 0.99
+config.TRAIN.GAMMA2 = 0.0
+
+config.TRAIN.BEGIN_EPOCH = 0
+config.TRAIN.END_EPOCH = 140
+
+config.TRAIN.RESUME = False
+config.TRAIN.CHECKPOINT = ''
+
+config.TRAIN.BATCH_SIZE = 32
+config.TRAIN.SHUFFLE = True
+
+# testing
+config.TEST = edict()
+
+# size of images for each device
+config.TEST.BATCH_SIZE = 32
+# Test Model Epoch
+config.TEST.FLIP_TEST = False
+config.TEST.POST_PROCESS = True
+config.TEST.SHIFT_HEATMAP = True
+
+config.TEST.USE_GT_BBOX = False
+# nms
+config.TEST.OKS_THRE = 0.5
+config.TEST.IN_VIS_THRE = 0.0
+config.TEST.COCO_BBOX_FILE = ''
+config.TEST.BBOX_THRE = 1.0
+config.TEST.MODEL_FILE = ''
+config.TEST.IMAGE_THRE = 0.0
+config.TEST.NMS_THRE = 1.0
+
+# debug
+config.DEBUG = edict()
+config.DEBUG.DEBUG = False
+config.DEBUG.SAVE_BATCH_IMAGES_GT = False
+config.DEBUG.SAVE_BATCH_IMAGES_PRED = False
+config.DEBUG.SAVE_HEATMAPS_GT = False
+config.DEBUG.SAVE_HEATMAPS_PRED = False
+
+
+def _update_dict(k, v):
+ if k == 'DATASET':
+ if 'MEAN' in v and v['MEAN']:
+ v['MEAN'] = np.array([eval(x) if isinstance(x, str) else x
+ for x in v['MEAN']])
+ if 'STD' in v and v['STD']:
+ v['STD'] = np.array([eval(x) if isinstance(x, str) else x
+ for x in v['STD']])
+ if k == 'MODEL':
+ if 'EXTRA' in v and 'HEATMAP_SIZE' in v['EXTRA']:
+ if isinstance(v['EXTRA']['HEATMAP_SIZE'], int):
+ v['EXTRA']['HEATMAP_SIZE'] = np.array(
+ [v['EXTRA']['HEATMAP_SIZE'], v['EXTRA']['HEATMAP_SIZE']])
+ else:
+ v['EXTRA']['HEATMAP_SIZE'] = np.array(
+ v['EXTRA']['HEATMAP_SIZE'])
+ if 'IMAGE_SIZE' in v:
+ if isinstance(v['IMAGE_SIZE'], int):
+ v['IMAGE_SIZE'] = np.array([v['IMAGE_SIZE'], v['IMAGE_SIZE']])
+ else:
+ v['IMAGE_SIZE'] = np.array(v['IMAGE_SIZE'])
+ for vk, vv in v.items():
+ if vk in config[k]:
+ config[k][vk] = vv
+ else:
+            raise ValueError("{}.{} does not exist in config.py".format(k, vk))
+
+
+def update_config(config_file):
+ exp_config = None
+ with open(config_file) as f:
+        exp_config = edict(yaml.safe_load(f))  # safe_load works with both old and new PyYAML
+ for k, v in exp_config.items():
+ if k in config:
+ if isinstance(v, dict):
+ _update_dict(k, v)
+ else:
+ if k == 'SCALES':
+ config[k][0] = (tuple(v))
+ else:
+ config[k] = v
+ else:
+            raise ValueError("{} does not exist in config.py".format(k))
+
+
+def gen_config(config_file):
+ cfg = dict(config)
+ for k, v in cfg.items():
+ if isinstance(v, edict):
+ cfg[k] = dict(v)
+
+ with open(config_file, 'w') as f:
+ yaml.dump(dict(cfg), f, default_flow_style=False)
+
+
+def update_dir(model_dir, log_dir, data_dir):
+ if model_dir:
+ config.OUTPUT_DIR = model_dir
+
+ if log_dir:
+ config.LOG_DIR = log_dir
+
+ if data_dir:
+ config.DATA_DIR = data_dir
+
+ config.DATASET.ROOT = os.path.join(
+ config.DATA_DIR, config.DATASET.ROOT)
+
+ config.TEST.COCO_BBOX_FILE = os.path.join(
+ config.DATA_DIR, config.TEST.COCO_BBOX_FILE)
+
+ config.MODEL.PRETRAINED = os.path.join(
+ config.DATA_DIR, config.MODEL.PRETRAINED)
+
+
+def get_model_name(cfg):
+ name = cfg.MODEL.NAME
+ full_name = cfg.MODEL.NAME
+ extra = cfg.MODEL.EXTRA
+ if name in ['pose_resnet']:
+ name = '{model}_{num_layers}'.format(
+ model=name,
+ num_layers=extra.NUM_LAYERS)
+ deconv_suffix = ''.join(
+ 'd{}'.format(num_filters)
+ for num_filters in extra.NUM_DECONV_FILTERS)
+ full_name = '{height}x{width}_{name}_{deconv_suffix}'.format(
+ height=cfg.MODEL.IMAGE_SIZE[1],
+ width=cfg.MODEL.IMAGE_SIZE[0],
+ name=name,
+ deconv_suffix=deconv_suffix)
+ else:
+        raise ValueError('Unknown model: {}'.format(cfg.MODEL))
+
+ return name, full_name
+
+
+if __name__ == '__main__':
+ import sys
+ gen_config(sys.argv[1])
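+
+# Hedged example of the naming scheme with the defaults above:
+#   >>> get_model_name(config)
+#   ('pose_resnet_50', '256x256_pose_resnet_50_d256d256d256')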
diff --git a/main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml b/main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ecdea1ee6b08031bea216fd72a102042b07b3880
--- /dev/null
+++ b/main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml
@@ -0,0 +1,127 @@
+AUTO_RESUME: true
+CUDNN:
+ BENCHMARK: true
+ DETERMINISTIC: false
+ ENABLED: true
+DATA_DIR: ''
+GPUS: (0,1,2,3)
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 24
+PRINT_FREQ: 100
+
+DATASET:
+ COLOR_RGB: true
+ DATASET: 'coco'
+ DATA_FORMAT: jpg
+ FLIP: true
+ NUM_JOINTS_HALF_BODY: 8
+ PROB_HALF_BODY: 0.3
+ ROOT: 'data/coco/'
+ ROT_FACTOR: 45
+ SCALE_FACTOR: 0.35
+ TEST_SET: 'val2017'
+ TRAIN_SET: 'train2017'
+MODEL:
+ INIT_WEIGHTS: true
+ NAME: pose_hrnet
+ NUM_JOINTS: 17
+ PRETRAINED: 'models/pytorch/imagenet/hrnet_w48-8ef0771d.pth'
+ TARGET_TYPE: gaussian
+ IMAGE_SIZE:
+ - 192
+ - 256
+ HEATMAP_SIZE:
+ - 48
+ - 64
+ SIGMA: 2
+ EXTRA:
+ PRETRAINED_LAYERS:
+ - 'conv1'
+ - 'bn1'
+ - 'conv2'
+ - 'bn2'
+ - 'layer1'
+ - 'transition1'
+ - 'stage2'
+ - 'transition2'
+ - 'stage3'
+ - 'transition3'
+ - 'stage4'
+ FINAL_CONV_KERNEL: 1
+ STAGE2:
+ NUM_MODULES: 1
+ NUM_BRANCHES: 2
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 48
+ - 96
+ FUSE_METHOD: SUM
+ STAGE3:
+ NUM_MODULES: 4
+ NUM_BRANCHES: 3
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 48
+ - 96
+ - 192
+ FUSE_METHOD: SUM
+ STAGE4:
+ NUM_MODULES: 3
+ NUM_BRANCHES: 4
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 48
+ - 96
+ - 192
+ - 384
+ FUSE_METHOD: SUM
+LOSS:
+ USE_TARGET_WEIGHT: true
+TRAIN:
+ BATCH_SIZE_PER_GPU: 32
+ SHUFFLE: true
+ BEGIN_EPOCH: 0
+ END_EPOCH: 210
+ OPTIMIZER: adam
+ LR: 0.001
+ LR_FACTOR: 0.1
+ LR_STEP:
+ - 170
+ - 200
+ WD: 0.0001
+ GAMMA1: 0.99
+ GAMMA2: 0.0
+ MOMENTUM: 0.9
+ NESTEROV: false
+TEST:
+ BATCH_SIZE_PER_GPU: 32
+ COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
+ BBOX_THRE: 1.0
+ IMAGE_THRE: 0.0
+ IN_VIS_THRE: 0.2
+ MODEL_FILE: ''
+ NMS_THRE: 1.0
+ OKS_THRE: 0.9
+ USE_GT_BBOX: true
+ FLIP_TEST: true
+ POST_PROCESS: true
+ SHIFT_HEATMAP: true
+DEBUG:
+ DEBUG: true
+ SAVE_BATCH_IMAGES_GT: true
+ SAVE_BATCH_IMAGES_PRED: true
+ SAVE_HEATMAPS_GT: true
+ SAVE_HEATMAPS_PRED: true
\ No newline at end of file
diff --git a/main/postometro_utils/positional_encoding.py b/main/postometro_utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e6b03f59d43f7ebc89669636319766e78ab4f7
--- /dev/null
+++ b/main/postometro_utils/positional_encoding.py
@@ -0,0 +1,57 @@
+# ----------------------------------------------------------------------------------------------
+# FastMETRO Official Code
+# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
+# Licensed under the MIT license.
+# ----------------------------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved [see https://github.com/facebookresearch/detr/blob/main/LICENSE for details]
+# ----------------------------------------------------------------------------------------------
+
+import math
+import torch
+from torch import nn
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, bs, h, w, device):
+ ones = torch.ones((bs, h, w), dtype=torch.bool, device=device)
+ y_embed = ones.cumsum(1, dtype=torch.float32)
+ x_embed = ones.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats) # cancel warning
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+def build_position_encoding(pos_type, hidden_dim):
+ N_steps = hidden_dim // 2
+ if pos_type == 'sine':
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+ else:
+        raise ValueError(f"position encoding type {pos_type} is not supported")
+
+ return position_embedding
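+
+
+# Hedged usage sketch: the sine embedding has no learned weights, so it only needs the
+# target grid size and device; the output has hidden_dim channels (half for y, half for x).
+def _example_sine_position_encoding():
+    pos_embed = build_position_encoding('sine', hidden_dim=512)   # N_steps = 256 per axis
+    pos = pos_embed(bs=2, h=8, w=8, device=torch.device('cpu'))
+    return pos.shape                                              # torch.Size([2, 512, 8, 8])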
\ No newline at end of file
diff --git a/main/postometro_utils/renderer_pyrender.py b/main/postometro_utils/renderer_pyrender.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4ec0597bba5561b98dfbe0a548a017ae7c42f8
--- /dev/null
+++ b/main/postometro_utils/renderer_pyrender.py
@@ -0,0 +1,225 @@
+# ----------------------------------------------------------------------------------------------
+# Modified from Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
+# Copyright (c) Hongsuk Choi. All Rights Reserved [see https://github.com/hongsukchoi/Pose2Mesh_RELEASE/blob/main/LICENSE for details]
+# ----------------------------------------------------------------------------------------------
+
+import os
+os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
+import torch
+import numpy as np
+import torch.nn.functional as F
+import math
+import cv2
+import trimesh
+import pyrender
+from pyrender.constants import RenderFlags
+
+def crop_bbox(bbox_meta, resolution, rgb, valid_mask):
+ bbox, original_img_height, original_img_width = bbox_meta['bbox'], *bbox_meta['img_hw']
+ start_x = int(bbox[0])
+ start_y = int(bbox[1])
+ end_x = start_x + int(resolution[0]) # w + start_x
+ end_y = start_y + int(resolution[1]) # h + start_y
+ real_start_x, real_start_y, real_end_x, real_end_y = max(0, start_x), max(0, start_y), min(original_img_width, end_x), min(original_img_height, end_y)
+ max_height, max_width = rgb.shape[:2]
+ real_rgb = rgb[(real_start_y - start_y):((real_end_y - end_y) if real_end_y < end_y else max_height),
+ (real_start_x - start_x):((real_end_x - end_x) if real_end_x < end_x else max_width)].copy()
+ real_valid_mask = valid_mask[(real_start_y - start_y):((real_end_y - end_y) if real_end_y < end_y else max_height),
+ (real_start_x - start_x):((real_end_x - end_x) if real_end_x < end_x else max_width)].copy()
+ return {'bbox': [real_start_x, real_start_y, real_end_x, real_end_y], 'img_hw': [original_img_height, original_img_width]}, real_rgb, real_valid_mask
+
+
+class WeakPerspectiveCamera(pyrender.Camera):
+ def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None):
+ super(WeakPerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name)
+ self.scale = scale
+ self.translation = translation
+
+ def get_projection_matrix(self, width=None, height=None):
+ P = np.eye(4)
+ P[0, 0] = self.scale[0]
+ P[1, 1] = self.scale[1]
+ P[0, 3] = self.translation[0] * self.scale[0]
+ P[1, 3] = -self.translation[1] * self.scale[1]
+ P[2, 2] = -1
+ return P
+
+
+class PyRender_Renderer:
+ def __init__(self, resolution=(256, 256), faces=None, orig_img=False, wireframe=False):
+ self.resolution = resolution
+ self.faces = faces
+ self.orig_img = orig_img
+ self.wireframe = wireframe
+ self.renderer = pyrender.OffscreenRenderer(viewport_width=self.resolution[0],
+ viewport_height=self.resolution[1],
+ point_size=1.0)
+
+ # set the scene & create light source
+ self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.05, 0.05, 0.05))
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(-45), [1, 0, 0])
+ self.scene.add(light, pose=light_pose)
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(45), [0, 1, 0])
+ self.scene.add(light, pose=light_pose)
+
+ # mesh colors
+ self.colors_dict = {'blue': np.array([0.35, 0.60, 0.92]),
+ 'neutral': np.array([0.7, 0.7, 0.6]),
+ 'pink': np.array([0.7, 0.5, 0.5]),
+ 'white': np.array([1.0, 0.98, 0.94]),
+ 'green': np.array([0.5, 0.55, 0.3]),
+ 'sky': np.array([0.3, 0.5, 0.55])}
+
+ def __call__(self, verts, bbox_meta, img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]),
+ angle=None, axis=None, mesh_filename=None, color_type=None, color=[0.7, 0.7, 0.6]):
+        if color_type is not None:
+ color = self.colors_dict[color_type]
+
+ mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False)
+ Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0])
+ mesh.apply_transform(Rx)
+ if mesh_filename is not None:
+ mesh.export(mesh_filename)
+ if angle and axis:
+ R = trimesh.transformations.rotation_matrix(math.radians(angle), axis)
+ mesh.apply_transform(R)
+
+ sy, tx, ty = cam
+ sx = sy
+ camera = WeakPerspectiveCamera(scale=[sx, sy], translation=[tx, ty], zfar=1000.0)
+
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor=0.2,
+ roughnessFactor=1.0,
+ alphaMode='OPAQUE',
+ baseColorFactor=(color[0], color[1], color[2], 1.0)
+ )
+
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+ mesh_node = self.scene.add(mesh, 'mesh')
+
+ camera_pose = np.eye(4)
+ cam_node = self.scene.add(camera, pose=camera_pose)
+
+ if self.wireframe:
+ render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME
+ else:
+ render_flags = RenderFlags.RGBA
+
+ rgb, depth = self.renderer.render(self.scene, flags=render_flags)
+ valid_mask = (depth > 0)[:, :, np.newaxis] # bbox size
+ # adjust bbox (no out of boundary)
+ bbox_meta, rgb, valid_mask = crop_bbox(bbox_meta, [self.resolution[0], self.resolution[1]], rgb, valid_mask)
+ # parse bbox
+ start_x, start_y, end_x, end_y, original_img_height, original_img_width = *bbox_meta['bbox'], *bbox_meta['img_hw']
+ # start_x = int(bbox_meta['bbox'][0])
+ # start_y = int(bbox_meta['bbox'][1])
+ # end_x = start_x + int(self.resolution[0]) # w + start_x
+ # end_y = start_y + int(self.resolution[1]) # h + start_y
+ whole_img_mask = np.zeros((original_img_height, original_img_width,1))
+ whole_img_mask[start_y:end_y, start_x:end_x] = valid_mask
+ whole_rgb = np.zeros((original_img_height, original_img_width,4))
+ whole_rgb[start_y:end_y, start_x:end_x,:3] = rgb
+ output_img = whole_rgb[:, :, :3] * whole_img_mask + (1 - whole_img_mask) * img
+ image = output_img.astype(np.uint8)
+
+ self.scene.remove_node(mesh_node)
+ self.scene.remove_node(cam_node)
+
+ return image
+
+
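+# Hedged usage sketch: `faces` would typically come from an SMPL model (e.g. smpl.faces.cpu().numpy())
+# and `verts` is a (6890, 3) array of predicted vertices; `cam` is the weak-perspective [s, tx, ty].
+# The renderer resolution is assumed to match the bbox crop size recorded in bbox_meta.
+def _example_render_overlay(verts, faces, full_img, bbox_xy, img_hw, cam=np.array([1.0, 0.0, 0.0])):
+    renderer = PyRender_Renderer(resolution=(256, 256), faces=faces)
+    bbox_meta = {'bbox': bbox_xy, 'img_hw': img_hw}   # bbox_xy = [x_min, y_min] of the crop; img_hw = [H, W]
+    return renderer(verts, bbox_meta, img=full_img, cam=cam, color_type='blue')
+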
+def visualize_reconstruction_pyrender(img, vertices, camera, renderer, color='blue', focal_length=1000):
+ img = (img * 255).astype(np.uint8)
+ save_mesh_path = None
+ rend_color = color
+
+ # Render front view
+ rend_img = renderer(vertices,
+ img=img,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+
+ combined = np.hstack([img, rend_img])
+
+ return combined
+
+def visualize_reconstruction_multi_view_pyrender(img, vertices, camera, renderer, color='blue', focal_length=1000):
+ img = (img * 255).astype(np.uint8)
+ save_mesh_path = None
+ rend_color = color
+
+ # Render front view
+ rend_img = renderer(vertices,
+ img=img,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+
+ # Render side views
+ aroundy0 = cv2.Rodrigues(np.array([0, np.radians(0.), 0]))[0]
+ aroundy1 = cv2.Rodrigues(np.array([0, np.radians(90.), 0]))[0]
+ aroundy2 = cv2.Rodrigues(np.array([0, np.radians(180.), 0]))[0]
+ aroundy3 = cv2.Rodrigues(np.array([0, np.radians(270.), 0]))[0]
+ aroundy4 = cv2.Rodrigues(np.array([0, np.radians(45.), 0]))[0]
+ center = vertices.mean(axis=0)
+ rot_vertices0 = np.dot((vertices - center), aroundy0) + center
+ rot_vertices1 = np.dot((vertices - center), aroundy1) + center
+ rot_vertices2 = np.dot((vertices - center), aroundy2) + center
+ rot_vertices3 = np.dot((vertices - center), aroundy3) + center
+ rot_vertices4 = np.dot((vertices - center), aroundy4) + center
+
+ # Render side-view shape
+ img_side0 = renderer(rot_vertices0,
+ img=np.ones_like(img)*255,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+ img_side1 = renderer(rot_vertices1,
+ img=np.ones_like(img)*255,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+ img_side2 = renderer(rot_vertices2,
+ img=np.ones_like(img)*255,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+ img_side3 = renderer(rot_vertices3,
+ img=np.ones_like(img)*255,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+ img_side4 = renderer(rot_vertices4,
+ img=np.ones_like(img)*255,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+
+ combined = np.hstack([img, rend_img, img_side0, img_side1, img_side2, img_side3, img_side4])
+
+ return combined
+
+def visualize_reconstruction_smpl_pyrender(img, vertices, camera, renderer, smpl_vertices, color='blue', focal_length=1000):
+ img = (img * 255).astype(np.uint8)
+ save_mesh_path = None
+ rend_color = color
+
+ # Render front view
+ rend_img = renderer(vertices,
+ img=img,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+
+ rend_img_smpl = renderer(smpl_vertices,
+ img=img,
+ cam=camera,
+ color_type=rend_color,
+ mesh_filename=save_mesh_path)
+
+ combined = np.hstack([img, rend_img, rend_img_smpl])
+
+ return combined
\ No newline at end of file
diff --git a/main/postometro_utils/smpl.py b/main/postometro_utils/smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..108ee9dfd390f789fd2ed815454d7108ff1b454b
--- /dev/null
+++ b/main/postometro_utils/smpl.py
@@ -0,0 +1,291 @@
+# ----------------------------------------------------------------------------------------------
+# METRO (https://github.com/microsoft/MeshTransformer)
+# Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details]
+# Licensed under the MIT license.
+# ----------------------------------------------------------------------------------------------
+"""
+This file contains the definition of the SMPL model
+
+It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/)
+"""
+from __future__ import division
+
+import torch
+import torch.nn as nn
+import numpy as np
+import scipy.sparse
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
+
+from postometro_utils.geometric_layers import rodrigues
+import data.config as cfg
+
+class SMPL(nn.Module):
+
+ def __init__(self, gender='neutral'):
+ super(SMPL, self).__init__()
+
+ if gender=='m':
+ model_file=cfg.SMPL_Male
+ elif gender=='f':
+ model_file=cfg.SMPL_Female
+ else:
+ model_file=cfg.SMPL_FILE
+
+ smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
+ J_regressor = smpl_model['J_regressor'].tocoo()
+ row = J_regressor.row
+ col = J_regressor.col
+ data = J_regressor.data
+ i = torch.LongTensor([row, col])
+ v = torch.FloatTensor(data)
+ J_regressor_shape = [24, 6890]
+ self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense())
+ self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
+ self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
+ self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
+ self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'])))
+ self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
+ self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
+ id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
+ self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
+
+ self.pose_shape = [24, 3]
+ self.beta_shape = [10]
+ self.translation_shape = [3]
+
+ self.pose = torch.zeros(self.pose_shape)
+ self.beta = torch.zeros(self.beta_shape)
+ self.translation = torch.zeros(self.translation_shape)
+
+ self.verts = None
+ self.J = None
+ self.R = None
+
+ J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float()
+ self.register_buffer('J_regressor_extra', J_regressor_extra)
+ self.joints_idx = cfg.JOINTS_IDX
+
+ J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float()
+ self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)
+
+ def forward(self, pose, beta):
+ device = pose.device
+ batch_size = pose.shape[0]
+ v_template = self.v_template[None, :]
+ shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
+ beta = beta[:, :, None]
+ v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
+ # batched sparse matmul not supported in pytorch
+ J = []
+ for i in range(batch_size):
+ J.append(torch.matmul(self.J_regressor, v_shaped[i]))
+ J = torch.stack(J, dim=0)
+        # input is rotation matrices: (bs, 24, 3, 3)
+ if pose.ndimension() == 4:
+ R = pose
+        # input is axis-angle pose: (bs, 72)
+        elif pose.ndimension() == 2:
+            pose_cube = pose.view(-1, 3) # (batch_size * 24, 3)
+ R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
+ R = R.view(batch_size, 24, 3, 3)
+ I_cube = torch.eye(3)[None, None, :].to(device)
+ # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
+ lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
+ posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
+ v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
+ J_ = J.clone()
+ J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
+ G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
+ pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
+ G_ = torch.cat([G_, pad_row], dim=2)
+ G = [G_[:, 0].clone()]
+ for i in range(1, 24):
+ G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
+ G = torch.stack(G, dim=1)
+
+ rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
+ zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
+ rest = torch.cat([zeros, rest], dim=-1)
+ rest = torch.matmul(G, rest)
+ G = G - rest
+ T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
+ rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
+ v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
+ return v
+
+ def get_joints(self, vertices):
+ """
+ This method is used to get the joint locations from the SMPL mesh
+ Input:
+ vertices: size = (B, 6890, 3)
+ Output:
+ 3D joints: size = (B, 38, 3)
+ """
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
+ joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
+ joints = torch.cat((joints, joints_extra), dim=1)
+ joints = joints[:, cfg.JOINTS_IDX]
+ return joints
+
+ def get_24_joints(self, vertices):
+ """
+ This method is used to get the joint locations from the SMPL mesh
+ Input:
+ vertices: size = (B, 6890, 3)
+ Output:
+            3D joints: size = (B, 24, 3)
+ """
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
+ return joints
+
+ def get_h36m_joints(self, vertices):
+ """
+ This method is used to get the joint locations from the SMPL mesh
+ Input:
+ vertices: size = (B, 6890, 3)
+ Output:
+ 3D joints: size = (B, 17, 3)
+ """
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
+ return joints
+
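+# Hedged usage sketch (assumes the SMPL pickle and regressor files configured in data/config.py resolve):
+def _example_smpl_t_pose():
+    smpl = SMPL()
+    verts = smpl(torch.zeros(1, 72), torch.zeros(1, 10))   # axis-angle pose + betas -> (1, 6890, 3)
+    return smpl.get_h36m_joints(verts).shape               # torch.Size([1, 17, 3])
+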
+class SparseMM(torch.autograd.Function):
+ """Redefine sparse @ dense matrix multiplication to enable backpropagation.
+ The builtin matrix multiplication operation does not support backpropagation in some cases.
+ """
+ @staticmethod
+ def forward(ctx, sparse, dense):
+ ctx.req_grad = dense.requires_grad
+ ctx.save_for_backward(sparse)
+ return torch.matmul(sparse, dense)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = None
+ sparse, = ctx.saved_tensors
+ if ctx.req_grad:
+ grad_input = torch.matmul(sparse.t(), grad_output)
+ return None, grad_input
+
+def spmm(sparse, dense):
+ return SparseMM.apply(sparse, dense)
+
+
+def scipy_to_pytorch(A, U, D):
+ """Convert scipy sparse matrices to pytorch sparse matrix."""
+ ptU = []
+ ptD = []
+
+ for i in range(len(U)):
+ u = scipy.sparse.coo_matrix(U[i])
+ i = torch.LongTensor(np.array([u.row, u.col]))
+ v = torch.FloatTensor(u.data)
+ ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
+
+ for i in range(len(D)):
+ d = scipy.sparse.coo_matrix(D[i])
+ i = torch.LongTensor(np.array([d.row, d.col]))
+ v = torch.FloatTensor(d.data)
+ ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
+
+ return ptU, ptD
+
+
+def adjmat_sparse(adjmat, nsize=1):
+ """Create row-normalized sparse graph adjacency matrix."""
+ adjmat = scipy.sparse.csr_matrix(adjmat)
+ if nsize > 1:
+ orig_adjmat = adjmat.copy()
+ for _ in range(1, nsize):
+ adjmat = adjmat * orig_adjmat
+ adjmat.data = np.ones_like(adjmat.data)
+ for i in range(adjmat.shape[0]):
+ adjmat[i,i] = 1
+ num_neighbors = np.array(1 / adjmat.sum(axis=-1))
+ adjmat = adjmat.multiply(num_neighbors)
+ adjmat = scipy.sparse.coo_matrix(adjmat)
+ row = adjmat.row
+ col = adjmat.col
+ data = adjmat.data
+ i = torch.LongTensor(np.array([row, col]))
+ v = torch.from_numpy(data).float()
+ adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
+ return adjmat
+
+def get_graph_params(filename, nsize=1):
+ """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
+ data = np.load(filename, encoding='latin1', allow_pickle=True)
+ A = data['A']
+ U = data['U']
+ D = data['D']
+ U, D = scipy_to_pytorch(A, U, D)
+ A = [adjmat_sparse(a, nsize=nsize) for a in A]
+ return A, U, D
+
+
+class Mesh(object):
+ """Mesh object that is used for handling certain graph operations."""
+ def __init__(self, filename=cfg.SMPL_sampling_matrix,
+ num_downsampling=1, nsize=1, device=torch.device('cuda')):
+ self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
+ self._A = [a.to(device) for a in self._A]
+ self._U = [u.to(device) for u in self._U]
+ self._D = [d.to(device) for d in self._D]
+ self.num_downsampling = num_downsampling
+
+ # load template vertices from SMPL and normalize them
+ smpl = SMPL()
+ ref_vertices = smpl.v_template
+ center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
+ ref_vertices -= center
+ ref_vertices /= ref_vertices.abs().max().item()
+
+ self._ref_vertices = ref_vertices.to(device)
+ self.faces = smpl.faces.int().to(device)
+
+ @property
+ def ref_vertices(self):
+ """Return the template vertices at the specified subsampling level."""
+ ref_vertices = self._ref_vertices
+ for i in range(self.num_downsampling):
+ ref_vertices = torch.spmm(self._D[i], ref_vertices)
+ return ref_vertices
+
+ def adjmat(self, num_downsampling):
+ """Return the graph adjacency matrix at the specified subsampling level."""
+ return self._A[num_downsampling].float()
+
+ def downsample(self, x, n1=0, n2=None):
+ """Downsample mesh."""
+ if n2 is None:
+ n2 = self.num_downsampling
+ if x.ndimension() < 3:
+ for i in range(n1, n2):
+ x = spmm(self._D[i], x)
+ elif x.ndimension() == 3:
+ out = []
+ for i in range(x.shape[0]):
+ y = x[i]
+ for j in range(n1, n2):
+ y = spmm(self._D[j], y)
+ out.append(y)
+ x = torch.stack(out, dim=0)
+ return x
+
+ def upsample(self, x, n1=1, n2=0):
+ """Upsample mesh."""
+ if x.ndimension() < 3:
+ for i in reversed(range(n2, n1)):
+ x = spmm(self._U[i], x)
+ elif x.ndimension() == 3:
+ out = []
+ for i in range(x.shape[0]):
+ y = x[i]
+ for j in reversed(range(n2, n1)):
+ y = spmm(self._U[j], y)
+ out.append(y)
+ x = torch.stack(out, dim=0)
+ return x
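+
+
+# Hedged usage sketch: round-trip a batch of full-resolution SMPL vertices through the
+# precomputed sampling matrices (requires the SMPL and sampling files from data/config.py;
+# the coarse vertex count is typically 1723 at level 1).
+def _example_mesh_sampling():
+    mesh = Mesh(device=torch.device('cpu'))
+    full = torch.randn(1, 6890, 3)            # any batch of full-resolution vertices
+    coarse = mesh.downsample(full)            # apply D[0]: (1, 6890, 3) -> (1, V_coarse, 3)
+    return mesh.upsample(coarse).shape        # back to torch.Size([1, 6890, 3])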
diff --git a/main/postometro_utils/transformer.py b/main/postometro_utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..35132acf9a03128252539d28407620d52f25f361
--- /dev/null
+++ b/main/postometro_utils/transformer.py
@@ -0,0 +1,249 @@
+# ----------------------------------------------------------------------------------------------
+# FastMETRO Official Code
+# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
+# Licensed under the MIT license.
+# ----------------------------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved [see https://github.com/facebookresearch/detr/blob/main/LICENSE for details]
+# ----------------------------------------------------------------------------------------------
+
+"""
+Transformer encoder-decoder architecture in FastMETRO model.
+"""
+import copy
+import torch
+import torch.nn.functional as F
+from typing import Optional
+from torch import nn, Tensor
+
+class Transformer(nn.Module):
+ """Transformer encoder-decoder"""
+ def __init__(self, model_dim=512, nhead=8, num_enc_layers=3, num_dec_layers=3,
+ feedforward_dim=2048, dropout=0.1, activation="relu"):
+ """
+ Parameters:
+ - model_dim: The hidden dimension size in the transformer architecture
+ - nhead: The number of attention heads in the attention modules
+ - num_enc_layers: The number of encoder layers in the transformer encoder
+ - num_dec_layers: The number of decoder layers in the transformer decoder
+ - feedforward_dim: The hidden dimension size in MLP
+ - dropout: The dropout rate in the transformer architecture
+ - activation: The activation function used in MLP
+ """
+ super().__init__()
+ self.model_dim = model_dim
+ self.nhead = nhead
+
+ # transformer encoder
+ encoder_layer = TransformerEncoderLayer(model_dim, nhead, feedforward_dim, dropout, activation)
+ encoder_norm = nn.LayerNorm(model_dim)
+ self.encoder = TransformerEncoder(encoder_layer, num_enc_layers, encoder_norm)
+
+ # transformer decoder
+ decoder_layer = TransformerDecoderLayer(model_dim, nhead, feedforward_dim, dropout, activation)
+ decoder_norm = nn.LayerNorm(model_dim)
+ self.decoder = TransformerDecoder(decoder_layer, num_dec_layers, decoder_norm)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, img_features, cam_token, jv_tokens, pos_embed, pct_token = None, attention_mask=None):
+ device = img_features.device
+ hw, bs, _ = img_features.shape # (height * width), batch_size, feature_dim
+
+ if pct_token is None:
+ mask = torch.zeros((bs, hw), dtype=torch.bool, device=device) # batch_size X (height * width)
+ else:
+ pct_len = pct_token.size(0)
+ mask = torch.zeros((bs, hw + pct_len), dtype=torch.bool, device=device)
+ # Transformer Encoder
+ zero_mask = torch.zeros((bs, 1), dtype=torch.bool, device=device) # batch_size X 1
+        mem_mask = torch.cat([zero_mask, mask], dim=1) # batch_size X (1 + height * width [+ num_pct_tokens])
+ cam_with_img = torch.cat([cam_token, img_features], dim=0) # (1 + height * width) X batch_size X feature_dim
+ e_outputs = self.encoder(cam_with_img, pct_token, src_key_padding_mask=mem_mask, pos=pos_embed) # (1 + height * width) X batch_size X feature_dim
+ if pct_token is not None:
+ cam_features, enc_img_features, pct_features = e_outputs.split([1, hw, pct_len], dim=0)
+ enc_img_features = torch.cat([enc_img_features, pct_features], dim = 0) # concat pct to img features
+ else:
+ cam_features, enc_img_features = e_outputs.split([1, hw], dim=0)
+ pct_features = None
+
+ # Transformer Decoder
+ zero_tgt = torch.zeros_like(jv_tokens) # (num_joints + num_vertices) X batch_size X feature_dim
+ jv_features = self.decoder(jv_tokens, enc_img_features,
+ tgt_mask=attention_mask,
+ memory_key_padding_mask=mask, pos=pos_embed, query_pos=zero_tgt) # (num_joints + num_vertices) X batch_size X feature_dim
+
+ return cam_features, enc_img_features, jv_features, pct_features
+
+
+class TransformerEncoder(nn.Module):
+ """Transformer encoder"""
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.num_layers = num_layers
+ self.norm = norm
+ self.layers = _get_clones(encoder_layer, num_layers)
+
+ def forward(self, src, pct_token = None,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if pct_token is not None:
+ output = torch.concat([src, pct_token], dim = 0)
+ else:
+ output = src
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ """Transformer decoder"""
+ def __init__(self, decoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.num_layers = num_layers
+ self.norm = norm
+ self.layers = _get_clones(decoder_layer, num_layers)
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ output = tgt
+
+ for layer in self.layers:
+ output = layer(output, memory, tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos, query_pos=query_pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerEncoderLayer(nn.Module):
+ """Transformer encoder layer"""
+ def __init__(self, model_dim, nhead, feedforward_dim=2048, dropout=0.1, activation="relu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(model_dim, nhead, dropout=dropout, batch_first=False)
+
+ # MLP
+ self.linear1 = nn.Linear(model_dim, feedforward_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(feedforward_dim, model_dim)
+
+ # Layer Normalization & Dropout
+ self.norm1 = nn.LayerNorm(model_dim)
+ self.norm2 = nn.LayerNorm(model_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        # tensor[0] is the camera token; no positional encoding is added to it
+        if pos is None:
+            return tensor
+        pos_len = pos.size(0)
+        return torch.cat([tensor[:1], tensor[1:1+pos_len] + pos, tensor[1+pos_len:]], dim=0)
+
+ def forward(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(q, k, value=src2, attn_mask=None,
+ key_padding_mask=None)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+
+class TransformerDecoderLayer(nn.Module):
+ """Transformer decoder layer"""
+ def __init__(self, model_dim, nhead, feedforward_dim=2048, dropout=0.1, activation="relu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(model_dim, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(model_dim, nhead, dropout=dropout)
+
+ # MLP
+ self.linear1 = nn.Linear(model_dim, feedforward_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(feedforward_dim, model_dim)
+
+ # Layer Normalization & Dropout
+ self.norm1 = nn.LayerNorm(model_dim)
+ self.norm2 = nn.LayerNorm(model_dim)
+ self.norm3 = nn.LayerNorm(model_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+ self.activation = _get_activation_fn(activation)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        if pos is None:
+            return tensor
+        pos_len = pos.size(0)
+        return torch.cat([tensor[:pos_len] + pos, tensor[pos_len:]], dim=0)
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+def build_transformer(transformer_config):
+ return Transformer(model_dim=transformer_config['model_dim'],
+ dropout=transformer_config['dropout'],
+ nhead=transformer_config['nhead'],
+ feedforward_dim=transformer_config['feedforward_dim'],
+ num_enc_layers=transformer_config['num_enc_layers'],
+ num_dec_layers=transformer_config['num_dec_layers'])
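+
+
+# Hedged usage sketch with random tensors in the (sequence, batch, dim) layout used above;
+# the token counts (14 joints + 431 coarse vertices) follow the FastMETRO convention and are
+# assumptions here, as is flattening a PositionEmbeddingSine map into pos_embed.
+def _example_transformer_forward():
+    cfg = {'model_dim': 512, 'dropout': 0.1, 'nhead': 8, 'feedforward_dim': 2048,
+           'num_enc_layers': 3, 'num_dec_layers': 3}
+    transformer = build_transformer(cfg)
+    bs, h, w, num_jv = 2, 8, 8, 14 + 431
+    img_features = torch.randn(h * w, bs, 512)   # flattened CNN feature map
+    cam_token = torch.randn(1, bs, 512)          # camera token
+    jv_tokens = torch.randn(num_jv, bs, 512)     # joint + vertex query tokens
+    pos_embed = torch.randn(h * w, bs, 512)      # positional encoding for the image features
+    cam_feat, enc_img, jv_feat, _ = transformer(img_features, cam_token, jv_tokens, pos_embed)
+    return cam_feat.shape, jv_feat.shape         # torch.Size([1, 2, 512]), torch.Size([445, 2, 512])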
\ No newline at end of file