Sonic / src /utils /RIFE /RIFE_HDv3.py
xiaozhongji
init spaces
79d88c4
import torch
from .IFNet_HDv3 import *
import torch.nn.functional as F
class RIFEModel:
    """Wrapper around the RIFE HDv3 ``IFNet`` network for frame interpolation.

    Holds an ``IFNet`` instance on ``self.device`` and exposes helpers to switch
    train/eval mode, load pretrained weights, and run inference on two frames.
    """

    def __init__(self, device=None):
        """Create the network on ``device``.

        Args:
            device: Target device passed to ``Module.to``. When ``None``, CUDA is
                used if ``torch.cuda.is_available()``; otherwise, CPU is used.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        # The network is created directly in eval mode.
        self.flownet = IFNet().to(self.device).eval()

    def train(self):
        """Switch the wrapped network to training mode."""
        self.flownet.train()

    def eval(self):
        """Switch the wrapped network to evaluation mode."""
        self.flownet.eval()

    @staticmethod
    def _convert_state_dict(param, rank=-1):
        """Normalize checkpoint keys before ``load_state_dict``.

        When ``rank == -1``, a single leading ``"module."`` prefix is stripped
        from each key. That prefix is added by ``DataParallel``/DDP wrappers when
        the model is saved. Keys without the prefix are kept unchanged; they are
        no longer silently dropped. For any other ``rank``, ``param`` is returned
        as-is.

        Args:
            param: Mapping from parameter name to tensor, as loaded from disk.
            rank: Distributed rank; ``-1`` selects the prefix-stripping path.

        Returns:
            The (possibly re-keyed) state-dict mapping.
        """
        if rank != -1:
            return param
        prefix = "module."
        # Strip only a *leading* prefix. str.replace would also mangle keys that
        # merely contain "module." (e.g. "submodule.w" -> "subw").
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in param.items()
        }

    def load_model(self, path, rank=-1):
        """Load ``<path>/flownet.pkl`` into the wrapped network.

        Args:
            path: Directory that contains ``flownet.pkl``.
            rank: Forwarded to ``_convert_state_dict``; ``-1`` strips a leading
                ``"module."`` from each parameter name.

        Raises:
            Whatever ``torch.load`` raises (e.g. a missing file), and
            ``RuntimeError`` from the strict ``load_state_dict`` when the keys
            or shapes do not match the network.
        """
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        # Loading onto CPU first lets a GPU-saved checkpoint load on any
        # machine; load_state_dict then copies the values into the network's
        # existing parameters.
        state_dict = torch.load('{}/flownet.pkl'.format(path), map_location='cpu')
        self.flownet.load_state_dict(self._convert_state_dict(state_dict, rank))

    def inference(self, img0, img1, scale=1.0):
        """Run the network on a pair of frames and return its final merged output.

        Args:
            img0: First frame tensor. Presumably (N, C, H, W); TODO confirm
                against IFNet.
            img1: Second frame tensor, concatenated with ``img0`` along dim 1.
            scale: Divides the three per-level scale factors (4, 2, 1) passed
                to IFNet.

        Returns:
            ``merged[2]``, the third element of IFNet's merged outputs.
            Presumably this is the finest-level prediction; confirm against
            IFNet.
        """
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, scale_list)
        return merged[2]