# HaWoR/lib/models/hawor.py
import einops
import numpy as np
import torch
import pytorch_lightning as pl
from typing import Dict
from torchvision.utils import make_grid
from tqdm import tqdm
from yacs.config import CfgNode
from lib.datasets.track_dataset import TrackDatasetEval
from lib.models.modules import MANOTransformerDecoderHead, temporal_attention
from hawor.utils.pylogger import get_pylogger
from hawor.utils.render_openpose import render_openpose
from lib.utils.geometry import rot6d_to_rotmat_hmr2 as rot6d_to_rotmat
from lib.utils.geometry import perspective_projection
from hawor.utils.rotation import angle_axis_to_rotation_matrix
from torch.utils.data import default_collate
from .backbones import create_backbone
from .mano_wrapper import MANO
log = get_pylogger(__name__)
idx = 0
class HAWOR(pl.LightningModule):
def __init__(self, cfg: CfgNode):
"""
Setup HAWOR model
Args:
cfg (CfgNode): Config file as a yacs CfgNode
"""
super().__init__()
# Save hyperparameters
self.save_hyperparameters(logger=False, ignore=['init_renderer'])
self.cfg = cfg
self.crop_size = cfg.MODEL.IMAGE_SIZE
self.seq_len = 16
self.pose_num = 16
self.pose_dim = 6 # rot6d representation
self.box_info_dim = 3
# Create backbone feature extractor
self.backbone = create_backbone(cfg)
        try:
            if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
                whole_state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict']
                # Keep only the backbone weights and strip the 'backbone.' prefix
                backbone_state_dict = {}
                for key in whole_state_dict:
                    if key.startswith('backbone.'):
                        backbone_state_dict[key[len('backbone.'):]] = whole_state_dict[key]
                self.backbone.load_state_dict(backbone_state_dict)
                print(f'Loaded backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
                # Freeze the pretrained backbone
                for param in self.backbone.parameters():
                    param.requires_grad = False
            else:
                print('WARNING: initializing backbone from scratch!')
        except Exception:
            print('WARNING: initializing backbone from scratch!')
# Space-time memory
if cfg.MODEL.ST_MODULE:
hdim = cfg.MODEL.ST_HDIM
nlayer = cfg.MODEL.ST_NLAYER
self.st_module = temporal_attention(in_dim=1280+3,
out_dim=1280,
hdim=hdim,
nlayer=nlayer,
residual=True)
print(f'Using Temporal Attention space-time: {nlayer} layers {hdim} dim.')
else:
self.st_module = None
# Motion memory
if cfg.MODEL.MOTION_MODULE:
hdim = cfg.MODEL.MOTION_HDIM
nlayer = cfg.MODEL.MOTION_NLAYER
self.motion_module = temporal_attention(in_dim=self.pose_num * self.pose_dim + self.box_info_dim,
out_dim=self.pose_num * self.pose_dim,
hdim=hdim,
nlayer=nlayer,
residual=False)
print(f'Using Temporal Attention motion layer: {nlayer} layers {hdim} dim.')
else:
self.motion_module = None
# Create MANO head
# self.mano_head = build_mano_head(cfg)
self.mano_head = MANOTransformerDecoderHead(cfg)
        # optionally wrap the backbone and MANO head with torch.compile
if cfg.MODEL.BACKBONE.get('TORCH_COMPILE', 0):
log.info("Model will use torch.compile")
self.backbone = torch.compile(self.backbone)
self.mano_head = torch.compile(self.mano_head)
        # Define loss functions
        # NOTE: compute_loss() below expects self.keypoint_3d_loss, self.keypoint_2d_loss and
        # self.mano_parameter_loss to be set; re-enable these (and import the corresponding
        # loss classes) before training.
        # self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        # self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        # self.mano_parameter_loss = ParameterLoss()
# Instantiate MANO model
mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
self.mano = MANO(**mano_cfg)
        # Buffer that shows whether we need to initialize ActNorm layers
self.register_buffer('initialized', torch.tensor(False))
        # Use manual optimization (backward and optimizer step are handled in training_step)
self.automatic_optimization = False
if cfg.MODEL.get('LOAD_WEIGHTS', None):
whole_state_dict = torch.load(cfg.MODEL.LOAD_WEIGHTS, map_location='cpu')['state_dict']
self.load_state_dict(whole_state_dict, strict=True)
print(f"load {cfg.MODEL.LOAD_WEIGHTS}")
def get_parameters(self):
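        """Collect parameters of the MANO head, the optional temporal modules and the backbone."""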
all_params = list(self.mano_head.parameters())
        if self.st_module is not None:
            all_params += list(self.st_module.parameters())
        if self.motion_module is not None:
            all_params += list(self.motion_module.parameters())
all_params += list(self.backbone.parameters())
return all_params
def configure_optimizers(self) -> torch.optim.Optimizer:
"""
Setup model and distriminator Optimizers
Returns:
Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
"""
param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
optimizer = torch.optim.AdamW(params=param_groups,
# lr=self.cfg.TRAIN.LR,
weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
return optimizer
def forward_step(self, batch: Dict, train: bool = False) -> Dict:
"""
Run a forward step of the network
Args:
batch (Dict): Dictionary containing batch data
train (bool): Flag indicating whether it is training or validation mode
Returns:
Dict: Dictionary containing the regression output
"""
image = batch['img'].flatten(0, 1)
center = batch['center'].flatten(0, 1)
scale = batch['scale'].flatten(0, 1)
img_focal = batch['img_focal'].flatten(0, 1)
img_center = batch['img_center'].flatten(0, 1)
bn = len(image)
        # compute CLIFF-style bbox conditioning from the crop and camera intrinsics
bbox_info = self.bbox_est(center, scale, img_focal, img_center)
# backbone
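        # trim 32 px from each side of the crop width (square crop -> backbone aspect ratio)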
feature = self.backbone(image[:,:,:,32:-32])
feature = feature.float()
# space-time module
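        # concatenate the bbox conditioning, then run temporal attention over the 16-frame axis
        # independently at each spatial location of the feature map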
if self.st_module is not None:
bb = einops.repeat(bbox_info, 'b c -> b c h w', h=16, w=12)
feature = torch.cat([feature, bb], dim=1)
feature = einops.rearrange(feature, '(b t) c h w -> (b h w) t c', t=16)
feature = self.st_module(feature)
feature = einops.rearrange(feature, '(b h w) t c -> (b t) c h w', h=16, w=12)
        # mano_head: transformer decoder + MANO parameter regression
# pred_mano_params, pred_cam, pred_mano_params_list = self.mano_head(feature)
# pred_shape = pred_mano_params_list['pred_shape']
# pred_pose = pred_mano_params_list['pred_pose']
pred_pose, pred_shape, pred_cam = self.mano_head(feature)
pred_rotmat_0 = rot6d_to_rotmat(pred_pose).reshape(-1, self.pose_num, 3, 3)
        # MANO motion module
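        # refine the rot6d pose sequence with temporal attention, conditioned on the bbox info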
if self.motion_module is not None:
bb = einops.rearrange(bbox_info, '(b t) c -> b t c', t=16)
pred_pose = einops.rearrange(pred_pose, '(b t) c -> b t c', t=16)
pred_pose = torch.cat([pred_pose, bb], dim=2)
pred_pose = self.motion_module(pred_pose)
pred_pose = einops.rearrange(pred_pose, 'b t c -> (b t) c')
out = {}
if 'do_flip' in batch:
pred_cam[..., 1] *= -1
center[..., 0] = img_center[..., 0]*2 - center[..., 0] - 1
out['pred_cam'] = pred_cam
out['pred_pose'] = pred_pose
out['pred_shape'] = pred_shape
out['pred_rotmat'] = rot6d_to_rotmat(out['pred_pose']).reshape(-1, self.pose_num, 3, 3)
out['pred_rotmat_0'] = pred_rotmat_0
s_out = self.mano.query(out)
j3d = s_out.joints
j2d = self.project(j3d, out['pred_cam'], center, scale, img_focal, img_center)
j2d = j2d / self.crop_size - 0.5 # norm to [-0.5, 0.5]
trans_full = self.get_trans(out['pred_cam'], center, scale, img_focal, img_center)
out['trans_full'] = trans_full
output = {
'pred_mano_params': {
'global_orient': out['pred_rotmat'][:, :1].clone(),
'hand_pose': out['pred_rotmat'][:, 1:].clone(),
'betas': out['pred_shape'].clone(),
},
'pred_keypoints_3d': j3d.clone(),
'pred_keypoints_2d': j2d.clone(),
'out': out,
}
# print(output)
# output['gt_project_j2d'] = self.project(batch['gt_j3d_wo_trans'].clone().flatten(0,1), out['pred_cam'], center, scale, img_focal, img_center)
# output['gt_project_j2d'] = output['gt_project_j2d'] / self.crop_size - 0.5
return output
def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
"""
Compute losses given the input batch and the regression output
Args:
batch (Dict): Dictionary containing batch data
output (Dict): Dictionary containing the regression output
train (bool): Flag indicating whether it is training or validation mode
Returns:
torch.Tensor : Total loss for current batch
"""
pred_mano_params = output['pred_mano_params']
pred_keypoints_2d = output['pred_keypoints_2d']
pred_keypoints_3d = output['pred_keypoints_3d']
batch_size = pred_mano_params['hand_pose'].shape[0]
device = pred_mano_params['hand_pose'].device
dtype = pred_mano_params['hand_pose'].dtype
# Get annotations
gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1)
gt_keypoints_2d = torch.cat([gt_keypoints_2d, torch.ones(*gt_keypoints_2d.shape[:-1], 1, device=gt_keypoints_2d.device)], dim=-1)
gt_keypoints_3d = batch['gt_j3d_wo_trans'].flatten(0, 1)
gt_keypoints_3d = torch.cat([gt_keypoints_3d, torch.ones(*gt_keypoints_3d.shape[:-1], 1, device=gt_keypoints_3d.device)], dim=-1)
pose_gt = batch['gt_cam_full_pose'].flatten(0, 1).reshape(-1, 16, 3)
rotmat_gt = angle_axis_to_rotation_matrix(pose_gt)
gt_mano_params = {
'global_orient': rotmat_gt[:, :1],
'hand_pose': rotmat_gt[:, 1:],
'betas': batch['gt_cam_betas'],
}
        # Compute 2D and 3D keypoint losses
loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
# to avoid nan
loss_keypoints_2d = torch.nan_to_num(loss_keypoints_2d)
# Compute loss on MANO parameters
loss_mano_params = {}
for k, pred in pred_mano_params.items():
gt = gt_mano_params[k].view(batch_size, -1)
loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1))
loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])
losses = dict(loss=loss.detach(),
loss_keypoints_2d=loss_keypoints_2d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'],
loss_keypoints_3d=loss_keypoints_3d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'])
for k, v in loss_mano_params.items():
losses['loss_' + k] = v.detach() * self.cfg.LOSS_WEIGHTS[k.upper()]
output['losses'] = losses
return loss
    # Tensorboard logging should run from rank zero only
@pl.utilities.rank_zero.rank_zero_only
def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True, render_log: bool = True) -> None:
"""
Log results to Tensorboard
Args:
batch (Dict): Dictionary containing batch data
output (Dict): Dictionary containing the regression output
step_count (int): Global training step count
train (bool): Flag indicating whether it is training or validation mode
"""
mode = 'train' if train else 'val'
batch_size = output['pred_keypoints_2d'].shape[0]
images = batch['img'].flatten(0,1)
images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
losses = output['losses']
if write_to_summary_writer:
summary_writer = self.logger.experiment
for loss_name, val in losses.items():
summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
if render_log:
gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0,1).clone()
pred_keypoints_2d = output['pred_keypoints_2d'].clone().detach().reshape(batch_size, -1, 2)
gt_project_j2d = pred_keypoints_2d
# gt_project_j2d = output['gt_project_j2d'].clone().detach().reshape(batch_size, -1, 2)
num_images = 4
skip=16
predictions = self.visualize_tensorboard(images[:num_images*skip:skip].cpu().numpy(),
pred_keypoints_2d[:num_images*skip:skip].cpu().numpy(),
gt_project_j2d[:num_images*skip:skip].cpu().numpy(),
gt_keypoints_2d[:num_images*skip:skip].cpu().numpy(),
)
summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
def forward(self, batch: Dict) -> Dict:
"""
Run a forward step of the network in val mode
Args:
batch (Dict): Dictionary containing batch data
Returns:
Dict: Dictionary containing the regression output
"""
return self.forward_step(batch, train=False)
def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
"""
Run a full training step
Args:
joint_batch (Dict): Dictionary containing image and mocap batch data
batch_idx (int): Unused.
batch_idx (torch.Tensor): Unused.
Returns:
Dict: Dictionary containing regression output.
"""
batch = joint_batch['img']
optimizer = self.optimizers(use_pl_optimizer=True)
batch_size = batch['img'].shape[0]
output = self.forward_step(batch, train=True)
# pred_mano_params = output['pred_mano_params']
loss = self.compute_loss(batch, output, train=True)
# Error if Nan
if torch.isnan(loss):
raise ValueError('Loss is NaN')
optimizer.zero_grad()
self.manual_backward(loss)
# Clip gradient
if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
optimizer.step()
# if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
if self.global_step > 0 and self.global_step % 100 == 0:
self.tensorboard_logging(batch, output, self.global_step, train=True, render_log=self.cfg.TRAIN.get("RENDER_LOG", True))
self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=batch_size)
return output
def inference(self, imgfiles, boxes, img_focal, img_center, device='cuda', do_flip=False):
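        """Run chunked inference over a tracked hand sequence.
        Frames are processed in chunks of 16 (the last chunk is padded by repeating its final
        frame) and the per-frame predictions are concatenated into a single results dict."""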
db = TrackDatasetEval(imgfiles, boxes, img_focal=img_focal,
img_center=img_center, normalization=True, dilate=1.2, do_flip=do_flip)
# Results
pred_cam = []
pred_pose = []
pred_shape = []
pred_rotmat = []
pred_trans = []
        # TODO: batch the chunks for a more efficient implementation
items = []
for i in tqdm(range(len(db))):
item = db[i]
items.append(item)
# padding to 16
if i == len(db) - 1 and len(db) % 16 != 0:
pad = 16 - len(db) % 16
for _ in range(pad):
items.append(item)
if len(items) < 16:
continue
elif len(items) == 16:
batch = default_collate(items)
items = []
else:
raise NotImplementedError
with torch.no_grad():
                batch = {k: v.to(device).unsqueeze(0) for k, v in batch.items() if isinstance(v, torch.Tensor)}
# for image_i in range(16):
# hawor_input_cv2 = vis_tensor_cv2(batch['img'][:, image_i])
# cv2.imwrite(f'debug_vis_model.png', hawor_input_cv2)
# print("vis")
output = self.forward(batch)
out = output['out']
                if i == len(db) - 1 and len(db) % 16 != 0:
                    # drop the padded (repeated) frames appended to the last chunk
                    out = {k: v[:len(db) % 16] for k, v in out.items()}
pred_cam.append(out['pred_cam'].cpu())
pred_pose.append(out['pred_pose'].cpu())
pred_shape.append(out['pred_shape'].cpu())
pred_rotmat.append(out['pred_rotmat'].cpu())
pred_trans.append(out['trans_full'].cpu())
results = {'pred_cam': torch.cat(pred_cam),
'pred_pose': torch.cat(pred_pose),
'pred_shape': torch.cat(pred_shape),
'pred_rotmat': torch.cat(pred_rotmat),
'pred_trans': torch.cat(pred_trans),
'img_focal': img_focal,
'img_center': img_center}
return results
def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
"""
Run a validation step and log to Tensorboard
Args:
batch (Dict): Dictionary containing batch data
batch_idx (int): Unused.
Returns:
Dict: Dictionary containing regression output.
"""
# batch_size = batch['img'].shape[0]
output = self.forward_step(batch, train=False)
loss = self.compute_loss(batch, output, train=False)
output['loss'] = loss
self.tensorboard_logging(batch, output, self.global_step, train=False)
return output
def visualize_tensorboard(self, images, pred_keypoints, gt_project_j2d, gt_keypoints):
pred_keypoints = 256 * (pred_keypoints + 0.5)
gt_keypoints = 256 * (gt_keypoints + 0.5)
gt_project_j2d = 256 * (gt_project_j2d + 0.5)
pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
gt_keypoints = np.concatenate((gt_keypoints, np.ones_like(gt_keypoints)[:, :, [0]]), axis=-1)
gt_project_j2d = np.concatenate((gt_project_j2d, np.ones_like(gt_project_j2d)[:, :, [0]]), axis=-1)
images_np = np.transpose(images, (0,2,3,1))
rend_imgs = []
for i in range(images_np.shape[0]):
pred_keypoints_img = render_openpose(255 * images_np[i].copy(), pred_keypoints[i]) / 255
gt_project_j2d_img = render_openpose(255 * images_np[i].copy(), gt_project_j2d[i]) / 255
gt_keypoints_img = render_openpose(255*images_np[i].copy(), gt_keypoints[i]) / 255
rend_imgs.append(torch.from_numpy(images[i]))
rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1))
rend_imgs.append(torch.from_numpy(gt_project_j2d_img).permute(2,0,1))
rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1))
rend_imgs = make_grid(rend_imgs, nrow=4, padding=2)
return rend_imgs
def project(self, points, pred_cam, center, scale, img_focal, img_center, return_full=False):
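        """Project 3D points to 2D in full-image coordinates, then map them into the crop."""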
trans_full = self.get_trans(pred_cam, center, scale, img_focal, img_center)
# Projection in full frame image coordinate
points = points + trans_full
points2d_full = perspective_projection(points, rotation=None, translation=None,
focal_length=img_focal, camera_center=img_center)
        # Adjust projected points to the crop image coordinate frame
        # (so that 1. the loss can be computed directly in the crop, and
        #          2. each keypoint's pixel can be looked up in the crop)
b = scale * 200
points2d = points2d_full - (center - b[:,None]/2)[:,None,:]
points2d = points2d * (self.crop_size / b)[:,None,None]
if return_full:
return points2d_full, points2d
else:
return points2d
def get_trans(self, pred_cam, center, scale, img_focal, img_center):
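        """Convert the predicted crop camera (s, tx, ty) into a full-image translation:
        with b = 200 * scale, the depth is tz = 2 * focal / (b * s), and tx/ty are shifted by
        the offset of the crop center from the image center."""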
b = scale * 200
cx, cy = center[:,0], center[:,1] # center of crop
s, tx, ty = pred_cam.unbind(-1)
img_cx, img_cy = img_center[:,0], img_center[:,1] # center of original image
bs = b*s
tx_full = tx + 2*(cx-img_cx)/bs
ty_full = ty + 2*(cy-img_cy)/bs
tz_full = 2*img_focal/bs
trans_full = torch.stack([tx_full, ty_full, tz_full], dim=-1)
trans_full = trans_full.unsqueeze(1)
return trans_full
def bbox_est(self, center, scale, img_focal, img_center):
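        """Build the CLIFF-style bbox conditioning vector [cx - img_cx, cy - img_cy, b],
        normalized by the focal length."""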
# Original image center
img_cx, img_cy = img_center[:,0], img_center[:,1]
# Implement CLIFF (Li et al.) bbox feature
cx, cy, b = center[:, 0], center[:, 1], scale * 200
bbox_info = torch.stack([cx - img_cx, cy - img_cy, b], dim=-1)
bbox_info[:, :2] = bbox_info[:, :2] / img_focal.unsqueeze(-1) * 2.8
bbox_info[:, 2] = (bbox_info[:, 2] - 0.24 * img_focal) / (0.06 * img_focal)
return bbox_info
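

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how this module is typically driven for
# inference. The config path, checkpoint handling and box format below are
# assumptions for illustration, not guarantees made by this file.
#
#   from yacs.config import CfgNode
#   from lib.models.hawor import HAWOR
#
#   cfg = CfgNode(new_allowed=True)
#   cfg.merge_from_file('path/to/hawor.yaml')          # hypothetical config file
#   model = HAWOR(cfg).eval().to('cuda')
#
#   # imgfiles: list of frame paths; boxes: per-frame hand bounding boxes for one track
#   results = model.inference(imgfiles, boxes,
#                             img_focal=focal_length,
#                             img_center=img_center)
#   print(results['pred_trans'].shape)                 # full-image hand translation per frame
# ---------------------------------------------------------------------------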