diff --git a/app.py b/app.py
index 8ee0a9f7de2d664fcdbb768ad90da1b8efb768e5..8af72faa0092c69210f873b83ae34009278dced8 100644
--- a/app.py
+++ b/app.py
@@ -33,40 +33,7 @@ def infer(image_input, in_threshold=0.5, num_people="Single person", render_mesh
     os.system(f'rm -rf {OUT_FOLDER}/*')
     multi_person = False if (num_people == "Single person") else True
     vis_img, num_bbox, mmdet_box = inferer.infer(image_input, in_threshold, 0, multi_person, not(render_mesh))
-
-    # cap = cv2.VideoCapture(video_input)
-    # fps = math.ceil(cap.get(5))
-    # width = int(cap.get(3))
-    # height = int(cap.get(4))
-    # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    # video_path = osp.join(OUT_FOLDER, f'out.m4v')
-    # final_video_path = osp.join(OUT_FOLDER, f'out.mp4')
-    # video_output = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
-    # success = 1
-    # frame = 0
-    # while success:
-    #     success, original_img = cap.read()
-    #     if not success:
-    #         break
-    #     frame += 1
-    #     img, mesh_paths, smplx_paths = inferer.infer(original_img, in_threshold, frame, multi_person, not(render_mesh))
-    #     video_output.write(img)
-    #     yield img, None, None, None
-    # cap.release()
-    # video_output.release()
-    # cv2.destroyAllWindows()
-    # os.system(f'ffmpeg -i {video_path} -c copy {final_video_path}')
-
-    # #Compress mesh and smplx files
-    # save_path_mesh = os.path.join(OUT_FOLDER, 'mesh')
-    # save_mesh_file = os.path.join(OUT_FOLDER, 'mesh.zip')
-    # os.makedirs(save_path_mesh, exist_ok= True)
-    # save_path_smplx = os.path.join(OUT_FOLDER, 'smplx')
-    # save_smplx_file = os.path.join(OUT_FOLDER, 'smplx.zip')
-    # os.makedirs(save_path_smplx, exist_ok= True)
-    # os.system(f'zip -r {save_mesh_file} {save_path_mesh}')
-    # os.system(f'zip -r {save_smplx_file} {save_path_smplx}')
-    # yield img, video_path, save_mesh_file, save_smplx_file
+    return vis_img, "bbox num: {}, bbox meta: {}".format(num_bbox, mmdet_box)

 TITLE = '''

PostoMETRO: Pose Token Enhanced Mesh Transformer for Robust 3D Human Mesh Recovery

 '''
@@ -113,6 +80,9 @@ with gr.Blocks(title="PostoMETRO", css=".gradio-container") as demo:
             ['/home/user/app/assets/02.jpg'],
             ['/home/user/app/assets/03.jpg'],
             ['/home/user/app/assets/04.jpg'],
+            ['/home/user/app/assets/05.jpg'],
+            ['/home/user/app/assets/06.jpg'],
+            ['/home/user/app/assets/07.jpg'],
             ],
             inputs=[image_input, 0.2])

diff --git a/assets/02.jpg b/assets/02.jpg
index e3176774fcf49344c3b644cb7ba4b03d581bbb8e..6c13965baa0fe4a8e27dee7178dc921ca931cc5d 100644
Binary files a/assets/02.jpg and b/assets/02.jpg differ
diff --git a/assets/04.jpg b/assets/04.jpg
index 846802cc2c91b0bd93b0aa3606e5a8a124eca151..e58fde18c3c9f558b7d894d470f0a5662721b5cb 100644
Binary files a/assets/04.jpg and b/assets/04.jpg differ
diff --git a/assets/05.jpg b/assets/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a4f520bc9c426c09bb834f8d0ed28992ba718482
Binary files /dev/null and b/assets/05.jpg differ
diff --git a/assets/06.jpg b/assets/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c58f911ffc33611b7133bc16bff82b81403b24aa
Binary files /dev/null and b/assets/06.jpg differ
diff --git a/assets/07.jpg b/assets/07.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b99df507302d7193d4d30b6e136636a78cfa4940
Binary files /dev/null and b/assets/07.jpg differ
diff --git a/common/utils/__pycache__/__init__.cpython-39.pyc b/common/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f038c1d5b690eefd173a0276cb2eaea13d87b765
Binary files /dev/null and b/common/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/inference_utils.cpython-39.pyc b/common/utils/__pycache__/inference_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfd27635a9573b1b638134c057d81378afa4e736
Binary files /dev/null and b/common/utils/__pycache__/inference_utils.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/preprocessing.cpython-39.pyc b/common/utils/__pycache__/preprocessing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb5436efaaddef63f1fa8b7f2060c3d7c4fcf8aa
Binary files /dev/null and b/common/utils/__pycache__/preprocessing.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/transforms.cpython-39.pyc b/common/utils/__pycache__/transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8f53f05922da528904b78eeb8193f83b4665253
Binary files /dev/null and b/common/utils/__pycache__/transforms.cpython-39.pyc differ
diff --git a/common/utils/__pycache__/vis.cpython-39.pyc b/common/utils/__pycache__/vis.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ed1d51ef5a2cd9c2a73e6ab0ad4242bf8735736
Binary files /dev/null and b/common/utils/__pycache__/vis.cpython-39.pyc differ
diff --git a/common/utils/vis.py b/common/utils/vis.py
index f5b7dd3b6775e16bff638c8383ed04ab916978c1..10d2dc9f98713d9faca95388d80ca27eaad7a569 100644
--- a/common/utils/vis.py
+++ b/common/utils/vis.py
@@ -5,7 +5,7 @@ from mpl_toolkits.mplot3d import Axes3D
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 import os
-os.environ["PYOPENGL_PLATFORM"] = "egl"
+os.environ["PYOPENGL_PLATFORM"] = "osmesa"
 import pyrender
 import trimesh
 from config import cfg
@@ -138,6 +138,20 @@ def perspective_projection(vertices, cam_param):
     vertices[:, 1] = vertices[:, 1] * fy / vertices[:, 2] + cy
     return vertices

+class WeakPerspectiveCamera(pyrender.Camera):
+    def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None):
+        super(WeakPerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name)
+        self.scale = scale
+        self.translation = translation
+
+    def get_projection_matrix(self, width=None, height=None):
+        P = np.eye(4)
+        P[0, 0] = self.scale[0]
+        P[1, 1] = self.scale[1]
+        P[0, 3] = self.translation[0] * self.scale[0]
+        P[1, 3] = -self.translation[1] * self.scale[1]
+        P[2, 2] = -1
+        return P

 def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
     if mesh_as_vertices:
@@ -150,28 +164,32 @@ def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
     rot = trimesh.transformations.rotation_matrix(
         np.radians(180), [1, 0, 0])
     mesh.apply_transform(rot)
-    material = pyrender.MetallicRoughnessMaterial(metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0))
-    mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=False)
-    scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
+    color=[0.7, 0.7, 0.6]
+    material = pyrender.MetallicRoughnessMaterial(
+        metallicFactor=0.2,
+        roughnessFactor=1.0,
+        alphaMode='OPAQUE',
+        baseColorFactor=(color[0], color[1], color[2], 1.0)
+        )
+    mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+    scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.05, 0.05, 0.05))
+    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
+    light_pose = trimesh.transformations.rotation_matrix(np.radians(-45), [1, 0, 0])
+    scene.add(light, pose=light_pose)
+    light_pose = trimesh.transformations.rotation_matrix(np.radians(45), [0, 1, 0])
+    scene.add(light, pose=light_pose)
     scene.add(mesh, 'mesh')
-    focal, princpt = cam_param['focal'], cam_param['princpt']
-    camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
-    scene.add(camera)
+    # focal, princpt = cam_param['focal'], cam_param['princpt']
+    # camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
+    sx, sy, tx, ty = cam_param
+    camera = WeakPerspectiveCamera(scale=[sx, sy], translation=[tx, ty], zfar=1000.0)
+    camera_pose = np.eye(4)
+    scene.add(camera, pose=camera_pose)

     # renderer
     renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
-    # light
-    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
-    light_pose = np.eye(4)
-    light_pose[:3, 3] = np.array([0, -1, 1])
-    scene.add(light, pose=light_pose)
-    light_pose[:3, 3] = np.array([0, 1, 1])
-    scene.add(light, pose=light_pose)
-    light_pose[:3, 3] = np.array([1, 1, 2])
-    scene.add(light, pose=light_pose)
-
     # render
     rgb, depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
     rgb = rgb[:,:,:3].astype(np.float32)
diff --git a/main/__pycache__/config.cpython-39.pyc b/main/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e0d91d6cbb13956029fded68bee60952a27873b
Binary files /dev/null and b/main/__pycache__/config.cpython-39.pyc differ
diff --git a/main/__pycache__/postometro.cpython-39.pyc b/main/__pycache__/postometro.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03d340f47d4f0e13d4b2e18799e1caa874c92882
Binary files /dev/null and b/main/__pycache__/postometro.cpython-39.pyc differ
diff --git a/main/config/config_postometro.py b/main/config/config_postometro.py
index 00a09484bf30e29bdce054d284d65e019bf9799e..ac522eb12ef6a146ef7106ab3f9a32af4e6121df 100644
---
a/main/config/config_postometro.py +++ b/main/config/config_postometro.py @@ -3,109 +3,12 @@ import os.path as osp # will be update in exp num_gpus = -1 -exp_name = 'output/exp1/pre_analysis' - -# quick access -save_epoch = 1 -lr = 1e-5 -end_epoch = 10 -train_batch_size = 16 - -syncbn = True -bbox_ratio = 1.2 - -# continue -continue_train = False -start_over = True - -# dataset setting -agora_fix_betas = True -agora_fix_global_orient_transl = True -agora_valid_root_pose = True - -# all -dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \ - 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX'] -trainset_3d = ['MSCOCO','AGORA', 'UBody'] -trainset_2d = ['PW3D', 'MPII', 'Human36M'] -trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack', - 'EgoBody_Egocentric', 'PROX', 'CrowdPose', - 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety', - 'Behave', 'UP3D', 'ARCTIC', - 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody', - 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET'] -testset = 'EHF' - -use_cache = True -# downsample -BEDLAM_train_sample_interval = 5 -EgoBody_Kinect_train_sample_interval = 10 -train_sample_interval = 10 # UBody -MPI_INF_3DHP_train_sample_interval = 5 -InstaVariety_train_sample_interval = 10 -RenBody_HiRes_train_sample_interval = 5 -ARCTIC_train_sample_interval = 10 -# RenBody_train_sample_interval = 10 -FIT3D_train_sample_interval = 10 -Talkshow_train_sample_interval = 10 - -# strategy -data_strategy = 'balance' # 'balance' need to define total_data_len -total_data_len = 4500000 - -# model -smplx_loss_weight = 1.0 #2 for agora_model for smplx shape -smplx_pose_weight = 10.0 - -smplx_kps_3d_weight = 100.0 -smplx_kps_2d_weight = 1.0 -net_kps_2d_weight = 1.0 - -agora_benchmark = 'agora_model' # 'agora_model', 'test_only' - -model_type = 'smpler_x_h' -encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_huge.py' -encoder_pretrained_model_path = 'pretrained_models/vitpose_huge.pth' -feat_dim = 1280 - -## =====FIXED ARGS============================================================ -## model setting -upscale = 4 -hand_pos_joint_num = 20 -face_pos_joint_num = 72 -num_task_token = 24 -num_noise_sample = 0 - -## UBody setting -train_sample_interval = 10 -test_sample_interval = 100 -make_same_len = False ## input, output size input_img_shape = (256, 256) input_body_shape = (256, 256) -output_hm_shape = (16, 16, 12) -input_hand_shape = (256, 256) -output_hand_hm_shape = (16, 16, 16) -output_face_hm_shape = (8, 8, 8) -input_face_shape = (192, 192) -focal = (5000, 5000) # virtual focal lengths -princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position -body_3d_size = 2 -hand_3d_size = 0.3 -face_3d_size = 0.3 -camera_3d_size = 2.5 - -## training config -print_iters = 100 -lr_mult = 1 -## testing config -test_batch_size = 32 - -## others -num_thread = 2 -vis = False +renderer_input_body_shape = (256, 256) +focal = (5000, 5000) # virtual focal lengths +princpt = (renderer_input_body_shape[1] / 2, renderer_input_body_shape[0] / 2) # virtual principal point position -## directory -output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None diff --git a/main/config/config_smpler_x_b32.py b/main/config/config_smpler_x_b32.py deleted file mode 100644 index b737e5307a76fbeaf6eefa6e2bc775c52760fab4..0000000000000000000000000000000000000000 --- a/main/config/config_smpler_x_b32.py +++ /dev/null @@ 
-1,112 +0,0 @@ -import os -import os.path as osp - -# will be update in exp -num_gpus = -1 -exp_name = 'output/exp1/pre_analysis' - -# quick access -save_epoch = 1 -lr = 1e-5 -end_epoch = 10 -train_batch_size = 32 - -syncbn = True -bbox_ratio = 1.2 - -# continue -continue_train = False -start_over = True - -# dataset setting -agora_fix_betas = True -agora_fix_global_orient_transl = True -agora_valid_root_pose = True - -# all -dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \ - 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX'] -trainset_3d = ['MSCOCO','AGORA', 'UBody'] -trainset_2d = ['PW3D', 'MPII', 'Human36M'] -trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack', - 'EgoBody_Egocentric', 'PROX', 'CrowdPose', - 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety', - 'Behave', 'UP3D', 'ARCTIC', - 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody', - 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET'] -testset = 'EHF' - -use_cache = True -# downsample -BEDLAM_train_sample_interval = 5 -EgoBody_Kinect_train_sample_interval = 10 -train_sample_interval = 10 # UBody -MPI_INF_3DHP_train_sample_interval = 5 -InstaVariety_train_sample_interval = 10 -RenBody_HiRes_train_sample_interval = 5 -ARCTIC_train_sample_interval = 10 -# RenBody_train_sample_interval = 10 -FIT3D_train_sample_interval = 10 -Talkshow_train_sample_interval = 10 - -# strategy -data_strategy = 'balance' # 'balance' need to define total_data_len -total_data_len = 4500000 - -# model -smplx_loss_weight = 1.0 #2 for agora_model for smplx shape -smplx_pose_weight = 10.0 - -smplx_kps_3d_weight = 100.0 -smplx_kps_2d_weight = 1.0 -net_kps_2d_weight = 1.0 - -agora_benchmark = 'agora_model' # 'agora_model', 'test_only' - -model_type = 'smpler_x_b' -encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_base.py' -encoder_pretrained_model_path = 'pretrained_models/vitpose_base.pth' -feat_dim = 768 - - -## =====FIXED ARGS============================================================ -## model setting -upscale = 4 -hand_pos_joint_num = 20 -face_pos_joint_num = 72 -num_task_token = 24 -num_noise_sample = 0 - -## UBody setting -train_sample_interval = 10 -test_sample_interval = 100 -make_same_len = False - -## input, output size -input_img_shape = (512, 384) -input_body_shape = (256, 192) -output_hm_shape = (16, 16, 12) -input_hand_shape = (256, 256) -output_hand_hm_shape = (16, 16, 16) -output_face_hm_shape = (8, 8, 8) -input_face_shape = (192, 192) -focal = (5000, 5000) # virtual focal lengths -princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position -body_3d_size = 2 -hand_3d_size = 0.3 -face_3d_size = 0.3 -camera_3d_size = 2.5 - -## training config -print_iters = 100 -lr_mult = 1 - -## testing config -test_batch_size = 32 - -## others -num_thread = 2 -vis = False - -## directory -output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None diff --git a/main/config/config_smpler_x_h32.py b/main/config/config_smpler_x_h32.py deleted file mode 100644 index 2ffd86e9e965f9f2d3fd5efdb98ad2cb83fa81ed..0000000000000000000000000000000000000000 --- a/main/config/config_smpler_x_h32.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import os.path as osp - -# will be update in exp -num_gpus = -1 -exp_name = 'output/exp1/pre_analysis' - -# quick access -save_epoch = 1 -lr = 1e-5 -end_epoch = 10 -train_batch_size = 16 - -syncbn = True -bbox_ratio = 1.2 - -# 
continue -continue_train = False -start_over = True - -# dataset setting -agora_fix_betas = True -agora_fix_global_orient_transl = True -agora_valid_root_pose = True - -# all -dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \ - 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX'] -trainset_3d = ['MSCOCO','AGORA', 'UBody'] -trainset_2d = ['PW3D', 'MPII', 'Human36M'] -trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack', - 'EgoBody_Egocentric', 'PROX', 'CrowdPose', - 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety', - 'Behave', 'UP3D', 'ARCTIC', - 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody', - 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET'] -testset = 'EHF' - -use_cache = True -# downsample -BEDLAM_train_sample_interval = 5 -EgoBody_Kinect_train_sample_interval = 10 -train_sample_interval = 10 # UBody -MPI_INF_3DHP_train_sample_interval = 5 -InstaVariety_train_sample_interval = 10 -RenBody_HiRes_train_sample_interval = 5 -ARCTIC_train_sample_interval = 10 -# RenBody_train_sample_interval = 10 -FIT3D_train_sample_interval = 10 -Talkshow_train_sample_interval = 10 - -# strategy -data_strategy = 'balance' # 'balance' need to define total_data_len -total_data_len = 4500000 - -# model -smplx_loss_weight = 1.0 #2 for agora_model for smplx shape -smplx_pose_weight = 10.0 - -smplx_kps_3d_weight = 100.0 -smplx_kps_2d_weight = 1.0 -net_kps_2d_weight = 1.0 - -agora_benchmark = 'agora_model' # 'agora_model', 'test_only' - -model_type = 'smpler_x_h' -encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_huge.py' -encoder_pretrained_model_path = 'pretrained_models/vitpose_huge.pth' -feat_dim = 1280 - -## =====FIXED ARGS============================================================ -## model setting -upscale = 4 -hand_pos_joint_num = 20 -face_pos_joint_num = 72 -num_task_token = 24 -num_noise_sample = 0 - -## UBody setting -train_sample_interval = 10 -test_sample_interval = 100 -make_same_len = False - -## input, output size -input_img_shape = (512, 384) -input_body_shape = (256, 192) -output_hm_shape = (16, 16, 12) -input_hand_shape = (256, 256) -output_hand_hm_shape = (16, 16, 16) -output_face_hm_shape = (8, 8, 8) -input_face_shape = (192, 192) -focal = (5000, 5000) # virtual focal lengths -princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position -body_3d_size = 2 -hand_3d_size = 0.3 -face_3d_size = 0.3 -camera_3d_size = 2.5 - -## training config -print_iters = 100 -lr_mult = 1 - -## testing config -test_batch_size = 32 - -## others -num_thread = 2 -vis = False - -## directory -output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None diff --git a/main/config/config_smpler_x_l32.py b/main/config/config_smpler_x_l32.py deleted file mode 100644 index 1cfedc0b6b59d17d2b666bfdfabff6c45069456b..0000000000000000000000000000000000000000 --- a/main/config/config_smpler_x_l32.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import os.path as osp - -# will be update in exp -num_gpus = -1 -exp_name = 'output/exp1/pre_analysis' - -# quick access -save_epoch = 1 -lr = 1e-5 -end_epoch = 10 -train_batch_size = 32 - -syncbn = True -bbox_ratio = 1.2 - -# continue -continue_train = False -start_over = True - -# dataset setting -agora_fix_betas = True -agora_fix_global_orient_transl = True -agora_valid_root_pose = True - -# all -dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 
'GTA_Human2', \ - 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX'] -trainset_3d = ['MSCOCO','AGORA', 'UBody'] -trainset_2d = ['PW3D', 'MPII', 'Human36M'] -trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack', - 'EgoBody_Egocentric', 'PROX', 'CrowdPose', - 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety', - 'Behave', 'UP3D', 'ARCTIC', - 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody', - 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET'] -testset = 'EHF' - -use_cache = True -# downsample -BEDLAM_train_sample_interval = 5 -EgoBody_Kinect_train_sample_interval = 10 -train_sample_interval = 10 # UBody -MPI_INF_3DHP_train_sample_interval = 5 -InstaVariety_train_sample_interval = 10 -RenBody_HiRes_train_sample_interval = 5 -ARCTIC_train_sample_interval = 10 -# RenBody_train_sample_interval = 10 -FIT3D_train_sample_interval = 10 -Talkshow_train_sample_interval = 10 - -# strategy -data_strategy = 'balance' # 'balance' need to define total_data_len -total_data_len = 4500000 - -# model -smplx_loss_weight = 1.0 #2 for agora_model for smplx shape -smplx_pose_weight = 10.0 - -smplx_kps_3d_weight = 100.0 -smplx_kps_2d_weight = 1.0 -net_kps_2d_weight = 1.0 - -agora_benchmark = 'agora_model' # 'agora_model', 'test_only' - -model_type = 'smpler_x_l' -encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_large.py' -encoder_pretrained_model_path = 'pretrained_models/vitpose_large.pth' -feat_dim = 1024 - - -## =====FIXED ARGS============================================================ -## model setting -upscale = 4 -hand_pos_joint_num = 20 -face_pos_joint_num = 72 -num_task_token = 24 -num_noise_sample = 0 - -## UBody setting -train_sample_interval = 10 -test_sample_interval = 100 -make_same_len = False - -## input, output size -input_img_shape = (512, 384) -input_body_shape = (256, 192) -output_hm_shape = (16, 16, 12) -input_hand_shape = (256, 256) -output_hand_hm_shape = (16, 16, 16) -output_face_hm_shape = (8, 8, 8) -input_face_shape = (192, 192) -focal = (5000, 5000) # virtual focal lengths -princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position -body_3d_size = 2 -hand_3d_size = 0.3 -face_3d_size = 0.3 -camera_3d_size = 2.5 - -## training config -print_iters = 100 -lr_mult = 1 - -## testing config -test_batch_size = 32 - -## others -num_thread = 2 -vis = False - -## directory -output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None diff --git a/main/config/config_smpler_x_s32.py b/main/config/config_smpler_x_s32.py deleted file mode 100644 index 090501bef40b1130e733d9567c05dd11b22b9ed1..0000000000000000000000000000000000000000 --- a/main/config/config_smpler_x_s32.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import os.path as osp - -# will be update in exp -num_gpus = -1 -exp_name = 'output/exp1/pre_analysis' - -# quick access -save_epoch = 1 -lr = 1e-5 -end_epoch = 10 -train_batch_size = 32 - -syncbn = True -bbox_ratio = 1.2 - -# continue -continue_train = False -start_over = True - -# dataset setting -agora_fix_betas = True -agora_fix_global_orient_transl = True -agora_valid_root_pose = True - -# all data -dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \ - 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX'] -trainset_3d = ['MSCOCO','AGORA', 'UBody'] -trainset_2d = ['PW3D', 'MPII', 'Human36M'] -trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 
'PoseTrack', - 'EgoBody_Egocentric', 'PROX', 'CrowdPose', - 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety', - 'Behave', 'UP3D', 'ARCTIC', - 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody', - 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET'] -testset = 'EHF' - -use_cache = True -# downsample -BEDLAM_train_sample_interval = 5 -EgoBody_Kinect_train_sample_interval = 10 -train_sample_interval = 10 # UBody -MPI_INF_3DHP_train_sample_interval = 5 -InstaVariety_train_sample_interval = 10 -RenBody_HiRes_train_sample_interval = 5 -ARCTIC_train_sample_interval = 10 -# RenBody_train_sample_interval = 10 -FIT3D_train_sample_interval = 10 -Talkshow_train_sample_interval = 10 - -# strategy -data_strategy = 'balance' # 'balance' need to define total_data_len -total_data_len = 4500000 - -# model -smplx_loss_weight = 1.0 #2 for agora_model for smplx shape -smplx_pose_weight = 10.0 - -smplx_kps_3d_weight = 100.0 -smplx_kps_2d_weight = 1.0 -net_kps_2d_weight = 1.0 - -agora_benchmark = 'agora_model' # 'agora_model', 'test_only' - -model_type = 'smpler_x_s' -encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_small.py' -encoder_pretrained_model_path = 'pretrained_models/vitpose_small.pth' -feat_dim = 384 - -## =====FIXED ARGS============================================================ -## model setting -upscale = 4 -hand_pos_joint_num = 20 -face_pos_joint_num = 72 -num_task_token = 24 -num_noise_sample = 0 - -## UBody setting -train_sample_interval = 10 -test_sample_interval = 100 -make_same_len = False - -## input, output size -input_img_shape = (512, 384) -input_body_shape = (256, 192) -output_hm_shape = (16, 16, 12) -input_hand_shape = (256, 256) -output_hand_hm_shape = (16, 16, 16) -output_face_hm_shape = (8, 8, 8) -input_face_shape = (192, 192) -focal = (5000, 5000) # virtual focal lengths -princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position -body_3d_size = 2 -hand_3d_size = 0.3 -face_3d_size = 0.3 -camera_3d_size = 2.5 - -## training config -print_iters = 100 -lr_mult = 1 - -## testing config -test_batch_size = 32 - -## others -num_thread = 2 -vis = False - -## directory -output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None diff --git a/main/inference.py b/main/inference.py index 6e9445582d9fb6fc200913efc519516d84b58988..af4d1c505367471e52c876a26a5b76f6412ce4e3 100644 --- a/main/inference.py +++ b/main/inference.py @@ -11,11 +11,12 @@ sys.path.insert(0, osp.join(CUR_DIR, '..', 'main')) sys.path.insert(0, osp.join(CUR_DIR , '..', 'common')) from config import cfg import cv2 -from tqdm import tqdm -import json -from typing import Literal, Union from mmdet.apis import init_detector, inference_detector from utils.inference_utils import process_mmdet_results, non_max_suppression +from postometro_utils.smpl import SMPL +import data.config as smpl_cfg +from postometro import get_model +from postometro_utils.renderer_pyrender import PyRender_Renderer class Inferer: @@ -29,16 +30,18 @@ class Inferer: # ckpt_path = osp.join(CUR_DIR, '../pretrained_models', f'{pretrained_model}.pth.tar') ckpt_path = None # for config cfg.get_config_fromfile(config_path) + # uodate config cfg.update_config(num_gpus, ckpt_path, output_folder, self.device) self.cfg = cfg cudnn.benchmark = True - # # load model - # from base import Demoer - # demoer = Demoer() - # demoer._make_model() - # demoer.model.eval() - # self.demoer = demoer + # load SMPL + self.smpl = SMPL().to(self.device) + 
self.faces = self.smpl.faces.cpu().numpy() + + # load model + hmr_model_checkpoint_file = osp.join(CUR_DIR, '../pretrained_models/postometro/resnet_state_dict.bin') + self.hmr_model = get_model(backbone_str='resnet50',device=self.device, checkpoint_file = hmr_model_checkpoint_file) # load faster-rcnn as human detector checkpoint_file = osp.join(CUR_DIR, '../pretrained_models/mmdet/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth') @@ -46,17 +49,20 @@ class Inferer: model = init_detector(config_file, checkpoint_file, device=self.device) # or device='cuda:0' self.model = model - def infer(self, original_img, iou_thr, frame, multi_person=False, mesh_as_vertices=False): + def infer(self, original_img, iou_thr, multi_person=False, mesh_as_vertices=False): from utils.preprocessing import process_bbox, generate_patch_image - # from utils.vis import render_mesh, save_obj + from utils.vis import render_mesh # from utils.human_models import smpl_x - mesh_paths = [] - smplx_paths = [] + # prepare input image - transform = transforms.ToTensor() + transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) vis_img = original_img.copy() original_img_height, original_img_width = original_img.shape[:2] + # load renderer + # self.renderer = PyRender_Renderer(resolution=(original_img_width, original_img_height), faces=self.faces) + ## mmdet inference mmdet_results = inference_detector(self.model, original_img) mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True) @@ -99,51 +105,69 @@ class Inferer: top_left = (int(bbox[0]), int(bbox[1])) bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])) cv2.rectangle(vis_img, top_left, bottom_right, (0, 0, 255), 2) - - + # human model inference - # img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape) - # img = transform(img.astype(np.float32))/255 - # img = img.to(cfg.device)[None,:,:,:] - # inputs = {'img': img} - # targets = {} - # meta_info = {} - - # # mesh recovery - # with torch.no_grad(): - # out = self.demoer.model(inputs, targets, meta_info, 'test') - # mesh = out['smplx_mesh_cam'].detach().cpu().numpy()[0] - - # ## save mesh - # save_path_mesh = os.path.join(self.output_folder, 'mesh') - # os.makedirs(save_path_mesh, exist_ok= True) - # obj_path = os.path.join(save_path_mesh, f'{frame:05}_{bbox_id}.obj') - # save_obj(mesh, smpl_x.face, obj_path) - # mesh_paths.append(obj_path) - # ## save single person param - # smplx_pred = {} - # smplx_pred['global_orient'] = out['smplx_root_pose'].reshape(-1,3).cpu().numpy() - # smplx_pred['body_pose'] = out['smplx_body_pose'].reshape(-1,3).cpu().numpy() - # smplx_pred['left_hand_pose'] = out['smplx_lhand_pose'].reshape(-1,3).cpu().numpy() - # smplx_pred['right_hand_pose'] = out['smplx_rhand_pose'].reshape(-1,3).cpu().numpy() - # smplx_pred['jaw_pose'] = out['smplx_jaw_pose'].reshape(-1,3).cpu().numpy() - # smplx_pred['leye_pose'] = np.zeros((1, 3)) - # smplx_pred['reye_pose'] = np.zeros((1, 3)) - # smplx_pred['betas'] = out['smplx_shape'].reshape(-1,10).cpu().numpy() - # smplx_pred['expression'] = out['smplx_expr'].reshape(-1,10).cpu().numpy() - # smplx_pred['transl'] = out['cam_trans'].reshape(-1,3).cpu().numpy() - # save_path_smplx = os.path.join(self.output_folder, 'smplx') - # os.makedirs(save_path_smplx, exist_ok= True) - - # npz_path = os.path.join(save_path_smplx, f'{frame:05}_{bbox_id}.npz') - # np.savez(npz_path, **smplx_pred) - # smplx_paths.append(npz_path) - - # ## render single 
person mesh - # focal = [self.cfg.focal[0] / self.cfg.input_body_shape[1] * bbox[2], self.cfg.focal[1] / self.cfg.input_body_shape[0] * bbox[3]] - # princpt = [self.cfg.princpt[0] / self.cfg.input_body_shape[1] * bbox[2] + bbox[0], self.cfg.princpt[1] / self.cfg.input_body_shape[0] * bbox[3] + bbox[1]] - # vis_img = render_mesh(vis_img, mesh, smpl_x.face, {'focal': focal, 'princpt': princpt}, + img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape) + vis_patched_images = img.copy() + # here we pre-process images + img = img.transpose((2,0,1)) # h,w,c -> c,h,w + img = torch.from_numpy(img).float() / 255.0 + # Store image before normalization to use it in visualization + img = transform(img) + img = img.to(cfg.device)[None,:,:,:] + + self.renderer = PyRender_Renderer(resolution=(bbox[2], bbox[3]), faces=self.faces) + + # mesh recovery + with torch.no_grad(): + out = self.hmr_model(img) + pred_cam, pred_3d_vertices_fine = out['pred_cam'], out['pred_3d_vertices_fine'] + pred_3d_joints_from_smpl = self.smpl.get_h36m_joints(pred_3d_vertices_fine) # batch_size X 17 X 3 + pred_3d_joints_from_smpl_pelvis = pred_3d_joints_from_smpl[:,smpl_cfg.H36M_J17_NAME.index('Pelvis'),:] + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,smpl_cfg.H36M_J17_TO_J14,:] # batch_size X 14 X 3 + # normalize predicted vertices + pred_3d_vertices_fine = pred_3d_vertices_fine - pred_3d_joints_from_smpl_pelvis[:, None, :] # batch_size X 6890 X 3 + pred_3d_vertices_fine = pred_3d_vertices_fine.detach().cpu().numpy()[0] # 6890 X 3 + pred_cam = pred_cam.detach().cpu().numpy()[0] + bbox_cx, bbox_cy = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2 + img_cx, img_cy = original_img_width / 2, original_img_height / 2 + cx_delta, cy_delta = bbox_cx / img_cx - 1, bbox_cy / img_cy - 1 + + # render single person mesh + # focal = [self.cfg.focal[0] / self.cfg.renderer_input_body_shape[1] * bbox[2], self.cfg.focal[1] / self.cfg.renderer_input_body_shape[0] * bbox[3]] + # princpt = [self.cfg.princpt[0] / self.cfg.renderer_input_body_shape[1] * bbox[2] + bbox[0], self.cfg.princpt[1] / self.cfg.renderer_input_body_shape[0] * bbox[3] + bbox[1]] + # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, {'focal': focal, 'princpt': princpt}, # mesh_as_vertices=mesh_as_vertices) - # vis_img = vis_img.astype('uint8') - return vis_img, num_bbox, ok_bboxes + # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), pred_cam[1], pred_cam[2]], mesh_as_vertices=mesh_as_vertices) + # import ipdb + # ipdb.set_trace() + vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), + pred_cam[1] + cx_delta / (pred_cam[0] / (original_img_width / bbox[2])), + pred_cam[2] + cy_delta / (pred_cam[0] / (original_img_height / bbox[3]))], + mesh_as_vertices=mesh_as_vertices) + # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), 0, 0], mesh_as_vertices=mesh_as_vertices) + + # bbox_meta = {'bbox': bbox, 'img_hw': [original_img_height, original_img_width]} + # vis_img = self.renderer(pred_3d_vertices_fine, bbox_meta, vis_img, pred_cam) + vis_img = vis_img.astype('uint8') + return vis_img, len(ok_bboxes), ok_bboxes + + +if __name__ == '__main__': + from PIL import Image + inferer = 
Inferer('postometro', 0, './out_folder') # gpu + image_path = f'../assets/07.jpg' + image = Image.open(image_path) + # Convert the PIL image to a NumPy array + image_np = np.array(image) + vis_img, _ , _ = inferer.infer(image_np, 0.2, multi_person=True, mesh_as_vertices=False) + save_path = f'./saved_vis_07.jpg' + + # Ensure the image is in the correct format (PIL expects uint8) + if vis_img.dtype != np.uint8: + vis_img = vis_img.astype('uint8') + # Convert the Numpy array (if RGB) to a PIL image and save + image = Image.fromarray(vis_img) + image.save(save_path) + diff --git a/main/pct_utils/__pycache__/modules.cpython-39.pyc b/main/pct_utils/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421fd611fd1622031e1a82917072fb06efc4ed9f Binary files /dev/null and b/main/pct_utils/__pycache__/modules.cpython-39.pyc differ diff --git a/main/pct_utils/__pycache__/pct.cpython-39.pyc b/main/pct_utils/__pycache__/pct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25330b0d344355415e0196f7fdabc6d5f1063b7c Binary files /dev/null and b/main/pct_utils/__pycache__/pct.cpython-39.pyc differ diff --git a/main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc b/main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2cd494fc2451478426a1b2b6fa72e61fb0a12b7 Binary files /dev/null and b/main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc differ diff --git a/main/pct_utils/__pycache__/pct_head.cpython-39.pyc b/main/pct_utils/__pycache__/pct_head.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..132027917b0b4337373b740be122468af2477065 Binary files /dev/null and b/main/pct_utils/__pycache__/pct_head.cpython-39.pyc differ diff --git a/main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc b/main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c31244caba23440d0c9250b64cc69f33b4816c94 Binary files /dev/null and b/main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc differ diff --git a/main/pct_utils/modules.py b/main/pct_utils/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2c868a46adb650284343178e1ea8c9a5c51ff73a --- /dev/null +++ b/main/pct_utils/modules.py @@ -0,0 +1,117 @@ +# -------------------------------------------------------- +# Borrow from unofficial MLPMixer (https://github.com/920232796/MlpMixer-pytorch) +# Borrow from ResNet +# Modified by Zigang Geng (zigang@mail.ustc.edu.cn) +# -------------------------------------------------------- + +import torch +import torch.nn as nn + + +class FCBlock(nn.Module): + def __init__(self, dim, out_dim): + super().__init__() + + self.ff = nn.Sequential( + nn.Linear(dim, out_dim), + nn.LayerNorm(out_dim), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.ff(x) + + +class MLPBlock(nn.Module): + def __init__(self, dim, inter_dim, dropout_ratio): + super().__init__() + + self.ff = nn.Sequential( + nn.Linear(dim, inter_dim), + nn.GELU(), + nn.Dropout(dropout_ratio), + nn.Linear(inter_dim, dim), + nn.Dropout(dropout_ratio) + ) + + def forward(self, x): + return self.ff(x) + + +class MixerLayer(nn.Module): + def __init__(self, + hidden_dim, + hidden_inter_dim, + token_dim, + token_inter_dim, + dropout_ratio): + super().__init__() + + self.layernorm1 = nn.LayerNorm(hidden_dim) + self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio) + 
self.layernorm2 = nn.LayerNorm(hidden_dim) + self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio) + + def forward(self, x): + y = self.layernorm1(x) + y = y.transpose(2, 1) + y = self.MLP_token(y) + y = y.transpose(2, 1) + z = self.layernorm2(x + y) + z = self.MLP_channel(z) + out = x + y + z + return out + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, + downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, bias=False, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes, momentum=0.1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, bias=False, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes, momentum=0.1) + self.downsample = downsample + self.stride = stride + + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + +def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.Conv2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=kernel, + stride=stride, + padding=padding + )) + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) \ No newline at end of file diff --git a/main/pct_utils/pct.py b/main/pct_utils/pct.py new file mode 100644 index 0000000000000000000000000000000000000000..ff60e0326a6f96bcf06f2ef3fa0c1c0d0bff8eb1 --- /dev/null +++ b/main/pct_utils/pct.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from pct_utils.pct_head import PCT_Head + +class PCT(nn.Module): + def __init__(self, + args, + backbone, + stage_pct, + in_channels, + image_size, + num_joints, + pretrained=None, + tokenizer_pretrained=None): + super().__init__() + self.stage_pct = stage_pct + assert self.stage_pct in ["tokenizer", "classifier"] + self.guide_ratio = args.tokenizer_guide_ratio + self.image_guide = self.guide_ratio > 0.0 + self.num_joints = num_joints + + self.backbone = backbone + if self.image_guide: + self.extra_backbone = backbone + + self.keypoint_head = PCT_Head(args,stage_pct,in_channels,image_size,num_joints) + + if (pretrained is not None) or (tokenizer_pretrained is not None): + self.init_weights(pretrained, tokenizer_pretrained) + + def init_weights(self, pretrained, tokenizer): + """Weight initialization for model.""" + if self.stage_pct == "classifier": + self.backbone.init_weights(pretrained) + if self.image_guide: + self.extra_backbone.init_weights(pretrained) + self.keypoint_head.init_weights() + self.keypoint_head.tokenizer.init_weights(tokenizer) + + def forward(self,img, joints, train = True): + if train: + output = None if self.stage_pct == "tokenizer" else self.backbone(img) + extra_output = self.extra_backbone(img) if self.image_guide else None + + p_logits, p_joints, g_logits, e_latent_loss = \ + self.keypoint_head(output, extra_output, joints, train=True) + return { + 'cls_logits': p_logits, + 'pred_pose': p_joints, + 'encoding_indices': g_logits, + 'e_latent_loss': e_latent_loss 
+ } + else: + results = {} + + batch_size, _, img_height, img_width = img.shape + + output = None if self.stage_pct == "tokenizer" \ + else self.backbone(img) + extra_output = self.extra_backbone(img) \ + if self.image_guide and self.stage_pct == "tokenizer" else None + + p_joints, encoding_scores, out_part_token_feat = \ + self.keypoint_head(output, extra_output, joints, train=False) + return { + 'pred_pose': p_joints, + 'encoding_scores': encoding_scores, + 'part_token_feat': out_part_token_feat + } diff --git a/main/pct_utils/pct_backbone.py b/main/pct_utils/pct_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6cacc11c1dfe29830a0a031b814c55e2dd1593f6 --- /dev/null +++ b/main/pct_utils/pct_backbone.py @@ -0,0 +1,1475 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +import torch.utils.checkpoint as checkpoint +from torch.nn.utils import weight_norm +from torch import Tensor, Size +from typing import Union, List +import numpy as np +import logging + +# Copyright (c) Open-MMLab. All rights reserved. +# Copy from mmcv source code. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +import numpy as np +from scipy import interpolate + +import torch +import torchvision +import torch.distributed as dist +from torch.utils import model_zoo +from torch.nn import functional as F + + +def _load_checkpoint(filename, map_location=None): + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint_swin(model, + filename, + map_location='cpu', + strict=False, + rpe_interpolation='outer_mask', + logger=None): + """Load checkpoint from a file or URI. + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + elif 'module' in checkpoint: + state_dict = checkpoint['module'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + if list(state_dict.keys())[0].startswith('backbone.'): + state_dict = {k[9:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[2].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + # directly load here + + model.load_state_dict(state_dict, strict=True) + + return checkpoint + + +_shape_t = Union[int, List[int], Size] + +from itertools import repeat +import collections.abc + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse + +to_2tuple = _ntuple(2) + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. 
+ # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) + + +class LayerNorm2D(nn.Module): + def __init__(self, normalized_shape, norm_layer=None): + super().__init__() + self.ln = norm_layer(normalized_shape) if norm_layer is not None else nn.Identity() + + def forward(self, x): + """ + x: N C H W + """ + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2) + return x + + +class LayerNormFP32(nn.LayerNorm): + def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None: + super(LayerNormFP32, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: Tensor) -> Tensor: + return F.layer_norm( + input.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(input) + + +class LinearFP32(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(LinearFP32, self).__init__(in_features, out_features, bias) + + def forward(self, input: Tensor) -> Tensor: + return F.linear(input.float(), self.weight.float(), + self.bias.float() if self.bias is not None else None) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., + norm_layer=None, mlpfp32=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.mlpfp32 = mlpfp32 + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + if norm_layer is not None: + self.norm = norm_layer(hidden_features) + else: + self.norm = None + + def forward(self, x, H, W): + x = self.fc1(x) + if self.norm: + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + if self.mlpfp32: + x = self.fc2.float()(x.type(torch.float32)) + x = self.drop.float()(x) + # print(f"======>[MLP FP32]") + else: + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMlp(nn.Module): + def __init__(self, in_features, 
hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., + norm_layer=None, mlpfp32=False, proj_ln=False): + super().__init__() + self.mlp = Mlp(in_features=in_features, hidden_features=hidden_features, out_features=out_features, + act_layer=act_layer, drop=drop, norm_layer=norm_layer, mlpfp32=mlpfp32) + self.conv_proj = nn.Conv2d(in_features, + in_features, + kernel_size=3, + padding=1, + stride=1, + bias=False, + groups=in_features) + self.proj_ln = LayerNorm2D(in_features, LayerNormFP32) if proj_ln else None + + def forward(self, x, H, W): + B, L, C = x.shape + assert L == H * W + x = x.view(B, H, W, C).permute(0, 3, 1, 2) # B C H W + x = self.conv_proj(x) + if self.proj_ln: + x = self.proj_ln(x) + x = x.permute(0, 2, 3, 1) # B H W C + x = x.reshape(B, L, C) + x = self.mlp(x, H, W) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + relative_coords_table_type='norm8_log', rpe_hidden_dim=512, + rpe_output_type='normal', attn_type='normal', mlpfp32=False, pretrain_window_size=-1): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + self.mlpfp32 = mlpfp32 + self.attn_type = attn_type + self.rpe_output_type = rpe_output_type + self.relative_coords_table_type = relative_coords_table_type + + if self.attn_type == 'cosine_mh': + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + elif self.attn_type == 'normal': + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + else: + raise NotImplementedError() + if self.relative_coords_table_type != "none": + # mlp to generate table of relative position bias + self.rpe_mlp = nn.Sequential(nn.Linear(2, rpe_hidden_dim, bias=True), + nn.ReLU(inplace=True), + LinearFP32(rpe_hidden_dim, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if relative_coords_table_type == 'linear': + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + elif relative_coords_table_type == 'linear_bylayer': + relative_coords_table[:, :, :, 0] /= (pretrain_window_size - 1) + relative_coords_table[:, :, :, 1] /= (pretrain_window_size - 1) + elif relative_coords_table_type == 'norm8_log': + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8 + elif relative_coords_table_type == 'norm8_log_192to640': + if self.window_size[0] == 40: + relative_coords_table[:, :, :, 0] /= (11) + relative_coords_table[:, :, :, 1] /= (11) + elif self.window_size[0] == 20: + relative_coords_table[:, :, :, 0] /= (5) + relative_coords_table[:, :, :, 1] /= (5) + else: + raise NotImplementedError + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8 + # check + elif relative_coords_table_type == 'norm8_log_256to640': + if self.window_size[0] == 40: + relative_coords_table[:, :, :, 0] /= (15) + relative_coords_table[:, :, :, 1] /= (15) + elif self.window_size[0] == 20: + relative_coords_table[:, :, :, 0] /= (7) + relative_coords_table[:, :, :, 1] /= (7) + else: + raise NotImplementedError + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8 + elif relative_coords_table_type == 'norm8_log_bylayer': + relative_coords_table[:, :, :, 0] /= (pretrain_window_size - 1) + relative_coords_table[:, :, :, 1] /= (pretrain_window_size - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = 
torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8 + else: + raise NotImplementedError + self.register_buffer("relative_coords_table", relative_coords_table) + else: + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + trunc_normal_(self.relative_position_bias_table, std=.02) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + if self.attn_type == 'cosine_mh': + q = F.normalize(q.float(), dim=-1) + k = F.normalize(k.float(), dim=-1) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01, device=self.logit_scale.device))).exp() + attn = (q @ k.transpose(-2, -1)) * logit_scale.float() + elif self.attn_type == 'normal': + q = q * self.scale + attn = (q.float() @ k.float().transpose(-2, -1)) + else: + raise NotImplementedError() + + if self.relative_coords_table_type != "none": + # relative_position_bias_table: 2*Wh-1 * 2*Ww-1, nH + relative_position_bias_table = self.rpe_mlp(self.relative_coords_table).view(-1, self.num_heads) + else: + relative_position_bias_table = self.relative_position_bias_table + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + if self.rpe_output_type == 'normal': + pass + elif self.rpe_output_type == 'sigmoid': + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + else: + raise NotImplementedError + + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = self.softmax(attn) + attn = attn.type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + if self.mlpfp32: + x = self.proj.float()(x.type(torch.float32)) + x = self.proj_drop.float()(x) + # print(f"======>[ATTN FP32]") + else: + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlockPost(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + use_mlp_norm=False, endnorm=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + relative_coords_table_type='norm8_log', rpe_hidden_dim=512, + rpe_output_type='normal', attn_type='normal', mlp_type='normal', mlpfp32=False, + pretrain_window_size=-1): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_mlp_norm = use_mlp_norm + self.endnorm = endnorm + self.mlpfp32 = mlpfp32 + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + relative_coords_table_type=relative_coords_table_type, rpe_output_type=rpe_output_type, + rpe_hidden_dim=rpe_hidden_dim, attn_type=attn_type, mlpfp32=mlpfp32, + pretrain_window_size=pretrain_window_size) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if mlp_type == 'normal': + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32) + elif mlp_type == 'conv': + self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32) + elif mlp_type == 'conv_ln': + self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32, proj_ln=True) + + if self.endnorm: + self.enorm = norm_layer(dim) + else: + self.enorm = None + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + H, W = self.H, self.W + B, L, C = x.shape + assert L == H * W, f"input feature has wrong size, with L = {L}, H = {H}, W = {W}" + + shortcut = x + + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + if pad_r > 0 or pad_b > 0: + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + orig_type = x.dtype # attn may force to fp32 + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = 
x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + if self.mlpfp32: + x = self.norm1.float()(x) + x = x.type(orig_type) + else: + x = self.norm1(x) + x = shortcut + self.drop_path(x) + shortcut = x + + orig_type = x.dtype + x = self.mlp(x, H, W) + if self.mlpfp32: + x = self.norm2.float()(x) + x = x.type(orig_type) + else: + x = self.norm2(x) + x = shortcut + self.drop_path(x) + + if self.endnorm: + x = self.enorm(x) + + return x + + +class SwinTransformerBlockPre(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + use_mlp_norm=False, endnorm=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + init_values=None, relative_coords_table_type='norm8_log', rpe_hidden_dim=512, + rpe_output_type='normal', attn_type='normal', mlp_type='normal', mlpfp32=False, + pretrain_window_size=-1): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_mlp_norm = use_mlp_norm + self.endnorm = endnorm + self.mlpfp32 = mlpfp32 + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + relative_coords_table_type=relative_coords_table_type, rpe_output_type=rpe_output_type, + rpe_hidden_dim=rpe_hidden_dim, attn_type=attn_type, mlpfp32=mlpfp32, + pretrain_window_size=pretrain_window_size) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if mlp_type == 'normal': + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32) + elif mlp_type == 'conv': + self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32) + elif mlp_type == 'conv_ln': + self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32, proj_ln=True) + + if init_values is not None and init_values >= 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = 1.0, 1.0 + + if self.endnorm: + self.enorm = norm_layer(dim) + else: + self.enorm = None + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + H, W = self.H, self.W + B, L, C = x.shape + assert L == H * W, f"input feature has wrong size, with L = {L}, H = {H}, W = {W}" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + if pad_r > 0 or pad_b > 0: + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + orig_type = x.dtype # attn may force to fp32 + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + if self.mlpfp32: + x = self.gamma_1 * x + x = x.type(orig_type) + else: + x = self.gamma_1 * x + x = shortcut + self.drop_path(x) + shortcut = x + + orig_type = x.dtype + x = self.norm2(x) + if self.mlpfp32: + x = self.gamma_2 * self.mlp(x, H, W) + x = x.type(orig_type) + else: + x = self.gamma_2 * self.mlp(x, H, W) + x = shortcut + self.drop_path(x) + + if self.endnorm: + x = self.enorm(x) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True): + super().__init__() + self.dim = dim + self.postnorm = postnorm + + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) if postnorm else norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + if self.postnorm: + x = self.reduction(x) + x = self.norm(x) + else: + x = self.norm(x) + x = self.reduction(x) + + return x + + +class PatchReduction1C(nn.Module): + r""" Patch Reduction Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True): + super().__init__() + self.dim = dim + self.postnorm = postnorm + + self.reduction = nn.Linear(dim, dim, bias=False) + self.norm = norm_layer(dim) + + def forward(self, x, H, W): + """ + x: B, H*W, C + """ + if self.postnorm: + x = self.reduction(x) + x = self.norm(x) + else: + x = self.norm(x) + x = self.reduction(x) + + return x + + +class ConvPatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True): + super().__init__() + self.dim = dim + self.postnorm = postnorm + + self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=3, stride=2, padding=1) + self.norm = norm_layer(2 * dim) if postnorm else norm_layer(dim) + + def forward(self, x, H, W): + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + if self.postnorm: + x = x.permute(0, 3, 1, 2) # B C H W + x = self.reduction(x).flatten(2).transpose(1, 2) # B H//2*W//2 2*C + x = self.norm(x) + else: + x = self.norm(x) + x = x.permute(0, 3, 1, 2) # B C H W + x = self.reduction(x).flatten(2).transpose(1, 2) # B H//2*W//2 2*C + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. 
Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + use_shift (bool): Whether to use shifted window. Default: True. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + checkpoint_blocks=255, + init_values=None, + endnorm_interval=-1, + use_mlp_norm=False, + use_shift=True, + relative_coords_table_type='norm8_log', + rpe_hidden_dim=512, + rpe_output_type='normal', + attn_type='normal', + mlp_type='normal', + mlpfp32_blocks=[-1], + postnorm=True, + pretrain_window_size=-1): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + self.checkpoint_blocks = checkpoint_blocks + self.init_values = init_values if init_values is not None else 0.0 + self.endnorm_interval = endnorm_interval + self.mlpfp32_blocks = mlpfp32_blocks + self.postnorm = postnorm + + # build blocks + if self.postnorm: + self.blocks = nn.ModuleList([ + SwinTransformerBlockPost( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) or (not use_shift) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_mlp_norm=use_mlp_norm, + endnorm=True if ((i + 1) % endnorm_interval == 0) and ( + endnorm_interval > 0) else False, + relative_coords_table_type=relative_coords_table_type, + rpe_hidden_dim=rpe_hidden_dim, + rpe_output_type=rpe_output_type, + attn_type=attn_type, + mlp_type=mlp_type, + mlpfp32=True if i in mlpfp32_blocks else False, + pretrain_window_size=pretrain_window_size) + for i in range(depth)]) + else: + self.blocks = nn.ModuleList([ + SwinTransformerBlockPre( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) or (not use_shift) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + init_values=init_values, + use_mlp_norm=use_mlp_norm, + endnorm=True if ((i + 1) % endnorm_interval == 0) and ( + endnorm_interval > 0) else False, + relative_coords_table_type=relative_coords_table_type, + rpe_hidden_dim=rpe_hidden_dim, + rpe_output_type=rpe_output_type, + attn_type=attn_type, + mlp_type=mlp_type, + mlpfp32=True if i in mlpfp32_blocks else False, + pretrain_window_size=pretrain_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, postnorm=postnorm) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
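        Example (illustrative sketch of the region labelling behind the
        SW-MSA attention mask computed below; window_size=2, shift_size=1
        and a 4x4 padded map are made-up values, not the defaults):

            >>> import torch
            >>> ws, ss, Hp, Wp = 2, 1, 4, 4
            >>> img_mask = torch.zeros((1, Hp, Wp, 1))
            >>> cnt = 0
            >>> for h in (slice(0, -ws), slice(-ws, -ss), slice(-ss, None)):
            ...     for w in (slice(0, -ws), slice(-ws, -ss), slice(-ss, None)):
            ...         img_mask[:, h, w, :] = cnt
            ...         cnt += 1
            >>> # window token pairs with different labels later receive a -100.0 bias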
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + for idx, blk in enumerate(self.blocks): + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + + if self.downsample is not None: + x_down = self.downsample(x, H, W) + if isinstance(self.downsample, PatchReduction1C): + return x, H, W, x_down, H, W + else: + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + def _init_block_norm_weights(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, self.init_values) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, self.init_values) + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class ResNetDLNPatchEmbed(nn.Module): + def __init__(self, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(4) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.conv1 = nn.Sequential(nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False), + LayerNorm2D(64, norm_layer), + nn.GELU(), + nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False), + LayerNorm2D(64, norm_layer), + nn.GELU(), + nn.Conv2d(64, embed_dim, 3, stride=1, padding=1, bias=False)) + self.norm = LayerNorm2D(embed_dim, norm_layer if norm_layer is not None else LayerNormFP32) # use ln always + self.act = nn.GELU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.conv1(x) + x = self.norm(x) + x = self.act(x) + x = self.maxpool(x) + # x = x.flatten(2).transpose(1, 2) + return x + + +class SwinV2TransformerRPE2FC(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. 
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + use_shift (bool): Whether to use shifted window. Default: True. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=partial(LayerNormFP32, eps=1e-6), + ape=False, + patch_norm=True, + use_checkpoint=False, + init_values=1e-5, + endnorm_interval=-1, + use_mlp_norm_layers=[], + relative_coords_table_type='norm8_log', + rpe_hidden_dim=512, + attn_type='cosine_mh', + rpe_output_type='sigmoid', + rpe_wd=False, + postnorm=True, + mlp_type='normal', + patch_embed_type='normal', + patch_merge_type='normal', + strid16=False, + checkpoint_blocks=[255, 255, 255, 255], + mlpfp32_layer_blocks=[[-1], [-1], [-1], [-1]], + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_shift=True, + rpe_interpolation='geo', + pretrain_window_size=[-1, -1, -1, -1], + **kwargs): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.depths = depths + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.rpe_interpolation = rpe_interpolation + self.mlp_ratio = mlp_ratio + self.endnorm_interval = endnorm_interval + self.use_mlp_norm_layers = use_mlp_norm_layers + self.relative_coords_table_type = relative_coords_table_type + self.rpe_hidden_dim = rpe_hidden_dim + self.rpe_output_type = rpe_output_type + self.rpe_wd = rpe_wd + self.attn_type = attn_type + self.postnorm = postnorm + self.mlp_type = mlp_type + self.strid16 = strid16 + + if isinstance(window_size, list): + pass + elif isinstance(window_size, int): + window_size = [window_size] * self.num_layers + else: + raise TypeError("We only support list or int for window size") + + if isinstance(use_shift, list): + pass + elif isinstance(use_shift, bool): + use_shift = [use_shift] * self.num_layers + else: + raise TypeError("We only support list or bool for use_shift") + + if isinstance(use_checkpoint, list): + pass + elif isinstance(use_checkpoint, bool): + use_checkpoint = [use_checkpoint] * self.num_layers + else: + raise TypeError("We only support list or bool for use_checkpoint") + + # split image into non-overlapping patches + if patch_embed_type == 'normal': + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + elif patch_embed_type == 'resnetdln': + assert patch_size == 4, "check" + self.patch_embed = ResNetDLNPatchEmbed( + in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer) + elif patch_embed_type == 'resnetdnf': + assert patch_size == 4, "check" + self.patch_embed = ResNetDLNPatchEmbed( + in_chans=in_chans, embed_dim=embed_dim, norm_layer=None) + else: + raise NotImplementedError() + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, 
std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + if patch_merge_type == 'normal': + downsample_layer = PatchMerging + elif patch_merge_type == 'conv': + downsample_layer = ConvPatchMerging + else: + raise NotImplementedError() + # build layers + self.layers = nn.ModuleList() + num_features = [] + for i_layer in range(self.num_layers): + cur_dim = int(embed_dim * 2 ** (i_layer - 1)) \ + if (i_layer == self.num_layers - 1 and strid16) else \ + int(embed_dim * 2 ** i_layer) + num_features.append(cur_dim) + if i_layer < self.num_layers - 2: + cur_downsample_layer = downsample_layer + elif i_layer == self.num_layers - 2: + if strid16: + cur_downsample_layer = PatchReduction1C + else: + cur_downsample_layer = downsample_layer + else: + cur_downsample_layer = None + layer = BasicLayer( + dim=cur_dim, + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=cur_downsample_layer, + use_checkpoint=use_checkpoint[i_layer], + checkpoint_blocks=checkpoint_blocks[i_layer], + init_values=init_values, + endnorm_interval=endnorm_interval, + use_mlp_norm=True if i_layer in use_mlp_norm_layers else False, + use_shift=use_shift[i_layer], + relative_coords_table_type=self.relative_coords_table_type, + rpe_hidden_dim=self.rpe_hidden_dim, + rpe_output_type=self.rpe_output_type, + attn_type=self.attn_type, + mlp_type=self.mlp_type, + mlpfp32_blocks=mlpfp32_layer_blocks[i_layer], + postnorm=self.postnorm, + pretrain_window_size=pretrain_window_size[i_layer] + ) + self.layers.append(layer) + + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
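        Example (illustrative sketch only; 'swin_v2_ckpt.pth' is a
        placeholder checkpoint path and backbone_config an assumed kwargs
        dict, neither of which is defined in this file):

            backbone = SwinV2TransformerRPE2FC(**backbone_config)
            # pass a checkpoint path, or None for random initialization
            backbone.init_weights(pretrained='swin_v2_ckpt.pth')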
+ """ + self.norm3.eval() + for param in self.norm3.parameters(): + param.requires_grad = False + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + self.apply(_init_weights) + for bly in self.layers: + bly._init_block_norm_weights() + + if isinstance(pretrained, str): + logger = None + load_checkpoint_swin(self, pretrained, strict=False, map_location='cpu', + logger=logger, rpe_interpolation=self.rpe_interpolation) + elif pretrained is None: + pass + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer.float()(x_out.float()) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinV2TransformerRPE2FC, self).train(mode) + self._freeze_stages() diff --git a/main/pct_utils/pct_head.py b/main/pct_utils/pct_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c53f58e08287a5c1e558212f89723fa0c2533b99 --- /dev/null +++ b/main/pct_utils/pct_head.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn + +from pct_utils.pct_tokenizer import PCT_Tokenizer +from pct_utils.modules import MixerLayer, FCBlock, BasicBlock + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +def normal_init(module, mean=0, std=1, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +class PCT_Head(nn.Module): + """ Head of Pose Compositional Tokens. + paper ref: Zigang Geng et al. 
"Human Pose as + Compositional Tokens" + + The pipelines of two stage during training and inference: + + Tokenizer Stage & Train: + Joints -> (Img Guide) -> Encoder -> Codebook -> Decoder -> Recovered Joints + Loss: (Joints, Recovered Joints) + Tokenizer Stage & Test: + Joints -> (Img Guide) -> Encoder -> Codebook -> Decoder -> Recovered Joints + + Classifer Stage & Train: + Img -> Classifier -> Predict Class -> Codebook -> Decoder -> Recovered Joints + Joints -> (Img Guide) -> Encoder -> Codebook -> Groundtruth Class + Loss: (Predict Class, Groundtruth Class), (Joints, Recovered Joints) + Classifer Stage & Test: + Img -> Classifier -> Predict Class -> Codebook -> Decoder -> Recovered Joints + + Args: + stage_pct (str): Training stage (Tokenizer or Classifier). + in_channels (int): Feature Dim of the backbone feature. + image_size (tuple): Input image size. + num_joints (int): Number of annotated joints in the dataset. + cls_head (dict): Config for PCT classification head. Default: None. + tokenizer (dict): Config for PCT tokenizer. Default: None. + loss_keypoint (dict): Config for loss for training classifier. Default: None. + """ + + def __init__(self, + args, + stage_pct, + in_channels, + image_size, + num_joints, + cls_head=None, + tokenizer=None, + loss_keypoint=None,): + super().__init__() + + self.image_size = image_size + self.stage_pct = stage_pct + + self.guide_ratio = args.tokenizer_guide_ratio + self.img_guide = self.guide_ratio > 0.0 + + self.conv_channels = args.cls_head_conv_channels + self.hidden_dim = args.cls_head_hidden_dim + + self.num_blocks = args.cls_head_num_blocks + self.hidden_inter_dim = args.cls_head_hidden_inter_dim + self.token_inter_dim = args.cls_head_token_inter_dim + self.dropout = args.cls_head_dropout + + self.token_num = args.tokenizer_codebook_token_num + self.token_class_num = args.tokenizer_codebook_token_class_num + + if stage_pct == "classifier": + self.conv_trans = self._make_transition_for_head( + in_channels, self.conv_channels) + self.conv_head = self._make_cls_head(args) + + input_size = (image_size[0]//32)*(image_size[1]//32) + self.mixer_trans = FCBlock( + self.conv_channels * input_size, + self.token_num * self.hidden_dim) + + self.mixer_head = nn.ModuleList( + [MixerLayer(self.hidden_dim, self.hidden_inter_dim, + self.token_num, self.token_inter_dim, + self.dropout) for _ in range(self.num_blocks)]) + self.mixer_norm_layer = FCBlock( + self.hidden_dim, self.hidden_dim) + + self.cls_pred_layer = nn.Linear( + self.hidden_dim, self.token_class_num) + + self.tokenizer = PCT_Tokenizer( + args = args, stage_pct=stage_pct, num_joints=num_joints, + guide_ratio=self.guide_ratio, guide_channels=in_channels) + + def forward(self, x, extra_x, joints=None, train=True): + """Forward function.""" + + if self.stage_pct == "classifier": + batch_size = x[-1].shape[0] + cls_feat = self.conv_head[0](self.conv_trans(x[-1])) + + cls_feat = cls_feat.flatten(2).transpose(2,1).flatten(1) + cls_feat = self.mixer_trans(cls_feat) + cls_feat = cls_feat.reshape(batch_size, self.token_num, -1) + + for mixer_layer in self.mixer_head: + cls_feat = mixer_layer(cls_feat) + cls_feat = self.mixer_norm_layer(cls_feat) + + cls_logits = self.cls_pred_layer(cls_feat) + + encoding_scores = cls_logits.topk(1, dim=2)[0] + cls_logits = cls_logits.flatten(0,1) + cls_logits_softmax = cls_logits.clone().softmax(1) + else: + encoding_scores = None + cls_logits = None + cls_logits_softmax = None + + if not self.img_guide or \ + (self.stage_pct == "classifier" and not train): + joints_feat 
= None + else: + joints_feat = self.extract_joints_feat(extra_x[-1], joints) + + output_joints, cls_label, e_latent_loss, out_part_token_feat = \ + self.tokenizer(joints, joints_feat, cls_logits_softmax, train=train) + + if train: + return cls_logits, output_joints, cls_label, e_latent_loss + else: + return output_joints, encoding_scores, out_part_token_feat + + def _make_transition_for_head(self, inplanes, outplanes): + transition_layer = [ + nn.Conv2d(inplanes, outplanes, 1, 1, 0, bias=False), + nn.BatchNorm2d(outplanes), + nn.ReLU(True) + ] + return nn.Sequential(*transition_layer) + + def _make_cls_head(self, args): + feature_convs = [] + feature_conv = self._make_layer( + BasicBlock, + args.cls_head_conv_channels, + args.cls_head_conv_channels, + args.cls_head_conv_num_blocks, + dilation=args.cls_head_dilation + ) + feature_convs.append(feature_conv) + + return nn.ModuleList(feature_convs) + + def _make_layer( + self, block, inplanes, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=0.1), + ) + + layers = [] + layers.append(block(inplanes, planes, + stride, downsample, dilation=dilation)) + inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + def extract_joints_feat(self, feature_map, joint_coords): + assert self.image_size[1] == self.image_size[0], \ + 'If you want to use a rectangle input, ' \ + 'please carefully check the length and width below.' + batch_size, _, _, height = feature_map.shape + stride = self.image_size[0] / feature_map.shape[-1] + joint_x = (joint_coords[:,:,0] / stride + 0.5).int() + joint_y = (joint_coords[:,:,1] / stride + 0.5).int() + joint_x = joint_x.clamp(0, feature_map.shape[-1] - 1) + joint_y = joint_y.clamp(0, feature_map.shape[-2] - 1) + joint_indices = (joint_y * height + joint_x).long() + + flattened_feature_map = feature_map.clone().flatten(2) + joint_features = flattened_feature_map[ + torch.arange(batch_size).unsqueeze(1), :, joint_indices] + + return joint_features + + def init_weights(self): + if self.stage_pct == "classifier": + self.tokenizer.eval() + for name, params in self.tokenizer.named_parameters(): + params.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) \ No newline at end of file diff --git a/main/pct_utils/pct_tokenizer.py b/main/pct_utils/pct_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b05318b9d8fc6ec6161eff69647608161db4ab1 --- /dev/null +++ b/main/pct_utils/pct_tokenizer.py @@ -0,0 +1,315 @@ +# -------------------------------------------------------- +# Pose Compositional Tokens +# Written by Zigang Geng (zigang@mail.ustc.edu.cn) +# -------------------------------------------------------- + +import os +import math +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from pct_utils.modules import MixerLayer + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + 
# Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) + +class PCT_Tokenizer(nn.Module): + """ Tokenizer of Pose Compositional Tokens. + paper ref: Zigang Geng et al. "Human Pose as + Compositional Tokens" + + Args: + stage_pct (str): Training stage (Tokenizer or Classifier). + tokenizer (list): Config about the tokenizer. + num_joints (int): Number of annotated joints in the dataset. + guide_ratio (float): The ratio of image guidance. + guide_channels (int): Feature Dim of the image guidance. 
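        Example (standalone sketch of the nearest-codebook lookup used in
        forward(); the 34 x 512 tokens and 2048 x 512 codebook mirror the
        default configuration, but the tensors here are random, illustrative
        values):

            >>> import torch
            >>> feats = torch.randn(34, 512)        # M part-token features
            >>> codebook = torch.randn(2048, 512)   # V codebook entries
            >>> dists = (feats ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * feats @ codebook.t()
            >>> labels = dists.argmin(dim=1)        # class label per token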
+ """ + + def __init__(self, + args, + stage_pct, + num_joints=14, + theta_dim=2, + guide_ratio=0, + guide_channels=0): + super().__init__() + + self.stage_pct = stage_pct + self.guide_ratio = guide_ratio + self.num_joints = num_joints + self.theta_dim = theta_dim + + self.drop_rate = args.tokenizer_encoder_drop_rate + self.enc_num_blocks = args.tokenizer_encoder_num_blocks + self.enc_hidden_dim = args.tokenizer_encoder_hidden_dim + self.enc_token_inter_dim = args.tokenizer_encoder_token_inter_dim + self.enc_hidden_inter_dim = args.tokenizer_encoder_hidden_inter_dim + self.enc_dropout = args.tokenizer_encoder_dropout + + self.dec_num_blocks = args.tokenizer_decoder_num_blocks + self.dec_hidden_dim = args.tokenizer_decoder_hidden_dim + self.dec_token_inter_dim = args.tokenizer_decoder_token_inter_dim + self.dec_hidden_inter_dim = args.tokenizer_decoder_hidden_inter_dim + self.dec_dropout = args.tokenizer_decoder_dropout + + self.token_num = args.tokenizer_codebook_token_num + self.token_class_num = args.tokenizer_codebook_token_class_num + self.token_dim = args.tokenizer_codebook_token_dim + self.decay = args.tokenizer_codebook_ema_decay + + self.invisible_token = nn.Parameter( + torch.zeros(1, 1, self.enc_hidden_dim)) + trunc_normal_(self.invisible_token, mean=0., std=0.02, a=-0.02, b=0.02) + + if self.guide_ratio > 0: + self.start_img_embed = nn.Linear( + guide_channels, int(self.enc_hidden_dim*self.guide_ratio)) + self.start_embed = nn.Linear( + 2, int(self.enc_hidden_dim*(1-self.guide_ratio))) + + self.encoder = nn.ModuleList( + [MixerLayer(self.enc_hidden_dim, self.enc_hidden_inter_dim, + self.num_joints, self.enc_token_inter_dim, + self.enc_dropout) for _ in range(self.enc_num_blocks)]) + self.encoder_layer_norm = nn.LayerNorm(self.enc_hidden_dim) + + self.token_mlp = nn.Linear( + self.num_joints, self.token_num) + self.feature_embed = nn.Linear( + self.enc_hidden_dim, self.token_dim) + + self.register_buffer('codebook', + torch.empty(self.token_class_num, self.token_dim)) + self.codebook.data.normal_() + self.register_buffer('ema_cluster_size', + torch.zeros(self.token_class_num)) + self.register_buffer('ema_w', + torch.empty(self.token_class_num, self.token_dim)) + self.ema_w.data.normal_() + + self.decoder_token_mlp = nn.Linear( + self.token_num, self.num_joints) + self.decoder_start = nn.Linear( + self.token_dim, self.dec_hidden_dim) + + self.decoder = nn.ModuleList( + [MixerLayer(self.dec_hidden_dim, self.dec_hidden_inter_dim, + self.num_joints, self.dec_token_inter_dim, + self.dec_dropout) for _ in range(self.dec_num_blocks)]) + self.decoder_layer_norm = nn.LayerNorm(self.dec_hidden_dim) + + self.recover_embed = nn.Linear(self.dec_hidden_dim, 2) + + def forward(self, joints, joints_feature, cls_logits, train=True): + """Forward function. """ + + if train or self.stage_pct == "tokenizer": + # Encoder of Tokenizer, Get the PCT groundtruth class labels. 
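            # The encoder path below: (1) embeds the 2D joint coordinates
            # (optionally concatenated with image-guided features), (2) swaps
            # invisible joints for the learned `invisible_token`, (3) runs the
            # MLP-Mixer encoder blocks, and (4) assigns each resulting
            # part-token feature to its nearest codebook entry, giving the
            # one-hot ground-truth class labels used by the classifier stage.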
+ bs, num_joints, _ = joints.shape + device = joints.device + joints_coord, joints_visible, bs \ + = joints[:,:,:-1], joints[:,:,-1].bool(), joints.shape[0] + + encode_feat = self.start_embed(joints_coord) + if self.guide_ratio > 0: + encode_img_feat = self.start_img_embed(joints_feature) + encode_feat = torch.cat((encode_feat, encode_img_feat), dim=2) + + if train and self.stage_pct == "tokenizer": + rand_mask_ind = torch.rand( + joints_visible.shape, device=joints.device) > self.drop_rate + joints_visible = torch.logical_and(rand_mask_ind, joints_visible) + + mask_tokens = self.invisible_token.expand(bs, joints.shape[1], -1) + w = joints_visible.unsqueeze(-1).type_as(mask_tokens) + encode_feat = encode_feat * w + mask_tokens * (1 - w) + + for num_layer in self.encoder: + encode_feat = num_layer(encode_feat) + encode_feat = self.encoder_layer_norm(encode_feat) + + encode_feat = encode_feat.transpose(2, 1) + encode_feat = self.token_mlp(encode_feat).transpose(2, 1) + encode_feat = self.feature_embed(encode_feat).flatten(0,1) + + distances = torch.sum(encode_feat**2, dim=1, keepdim=True) \ + + torch.sum(self.codebook**2, dim=1) \ + - 2 * torch.matmul(encode_feat, self.codebook.t()) + + encoding_indices = torch.argmin(distances, dim=1) + encodings = torch.zeros( + encoding_indices.shape[0], self.token_class_num, device=joints.device) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) + else: + # here it suppose cls_logits shape [bs * token_num * token_cls_num] + # predict prob of each token 0,1,2...M-1 belongs to entries 0,1,2...V-1 + # see paper + bs = cls_logits.shape[0] // self.token_num + encoding_indices = None + + if self.stage_pct == "classifier": + part_token_feat = torch.matmul(cls_logits, self.codebook) + else: + part_token_feat = torch.matmul(encodings, self.codebook) + + if train and self.stage_pct == "tokenizer": + # Updating Codebook using EMA + dw = torch.matmul(encodings.t(), encode_feat.detach()) + # sync + n_encodings, n_dw = encodings.numel(), dw.numel() + encodings_shape, dw_shape = encodings.shape, dw.shape + combined = torch.cat((encodings.flatten(), dw.flatten())) + dist.all_reduce(combined) # math sum + sync_encodings, sync_dw = torch.split(combined, [n_encodings, n_dw]) + sync_encodings, sync_dw = \ + sync_encodings.view(encodings_shape), sync_dw.view(dw_shape) + + self.ema_cluster_size = self.ema_cluster_size * self.decay + \ + (1 - self.decay) * torch.sum(sync_encodings, 0) + + n = torch.sum(self.ema_cluster_size.data) + self.ema_cluster_size = ( + (self.ema_cluster_size + 1e-5) + / (n + self.token_class_num * 1e-5) * n) + + self.ema_w = self.ema_w * self.decay + (1 - self.decay) * sync_dw + self.codebook = self.ema_w / self.ema_cluster_size.unsqueeze(1) + e_latent_loss = F.mse_loss(part_token_feat.detach(), encode_feat) + part_token_feat = encode_feat + (part_token_feat - encode_feat).detach() + else: + e_latent_loss = None + + # Decoder of Tokenizer, Recover the joints. 
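        # The decoder below maps the M quantized part tokens back to the
        # num_joints slots via `decoder_token_mlp`, refines them with
        # MLP-Mixer blocks, and regresses 2D joint coordinates with
        # `recover_embed`; a detached clone of the part-token features is
        # also returned so downstream models (e.g. PostoMETRO) can reuse
        # them as pose tokens.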
+ part_token_feat = part_token_feat.view(bs, -1, self.token_dim) + + # Store part token + out_part_token_feat = part_token_feat.clone().detach() + + part_token_feat = part_token_feat.transpose(2,1) + part_token_feat = self.decoder_token_mlp(part_token_feat).transpose(2,1) + decode_feat = self.decoder_start(part_token_feat) + + for num_layer in self.decoder: + decode_feat = num_layer(decode_feat) + decode_feat = self.decoder_layer_norm(decode_feat) + + recoverd_joints = self.recover_embed(decode_feat) + + return recoverd_joints, encoding_indices, e_latent_loss, out_part_token_feat + + def init_weights(self, pretrained=""): + """Initialize model weights.""" + + parameters_names = set() + for name, _ in self.named_parameters(): + parameters_names.add(name) + + buffers_names = set() + for name, _ in self.named_buffers(): + buffers_names.add(name) + + if os.path.isfile(pretrained): + assert (self.stage_pct == "classifier"), \ + "Training tokenizer does not need to load model" + pretrained_state_dict = torch.load(pretrained, + map_location=lambda storage, loc: storage) + + need_init_state_dict = {} + + if 'state_dict' in pretrained_state_dict: + key = 'state_dict' + else: + key = 'model' + for name, m in pretrained_state_dict[key].items(): + if 'keypoint_head.tokenizer.' in name: + name = name.replace('keypoint_head.tokenizer.', '') + if name in parameters_names or name in buffers_names: + need_init_state_dict[name] = m + self.load_state_dict(need_init_state_dict, strict=True) + else: + if self.stage_pct == "classifier": + print('If you are training a classifier, '\ + 'must check that the well-trained tokenizer '\ + 'is located in the correct path.') + + +def save_checkpoint(model, optimizer, epoch, loss, filepath): + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss + } + torch.save(checkpoint, filepath) + print(f"Checkpoint saved at {filepath}") + +def load_checkpoint(model, optimizer, filepath): + checkpoint = torch.load(filepath) + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + epoch = checkpoint['epoch'] + loss = checkpoint['loss'] + + print(f"Checkpoint loaded from {filepath}. Resuming training from epoch {epoch} with loss {loss}") + + return epoch, loss diff --git a/main/postometro.py b/main/postometro.py new file mode 100644 index 0000000000000000000000000000000000000000..aacd40b450df7299b9a5dc9dd100b97910a60cc5 --- /dev/null +++ b/main/postometro.py @@ -0,0 +1,305 @@ +# ---------------------------------------------------------------------------------------------- +# FastMETRO Official Code +# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved +# Licensed under the MIT license. +# ---------------------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------------------- +# PostoMETRO Official Code +# Copyright (c) MIRACLE Lab. All Rights Reserved +# Licensed under the MIT license. 
+# ---------------------------------------------------------------------------------------------- + +from __future__ import absolute_import, division, print_function +import torch +import numpy as np +import argparse +import os +import os.path as osp +from torch import nn +from postometro_utils.smpl import Mesh +from postometro_utils.transformer import build_transformer +from postometro_utils.positional_encoding import build_position_encoding +from postometro_utils.modules import FCBlock, MixerLayer +from pct_utils.pct import PCT +from pct_utils.pct_backbone import SwinV2TransformerRPE2FC +from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet +from postometro_utils.pose_resnet_config import config as resnet_config +from postometro_utils.pose_hrnet import get_pose_hrnet +from postometro_utils.pose_hrnet_config import _C as hrnet_config +from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config + +CUR_DIR = osp.dirname(os.path.abspath(__file__)) + +class PostoMETRO(nn.Module): + """PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image""" + def __init__(self, args, backbone, mesh_sampler, pct = None, num_joints=14, num_vertices=431): + """ + Parameters: + - args: Arguments + - backbone: CNN Backbone used to extract image features from the given image + - mesh_sampler: Mesh Sampler used in the coarse-to-fine mesh upsampling + - num_joints: The number of joint tokens used in the transformer decoder + - num_vertices: The number of vertex tokens used in the transformer decoder + """ + super().__init__() + self.args = args + self.backbone = backbone + self.mesh_sampler = mesh_sampler + self.num_joints = num_joints + self.num_vertices = num_vertices + + # the number of transformer layers, set to default + num_enc_layers = 3 + num_dec_layers = 3 + + # configurations for the first transformer + self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead, + "feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers, + "pos_type": args.pos_type} + # configurations for the second transformer + self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead, + "feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers, + "pos_type": args.pos_type} + # build transformers + self.transformer_1 = build_transformer(self.transformer_config_1) + self.transformer_2 = build_transformer(self.transformer_config_2) + + # dimensionality reduction + self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) + self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) + self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) + + # token embeddings + self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"]) + self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"]) + self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"]) + # positional encodings + self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim']) + self.position_encoding_2 = 
build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim']) + # estimators + self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3) + self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3) + + # 1x1 Convolution + self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1) + + # attention mask + zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool)) + zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool)) + adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt')) + adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt')) + adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt')) + adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense() + temp_mask_1 = (adjacency_matrix == 0) + temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1) + self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0) + + # learnable upsampling layer is used (from coarse mesh to intermediate mesh); for visually pleasing mesh result + ### pre-computed upsampling matrix is used (from intermediate mesh to fine mesh); to reduce optimization difficulty + self.coarse2intermediate_upsample = nn.Linear(431, 1723) + + # using extra token + self.pct = None + if pct is not None: + self.pct = pct + # +1 to align with uncertainty score + self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"]) + self.start_embed = nn.Linear(512, args.enc_hidden_dim) + self.encoder = nn.ModuleList( + [MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim, + args.num_joints, args.token_inter_dim, + args.enc_dropout) for _ in range(args.enc_num_blocks)]) + self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim) + self.token_mlp = nn.Linear(args.num_joints, args.token_num) + self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"]) + + + def forward(self, images): + device = images.device + batch_size = images.size(0) + + # preparation + cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512 + jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512 + attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices) + + pct_token = None + if self.pct is not None: + pct_out = self.pct(images, None, train=False) + pct_pose = pct_out['part_token_feat'].clone() + + encode_feat = self.start_embed(pct_pose) # 2, 17, 512 + for num_layer in self.encoder: + encode_feat = num_layer(encode_feat) + encode_feat = self.encoder_layer_norm(encode_feat) + encode_feat = encode_feat.transpose(2, 1) + encode_feat = self.token_mlp(encode_feat).transpose(2, 1) + pct_token_out = encode_feat.permute(1,0,2) + + pct_score = pct_out['encoding_scores'] + pct_score = pct_score.permute(1,0,2) + pct_token = torch.cat([pct_token_out, pct_score], dim = -1) + pct_token = self.token_mixer(pct_token) # [b, 34, 512] + + # extract image features through a CNN backbone + _img_features = self.backbone(images) # batch_size X 2048 X 7 X 7 + _, _, h, w = _img_features.shape + img_features = 
self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512 + + # positional encodings + pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512 + pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 128 + + # first transformer encoder-decoder + cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token = pct_token, attention_mask=attention_mask) + + # progressive dimensionality reduction + reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128 + reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # 49 X batch_size X 128 + reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128 + reduced_pct_features_1 = None + if pct_features_1 is not None: + reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1) + + # second transformer encoder-decoder + cam_features_2, _, jv_features_2,_ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token = reduced_pct_features_1, attention_mask=attention_mask) + + # estimators + pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3 + + pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3 + pred_3d_joints = pred_3d_coordinates[:,:self.num_joints,:] # batch_size X num_joints X 3 + pred_3d_vertices_coarse = pred_3d_coordinates[:,self.num_joints:,:] # batch_size X num_vertices(coarse) X 3 + + # coarse-to-intermediate mesh upsampling + pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1,2)).transpose(1,2) # batch_size X num_vertices(intermediate) X 3 + # intermediate-to-fine mesh upsampling + pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3 + + out = {} + out['pred_cam'] = pred_cam + out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2)).cuda(device) + out['pred_3d_joints'] = pred_3d_joints + out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse + out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate + out['pred_3d_vertices_fine'] = pred_3d_vertices_fine + + return out + + +defaults_args = argparse.Namespace( + pos_type = 'sine', + transformer_dropout = 0.1, + transformer_nhead = 8, + conv_1x1_dim = 2048, + tokenizer_codebook_token_dim = 512, + model_dim_1 = 512, + feedforward_dim_1 = 2048, + model_dim_2 = 128, + feedforward_dim_2 = 512, + enc_hidden_dim = 512, + enc_hidden_inter_dim = 512, + token_inter_dim = 64, + enc_dropout = 0.0, + enc_num_blocks = 4, + num_joints = 34, + token_num = 34 +) + +default_pct_args = argparse.Namespace( + pct_backbone_channel = 1536, + tokenizer_guide_ratio=0.5, + cls_head_conv_channels=256, + cls_head_hidden_dim=64, + cls_head_num_blocks=4, + cls_head_hidden_inter_dim=256, + cls_head_token_inter_dim=64, + cls_head_dropout=0.0, + cls_head_conv_num_blocks=2, + cls_head_dilation=1, + # tokenzier + tokenizer_encoder_drop_rate=0.2, + tokenizer_encoder_num_blocks=4, + tokenizer_encoder_hidden_dim=512, + tokenizer_encoder_token_inter_dim=64, + tokenizer_encoder_hidden_inter_dim=512, + tokenizer_encoder_dropout=0.0, + tokenizer_decoder_num_blocks=1, + 
tokenizer_decoder_hidden_dim=32, + tokenizer_decoder_token_inter_dim=64, + tokenizer_decoder_hidden_inter_dim=64, + tokenizer_decoder_dropout=0.0, + tokenizer_codebook_token_num=34, + tokenizer_codebook_token_dim=512, + tokenizer_codebook_token_class_num=2048, + tokenizer_codebook_ema_decay=0.9, +) + +backbone_config=dict( + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=[16, 16, 16, 8], + pretrain_window_size=[12, 12, 12, 6], + ape=False, + drop_path_rate=0.5, + patch_norm=True, + use_checkpoint=True, + rpe_interpolation='geo', + use_shift=[True, True, False, False], + relative_coords_table_type='norm8_log_bylayer', + attn_type='cosine_mh', + rpe_output_type='sigmoid', + postnorm=True, + mlp_type='normal', + out_indices=(3,), + patch_embed_type='normal', + patch_merge_type='normal', + strid16=False, + frozen_stages=5, +) + +def get_model(backbone_str = 'resnet50', device = torch.device('cpu'), checkpoint_file = None): + if backbone_str == 'hrnet-w48': + defaults_args.conv_1x1_dim = 384 + # update hrnet config by yaml + hrnet_yaml = osp.join(CUR_DIR,'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml') + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_pose_hrnet(hrnet_config, None) + else: + backbone = get_pose_resnet(resnet_config, is_train=False) + mesh_upsampler = Mesh(device=device) + pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config) + # initialize pct head + pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device) + model = PostoMETRO(defaults_args, backbone, mesh_upsampler, pct=pct).to(device) + print("[INFO] model loaded, params: {}, {}".format(backbone_str, device)) + if checkpoint_file: + cpu_device = torch.device('cpu') + state_dict = torch.load(checkpoint_file, map_location=cpu_device) + model.load_state_dict(state_dict, strict=True) + del state_dict + print("[INFO] checkpoint loaded, params: {}, {}".format(backbone_str, device)) + return model + +if __name__ == '__main__': + test_model = get_model(device=torch.device('cuda')) + images = torch.randn(1,3,256,256).to(torch.device('cuda')) + test_out = test_model(images) + print("[TEST] resnet50, cuda : pass") + + test_model = get_model() + images = torch.randn(1,3,256,256).to() + test_out = test_model(images) + print("[TEST] resnet50, cpu : pass") + + test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda')) + images = torch.randn(1,3,256,256).to(torch.device('cuda')) + test_out = test_model(images) + print("[TEST] hrnet-w48, cuda : pass") + + test_model = get_model(backbone_str='hrnet-w48') + images = torch.randn(1,3,256,256).to() + test_out = test_model(images) + print("[TEST] hrnet-w48, cpu : pass") diff --git a/main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc b/main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a84550f3a4bdd7fef159896e31002a97dc8b171a Binary files /dev/null and b/main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/modules.cpython-39.pyc b/main/postometro_utils/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c14c460770f9de794e1635584e84a55d01480c Binary files /dev/null and b/main/postometro_utils/__pycache__/modules.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc 
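[Reviewer note] As a usage reference for the postometro.py module added above: a minimal single-image inference sketch. The preprocessing is an assumption (ImageNet normalization, a local person.jpg, module importable as `postometro`), not something taken from this diff; the demo in app.py drives the model through its own pipeline.

import torch
from PIL import Image
from torchvision import transforms
from postometro import get_model   # module name assumed; the factory is defined at the end of this file

# Assumed preprocessing: 256x256 input (as in the self-tests above) with
# ImageNet mean/std -- the repository's own transform may differ.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = get_model(backbone_str='resnet50', device=torch.device('cpu'))
model.eval()
img = preprocess(Image.open('person.jpg').convert('RGB')).unsqueeze(0)   # 1 x 3 x 256 x 256
with torch.no_grad():
    out = model(img)
print(out['pred_3d_vertices_fine'].shape)   # 1 x 6890 x 3 if the mesh sampler upsamples to the full SMPL mesh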
b/main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f0c47a615e65a6db352920d52d52a49107492e5 Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b0fcdf7c9c0294f5b0142084741debbcf4481cd Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7928135ef3c7b5be652144ed03a3024dca85127 Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc b/main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c121203e49e494758b826a7022484360bbad314d Binary files /dev/null and b/main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc b/main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..197a81e24a0f508e670f9413f7c51aa82f22907f Binary files /dev/null and b/main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc b/main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f56450a6e9b6213a210ed9b21b23eac5429f6839 Binary files /dev/null and b/main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/smpl.cpython-39.pyc b/main/postometro_utils/__pycache__/smpl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22203612db68c6bac6c09426e832ef46dad83786 Binary files /dev/null and b/main/postometro_utils/__pycache__/smpl.cpython-39.pyc differ diff --git a/main/postometro_utils/__pycache__/transformer.cpython-39.pyc b/main/postometro_utils/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e4b98396b2ae5d4d89c7461be3875058397574d Binary files /dev/null and b/main/postometro_utils/__pycache__/transformer.cpython-39.pyc differ diff --git a/main/postometro_utils/geometric_layers.py b/main/postometro_utils/geometric_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b59ca8abcac164ff695dbfa700f273d0d156ef --- /dev/null +++ b/main/postometro_utils/geometric_layers.py @@ -0,0 +1,679 @@ +# ---------------------------------------------------------------------------------------------- +# METRO (https://github.com/microsoft/MeshTransformer) +# Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details] +# Licensed under the MIT license. +# ---------------------------------------------------------------------------------------------- +""" +Useful geometric operations, e.g. 
Orthographic projection and a differentiable Rodrigues formula + +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" + +import torch +import torch.nn.functional as F + +def rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat2mat(quat) + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def orthographic_projection(X, camera): + """Perform orthographic projection of 3D points X using the camera parameters + Args: + X: size = [B, N, 3] + camera: size = [B, 3] + Returns: + Projected 2D points -- size = [B, N, 2] + """ + camera = camera.view(-1, 1, 3) + X_trans = X[:, :, :2] + camera[:, :, 1:] + shape = X_trans.shape + X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape) + return X_2d + +def orthographic_projection_reshape(X, camera): + """Perform orthographic projection of 3D points X using the camera parameters + Args: + X: size = [B, N, 3] + camera: size = [B, 3] + Returns: + Projected 2D points -- size = [B, N, 2] + """ + camera = camera.reshape(-1, 1, 3) + X_trans = X[:, :, :2] + camera[:, :, 1:] + shape = X_trans.shape + X_2d = (camera[:, :, 0] * X_trans.reshape(shape[0], -1)).reshape(shape) + return X_2d + +def orthographic_projection_reshape(X, camera): + """Perform orthographic projection of 3D points X using the camera parameters + Args: + X: size = [B, N, 3] + camera: size = [B, 3] + Returns: + Projected 2D points -- size = [B, N, 2] + """ + camera = camera.reshape(-1, 1, 3) + X_trans = X[:, :, :2] + camera[:, :, 1:] + shape = X_trans.shape + X_2d = (camera[:, :, 0] * X_trans.reshape(shape[0], -1)).reshape(shape) + return X_2d + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. 
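[Reviewer note] A short sanity check for the Rodrigues and orthographic-projection helpers above (values are illustrative; the import path assumes main/ is on PYTHONPATH, as in the rest of this diff).

import torch
from postometro_utils.geometric_layers import rodrigues, orthographic_projection

theta = torch.tensor([[0.0, 0.0, 1.5708]])          # ~90 deg about z, shape [B, 3]
R = rodrigues(theta)                                # [B, 3, 3] rotation matrices
print(torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(1, 3, 3), atol=1e-5))  # True: output is orthogonal

X = torch.randn(1, 431, 3)                          # e.g. coarse mesh vertices
cam = torch.tensor([[1.0, 0.0, 0.0]])               # [scale, tx, ty] weak-perspective camera
print(orthographic_projection(X, cam).shape)        # torch.Size([1, 431, 2])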
+ """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. 
+ """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. 
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + +def axis_angle_to_rotation_6d(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + 6D rotation representation, of size (*, 6) + """ + return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle)) + +def rotation_6d_to_axis_angle(d6): + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + return matrix_to_axis_angle(rotation_6d_to_matrix(d6)) + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. 
+ """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. 
+ """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. 
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + +def axis_angle_to_rotation_6d(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + 6D rotation representation, of size (*, 6) + """ + return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle)) + +def rotation_6d_to_axis_angle(d6): + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + + return matrix_to_axis_angle(rotation_6d_to_matrix(d6)) \ No newline at end of file diff --git a/main/postometro_utils/modules.py b/main/postometro_utils/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2c868a46adb650284343178e1ea8c9a5c51ff73a --- /dev/null +++ b/main/postometro_utils/modules.py @@ -0,0 +1,117 @@ +# -------------------------------------------------------- +# Borrow from unofficial MLPMixer (https://github.com/920232796/MlpMixer-pytorch) +# Borrow from ResNet +# Modified by Zigang Geng (zigang@mail.ustc.edu.cn) +# -------------------------------------------------------- + +import torch +import torch.nn as nn + + +class FCBlock(nn.Module): + def __init__(self, dim, out_dim): + super().__init__() + + self.ff = nn.Sequential( + nn.Linear(dim, out_dim), + nn.LayerNorm(out_dim), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.ff(x) + + +class MLPBlock(nn.Module): + def __init__(self, dim, inter_dim, dropout_ratio): + super().__init__() + + self.ff = nn.Sequential( + nn.Linear(dim, inter_dim), + nn.GELU(), + nn.Dropout(dropout_ratio), + nn.Linear(inter_dim, dim), + nn.Dropout(dropout_ratio) + ) + + def forward(self, x): + return self.ff(x) + + +class MixerLayer(nn.Module): + def __init__(self, + hidden_dim, + hidden_inter_dim, + token_dim, + token_inter_dim, + dropout_ratio): + super().__init__() + + self.layernorm1 = nn.LayerNorm(hidden_dim) + self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio) + self.layernorm2 = nn.LayerNorm(hidden_dim) + self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio) + + def forward(self, x): + y = self.layernorm1(x) + y = y.transpose(2, 1) + y = self.MLP_token(y) + y = y.transpose(2, 1) + z = self.layernorm2(x + y) + z = self.MLP_channel(z) + out = x + y + z + return out + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, + downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, bias=False, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes, momentum=0.1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, bias=False, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes, momentum=0.1) + self.downsample = downsample + self.stride = stride + + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + +def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.Conv2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=kernel, + stride=stride, + padding=padding + )) + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) \ No newline at end of file diff --git a/main/postometro_utils/pose_hrnet.py b/main/postometro_utils/pose_hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..007c29447896f05273b56c64451c7a8840d608bb --- /dev/null +++ 
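[Reviewer note] For context on modules.py above: MixerLayer and FCBlock are the MLP-Mixer-style blocks that postometro.py stacks on the PCT pose tokens. A shape-level sketch with illustrative dimensions (the real token count and widths come from the args in postometro.py):

import torch
from postometro_utils.modules import FCBlock, MixerLayer

x = torch.randn(2, 17, 512)                         # [batch, tokens, hidden] -- illustrative sizes
mixer = MixerLayer(hidden_dim=512, hidden_inter_dim=512,
                   token_dim=17, token_inter_dim=64, dropout_ratio=0.0)
print(mixer(x).shape)                               # torch.Size([2, 17, 512]) -- shape-preserving token/channel mixing
print(FCBlock(512, 128)(x).shape)                   # torch.Size([2, 17, 128]) -- Linear + LayerNorm + ReLU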
b/main/postometro_utils/pose_hrnet.py @@ -0,0 +1,502 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging + +import torch +import torch.nn as nn + + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 
+ num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM + ), + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + stride, + downsample + ) + ) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index] + ) + ) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels) + ) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, 1, 0, bias=False + ), + nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample(scale_factor=2**(j-i), mode='nearest') + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False + ), + nn.BatchNorm2d(num_outchannels_conv3x3) + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False + ), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True) + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class PoseHighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + self.inplanes = 64 + extra = cfg['MODEL']['EXTRA'] + super(PoseHighResolutionNet, self).__init__() + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, 
momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(Bottleneck, 64, 4) + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, + multi_scale_output=True) + # multi_scale_output=False) + + self.final_layer = nn.Conv2d( + in_channels=pre_stage_channels[0], + out_channels=cfg['MODEL']['NUM_JOINTS'], + kernel_size=extra['FINAL_CONV_KERNEL'], + stride=1, + padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0 + ) + + self.pretrained_layers = extra['PRETRAINED_LAYERS'] + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, 1, 1, bias=False + ), + nn.BatchNorm2d(num_channels_cur_layer[i]), + nn.ReLU(inplace=True) + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False + ), + nn.BatchNorm2d(outchannels), + nn.ReLU(inplace=True) + ) + ) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = 
layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list[-1] + # x = self.final_layer(y_list[0]) + # return x + + def init_weights(self, pretrained=''): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + + if os.path.isfile(pretrained): + pretrained_state_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + + need_init_state_dict = {} + for name, m in pretrained_state_dict.items(): + if name.split('.')[0] in self.pretrained_layers \ + or self.pretrained_layers[0] is '*': + need_init_state_dict[name] = m + out = self.load_state_dict(need_init_state_dict, strict=False) + elif pretrained: + logger.error('=> please download pre-trained models first!') + raise ValueError('{} is not exist!'.format(pretrained)) + + +def get_pose_hrnet(cfg, pretrained, **kwargs): + model = PoseHighResolutionNet(cfg, **kwargs) + if pretrained is not None: + model.init_weights(pretrained=pretrained) + + return model \ No newline at end of file diff --git a/main/postometro_utils/pose_hrnet_config.py b/main/postometro_utils/pose_hrnet_config.py new file mode 100644 index 0000000000000000000000000000000000000000..00c542fa56cd875f6655b844d8a49d6ce061838f --- /dev/null +++ b/main/postometro_utils/pose_hrnet_config.py @@ -0,0 +1,137 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
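[Reviewer note] The HRNet-W48 branch that get_model() in postometro.py selects can be exercised on its own. This is a sketch under the assumption that the yaml referenced there (postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml) is present; the printed shape is an expectation for the W48 config, not something asserted by this diff.

import os.path as osp
import torch
from postometro_utils.pose_hrnet import get_pose_hrnet
from postometro_utils.pose_hrnet_config import _C as hrnet_config
from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config

hrnet_update_config(hrnet_config, osp.join('postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml'))
backbone = get_pose_hrnet(hrnet_config, None)       # pretrained=None skips init_weights()
feat = backbone(torch.randn(1, 3, 256, 256))        # forward() returns y_list[-1], the lowest-resolution branch
print(feat.shape)                                   # expected torch.Size([1, 384, 8, 8]); hence conv_1x1_dim = 384 in get_model()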
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from yacs.config import CfgNode as CN + + +_C = CN() + +_C.OUTPUT_DIR = '' +_C.LOG_DIR = '' +_C.DATA_DIR = '' +_C.GPUS = (0,) +_C.WORKERS = 4 +_C.PRINT_FREQ = 20 +_C.AUTO_RESUME = False +_C.PIN_MEMORY = True +_C.RANK = 0 + +# Cudnn related params +_C.CUDNN = CN() +_C.CUDNN.BENCHMARK = True +_C.CUDNN.DETERMINISTIC = False +_C.CUDNN.ENABLED = True + +# common params for NETWORK +_C.MODEL = CN() +_C.MODEL.NAME = 'cls_hrnet' +_C.MODEL.INIT_WEIGHTS = True +_C.MODEL.PRETRAINED = '' +_C.MODEL.NUM_JOINTS = 17 +_C.MODEL.NUM_CLASSES = 1000 +_C.MODEL.TAG_PER_JOINT = True +_C.MODEL.TARGET_TYPE = 'gaussian' +_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 +_C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 +_C.MODEL.SIGMA = 2 +_C.MODEL.EXTRA = CN(new_allowed=True) + +_C.LOSS = CN() +_C.LOSS.USE_OHKM = False +_C.LOSS.TOPK = 8 +_C.LOSS.USE_TARGET_WEIGHT = True +_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False + +# DATASET related params +_C.DATASET = CN() +_C.DATASET.ROOT = '' +_C.DATASET.DATASET = 'mpii' +_C.DATASET.TRAIN_SET = 'train' +_C.DATASET.TEST_SET = 'valid' +_C.DATASET.DATA_FORMAT = 'jpg' +_C.DATASET.HYBRID_JOINTS_TYPE = '' +_C.DATASET.SELECT_DATA = False + +# training data augmentation +_C.DATASET.FLIP = True +_C.DATASET.SCALE_FACTOR = 0.25 +_C.DATASET.ROT_FACTOR = 30 +_C.DATASET.PROB_HALF_BODY = 0.0 +_C.DATASET.NUM_JOINTS_HALF_BODY = 8 +_C.DATASET.COLOR_RGB = False + +# train +_C.TRAIN = CN() + +_C.TRAIN.LR_FACTOR = 0.1 +_C.TRAIN.LR_STEP = [90, 110] +_C.TRAIN.LR = 0.001 + +_C.TRAIN.OPTIMIZER = 'adam' +_C.TRAIN.MOMENTUM = 0.9 +_C.TRAIN.WD = 0.0001 +_C.TRAIN.NESTEROV = False +_C.TRAIN.GAMMA1 = 0.99 +_C.TRAIN.GAMMA2 = 0.0 + +_C.TRAIN.BEGIN_EPOCH = 0 +_C.TRAIN.END_EPOCH = 140 + +_C.TRAIN.RESUME = False +_C.TRAIN.CHECKPOINT = '' + +_C.TRAIN.BATCH_SIZE_PER_GPU = 32 +_C.TRAIN.SHUFFLE = True + +# testing +_C.TEST = CN() + +# size of images for each device +_C.TEST.BATCH_SIZE_PER_GPU = 32 +# Test Model Epoch +_C.TEST.FLIP_TEST = False +_C.TEST.POST_PROCESS = False +_C.TEST.SHIFT_HEATMAP = False + +_C.TEST.USE_GT_BBOX = False + +# nms +_C.TEST.IMAGE_THRE = 0.1 +_C.TEST.NMS_THRE = 0.6 +_C.TEST.SOFT_NMS = False +_C.TEST.OKS_THRE = 0.5 +_C.TEST.IN_VIS_THRE = 0.0 +_C.TEST.COCO_BBOX_FILE = '' +_C.TEST.BBOX_THRE = 1.0 +_C.TEST.MODEL_FILE = '' + +# debug +_C.DEBUG = CN() +_C.DEBUG.DEBUG = False +_C.DEBUG.SAVE_BATCH_IMAGES_GT = False +_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False +_C.DEBUG.SAVE_HEATMAPS_GT = False +_C.DEBUG.SAVE_HEATMAPS_PRED = False + + +def update_config(cfg, config_file): + cfg.defrost() + cfg.merge_from_file(config_file) + cfg.freeze() + + +if __name__ == '__main__': + import sys + with open(sys.argv[1], 'w') as f: + print(_C, file=f) + diff --git a/main/postometro_utils/pose_resnet.py b/main/postometro_utils/pose_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c423f36418ef5ac0dcc85f082a029f0280b6b52f --- /dev/null +++ b/main/postometro_utils/pose_resnet.py @@ -0,0 +1,318 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging + +import torch +import torch.nn as nn +from collections import OrderedDict + + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck_CAFFE(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck_CAFFE, self).__init__() + # add stride to conv1x1 + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class PoseResNet(nn.Module): + + def __init__(self, block, layers, cfg, 
**kwargs): + self.inplanes = 64 + extra = cfg.MODEL.EXTRA + self.deconv_with_bias = extra.DECONV_WITH_BIAS + + super(PoseResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + # used for deconv layers + # self.deconv_layers = self._make_deconv_layer( + # extra.NUM_DECONV_LAYERS, + # extra.NUM_DECONV_FILTERS, + # extra.NUM_DECONV_KERNELS, + # ) + + # self.final_layer = nn.Conv2d( + # in_channels=extra.NUM_DECONV_FILTERS[-1], + # out_channels=cfg.MODEL.NUM_JOINTS, + # kernel_size=extra.FINAL_CONV_KERNEL, + # stride=1, + # padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + # ) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _get_deconv_cfg(self, deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d( + in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias)) + layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def forward(self, x, skip_early = False, use_pct = False): + if not use_pct: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + if skip_early: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + return x + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + def init_weights(self, pretrained=''): + if os.path.isfile(pretrained): + # pretrained_state_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + # self.load_state_dict(pretrained_state_dict, strict=False) + checkpoint = torch.load(pretrained) + if isinstance(checkpoint, OrderedDict): + state_dict = checkpoint + elif 
isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict_old = checkpoint['state_dict'] + state_dict = OrderedDict() + # delete 'module.' because it is saved from DataParallel module + for key in state_dict_old.keys(): + if key.startswith('module.'): + # state_dict[key[7:]] = state_dict[key] + # state_dict.pop(key) + state_dict[key[7:]] = state_dict_old[key] + else: + state_dict[key] = state_dict_old[key] + else: + raise RuntimeError( + 'No state_dict found in checkpoint file {}'.format(pretrained)) + state_dict_old = state_dict + state_dict = OrderedDict() + for k,v in state_dict_old.items(): + if 'deconv_layers' in k or 'final_layer' in k: + continue + else: + state_dict[k] = state_dict_old[k] + self.load_state_dict(state_dict, strict=True) + else: + logger.error('=> imagenet pretrained model dose not exist') + logger.error('=> please download it first') + raise ValueError('imagenet pretrained model does not exist') + + +resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3])} + + +def get_pose_net(cfg, is_train, **kwargs): + num_layers = cfg.MODEL.EXTRA.NUM_LAYERS + style = cfg.MODEL.STYLE + + block_class, layers = resnet_spec[num_layers] + + if style == 'caffe': + block_class = Bottleneck_CAFFE + + model = PoseResNet(block_class, layers, cfg, **kwargs) + + if is_train and cfg.MODEL.INIT_WEIGHTS: + model.init_weights(cfg.MODEL.PRETRAINED) + + return model diff --git a/main/postometro_utils/pose_resnet_config.py b/main/postometro_utils/pose_resnet_config.py new file mode 100644 index 0000000000000000000000000000000000000000..06bc863d38bf6fad4f00390ebb11c984ed9dd89e --- /dev/null +++ b/main/postometro_utils/pose_resnet_config.py @@ -0,0 +1,229 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import yaml + +import numpy as np +from easydict import EasyDict as edict + + +config = edict() + +config.OUTPUT_DIR = '' +config.LOG_DIR = '' +config.DATA_DIR = '' +config.GPUS = '0' +config.WORKERS = 4 +config.PRINT_FREQ = 20 + +# Cudnn related params +config.CUDNN = edict() +config.CUDNN.BENCHMARK = True +config.CUDNN.DETERMINISTIC = False +config.CUDNN.ENABLED = True + +# pose_resnet related params +POSE_RESNET = edict() +POSE_RESNET.NUM_LAYERS = 50 +POSE_RESNET.DECONV_WITH_BIAS = False +POSE_RESNET.NUM_DECONV_LAYERS = 3 +POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256] +POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4] +POSE_RESNET.FINAL_CONV_KERNEL = 1 +POSE_RESNET.TARGET_TYPE = 'gaussian' +POSE_RESNET.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 +POSE_RESNET.SIGMA = 2 + +MODEL_EXTRAS = { + 'pose_resnet': POSE_RESNET, +} + +# common params for NETWORK +config.MODEL = edict() +config.MODEL.NAME = 'pose_resnet' +config.MODEL.INIT_WEIGHTS = True +config.MODEL.PRETRAINED = '' +config.MODEL.NUM_JOINTS = 16 +config.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 +config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME] + +config.MODEL.STYLE = 'pytorch' + +config.LOSS = edict() +config.LOSS.USE_TARGET_WEIGHT = True + +# DATASET related params +config.DATASET = edict() +config.DATASET.ROOT = '' +config.DATASET.DATASET = 'mpii' +config.DATASET.TRAIN_SET = 'train' +config.DATASET.TEST_SET = 'valid' +config.DATASET.DATA_FORMAT = 'jpg' +config.DATASET.HYBRID_JOINTS_TYPE = '' +config.DATASET.SELECT_DATA = False + +# training data augmentation +config.DATASET.FLIP = True +config.DATASET.SCALE_FACTOR = 0.25 +config.DATASET.ROT_FACTOR = 30 + +# train +config.TRAIN = edict() + +config.TRAIN.LR_FACTOR = 0.1 +config.TRAIN.LR_STEP = [90, 110] +config.TRAIN.LR = 0.001 + +config.TRAIN.OPTIMIZER = 'adam' +config.TRAIN.MOMENTUM = 0.9 +config.TRAIN.WD = 0.0001 +config.TRAIN.NESTEROV = False +config.TRAIN.GAMMA1 = 0.99 +config.TRAIN.GAMMA2 = 0.0 + +config.TRAIN.BEGIN_EPOCH = 0 +config.TRAIN.END_EPOCH = 140 + +config.TRAIN.RESUME = False +config.TRAIN.CHECKPOINT = '' + +config.TRAIN.BATCH_SIZE = 32 +config.TRAIN.SHUFFLE = True + +# testing +config.TEST = edict() + +# size of images for each device +config.TEST.BATCH_SIZE = 32 +# Test Model Epoch +config.TEST.FLIP_TEST = False +config.TEST.POST_PROCESS = True +config.TEST.SHIFT_HEATMAP = True + +config.TEST.USE_GT_BBOX = False +# nms +config.TEST.OKS_THRE = 0.5 +config.TEST.IN_VIS_THRE = 0.0 +config.TEST.COCO_BBOX_FILE = '' +config.TEST.BBOX_THRE = 1.0 +config.TEST.MODEL_FILE = '' +config.TEST.IMAGE_THRE = 0.0 +config.TEST.NMS_THRE = 1.0 + +# debug +config.DEBUG = edict() +config.DEBUG.DEBUG = False +config.DEBUG.SAVE_BATCH_IMAGES_GT = False +config.DEBUG.SAVE_BATCH_IMAGES_PRED = False +config.DEBUG.SAVE_HEATMAPS_GT = False +config.DEBUG.SAVE_HEATMAPS_PRED = False + + +def _update_dict(k, v): + if k == 'DATASET': + if 'MEAN' in v and v['MEAN']: + v['MEAN'] = np.array([eval(x) if isinstance(x, str) else x + for x in v['MEAN']]) + if 'STD' in v and v['STD']: + v['STD'] = np.array([eval(x) if isinstance(x, str) else x + for x in v['STD']]) + if k == 'MODEL': + if 'EXTRA' in v and 'HEATMAP_SIZE' in v['EXTRA']: + if isinstance(v['EXTRA']['HEATMAP_SIZE'], int): + v['EXTRA']['HEATMAP_SIZE'] = np.array( 
+ [v['EXTRA']['HEATMAP_SIZE'], v['EXTRA']['HEATMAP_SIZE']]) + else: + v['EXTRA']['HEATMAP_SIZE'] = np.array( + v['EXTRA']['HEATMAP_SIZE']) + if 'IMAGE_SIZE' in v: + if isinstance(v['IMAGE_SIZE'], int): + v['IMAGE_SIZE'] = np.array([v['IMAGE_SIZE'], v['IMAGE_SIZE']]) + else: + v['IMAGE_SIZE'] = np.array(v['IMAGE_SIZE']) + for vk, vv in v.items(): + if vk in config[k]: + config[k][vk] = vv + else: + raise ValueError("{}.{} not exist in config.py".format(k, vk)) + + +def update_config(config_file): + exp_config = None + with open(config_file) as f: + exp_config = edict(yaml.load(f)) + for k, v in exp_config.items(): + if k in config: + if isinstance(v, dict): + _update_dict(k, v) + else: + if k == 'SCALES': + config[k][0] = (tuple(v)) + else: + config[k] = v + else: + raise ValueError("{} not exist in config.py".format(k)) + + +def gen_config(config_file): + cfg = dict(config) + for k, v in cfg.items(): + if isinstance(v, edict): + cfg[k] = dict(v) + + with open(config_file, 'w') as f: + yaml.dump(dict(cfg), f, default_flow_style=False) + + +def update_dir(model_dir, log_dir, data_dir): + if model_dir: + config.OUTPUT_DIR = model_dir + + if log_dir: + config.LOG_DIR = log_dir + + if data_dir: + config.DATA_DIR = data_dir + + config.DATASET.ROOT = os.path.join( + config.DATA_DIR, config.DATASET.ROOT) + + config.TEST.COCO_BBOX_FILE = os.path.join( + config.DATA_DIR, config.TEST.COCO_BBOX_FILE) + + config.MODEL.PRETRAINED = os.path.join( + config.DATA_DIR, config.MODEL.PRETRAINED) + + +def get_model_name(cfg): + name = cfg.MODEL.NAME + full_name = cfg.MODEL.NAME + extra = cfg.MODEL.EXTRA + if name in ['pose_resnet']: + name = '{model}_{num_layers}'.format( + model=name, + num_layers=extra.NUM_LAYERS) + deconv_suffix = ''.join( + 'd{}'.format(num_filters) + for num_filters in extra.NUM_DECONV_FILTERS) + full_name = '{height}x{width}_{name}_{deconv_suffix}'.format( + height=cfg.MODEL.IMAGE_SIZE[1], + width=cfg.MODEL.IMAGE_SIZE[0], + name=name, + deconv_suffix=deconv_suffix) + else: + raise ValueError('Unkown model: {}'.format(cfg.MODEL)) + + return name, full_name + + +if __name__ == '__main__': + import sys + gen_config(sys.argv[1]) diff --git a/main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml b/main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ecdea1ee6b08031bea216fd72a102042b07b3880 --- /dev/null +++ b/main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml @@ -0,0 +1,127 @@ +AUTO_RESUME: true +CUDNN: + BENCHMARK: true + DETERMINISTIC: false + ENABLED: true +DATA_DIR: '' +GPUS: (0,1,2,3) +OUTPUT_DIR: 'output' +LOG_DIR: 'log' +WORKERS: 24 +PRINT_FREQ: 100 + +DATASET: + COLOR_RGB: true + DATASET: 'coco' + DATA_FORMAT: jpg + FLIP: true + NUM_JOINTS_HALF_BODY: 8 + PROB_HALF_BODY: 0.3 + ROOT: 'data/coco/' + ROT_FACTOR: 45 + SCALE_FACTOR: 0.35 + TEST_SET: 'val2017' + TRAIN_SET: 'train2017' +MODEL: + INIT_WEIGHTS: true + NAME: pose_hrnet + NUM_JOINTS: 17 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w48-8ef0771d.pth' + TARGET_TYPE: gaussian + IMAGE_SIZE: + - 192 + - 256 + HEATMAP_SIZE: + - 48 + - 64 + SIGMA: 2 + EXTRA: + PRETRAINED_LAYERS: + - 'conv1' + - 'bn1' + - 'conv2' + - 'bn2' + - 'layer1' + - 'transition1' + - 'stage2' + - 'transition2' + - 'stage3' + - 'transition3' + - 'stage4' + FINAL_CONV_KERNEL: 1 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + 
- 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + - 384 + FUSE_METHOD: SUM +LOSS: + USE_TARGET_WEIGHT: true +TRAIN: + BATCH_SIZE_PER_GPU: 32 + SHUFFLE: true + BEGIN_EPOCH: 0 + END_EPOCH: 210 + OPTIMIZER: adam + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: + - 170 + - 200 + WD: 0.0001 + GAMMA1: 0.99 + GAMMA2: 0.0 + MOMENTUM: 0.9 + NESTEROV: false +TEST: + BATCH_SIZE_PER_GPU: 32 + COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json' + BBOX_THRE: 1.0 + IMAGE_THRE: 0.0 + IN_VIS_THRE: 0.2 + MODEL_FILE: '' + NMS_THRE: 1.0 + OKS_THRE: 0.9 + USE_GT_BBOX: true + FLIP_TEST: true + POST_PROCESS: true + SHIFT_HEATMAP: true +DEBUG: + DEBUG: true + SAVE_BATCH_IMAGES_GT: true + SAVE_BATCH_IMAGES_PRED: true + SAVE_HEATMAPS_GT: true + SAVE_HEATMAPS_PRED: true \ No newline at end of file diff --git a/main/postometro_utils/positional_encoding.py b/main/postometro_utils/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e6b03f59d43f7ebc89669636319766e78ab4f7 --- /dev/null +++ b/main/postometro_utils/positional_encoding.py @@ -0,0 +1,57 @@ +# ---------------------------------------------------------------------------------------------- +# FastMETRO Official Code +# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved +# Licensed under the MIT license. +# ---------------------------------------------------------------------------------------------- +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved [see https://github.com/facebookresearch/detr/blob/main/LICENSE for details] +# ---------------------------------------------------------------------------------------------- + +import math +import torch +from torch import nn + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, bs, h, w, device): + ones = torch.ones((bs, h, w), dtype=torch.bool, device=device) + y_embed = ones.cumsum(1, dtype=torch.float32) + x_embed = ones.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats) # cancel warning + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +def build_position_encoding(pos_type, hidden_dim): + N_steps = hidden_dim // 2 + if pos_type == 'sine': + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + else: + raise ValueError("not supported {pos_type}") + + return position_embedding \ No newline at end of file diff --git a/main/postometro_utils/renderer_pyrender.py b/main/postometro_utils/renderer_pyrender.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4ec0597bba5561b98dfbe0a548a017ae7c42f8 --- /dev/null +++ b/main/postometro_utils/renderer_pyrender.py @@ -0,0 +1,225 @@ +# ---------------------------------------------------------------------------------------------- +# Modified from Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) +# Copyright (c) Hongsuk Choi. 
All Rights Reserved [see https://github.com/hongsukchoi/Pose2Mesh_RELEASE/blob/main/LICENSE for details] +# ---------------------------------------------------------------------------------------------- + +import os +os.environ['PYOPENGL_PLATFORM'] = 'osmesa' +import torch +import numpy as np +import torch.nn.functional as F +import math +import cv2 +import trimesh +import pyrender +from pyrender.constants import RenderFlags + +def crop_bbox(bbox_meta, resolution, rgb, valid_mask): + bbox, original_img_height, original_img_width = bbox_meta['bbox'], *bbox_meta['img_hw'] + start_x = int(bbox[0]) + start_y = int(bbox[1]) + end_x = start_x + int(resolution[0]) # w + start_x + end_y = start_y + int(resolution[1]) # h + start_y + real_start_x, real_start_y, real_end_x, real_end_y = max(0, start_x), max(0, start_y), min(original_img_width, end_x), min(original_img_height, end_y) + max_height, max_width = rgb.shape[:2] + real_rgb = rgb[(real_start_y - start_y):((real_end_y - end_y) if real_end_y < end_y else max_height), + (real_start_x - start_x):((real_end_x - end_x) if real_end_x < end_x else max_width)].copy() + real_valid_mask = valid_mask[(real_start_y - start_y):((real_end_y - end_y) if real_end_y < end_y else max_height), + (real_start_x - start_x):((real_end_x - end_x) if real_end_x < end_x else max_width)].copy() + return {'bbox': [real_start_x, real_start_y, real_end_x, real_end_y], 'img_hw': [original_img_height, original_img_width]}, real_rgb, real_valid_mask + + +class WeakPerspectiveCamera(pyrender.Camera): + def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None): + super(WeakPerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name) + self.scale = scale + self.translation = translation + + def get_projection_matrix(self, width=None, height=None): + P = np.eye(4) + P[0, 0] = self.scale[0] + P[1, 1] = self.scale[1] + P[0, 3] = self.translation[0] * self.scale[0] + P[1, 3] = -self.translation[1] * self.scale[1] + P[2, 2] = -1 + return P + + +class PyRender_Renderer: + def __init__(self, resolution=(256, 256), faces=None, orig_img=False, wireframe=False): + self.resolution = resolution + self.faces = faces + self.orig_img = orig_img + self.wireframe = wireframe + self.renderer = pyrender.OffscreenRenderer(viewport_width=self.resolution[0], + viewport_height=self.resolution[1], + point_size=1.0) + + # set the scene & create light source + self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.05, 0.05, 0.05)) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0) + light_pose = trimesh.transformations.rotation_matrix(np.radians(-45), [1, 0, 0]) + self.scene.add(light, pose=light_pose) + light_pose = trimesh.transformations.rotation_matrix(np.radians(45), [0, 1, 0]) + self.scene.add(light, pose=light_pose) + + # mesh colors + self.colors_dict = {'blue': np.array([0.35, 0.60, 0.92]), + 'neutral': np.array([0.7, 0.7, 0.6]), + 'pink': np.array([0.7, 0.5, 0.5]), + 'white': np.array([1.0, 0.98, 0.94]), + 'green': np.array([0.5, 0.55, 0.3]), + 'sky': np.array([0.3, 0.5, 0.55])} + + def __call__(self, verts, bbox_meta, img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]), + angle=None, axis=None, mesh_filename=None, color_type=None, color=[0.7, 0.7, 0.6]): + if color_type != None: + color = self.colors_dict[color_type] + + mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False) + Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0]) + mesh.apply_transform(Rx) + if 
mesh_filename is not None: + mesh.export(mesh_filename) + if angle and axis: + R = trimesh.transformations.rotation_matrix(math.radians(angle), axis) + mesh.apply_transform(R) + + sy, tx, ty = cam + sx = sy + camera = WeakPerspectiveCamera(scale=[sx, sy], translation=[tx, ty], zfar=1000.0) + + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.2, + roughnessFactor=1.0, + alphaMode='OPAQUE', + baseColorFactor=(color[0], color[1], color[2], 1.0) + ) + + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + mesh_node = self.scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + cam_node = self.scene.add(camera, pose=camera_pose) + + if self.wireframe: + render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME + else: + render_flags = RenderFlags.RGBA + + rgb, depth = self.renderer.render(self.scene, flags=render_flags) + valid_mask = (depth > 0)[:, :, np.newaxis] # bbox size + # adjust bbox (no out of boundary) + bbox_meta, rgb, valid_mask = crop_bbox(bbox_meta, [self.resolution[0], self.resolution[1]], rgb, valid_mask) + # parse bbox + start_x, start_y, end_x, end_y, original_img_height, original_img_width = *bbox_meta['bbox'], *bbox_meta['img_hw'] + # start_x = int(bbox_meta['bbox'][0]) + # start_y = int(bbox_meta['bbox'][1]) + # end_x = start_x + int(self.resolution[0]) # w + start_x + # end_y = start_y + int(self.resolution[1]) # h + start_y + whole_img_mask = np.zeros((original_img_height, original_img_width,1)) + whole_img_mask[start_y:end_y, start_x:end_x] = valid_mask + whole_rgb = np.zeros((original_img_height, original_img_width,4)) + whole_rgb[start_y:end_y, start_x:end_x,:3] = rgb + output_img = whole_rgb[:, :, :3] * whole_img_mask + (1 - whole_img_mask) * img + image = output_img.astype(np.uint8) + + self.scene.remove_node(mesh_node) + self.scene.remove_node(cam_node) + + return image + + +def visualize_reconstruction_pyrender(img, vertices, camera, renderer, color='blue', focal_length=1000): + img = (img * 255).astype(np.uint8) + save_mesh_path = None + rend_color = color + + # Render front view + rend_img = renderer(vertices, + img=img, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + + combined = np.hstack([img, rend_img]) + + return combined + +def visualize_reconstruction_multi_view_pyrender(img, vertices, camera, renderer, color='blue', focal_length=1000): + img = (img * 255).astype(np.uint8) + save_mesh_path = None + rend_color = color + + # Render front view + rend_img = renderer(vertices, + img=img, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + + # Render side views + aroundy0 = cv2.Rodrigues(np.array([0, np.radians(0.), 0]))[0] + aroundy1 = cv2.Rodrigues(np.array([0, np.radians(90.), 0]))[0] + aroundy2 = cv2.Rodrigues(np.array([0, np.radians(180.), 0]))[0] + aroundy3 = cv2.Rodrigues(np.array([0, np.radians(270.), 0]))[0] + aroundy4 = cv2.Rodrigues(np.array([0, np.radians(45.), 0]))[0] + center = vertices.mean(axis=0) + rot_vertices0 = np.dot((vertices - center), aroundy0) + center + rot_vertices1 = np.dot((vertices - center), aroundy1) + center + rot_vertices2 = np.dot((vertices - center), aroundy2) + center + rot_vertices3 = np.dot((vertices - center), aroundy3) + center + rot_vertices4 = np.dot((vertices - center), aroundy4) + center + + # Render side-view shape + img_side0 = renderer(rot_vertices0, + img=np.ones_like(img)*255, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + img_side1 = renderer(rot_vertices1, + img=np.ones_like(img)*255, + cam=camera, + 
color_type=rend_color, + mesh_filename=save_mesh_path) + img_side2 = renderer(rot_vertices2, + img=np.ones_like(img)*255, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + img_side3 = renderer(rot_vertices3, + img=np.ones_like(img)*255, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + img_side4 = renderer(rot_vertices4, + img=np.ones_like(img)*255, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + + combined = np.hstack([img, rend_img, img_side0, img_side1, img_side2, img_side3, img_side4]) + + return combined + +def visualize_reconstruction_smpl_pyrender(img, vertices, camera, renderer, smpl_vertices, color='blue', focal_length=1000): + img = (img * 255).astype(np.uint8) + save_mesh_path = None + rend_color = color + + # Render front view + rend_img = renderer(vertices, + img=img, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + + rend_img_smpl = renderer(smpl_vertices, + img=img, + cam=camera, + color_type=rend_color, + mesh_filename=save_mesh_path) + + combined = np.hstack([img, rend_img, rend_img_smpl]) + + return combined \ No newline at end of file diff --git a/main/postometro_utils/smpl.py b/main/postometro_utils/smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..108ee9dfd390f789fd2ed815454d7108ff1b454b --- /dev/null +++ b/main/postometro_utils/smpl.py @@ -0,0 +1,291 @@ +# ---------------------------------------------------------------------------------------------- +# METRO (https://github.com/microsoft/MeshTransformer) +# Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details] +# Licensed under the MIT license. +# ---------------------------------------------------------------------------------------------- +""" +This file contains the definition of the SMPL model + +It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) +""" +from __future__ import division + +import torch +import torch.nn as nn +import numpy as np +import scipy.sparse +try: + import cPickle as pickle +except ImportError: + import pickle + +from postometro_utils.geometric_layers import rodrigues +import data.config as cfg + +class SMPL(nn.Module): + + def __init__(self, gender='neutral'): + super(SMPL, self).__init__() + + if gender=='m': + model_file=cfg.SMPL_Male + elif gender=='f': + model_file=cfg.SMPL_Female + else: + model_file=cfg.SMPL_FILE + + smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1') + J_regressor = smpl_model['J_regressor'].tocoo() + row = J_regressor.row + col = J_regressor.col + data = J_regressor.data + i = torch.LongTensor([row, col]) + v = torch.FloatTensor(data) + J_regressor_shape = [24, 6890] + self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense()) + self.register_buffer('weights', torch.FloatTensor(smpl_model['weights'])) + self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs'])) + self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template'])) + self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs']))) + self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64))) + self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64))) + id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])} + self.register_buffer('parent', 
torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])])) + + self.pose_shape = [24, 3] + self.beta_shape = [10] + self.translation_shape = [3] + + self.pose = torch.zeros(self.pose_shape) + self.beta = torch.zeros(self.beta_shape) + self.translation = torch.zeros(self.translation_shape) + + self.verts = None + self.J = None + self.R = None + + J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float() + self.register_buffer('J_regressor_extra', J_regressor_extra) + self.joints_idx = cfg.JOINTS_IDX + + J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float() + self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct) + + def forward(self, pose, beta): + device = pose.device + batch_size = pose.shape[0] + v_template = self.v_template[None, :] + shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1) + beta = beta[:, :, None] + v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template + # batched sparse matmul not supported in pytorch + J = [] + for i in range(batch_size): + J.append(torch.matmul(self.J_regressor, v_shaped[i])) + J = torch.stack(J, dim=0) + # input it rotmat: (bs,24,3,3) + if pose.ndimension() == 4: + R = pose + # input it rotmat: (bs,72) + elif pose.ndimension() == 2: + pose_cube = pose.view(-1, 3) # (batch_size * 24, 1, 3) + R = rodrigues(pose_cube).view(batch_size, 24, 3, 3) + R = R.view(batch_size, 24, 3, 3) + I_cube = torch.eye(3)[None, None, :].to(device) + # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1) + lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1) + posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1) + v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3) + J_ = J.clone() + J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] + G_ = torch.cat([R, J_[:, :, :, None]], dim=-1) + pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1) + G_ = torch.cat([G_, pad_row], dim=2) + G = [G_[:, 0].clone()] + for i in range(1, 24): + G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :])) + G = torch.stack(G, dim=1) + + rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1) + zeros = torch.zeros(batch_size, 24, 4, 3).to(device) + rest = torch.cat([zeros, rest], dim=-1) + rest = torch.matmul(G, rest) + G = G - rest + T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1) + rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1) + v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] + return v + + def get_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 38, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor]) + joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra]) + joints = torch.cat((joints, joints_extra), dim=1) + joints = joints[:, cfg.JOINTS_IDX] + return joints + + def get_24_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 38, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor]) + return joints + + def 
get_h36m_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 17, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) + return joints + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. + """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + return SparseMM.apply(sparse, dense) + + +def scipy_to_pytorch(A, U, D): + """Convert scipy sparse matrices to pytorch sparse matrix.""" + ptU = [] + ptD = [] + + for i in range(len(U)): + u = scipy.sparse.coo_matrix(U[i]) + i = torch.LongTensor(np.array([u.row, u.col])) + v = torch.FloatTensor(u.data) + ptU.append(torch.sparse.FloatTensor(i, v, u.shape)) + + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) + + return ptU, ptD + + +def adjmat_sparse(adjmat, nsize=1): + """Create row-normalized sparse graph adjacency matrix.""" + adjmat = scipy.sparse.csr_matrix(adjmat) + if nsize > 1: + orig_adjmat = adjmat.copy() + for _ in range(1, nsize): + adjmat = adjmat * orig_adjmat + adjmat.data = np.ones_like(adjmat.data) + for i in range(adjmat.shape[0]): + adjmat[i,i] = 1 + num_neighbors = np.array(1 / adjmat.sum(axis=-1)) + adjmat = adjmat.multiply(num_neighbors) + adjmat = scipy.sparse.coo_matrix(adjmat) + row = adjmat.row + col = adjmat.col + data = adjmat.data + i = torch.LongTensor(np.array([row, col])) + v = torch.from_numpy(data).float() + adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape) + return adjmat + +def get_graph_params(filename, nsize=1): + """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" + data = np.load(filename, encoding='latin1', allow_pickle=True) + A = data['A'] + U = data['U'] + D = data['D'] + U, D = scipy_to_pytorch(A, U, D) + A = [adjmat_sparse(a, nsize=nsize) for a in A] + return A, U, D + + +class Mesh(object): + """Mesh object that is used for handling certain graph operations.""" + def __init__(self, filename=cfg.SMPL_sampling_matrix, + num_downsampling=1, nsize=1, device=torch.device('cuda')): + self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) + self._A = [a.to(device) for a in self._A] + self._U = [u.to(device) for u in self._U] + self._D = [d.to(device) for d in self._D] + self.num_downsampling = num_downsampling + + # load template vertices from SMPL and normalize them + smpl = SMPL() + ref_vertices = smpl.v_template + center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] + ref_vertices -= center + ref_vertices /= ref_vertices.abs().max().item() + + self._ref_vertices = ref_vertices.to(device) + self.faces = smpl.faces.int().to(device) + + @property + def ref_vertices(self): + """Return the template vertices at the specified subsampling level.""" + ref_vertices = self._ref_vertices + for i in range(self.num_downsampling): + ref_vertices = 
torch.spmm(self._D[i], ref_vertices) + return ref_vertices + + def adjmat(self, num_downsampling): + """Return the graph adjacency matrix at the specified subsampling level.""" + return self._A[num_downsampling].float() + + def downsample(self, x, n1=0, n2=None): + """Downsample mesh.""" + if n2 is None: + n2 = self.num_downsampling + if x.ndimension() < 3: + for i in range(n1, n2): + x = spmm(self._D[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in range(n1, n2): + y = spmm(self._D[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x + + def upsample(self, x, n1=1, n2=0): + """Upsample mesh.""" + if x.ndimension() < 3: + for i in reversed(range(n2, n1)): + x = spmm(self._U[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in reversed(range(n2, n1)): + y = spmm(self._U[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x diff --git a/main/postometro_utils/transformer.py b/main/postometro_utils/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..35132acf9a03128252539d28407620d52f25f361 --- /dev/null +++ b/main/postometro_utils/transformer.py @@ -0,0 +1,249 @@ +# ---------------------------------------------------------------------------------------------- +# FastMETRO Official Code +# Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved +# Licensed under the MIT license. +# ---------------------------------------------------------------------------------------------- +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved [see https://github.com/facebookresearch/detr/blob/main/LICENSE for details] +# ---------------------------------------------------------------------------------------------- + +""" +Transformer encoder-decoder architecture in FastMETRO model. 
+""" +import copy +import torch +import torch.nn.functional as F +from typing import Optional +from torch import nn, Tensor + +class Transformer(nn.Module): + """Transformer encoder-decoder""" + def __init__(self, model_dim=512, nhead=8, num_enc_layers=3, num_dec_layers=3, + feedforward_dim=2048, dropout=0.1, activation="relu"): + """ + Parameters: + - model_dim: The hidden dimension size in the transformer architecture + - nhead: The number of attention heads in the attention modules + - num_enc_layers: The number of encoder layers in the transformer encoder + - num_dec_layers: The number of decoder layers in the transformer decoder + - feedforward_dim: The hidden dimension size in MLP + - dropout: The dropout rate in the transformer architecture + - activation: The activation function used in MLP + """ + super().__init__() + self.model_dim = model_dim + self.nhead = nhead + + # transformer encoder + encoder_layer = TransformerEncoderLayer(model_dim, nhead, feedforward_dim, dropout, activation) + encoder_norm = nn.LayerNorm(model_dim) + self.encoder = TransformerEncoder(encoder_layer, num_enc_layers, encoder_norm) + + # transformer decoder + decoder_layer = TransformerDecoderLayer(model_dim, nhead, feedforward_dim, dropout, activation) + decoder_norm = nn.LayerNorm(model_dim) + self.decoder = TransformerDecoder(decoder_layer, num_dec_layers, decoder_norm) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, img_features, cam_token, jv_tokens, pos_embed, pct_token = None, attention_mask=None): + device = img_features.device + hw, bs, _ = img_features.shape # (height * width), batch_size, feature_dim + + if pct_token is None: + mask = torch.zeros((bs, hw), dtype=torch.bool, device=device) # batch_size X (height * width) + else: + pct_len = pct_token.size(0) + mask = torch.zeros((bs, hw + pct_len), dtype=torch.bool, device=device) + # Transformer Encoder + zero_mask = torch.zeros((bs, 1), dtype=torch.bool, device=device) # batch_size X 1 + mem_mask = torch.cat([zero_mask, mask], dim=1) # batch_size X (1 + height * width) + cam_with_img = torch.cat([cam_token, img_features], dim=0) # (1 + height * width) X batch_size X feature_dim + e_outputs = self.encoder(cam_with_img, pct_token, src_key_padding_mask=mem_mask, pos=pos_embed) # (1 + height * width) X batch_size X feature_dim + if pct_token is not None: + cam_features, enc_img_features, pct_features = e_outputs.split([1, hw, pct_len], dim=0) + enc_img_features = torch.cat([enc_img_features, pct_features], dim = 0) # concat pct to img features + else: + cam_features, enc_img_features = e_outputs.split([1, hw], dim=0) + pct_features = None + + # Transformer Decoder + zero_tgt = torch.zeros_like(jv_tokens) # (num_joints + num_vertices) X batch_size X feature_dim + jv_features = self.decoder(jv_tokens, enc_img_features, + tgt_mask=attention_mask, + memory_key_padding_mask=mask, pos=pos_embed, query_pos=zero_tgt) # (num_joints + num_vertices) X batch_size X feature_dim + + return cam_features, enc_img_features, jv_features, pct_features + + +class TransformerEncoder(nn.Module): + """Transformer encoder""" + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.num_layers = num_layers + self.norm = norm + self.layers = _get_clones(encoder_layer, num_layers) + + def forward(self, src, pct_token = None, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = 
None): + if pct_token is not None: + output = torch.concat([src, pct_token], dim = 0) + else: + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + """Transformer decoder""" + def __init__(self, decoder_layer, num_layers, norm=None): + super().__init__() + self.num_layers = num_layers + self.norm = norm + self.layers = _get_clones(decoder_layer, num_layers) + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + """Transformer encoder layer""" + def __init__(self, model_dim, nhead, feedforward_dim=2048, dropout=0.1, activation="relu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(model_dim, nhead, dropout=dropout, batch_first=False) + + # MLP + self.linear1 = nn.Linear(model_dim, feedforward_dim) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(feedforward_dim, model_dim) + + # Layer Normalization & Dropout + self.norm1 = nn.LayerNorm(model_dim) + self.norm2 = nn.LayerNorm(model_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + # tensor[0] is for a camera token (no positional encoding) + if pos is not None: + pos_len = pos.size(0) + return tensor if pos is None else torch.cat([tensor[:1], (tensor[1:1+pos_len] + pos), tensor[1+pos_len:]], dim=0) + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=None, + key_padding_mask=None)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + +class TransformerDecoderLayer(nn.Module): + """Transformer decoder layer""" + def __init__(self, model_dim, nhead, feedforward_dim=2048, dropout=0.1, activation="relu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(model_dim, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(model_dim, nhead, dropout=dropout) + + # MLP + self.linear1 = nn.Linear(model_dim, feedforward_dim) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(feedforward_dim, model_dim) + + # Layer Normalization & Dropout + self.norm1 = nn.LayerNorm(model_dim) + self.norm2 = nn.LayerNorm(model_dim) + self.norm3 = nn.LayerNorm(model_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + if pos is not None: + pos_len = 
pos.size(0) + return tensor if pos is None else torch.cat([tensor[:pos_len] + pos, tensor[pos_len:]], dim = 0) + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") + +def build_transformer(transformer_config): + return Transformer(model_dim=transformer_config['model_dim'], + dropout=transformer_config['dropout'], + nhead=transformer_config['nhead'], + feedforward_dim=transformer_config['feedforward_dim'], + num_enc_layers=transformer_config['num_enc_layers'], + num_dec_layers=transformer_config['num_dec_layers']) \ No newline at end of file
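
For a quick shape check of the backbone wrapper added in this patch, the sketch below can be used. It is not part of the diff; it assumes the backbone module is importable as postometro_utils.pose_resnet next to pose_resnet_config, following the postometro_utils.* import style used elsewhere in these files, and the expected output shape is for the default ResNet-50 settings in pose_resnet_config.py.

import torch
from postometro_utils.pose_resnet import get_pose_net            # assumed module name
from postometro_utils.pose_resnet_config import config as resnet_cfg

# is_train=False skips init_weights(), so no ImageNet checkpoint is needed for a shape check.
backbone = get_pose_net(resnet_cfg, is_train=False)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 256))                 # conv1 -> layer4, no deconv head
print(feats.shape)                                                # torch.Size([1, 2048, 8, 8]) for ResNet-50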
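
The new transformer and sine positional encoding can be wired together as follows. This is a sketch under assumptions: the token counts (14 joints + 431 coarse vertices) and the flatten/permute of the positional map follow the usual FastMETRO/DETR convention and are not taken from this diff; PostoMETRO's own model code may prepare these tensors differently.

import torch
from postometro_utils.transformer import build_transformer
from postometro_utils.positional_encoding import build_position_encoding

model_dim, h, w, bs = 512, 7, 7, 2
num_joints, num_vertices = 14, 431                                # assumed token counts, illustration only

transformer = build_transformer({'model_dim': model_dim, 'dropout': 0.1, 'nhead': 8,
                                 'feedforward_dim': 2048,
                                 'num_enc_layers': 3, 'num_dec_layers': 3})
pos_embedding = build_position_encoding('sine', model_dim)

img_features = torch.randn(h * w, bs, model_dim)                  # flattened backbone tokens: (H*W) x B x C
cam_token = torch.randn(1, bs, model_dim)                         # camera token
jv_tokens = torch.randn(num_joints + num_vertices, bs, model_dim) # joint + vertex tokens
pos = pos_embedding(bs, h, w, img_features.device).flatten(2).permute(2, 0, 1)  # -> (H*W) x B x C

cam_feat, enc_img_feat, jv_feat, pct_feat = transformer(img_features, cam_token, jv_tokens, pos)
print(cam_feat.shape, enc_img_feat.shape, jv_feat.shape)          # 1 x B x C, (H*W) x B x C, (J+V) x B x C
# pct_feat is None here because no pct_token was passed.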
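
The SMPL wrapper can be exercised in the same spirit, but only once the SMPL pickle and joint regressors referenced by data/config.py (cfg.SMPL_FILE, cfg.JOINT_REGRESSOR_TRAIN_EXTRA, cfg.JOINT_REGRESSOR_H36M_correct) are present; those assets are not part of this diff, so treat this as an untested sketch.

import torch
from postometro_utils.smpl import SMPL

smpl = SMPL()                               # neutral-gender model loaded from cfg.SMPL_FILE
pose = torch.zeros(2, 72)                   # axis-angle, 24 joints x 3; (B, 24, 3, 3) rotation matrices also accepted
beta = torch.zeros(2, 10)                   # shape coefficients
verts = smpl(pose, beta)                    # (2, 6890, 3); zero pose/shape returns the T-posed template
joints_h36m = smpl.get_h36m_joints(verts)   # (2, 17, 3) joints via the H3.6M regressor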
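
Since WeakPerspectiveCamera now appears both in renderer_pyrender.py and in common/utils/vis.py, a few lines of plain NumPy make the projection it encodes explicit. The helper below mirrors get_projection_matrix (scale and translate x/y, negate z, no division by depth); the function name and sample values are illustrative only.

import numpy as np

def weak_perspective_project(verts, scale, translation):
    """verts: (N, 3); scale: (sx, sy); translation: (tx, ty), matching WeakPerspectiveCamera."""
    P = np.eye(4)
    P[0, 0] = scale[0]
    P[1, 1] = scale[1]
    P[0, 3] = translation[0] * scale[0]
    P[1, 3] = -translation[1] * scale[1]
    P[2, 2] = -1
    verts_h = np.concatenate([verts, np.ones((len(verts), 1))], axis=1)  # homogeneous coordinates
    ndc = verts_h @ P.T
    return ndc[:, :2]      # x/y in normalized device coords; depth only drives the z-buffer, not scale

# cam = [s, tx, ty] as used by PyRender_Renderer.__call__, with sx = sy = s
xy = weak_perspective_project(np.random.randn(10, 3), scale=(5.0, 5.0), translation=(0.0, 0.2))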