import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode
from PIL import Image
from ds import *
from losses import *
from networks_SRGAN import *
from utils import *
# Inference runs on CPU in this Space; set device = 'cuda' for GPU.
device = 'cpu'
if device == 'cuda':
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
NetG = Generator()
# Count the parameters of the SRGAN generator.
params = sum(p.numel() for p in NetG.parameters())
print("Number of Parameters:", params)
NetC = BayesCap(in_channels=3, out_channels=3)

# Load the pretrained SRGAN generator and the BayesCap uncertainty head.
ensure_checkpoint_exists('BayesCap_SRGAN.pth')
NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
NetG.to(device)
NetG.eval()

ensure_checkpoint_exists('BayesCap_ckpt.pth')
NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
NetC.to(device)
NetC.eval()
def tensor01_to_pil(xt):
    """Convert a [0, 1]-valued CHW tensor to a PIL RGB image."""
    return transforms.ToPILImage(mode='RGB')(xt.squeeze())
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert a ``PIL.Image`` (or ndarray) to a tensor.

    Args:
        image (np.ndarray): The image data read by ``PIL.Image``.
        range_norm (bool): Rescale [0, 1] data to [-1, 1].
        half (bool): Whether to cast the tensor from torch.float32 to torch.half.

    Returns:
        Normalized image data.

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    # ``to_tensor`` lives in torchvision.transforms.functional, not torch.nn.functional.
    tensor = TF.to_tensor(image)
    if range_norm:
        tensor = tensor.mul_(2.0).sub_(1.0)
    if half:
        tensor = tensor.half()
    return tensor
def predict(img):
    """Run SRGAN + BayesCap on ``img`` and return the low-res input, the
    super-resolved output, and the per-pixel alpha/beta/uncertainty maps as PIL images."""
    image_size = (256, 256)
    upscale_factor = 4
    # lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
    # Resize with a single int so the aspect ratio is retained.
    lr_transforms = transforms.Resize(image_size[0]//upscale_factor, interpolation=IMode.BICUBIC, antialias=True)
    # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
    img = Image.fromarray(np.array(img))
    img = lr_transforms(img)
    lr_tensor = image2tensor(img, range_norm=False, half=False)
    xLR = lr_tensor.to(device).unsqueeze(0)
    xLR = xLR.type(dtype)
    # Pass the low-res input through the generator, then the BayesCap head.
    with torch.no_grad():
        xSR = NetG(xLR)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
    # The alpha head predicts an inverted scale, so take its reciprocal (with eps for stability).
    a_map = (1 / (xSRC_alpha[0] + 1e-5)).to('cpu').data
    b_map = xSRC_beta[0].to('cpu').data
    # Per-pixel variance of a generalized Gaussian: alpha^2 * Gamma(3/beta) / Gamma(1/beta).
    u_map = (a_map**2) * (torch.exp(torch.lgamma(3 / (b_map + 1e-2))) / torch.exp(torch.lgamma(1 / (b_map + 1e-2))))
    # Convert the batched tensors back to PIL images.
    x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1))
    x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1))
    # Clamp and min-max normalize each map, reorder CHW -> HWC, then render with a colormap.
    a_map = torch.clamp(a_map, min=0, max=0.1)
    a_map = (a_map - a_map.min()) / (a_map.max() - a_map.min())
    x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0, 2).transpose(0, 1).squeeze()) * 255))
    b_map = torch.clamp(b_map, min=0.45, max=0.75)
    b_map = (b_map - b_map.min()) / (b_map.max() - b_map.min())
    x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0, 2).transpose(0, 1).squeeze()) * 255))
    u_map = torch.clamp(u_map, min=0, max=0.15)
    u_map = (u_map - u_map.min()) / (u_map.max() - u_map.min())
    x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0, 2).transpose(0, 1).squeeze()) * 255))
    return x_LR, x_mean, x_alpha, x_beta, x_uncer
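
# A minimal sketch of driving `predict` without the Gradio UI (the path below is
# one of the demo examples listed further down; purely illustrative):
#
#   img = Image.open('./demo_examples/baby.png')
#   x_lr, x_sr, x_alpha, x_beta, x_uncer = predict(img)
#   x_uncer.save('uncertainty_map.png')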
import gradio as gr

title = "BayesCap"
method = "In this work, we propose a method (called BayesCap) to estimate the per-pixel uncertainty of a pretrained computer vision model like SRGAN (used for super-resolution). BayesCap takes the output of the pretrained model (in this case SRGAN) and predicts the per-pixel distribution parameters for the output, which can be used to quantify the per-pixel uncertainty. In our work, we model the per-pixel output as a <a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>Generalized Gaussian distribution</a> parameterized by three parameters: the mean, the scale (alpha), and the shape (beta). As a result, our model predicts these three parameters as shown below. From these three parameters, one can compute the uncertainty as shown in <a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>this article</a>."
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022) <br>" + method
article = "<p style='text-align: center'>BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks | <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
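
# The `method` text above refers to computing uncertainty from the three GGD
# parameters. A self-contained sketch of that formula, matching the per-pixel
# computation in `predict` (the log-gamma difference is just a numerically
# stabler way to write alpha^2 * Gamma(3/beta) / Gamma(1/beta)); `ggd_variance`
# is an illustrative helper, not part of the BayesCap codebase:
def ggd_variance(alpha: torch.Tensor, beta: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """Variance of a generalized Gaussian with scale ``alpha`` and shape ``beta``."""
    return (alpha ** 2) * torch.exp(torch.lgamma(3.0 / (beta + eps)) - torch.lgamma(1.0 / (beta + eps)))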
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Original"),
    outputs=[
        gr.outputs.Image(type='pil', label="Low-res"),
        gr.outputs.Image(type='pil', label="Super-res"),
        gr.outputs.Image(type='pil', label="Alpha"),
        gr.outputs.Image(type='pil', label="Beta"),
        gr.outputs.Image(type='pil', label="Uncertainty"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["./demo_examples/tue.jpeg"],
        ["./demo_examples/baby.png"],
        ["./demo_examples/bird.png"],
        ["./demo_examples/butterfly.png"],
        ["./demo_examples/head.png"],
        ["./demo_examples/woman.png"],
    ],
).launch()