# Hugging Face Spaces: Running on Zero
import os

import torch
import torch.nn.functional as F

from .IFNet_HDv3 import *
class RIFEModel:
    """Thin wrapper around ``IFNet`` (from ``.IFNet_HDv3``) for inference.

    Owns the flow network, moves it to a device, loads its weights from a
    checkpoint directory, and exposes a two-frame ``inference`` call.
    """

    def __init__(self, device=None):
        """Build the flow network on *device* and switch it to eval mode.

        Args:
            device: Device to place the network on. When ``None``, CUDA is
                used if available, otherwise the CPU.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.flownet = IFNet().to(self.device).eval()

    def train(self):
        """Put the flow network into training mode."""
        self.flownet.train()

    def eval(self):
        """Put the flow network into evaluation mode."""
        self.flownet.eval()

    def load_model(self, path, rank=-1):
        """Load ``flownet.pkl`` from the directory *path* into the network.

        Args:
            path: Directory containing ``flownet.pkl``.
            rank: ``-1`` for single-process use. A leading ``"module."`` key
                prefix (typically added when weights are saved from a
                ``DataParallel``/``DistributedDataParallel`` wrapper) is then
                stripped. Any other value loads the keys verbatim.

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when the
                checkpoint keys do not match the network.
        """
        prefix = "module."

        def convert(param):
            if rank != -1:
                return param
            if not any(k.startswith(prefix) for k in param):
                # Checkpoint saved from an unwrapped model: filtering on the
                # prefix would discard every weight, so keep keys as-is.
                return param
            # Strip only the leading wrapper prefix (not inner occurrences);
            # keys without it are dropped, matching the previous behavior.
            return {
                k[len(prefix):]: v
                for k, v in param.items()
                if k.startswith(prefix)
            }

        # NOTE(security): torch.load unpickles the file. Only load trusted
        # checkpoints (consider weights_only=True where the torch version
        # supports it).
        state = torch.load(os.path.join(path, "flownet.pkl"), map_location="cpu")
        self.flownet.load_state_dict(convert(state))

    def inference(self, img0, img1, scale=1.0):
        """Run the flow network on a pair of frames.

        The frames are concatenated along dim 1 (presumably the channel axis
        of NCHW tensors; TODO confirm) and passed to the network together
        with the scale list ``[4/scale, 2/scale, 1/scale]``.

        Args:
            img0: First input frame.
            img1: Second input frame; must be concatenable with *img0* on
                dim 1.
            scale: Divisor applied to every entry of the scale list. Must be
                non-zero.

        Returns:
            The third entry (index 2) of the network's ``merged`` output
            (presumably the finest-level interpolated frame; verify against
            ``IFNet``).
        """
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, scale_list)
        return merged[2]