import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

from PIL import Image

from ds import *
from losses import *
from networks_SRGAN import *
from utils import *
import utils  # module-level import needed for the utils.image2tensor call below

device = 'cuda' if torch.cuda.is_available() else 'cpu'


NetG = Generator()
num_params = sum(p.numel() for p in NetG.parameters())
print("Number of Parameters:", num_params)
NetC = BayesCap(in_channels=3, out_channels=3)
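
# Load pretrained checkpoints and switch both networks to eval mode: the SRGAN
# generator stays frozen, and BayesCap is the post-hoc uncertainty head on top of it.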

ensure_checkpoint_exists('BayesCap_SRGAN.pth')
NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
NetG.to(device)
NetG.eval()

ensure_checkpoint_exists('BayesCap_ckpt.pth')
NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
NetC.to(device)
NetC.eval()

def tensor01_to_pil(xt):
    """Convert a tensor with values in [0, 1] to an RGB PIL image."""
    return transforms.ToPILImage(mode='RGB')(xt.squeeze())


def predict(img):
    """
    img: input image (PIL Image); it is resized down by `upscale_factor`
    to produce the low-resolution input for the SRGAN generator.
    """
    image_size = (256, 256)
    upscale_factor = 4
    lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
    # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
    
    img = Image.fromarray(np.array(img))
    img = lr_transforms(img)
    lr_tensor = utils.image2tensor(img, range_norm=False, half=False)
    
    xLR = lr_tensor.unsqueeze(0).to(device=device, dtype=torch.float)
    # pass them through the network
    with torch.no_grad():
        xSR = NetG(xLR)
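        # The BayesCap head maps the frozen SR output to per-pixel parameters of a
        # generalized Gaussian: a reconstruction (mu) plus alpha and beta
        # (the scale is recovered as 1/alpha below).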
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
        
    a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
    b_map = xSRC_beta[0].to('cpu').data
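    # Predictive uncertainty = variance of the generalized Gaussian,
    # alpha^2 * Gamma(3/beta) / Gamma(1/beta), computed via lgamma;
    # the small epsilons keep the Gamma arguments finite.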
    u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2)))) 
    
    
    x_LR = tensor01_to_pil(xLR[0].to('cpu').clamp(0, 1))
    
    x_mean = tensor01_to_pil(xSR[0].to('cpu').clamp(0, 1))
    
    #im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
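    # Render the alpha, beta and uncertainty maps as colour heatmaps; the clamp
    # ranges below are display choices that make the maps easier to compare.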
    
    a_map = torch.clamp(a_map, min=0, max=0.1)
    a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
    x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
    
    b_map = torch.clamp(b_map, min=0.45, max=0.75)
    b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
    x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
    
    u_map = torch.clamp(u_map, min=0, max=0.15)
    u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
    x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
    
    return x_LR, x_mean, x_alpha, x_beta, x_uncer
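
# Hypothetical offline usage (not executed by the demo): the same pipeline can be
# run on a local file without the Gradio UI, e.g.
#   x_lr, x_sr, x_alpha, x_beta, x_uncer = predict(Image.open('./demo_examples/baby.png'))
#   x_uncer.save('uncertainty_map.png')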

import gradio as gr

title = "BayesCap"
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
article = "<p style='text-align: center'>BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks | <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
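
# Gradio interface: one input image, five outputs (low-res input, super-resolved
# image, and the alpha, beta and uncertainty heatmaps).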


gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Original"),
    outputs=[
        gr.outputs.Image(type='pil', label="Low-res"),
        gr.outputs.Image(type='pil', label="Super-res"),
        gr.outputs.Image(type='pil', label="Alpha"),
        gr.outputs.Image(type='pil', label="Beta"),
        gr.outputs.Image(type='pil', label="Uncertainty")
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["./demo_examples/baby.png"],
        ["./demo_examples/bird.png"],
        ["./demo_examples/butterfly.png"],
        ["./demo_examples/head.png"],
        ["./demo_examples/woman.png"],
    ]
).launch(share=True)