File size: 5,845 Bytes
99e984c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
974236a
 
 
 
 
99e984c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d43d5d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e984c
 
 
 
 
 
 
08c0204
 
 
99e984c
 
 
 
d43d5d3
99e984c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9661f5
 
99e984c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfff052
99e984c
 
 
 
 
 
269330d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

from PIL import Image

from ds import *
from losses import *
from networks_SRGAN import *
from utils import *

# Inference device. ``dtype`` is the legacy tensor type that ``predict`` casts
# its input to, so it has to agree with ``device``.
device = 'cpu'
dtype = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor

NetG = Generator()
# Total number of scalar weights in the generator (every parameter is counted,
# trainable or not). ``numel`` is also correct for 0-dim parameters, where
# ``np.prod(())`` would contribute the float 1.0.
params = sum(p.numel() for p in NetG.parameters())
print("Number of Parameters:", params)
NetC = BayesCap(in_channels=3, out_channels=3)


def _load_frozen(net, ckpt_path):
    """Load the state dict at *ckpt_path* into *net*, move it to ``device``, set eval mode."""
    ensure_checkpoint_exists(ckpt_path)
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here (consider ``weights_only=True`` if the installed
    # torch version supports it).
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.to(device)
    net.eval()


_load_frozen(NetG, 'BayesCap_SRGAN.pth')
_load_frozen(NetC, 'BayesCap_ckpt.pth')

def tensor01_to_pil(xt):
    """Convert an RGB image tensor into a ``PIL.Image``.

    Every size-1 dimension of *xt* is squeezed away before conversion, so a
    leading batch dimension of 1 is tolerated. Values are presumably expected
    in [0, 1], as the function name suggests.
    """
    squeezed = xt.squeeze()
    to_pil = transforms.ToPILImage(mode='RGB')
    return to_pil(squeezed)

def image2tensor(image: "Image.Image | np.ndarray", range_norm: bool, half: bool) -> torch.Tensor:
    """Convert a ``PIL.Image`` (or HxWxC ``np.ndarray``) to a CxHxW tensor.

    The conversion is delegated to torchvision's ``to_tensor``, which scales
    8-bit pixel values into [0, 1].

    Args:
        image: The image to convert, e.g. as returned by ``PIL.Image.open``.
        range_norm: If True, rescale the [0, 1] data to [-1, 1].
        half: If True, cast the result to ``torch.half``.

    Returns:
        The converted image tensor.

    Examples:
        >>> image = Image.open("image.bmp")  # doctest: +SKIP
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)  # doctest: +SKIP
    """
    # Call torchvision's functional API explicitly. ``F`` at the top of this
    # file is ``torch.nn.functional``, which has no ``to_tensor``; relying on a
    # later star import to rebind ``F`` is fragile.
    tensor = transforms.functional.to_tensor(image)

    if range_norm:
        # [0, 1] -> [-1, 1]; in-place is safe because to_tensor returns a new tensor.
        tensor = tensor.mul_(2.0).sub_(1.0)
    if half:
        tensor = tensor.half()

    return tensor


def _minmax_normalize(t):
    """Linearly rescale tensor *t* to [0, 1]; a constant tensor maps to all zeros."""
    lo, hi = t.min(), t.max()
    span = hi - lo
    if span <= 0:
        # A constant map (e.g. every value hit the clamp bound) would otherwise
        # give 0 / 0 = NaN, which the colormap renders as its "bad" colour.
        return torch.zeros_like(t)
    return (t - lo) / span


def _colorize(t, cmap):
    """Render a [0, 1] map *t* through matplotlib colormap *cmap* as an RGBA PIL image.

    Assumes *t* is a single-channel (1, H, W) map, so the squeeze yields (H, W).
    A multi-channel map would produce an array ``Image.fromarray`` rejects.
    TODO confirm the channel count of NetC's alpha/beta outputs.
    """
    hw = t.permute(1, 2, 0).squeeze().numpy()
    return Image.fromarray(np.uint8(cmap(hw) * 255))


def predict(img):
    """Super-resolve a downscaled copy of *img* and visualise BayesCap's outputs.

    The input is downscaled with bicubic interpolation so that its shorter
    side is 256 // 4 = 64 px (aspect ratio preserved). It is then passed
    through the generator ``NetG``, and ``NetC`` predicts per-pixel mu, alpha
    and beta maps for the super-resolved output.

    Args:
        img: Input image; anything ``np.array`` can convert (a ``PIL.Image``
            from Gradio, or an HxWxC array).

    Returns:
        Five ``PIL.Image`` objects: the low-res input, the super-resolved
        output, and colour-mapped alpha, beta and uncertainty maps.
    """
    image_size = (256, 256)
    upscale_factor = 4
    # An int size makes Resize match the shorter edge, preserving aspect ratio.
    lr_transforms = transforms.Resize(
        image_size[0] // upscale_factor, interpolation=IMode.BICUBIC, antialias=True
    )

    # Round-trip through np.array so both PIL images and arrays are accepted,
    # and force 3 channels: RGBA or greyscale input would otherwise reach NetG
    # with the wrong channel count.
    pil_img = Image.fromarray(np.array(img)).convert('RGB')
    lr_tensor = image2tensor(lr_transforms(pil_img), range_norm=False, half=False)

    xLR = lr_tensor.to(device).unsqueeze(0).type(dtype)
    with torch.no_grad():
        xSR = NetG(xLR)
        _, xSRC_alpha, xSRC_beta = NetC(xSR)

    # Drop the batch dimension and move the maps to the CPU for plotting.
    a_map = (1 / (xSRC_alpha[0] + 1e-5)).cpu()
    b_map = xSRC_beta[0].cpu()
    # Uncertainty = a^2 * Gamma(3/b) / Gamma(1/b), the generalized-Gaussian
    # variance formula with a_map as scale. The Gamma ratio is taken in log
    # space: evaluating both Gammas separately overflows float32 for small b,
    # and inf / inf would yield NaN.
    shape = b_map + 1e-2
    u_map = (a_map ** 2) * torch.exp(torch.lgamma(3 / shape) - torch.lgamma(1 / shape))

    x_LR = tensor01_to_pil(xLR[0].cpu().clip(0, 1))
    x_mean = tensor01_to_pil(xSR[0].cpu().clip(0, 1))

    # Clamp each map to a display range, then stretch it to [0, 1] for the colormap.
    x_alpha = _colorize(_minmax_normalize(torch.clamp(a_map, min=0, max=0.1)), cm.inferno)
    x_beta = _colorize(_minmax_normalize(torch.clamp(b_map, min=0.45, max=0.75)), cm.cividis)
    x_uncer = _colorize(_minmax_normalize(torch.clamp(u_map, min=0, max=0.15)), cm.hot)

    return x_LR, x_mean, x_alpha, x_beta, x_uncer

import gradio as gr

title = "BayesCap"
method = (
    "In this work, we propose a method (called BayesCap) to estimate the per-pixel "
    "uncertainty of a pretrained computer vision model like SRGAN (used for "
    "super-resolution). BayesCap takes the output of the pretrained model (in this "
    "case SRGAN) and predicts the per-pixel distribution parameters for the output, "
    "which can be used to quantify the per-pixel uncertainty. In our work, we model "
    "the per-pixel output as a "
    "<a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>"
    "Generalized Gaussian distribution</a> that is parameterized by 3 parameters: "
    "the mean, the scale (alpha), and the shape (beta). As a result, our model "
    "predicts these three parameters, as shown below. From these 3 parameters, one "
    "can compute the uncertainty as shown in "
    "<a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>"
    "this article</a>."
)
description = (
    "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural "
    "Networks (ECCV 2022) <br>" + method
)
article = (
    "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated "
    "Uncertainty in Frozen Neural Networks | "
    "<a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
)

# NOTE(review): gr.inputs / gr.outputs is the legacy Gradio component API
# (removed in Gradio 4); migrate to gr.Image if the pinned version changes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Original"),
    outputs=[
        gr.outputs.Image(type='pil', label="Low-res"),
        gr.outputs.Image(type='pil', label="Super-res"),
        gr.outputs.Image(type='pil', label="Alpha"),
        gr.outputs.Image(type='pil', label="Beta"),
        gr.outputs.Image(type='pil', label="Uncertainty"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["./demo_examples/tue.jpeg"],
        ["./demo_examples/baby.png"],
        ["./demo_examples/bird.png"],
        ["./demo_examples/butterfly.png"],
        ["./demo_examples/head.png"],
        ["./demo_examples/woman.png"],
    ],
)
demo.launch()