rashmi commited on
Commit
2d51d39
·
1 Parent(s): fdc89a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import time
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ from scipy.ndimage import gaussian_filter
8
+ from skimage.data import coins
9
+ from skimage.transform import rescale
10
+ from sklearn.cluster import AgglomerativeClustering
11
+ from sklearn.feature_extraction.image import grid_to_graph
12
+
13
+
14
+ plt.switch_backend('agg')
15
+
16
+ def cluster(gf_sigma, scale, anti_alias, mode, n_clusters,linkage):
17
+ orig_coins = coins()
18
+ smoothened_coins = gaussian_filter(orig_coins, sigma=gf_sigma)
19
+
20
+ # Resize it to 20% of the original size to speed up the processing Applying a Gaussian filter for smoothing
21
+ # prior to down-scaling reduces aliasing artifacts.
22
+
23
+ rescaled_coins = rescale(
24
+ smoothened_coins,
25
+ scale = scale,
26
+ mode="reflect",
27
+ anti_aliasing=False,
28
+ )
29
+
30
+ X = np.reshape(rescaled_coins, (-1, 1))
31
+
32
+
33
+ connectivity = grid_to_graph(*rescaled_coins.shape)
34
+
35
+
36
+ result = ""
37
+ result += "Compute structured hierarchical clustering...\n"
38
+ st = time.time()
39
+ ward = AgglomerativeClustering(
40
+ n_clusters=n_clusters, linkage="ward", connectivity=connectivity
41
+ )
42
+ ward.fit(X)
43
+ label = np.reshape(ward.labels_, rescaled_coins.shape)
44
+
45
+ result += f"Elapsed time: {time.time() - st:.3f}s \n"
46
+ result += f"Number of pixels: {label.size} \n"
47
+ result += f"Number of clusters: {np.unique(label).size} \n"
48
+
49
+
50
+ fig = plt.figure(figsize=(7, 7))
51
+ plt.imshow(rescaled_coins, cmap=plt.cm.gray)
52
+ for l in range(n_clusters):
53
+ plt.contour(
54
+ label == l,
55
+ colors=[
56
+ plt.cm.nipy_spectral(l / float(n_clusters)),
57
+ ],
58
+ )
59
+ plt.axis("off")
60
+
61
+ return result, fig
62
+
63
+ ## https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html
64
+
65
+ title = "A demo of structured Ward hierarchical clustering on an image of coins"
66
+
67
+
68
+ def do_submit(gf_sigma, scale, anti_alias, mode, n_clusters,linkage):
69
+ gf_sigma = float(gf_sigma)
70
+ scale = float(scale)
71
+ anti_alias = True if anti_alias == "True" else False
72
+ n_clusters = int(n_clusters)
73
+
74
+ result, fig = cluster(gf_sigma, scale, anti_alias, mode, n_clusters,linkage)
75
+ return result, fig
76
+
77
+
78
+
79
+ with gr.Blocks(title=title) as demo:
80
+ gr.Markdown(f"## {title}")
81
+ gr.Markdown("[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html)")
82
+ gr.Markdown("Compute the segmentation of a 2D image with Ward hierarchical clustering. \
83
+ The clustering is spatially constrained in order for each segmented region to be in one piece.")
84
+
85
+ with gr.Row(variant="evenly-spaced"):
86
+ gf_sigma = gr.Slider(minimum=1, maximum=10, label="Gaussian Filter Sigma", value=2, \
87
+ info="Standard deviation for Gaussian filtering before down-scaling.", step=0.1)
88
+
89
+ scale = gr.Slider(minimum=0.1, maximum=0.7, label="Scale", value=0.2, \
90
+ info="Scale factor for the image.", step=0.1)
91
+
92
+ anti_alias = gr.Radio(["True","False"], label="Anti Aliasing", value="False", \
93
+ info="Whether to apply a Gaussian filter to smooth the image prior to down-scaling. \
94
+ It is crucial to filter when down-sampling the image to avoid aliasing artifacts.\
95
+ If input image data type is bool, no anti-aliasing is applied.")
96
+
97
+ mode = gr.Dropdown(
98
+ ["constant", "edge", "symmetric", "reflect", "wrap"], value=["reflect"], multiselect=False, label="mode",\
99
+ info="Points outside the boundaries of the input are filled according to the given mode. Modes match the behaviour of numpy.pad."
100
+ )
101
+
102
+ with gr.Row():
103
+ ## Agglomerative Clustering parameters
104
+ n_clusters = gr.Slider(minimum=2, maximum=70,label="Number of Clusters", value=27, \
105
+ info="The number of clusters to find.", step=1)
106
+ linkage = gr.Dropdown(["ward", "complete", "average", "single"], value=["ward"], multiselect=False, label="linkage",\
107
+ info="Which linkage criterion to use. The linkage criterion determines which distance to use between sets of observation. ")
108
+
109
+ output = gr.Textbox(label="Output Box")
110
+ plt_out = gr.Plot()
111
+
112
+ submit_btn = gr.Button("Submit")
113
+ submit_btn.click(fn=do_submit, inputs=[gf_sigma, scale, anti_alias, mode, n_clusters,linkage], outputs=[output, plt_out])
114
+
115
+ if __name__ == "__main__":
116
+ demo.launch()
117
+