import einops
import numpy as np
import torch
import pytorch_lightning as pl
from typing import Dict
from torchvision.utils import make_grid
from tqdm import tqdm
from yacs.config import CfgNode

from lib.datasets.track_dataset import TrackDatasetEval
from lib.models.modules import MANOTransformerDecoderHead, temporal_attention
from hawor.utils.pylogger import get_pylogger
from hawor.utils.render_openpose import render_openpose
from lib.utils.geometry import rot6d_to_rotmat_hmr2 as rot6d_to_rotmat
from lib.utils.geometry import perspective_projection
from hawor.utils.rotation import angle_axis_to_rotation_matrix
from torch.utils.data import default_collate

from .backbones import create_backbone
from .mano_wrapper import MANO

log = get_pylogger(__name__)

idx = 0

class HAWOR(pl.LightningModule):

    def __init__(self, cfg: CfgNode):
        """
        Setup HAWOR model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        self.crop_size = cfg.MODEL.IMAGE_SIZE
        self.seq_len = 16
        self.pose_num = 16
        self.pose_dim = 6  # rot6d representation
        self.box_info_dim = 3

        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)
        try:
            if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
                whole_state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict']
                backbone_state_dict = {}
                for key in whole_state_dict:
                    if key.startswith('backbone.'):
                        backbone_state_dict[key[len('backbone.'):]] = whole_state_dict[key]
                self.backbone.load_state_dict(backbone_state_dict)
                print(f'Loaded backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
                # Freeze the pretrained backbone
                for param in self.backbone.parameters():
                    param.requires_grad = False
            else:
                print('WARNING: initializing backbone from scratch!')
        except Exception:
            print('WARNING: initializing backbone from scratch!')

        # Space-time memory
        if cfg.MODEL.ST_MODULE:
            hdim = cfg.MODEL.ST_HDIM
            nlayer = cfg.MODEL.ST_NLAYER
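            # The input dim is the 1280-channel backbone feature plus the 3-D
            # CLIFF-style bbox feature that is concatenated per spatial location
            # in forward_step.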
            self.st_module = temporal_attention(in_dim=1280 + 3,
                                                out_dim=1280,
                                                hdim=hdim,
                                                nlayer=nlayer,
                                                residual=True)
            print(f'Using Temporal Attention space-time: {nlayer} layers {hdim} dim.')
        else:
            self.st_module = None

        # Motion memory
        if cfg.MODEL.MOTION_MODULE:
            hdim = cfg.MODEL.MOTION_HDIM
            nlayer = cfg.MODEL.MOTION_NLAYER
            self.motion_module = temporal_attention(in_dim=self.pose_num * self.pose_dim + self.box_info_dim,
                                                    out_dim=self.pose_num * self.pose_dim,
                                                    hdim=hdim,
                                                    nlayer=nlayer,
                                                    residual=False)
            print(f'Using Temporal Attention motion layer: {nlayer} layers {hdim} dim.')
        else:
            self.motion_module = None

        # Create MANO head
        # self.mano_head = build_mano_head(cfg)
        self.mano_head = MANOTransformerDecoderHead(cfg)

        # Optionally compile the backbone and MANO head with torch.compile
        if cfg.MODEL.BACKBONE.get('TORCH_COMPILE', 0):
            log.info("Model will use torch.compile")
            self.backbone = torch.compile(self.backbone)
            self.mano_head = torch.compile(self.mano_head)

        # Define loss functions
        # self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        # self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        # self.mano_parameter_loss = ParameterLoss()

        # Instantiate MANO model
        mano_cfg = {k.lower(): v for k, v in dict(cfg.MANO).items()}
        self.mano = MANO(**mano_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers
        self.register_buffer('initialized', torch.tensor(False))

        # Disable automatic optimization; the optimizer is stepped manually in training_step
        self.automatic_optimization = False

        if cfg.MODEL.get('LOAD_WEIGHTS', None):
            whole_state_dict = torch.load(cfg.MODEL.LOAD_WEIGHTS, map_location='cpu')['state_dict']
            self.load_state_dict(whole_state_dict, strict=True)
            print(f"Loaded model weights from {cfg.MODEL.LOAD_WEIGHTS}")

    def get_parameters(self):
        all_params = list(self.mano_head.parameters())
        if self.st_module is not None:
            all_params += list(self.st_module.parameters())
        if self.motion_module is not None:
            all_params += list(self.motion_module.parameters())
        all_params += list(self.backbone.parameters())
        return all_params

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Setup the model optimizer
        Returns:
            torch.optim.Optimizer: AdamW optimizer over all trainable parameters
        """
        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]

        optimizer = torch.optim.AdamW(params=param_groups,
                                      # lr is set per param group above
                                      weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

        return optimizer

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """
        image = batch['img'].flatten(0, 1)
        center = batch['center'].flatten(0, 1)
        scale = batch['scale'].flatten(0, 1)
        img_focal = batch['img_focal'].flatten(0, 1)
        img_center = batch['img_center'].flatten(0, 1)
        bn = len(image)

        # Compute the CLIFF-style bbox feature from the crop and camera intrinsics
        bbox_info = self.bbox_est(center, scale, img_focal, img_center)

        # Backbone
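        # The crop's width is trimmed by 32 px on each side (256x256 -> 256x192,
        # assuming the default IMAGE_SIZE of 256); the backbone turns this into a
        # 16x12 grid of 1280-channel features (the h=16, w=12 used below).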
        feature = self.backbone(image[:, :, :, 32:-32])
        feature = feature.float()

        # Space-time module
        if self.st_module is not None:
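            # Append the bbox feature to every spatial location, then let each of the
            # 16x12 grid positions attend across the 16 frames of the clip before
            # reshaping back to per-frame feature maps.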
            bb = einops.repeat(bbox_info, 'b c -> b c h w', h=16, w=12)
            feature = torch.cat([feature, bb], dim=1)
            feature = einops.rearrange(feature, '(b t) c h w -> (b h w) t c', t=16)
            feature = self.st_module(feature)
            feature = einops.rearrange(feature, '(b h w) t c -> (b t) c h w', h=16, w=12)

        # MANO head: transformer decoder regressing pose, shape and camera
        # pred_mano_params, pred_cam, pred_mano_params_list = self.mano_head(feature)
        # pred_shape = pred_mano_params_list['pred_shape']
        # pred_pose = pred_mano_params_list['pred_pose']
        pred_pose, pred_shape, pred_cam = self.mano_head(feature)
        pred_rotmat_0 = rot6d_to_rotmat(pred_pose).reshape(-1, self.pose_num, 3, 3)

        # Motion module
        if self.motion_module is not None:
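            # Refine the per-frame rot6d pose (16 joints x 6) with temporal attention
            # over the 16-frame window, conditioned on the bbox feature; the
            # pre-refinement rotations are kept as pred_rotmat_0.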
            bb = einops.rearrange(bbox_info, '(b t) c -> b t c', t=16)
            pred_pose = einops.rearrange(pred_pose, '(b t) c -> b t c', t=16)
            pred_pose = torch.cat([pred_pose, bb], dim=2)
            pred_pose = self.motion_module(pred_pose)
            pred_pose = einops.rearrange(pred_pose, 'b t c -> (b t) c')

        out = {}

        if 'do_flip' in batch:
            pred_cam[..., 1] *= -1
            center[..., 0] = img_center[..., 0] * 2 - center[..., 0] - 1

        out['pred_cam'] = pred_cam
        out['pred_pose'] = pred_pose
        out['pred_shape'] = pred_shape
        out['pred_rotmat'] = rot6d_to_rotmat(out['pred_pose']).reshape(-1, self.pose_num, 3, 3)
        out['pred_rotmat_0'] = pred_rotmat_0

        s_out = self.mano.query(out)
        j3d = s_out.joints
        j2d = self.project(j3d, out['pred_cam'], center, scale, img_focal, img_center)
        j2d = j2d / self.crop_size - 0.5  # normalize to [-0.5, 0.5]

        trans_full = self.get_trans(out['pred_cam'], center, scale, img_focal, img_center)
        out['trans_full'] = trans_full

        output = {
            'pred_mano_params': {
                'global_orient': out['pred_rotmat'][:, :1].clone(),
                'hand_pose': out['pred_rotmat'][:, 1:].clone(),
                'betas': out['pred_shape'].clone(),
            },
            'pred_keypoints_3d': j3d.clone(),
            'pred_keypoints_2d': j2d.clone(),
            'out': out,
        }
        # print(output)
        # output['gt_project_j2d'] = self.project(batch['gt_j3d_wo_trans'].clone().flatten(0, 1), out['pred_cam'], center, scale, img_focal, img_center)
        # output['gt_project_j2d'] = output['gt_project_j2d'] / self.crop_size - 0.5

        return output

    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
        """
        Compute losses given the input batch and the regression output
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            torch.Tensor: Total loss for current batch
        """
        pred_mano_params = output['pred_mano_params']
        pred_keypoints_2d = output['pred_keypoints_2d']
        pred_keypoints_3d = output['pred_keypoints_3d']

        batch_size = pred_mano_params['hand_pose'].shape[0]
        device = pred_mano_params['hand_pose'].device
        dtype = pred_mano_params['hand_pose'].dtype

        # Get annotations
        gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1)
        gt_keypoints_2d = torch.cat([gt_keypoints_2d, torch.ones(*gt_keypoints_2d.shape[:-1], 1, device=gt_keypoints_2d.device)], dim=-1)
        gt_keypoints_3d = batch['gt_j3d_wo_trans'].flatten(0, 1)
        gt_keypoints_3d = torch.cat([gt_keypoints_3d, torch.ones(*gt_keypoints_3d.shape[:-1], 1, device=gt_keypoints_3d.device)], dim=-1)
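        # A column of ones is appended to the GT keypoints and treated as a per-joint
        # confidence by the keypoint losses, so every joint is supervised equally.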
        pose_gt = batch['gt_cam_full_pose'].flatten(0, 1).reshape(-1, 16, 3)
        rotmat_gt = angle_axis_to_rotation_matrix(pose_gt)
        gt_mano_params = {
            'global_orient': rotmat_gt[:, :1],
            'hand_pose': rotmat_gt[:, 1:],
            'betas': batch['gt_cam_betas'],
        }

        # Compute 2D and 3D keypoint losses
        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
        # Avoid NaN in the 2D loss
        loss_keypoints_2d = torch.nan_to_num(loss_keypoints_2d)

        # Compute loss on MANO parameters
        loss_mano_params = {}
        for k, pred in pred_mano_params.items():
            gt = gt_mano_params[k].view(batch_size, -1)
            loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1))

        loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \
               self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \
               sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])

        losses = dict(loss=loss.detach(),
                      loss_keypoints_2d=loss_keypoints_2d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'],
                      loss_keypoints_3d=loss_keypoints_3d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'])

        for k, v in loss_mano_params.items():
            losses['loss_' + k] = v.detach() * self.cfg.LOSS_WEIGHTS[k.upper()]

        output['losses'] = losses

        return loss

    # Tensorboard logging should run from first rank only
    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True, render_log: bool = True) -> None:
        """
        Log results to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            step_count (int): Global training step count
            train (bool): Flag indicating whether it is training or validation mode
            write_to_summary_writer (bool): Whether to write scalar losses to the summary writer
            render_log (bool): Whether to render keypoint visualizations
        """
        mode = 'train' if train else 'val'
        batch_size = output['pred_keypoints_2d'].shape[0]
        images = batch['img'].flatten(0, 1)
        # Undo ImageNet normalization for visualization
        images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)

        losses = output['losses']
        summary_writer = self.logger.experiment
        if write_to_summary_writer:
            for loss_name, val in losses.items():
                summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)

        if render_log:
            gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1).clone()
            pred_keypoints_2d = output['pred_keypoints_2d'].clone().detach().reshape(batch_size, -1, 2)
            gt_project_j2d = pred_keypoints_2d
            # gt_project_j2d = output['gt_project_j2d'].clone().detach().reshape(batch_size, -1, 2)

            num_images = 4
            skip = 16
            predictions = self.visualize_tensorboard(images[:num_images * skip:skip].cpu().numpy(),
                                                     pred_keypoints_2d[:num_images * skip:skip].cpu().numpy(),
                                                     gt_project_j2d[:num_images * skip:skip].cpu().numpy(),
                                                     gt_keypoints_2d[:num_images * skip:skip].cpu().numpy(),
                                                     )
            summary_writer.add_image('%s/predictions' % mode, predictions, step_count)

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)

    def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
        """
        Run a full training step
        Args:
            joint_batch (Dict): Dictionary containing image and mocap batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        batch = joint_batch['img']
        optimizer = self.optimizers(use_pl_optimizer=True)

        batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=True)
        # pred_mano_params = output['pred_mano_params']
        loss = self.compute_loss(batch, output, train=True)

        # Error if NaN
        if torch.isnan(loss):
            raise ValueError('Loss is NaN')

        optimizer.zero_grad()
        self.manual_backward(loss)

        # Clip gradients
        if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
            gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
            self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)

        optimizer.step()

        # if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
        if self.global_step > 0 and self.global_step % 100 == 0:
            self.tensorboard_logging(batch, output, self.global_step, train=True, render_log=self.cfg.TRAIN.get("RENDER_LOG", True))

        self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=batch_size)

        return output

    def inference(self, imgfiles, boxes, img_focal, img_center, device='cuda', do_flip=False):
        db = TrackDatasetEval(imgfiles, boxes, img_focal=img_focal,
                              img_center=img_center, normalization=True, dilate=1.2, do_flip=do_flip)

        # Results
        pred_cam = []
        pred_pose = []
        pred_shape = []
        pred_rotmat = []
        pred_trans = []

        # TODO: more efficient batched implementation
        items = []
        for i in tqdm(range(len(db))):
            item = db[i]
            items.append(item)

            # Pad the last window to 16 frames by repeating the final item
            if i == len(db) - 1 and len(db) % 16 != 0:
                pad = 16 - len(db) % 16
                for _ in range(pad):
                    items.append(item)

            if len(items) < 16:
                continue
            elif len(items) == 16:
                batch = default_collate(items)
                items = []
            else:
                raise NotImplementedError
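
            # Run the model on the collected 16-frame window; the frames added by
            # padding the last window are trimmed again right after the forward pass.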
            with torch.no_grad():
                batch = {k: v.to(device).unsqueeze(0) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                # for image_i in range(16):
                #     hawor_input_cv2 = vis_tensor_cv2(batch['img'][:, image_i])
                #     cv2.imwrite(f'debug_vis_model.png', hawor_input_cv2)
                #     print("vis")
                output = self.forward(batch)
                out = output['out']

                if i == len(db) - 1 and len(db) % 16 != 0:
                    out = {k: v[:len(db) % 16] for k, v in out.items()}
                else:
                    out = {k: v for k, v in out.items()}

                pred_cam.append(out['pred_cam'].cpu())
                pred_pose.append(out['pred_pose'].cpu())
                pred_shape.append(out['pred_shape'].cpu())
                pred_rotmat.append(out['pred_rotmat'].cpu())
                pred_trans.append(out['trans_full'].cpu())

        results = {'pred_cam': torch.cat(pred_cam),
                   'pred_pose': torch.cat(pred_pose),
                   'pred_shape': torch.cat(pred_shape),
                   'pred_rotmat': torch.cat(pred_rotmat),
                   'pred_trans': torch.cat(pred_trans),
                   'img_focal': img_focal,
                   'img_center': img_center}
        return results
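
    # Example usage of inference() (a hypothetical sketch; the image list, boxes and
    # intrinsics below are placeholders, not values shipped with this repo):
    #   model = HAWOR(cfg).eval().to('cuda')
    #   results = model.inference(imgfiles, boxes, img_focal, img_center, device='cuda')
    #   pred_trans = results['pred_trans']  # per-frame hand translation in the full image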

    def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
        """
        Run a validation step and log to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        # batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=False)
        loss = self.compute_loss(batch, output, train=False)
        output['loss'] = loss
        self.tensorboard_logging(batch, output, self.global_step, train=False)

        return output

    def visualize_tensorboard(self, images, pred_keypoints, gt_project_j2d, gt_keypoints):
        # Map keypoints normalized to [-0.5, 0.5] back to crop pixel coordinates
        # (the hard-coded 256 assumes the default 256x256 crop size)
        pred_keypoints = 256 * (pred_keypoints + 0.5)
        gt_keypoints = 256 * (gt_keypoints + 0.5)
        gt_project_j2d = 256 * (gt_project_j2d + 0.5)
        # Append a confidence of 1 to each keypoint for rendering
        pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
        gt_keypoints = np.concatenate((gt_keypoints, np.ones_like(gt_keypoints)[:, :, [0]]), axis=-1)
        gt_project_j2d = np.concatenate((gt_project_j2d, np.ones_like(gt_project_j2d)[:, :, [0]]), axis=-1)
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        for i in range(images_np.shape[0]):
            pred_keypoints_img = render_openpose(255 * images_np[i].copy(), pred_keypoints[i]) / 255
            gt_project_j2d_img = render_openpose(255 * images_np[i].copy(), gt_project_j2d[i]) / 255
            gt_keypoints_img = render_openpose(255 * images_np[i].copy(), gt_keypoints[i]) / 255
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_project_j2d_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1))
        rend_imgs = make_grid(rend_imgs, nrow=4, padding=2)
        return rend_imgs

    def project(self, points, pred_cam, center, scale, img_focal, img_center, return_full=False):
        trans_full = self.get_trans(pred_cam, center, scale, img_focal, img_center)

        # Projection in full-frame image coordinates
        points = points + trans_full
        points2d_full = perspective_projection(points, rotation=None, translation=None,
                                               focal_length=img_focal, camera_center=img_center)

        # Convert the projected points to crop image coordinates, so that
        # 1. the loss can be computed directly in the crop, and
        # 2. crop pixels can be queried at the projected locations
        b = scale * 200
        points2d = points2d_full - (center - b[:, None] / 2)[:, None, :]
        points2d = points2d * (self.crop_size / b)[:, None, None]

        if return_full:
            return points2d_full, points2d
        else:
            return points2d

    def get_trans(self, pred_cam, center, scale, img_focal, img_center):
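        # Convert the weak-perspective camera (s, tx, ty) predicted in the crop into a
        # full-image translation (CLIFF-style conversion): with crop size b = 200 * scale,
        # the depth is tz = 2 * f / (s * b), and the offset of the crop center from the
        # image center is folded into tx and ty.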
        b = scale * 200
        cx, cy = center[:, 0], center[:, 1]  # center of the crop
        s, tx, ty = pred_cam.unbind(-1)
        img_cx, img_cy = img_center[:, 0], img_center[:, 1]  # center of the original image

        bs = b * s
        tx_full = tx + 2 * (cx - img_cx) / bs
        ty_full = ty + 2 * (cy - img_cy) / bs
        tz_full = 2 * img_focal / bs

        trans_full = torch.stack([tx_full, ty_full, tz_full], dim=-1)
        trans_full = trans_full.unsqueeze(1)

        return trans_full

    def bbox_est(self, center, scale, img_focal, img_center):
        # Original image center
        img_cx, img_cy = img_center[:, 0], img_center[:, 1]

        # CLIFF (Li et al.) bbox feature: crop offset from the image center and crop size
        cx, cy, b = center[:, 0], center[:, 1], scale * 200
        bbox_info = torch.stack([cx - img_cx, cy - img_cy, b], dim=-1)
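        # Normalize relative to the focal length, following CLIFF: the offsets are scaled
        # by 2.8 / f, and the box size is shifted and scaled by the constants 0.24 and
        # 0.06 so its typical range is roughly standardized.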
        bbox_info[:, :2] = bbox_info[:, :2] / img_focal.unsqueeze(-1) * 2.8
        bbox_info[:, 2] = (bbox_info[:, 2] - 0.24 * img_focal) / (0.06 * img_focal)

        return bbox_info