Spaces:
Runtime error
Runtime error
import gradio as gr | |
from PIL import Image | |
from collections import OrderedDict | |
import torch | |
from models.model import GLPDepth | |
from PIL import Image | |
from torchvision import transforms | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
# Device on which the model is built and to which input tensors are moved
# before inference.
DEVICE = 'cpu'
def load_mde_model(path):
    """Build a GLPDepth model and load trained weights into it.

    Args:
        path: Path to a checkpoint file readable by ``torch.load``. Either a
            dict holding the weights under ``'model_state_dict'`` or a bare
            state dict is accepted.

    Returns:
        The GLPDepth model (constructed with ``max_depth=700.0`` and
        ``is_train=False``) on ``DEVICE``, switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict loading) when the
            checkpoint keys or shapes do not match the model.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location lets tensors saved on another device (e.g. a GPU) be
    # loaded onto the configured DEVICE.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Weights saved from an nn.DataParallel / DistributedDataParallel wrapper
    # carry a 'module.' prefix on every key. Strip exactly that prefix, and
    # only from keys that have it, rather than dropping 7 characters from all
    # keys whenever 'module' appears anywhere in the first key.
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model
# Load the trained weights once at import time; every request below reuses
# this single model instance.
model = load_mde_model('best_model.ckpt')

# Per-request preprocessing: resize the PIL image to the 512x512 size the
# demo works at, then convert it to a (C, H, W) float tensor.
_preprocess_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
]
preprocess = transforms.Compose(_preprocess_steps)
def predict(input_image):
    """Run the depth model on one image and return the prediction as a plot.

    Args:
        input_image: The image as a NumPy array (the Gradio image input's
            numpy format). Grayscale or RGBA arrays are converted to RGB.

    Returns:
        A ``(height, width, 3)`` uint8 array with the rendered Matplotlib
        figure: the model's ``'pred_d'`` output drawn with the 'jet' colormap
        (scaled from 0 to its maximum value) plus a colorbar.

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        raise ValueError('No input image provided.')

    # Let Pillow infer the mode from the array and then convert, instead of
    # forcing mode='RGB': the forced mode misreads arrays that are not
    # already (H, W, 3), and fromarray's mode argument is deprecated in
    # recent Pillow releases.
    pil_image = Image.fromarray(np.asarray(input_image).astype('uint8')).convert('RGB')

    # Add a leading batch dimension: the model is called on a 1-image batch.
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(torch_img)

    # squeeze() drops the size-1 batch/channel dims so imshow gets a 2-D map
    # (presumably pred_d is (1, 1, H, W) -- TODO confirm against GLPDepth).
    prediction = output['pred_d'].squeeze().detach().cpu().numpy()

    fig, ax = plt.subplots()
    try:
        im = ax.imshow(prediction, cmap='jet', vmin=0, vmax=np.max(prediction))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. Its array shape is the true
        # pixel size of the canvas, so no manual reshape is needed. Copy so
        # the result does not alias the canvas buffer.
        rendered = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak one figure.
        plt.close(fig)
    return rendered
# File extensions picked up as clickable examples from the examples folder.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')


def _list_examples(directory='demo_imgs'):
    """Return ``[[path], ...]`` example rows for image files in *directory*.

    Files are sorted by name so the example order is stable across runs, and
    non-image files (e.g. ``.gitkeep``) are skipped. Returns ``None`` when
    the directory is missing or has no images, so the demo still starts
    without examples.
    """
    if not os.path.isdir(directory):
        return None
    rows = [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
    return rows or None


iface = gr.Interface(
    fn=predict,
    # gr.Image no longer accepts `shape=` (removed in Gradio 4, TypeError at
    # startup). Resizing to 512x512 is already done by `preprocess`.
    inputs=gr.Image(type='numpy'),
    outputs=gr.Image(),
    examples=_list_examples(),
    title="DTM Estimation",
    description="This demo predicts a DTM using GLP Depth model. It will scale input image to 512x512 and at the end it will apply a colormap to better visualize the output.",
)
iface.launch()