|
|
|
import time |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
from scipy.ndimage import gaussian_filter |
|
from skimage.data import coins |
|
from skimage.transform import rescale |
|
from sklearn.cluster import AgglomerativeClustering |
|
from sklearn.feature_extraction.image import grid_to_graph |
|
|
|
|
|
# Use the non-interactive Agg backend: figures here are only rendered
# off-screen and handed back to the Gradio UI, never shown in a GUI window.
plt.switch_backend("agg")
|
|
|
def _single_choice(value, default):
    """Return a scalar choice from a possibly list-wrapped dropdown value.

    Callers may pass a one-element list (e.g. a single-select dropdown whose
    initial value is a list); it is unwrapped. ``None``, ``""`` or an empty
    list fall back to *default*.
    """
    if isinstance(value, (list, tuple)):
        value = value[0] if value else None
    return default if value in (None, "") else value


def cluster(gf_sigma, scale, anti_alias, mode, n_clusters, linkage):
    """Segment the ``coins`` image with structured hierarchical clustering.

    The image is smoothed with a Gaussian filter, rescaled, and its pixel
    intensities are clustered by ``AgglomerativeClustering`` constrained to a
    pixel-adjacency graph.

    Parameters
    ----------
    gf_sigma : float
        Standard deviation of the Gaussian filter applied before rescaling.
    scale : float
        Scale factor forwarded to ``skimage.transform.rescale``.
    anti_alias : bool
        Forwarded as ``anti_aliasing`` to ``rescale``.
    mode : str
        Boundary mode forwarded to ``rescale``. A one-element list is
        unwrapped; an empty value falls back to ``"reflect"``.
    n_clusters : int
        Number of clusters to find.
    linkage : str
        Linkage criterion for ``AgglomerativeClustering``. A one-element list
        is unwrapped; an empty value falls back to ``"ward"``.

    Returns
    -------
    tuple
        ``(report, fig)``. ``report`` is a text summary (elapsed time, pixel
        count, cluster count). ``fig`` is a matplotlib figure showing the
        rescaled image with one contour per cluster.
    """
    mode = _single_choice(mode, "reflect")
    linkage = _single_choice(linkage, "ward")

    smoothened_coins = gaussian_filter(coins(), sigma=gf_sigma)
    rescaled_coins = rescale(
        smoothened_coins,
        scale=scale,
        mode=mode,
        anti_aliasing=anti_alias,
    )

    # One feature per pixel: its intensity.
    X = np.reshape(rescaled_coins, (-1, 1))

    # Pixel-adjacency graph: merges are only allowed between neighbouring
    # pixels, which keeps every segmented region in one piece.
    connectivity = grid_to_graph(*rescaled_coins.shape)

    report = ["Compute structured hierarchical clustering...\n"]
    start = time.perf_counter()
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage=linkage, connectivity=connectivity
    )
    model.fit(X)
    label = np.reshape(model.labels_, rescaled_coins.shape)
    report.append(f"Elapsed time: {time.perf_counter() - start:.3f}s \n")
    report.append(f"Number of pixels: {label.size} \n")
    report.append(f"Number of clusters: {np.unique(label).size} \n")

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(rescaled_coins, cmap=plt.cm.gray)
    for cluster_id in range(n_clusters):
        ax.contour(
            label == cluster_id,
            colors=[plt.cm.nipy_spectral(cluster_id / n_clusters)],
        )
    ax.axis("off")

    # Drop the figure from pyplot's registry so repeated submissions do not
    # accumulate open figures; the Figure object itself can still be saved.
    plt.close(fig)

    return "".join(report), fig
|
|
|
|
|
|
|
# Page title: used as the browser-tab title (gr.Blocks) and as the heading.
title = (
    "A demo of structured Ward hierarchical clustering "
    "on an image of coins"
)
|
|
|
|
|
def do_submit(gf_sigma, scale, anti_alias, mode, n_clusters, linkage):
    """Coerce raw widget values and run :func:`cluster`.

    ``anti_alias`` comes from a radio button as the string ``"True"`` or
    ``"False"``. A real ``bool`` is also accepted so the function can be
    called programmatically; previously ``True`` was silently mapped to
    ``False``. ``n_clusters`` is cast to ``int`` because the slider value may
    be a float.

    Returns the ``(report_text, figure)`` pair produced by ``cluster``.
    """
    if isinstance(anti_alias, str):
        anti_alias = anti_alias == "True"
    else:
        anti_alias = bool(anti_alias)

    return cluster(
        float(gf_sigma),
        float(scale),
        anti_alias,
        mode,
        int(n_clusters),
        linkage,
    )
|
|
|
|
|
|
|
# Gradio UI: parameter controls, a text report, the segmentation plot and a
# submit button wired to do_submit.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown("This is an interactive demo for this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html).")
    # Implicit concatenation (not a backslash inside the literal) so source
    # indentation does not leak into the displayed text.
    gr.Markdown(
        "Compute the segmentation of a 2D image with Ward hierarchical clustering. "
        "The clustering is spatially constrained in order for each segmented region to be in one piece."
    )

    # Preprocessing controls: smoothing and rescaling of the input image.
    with gr.Row(variant="evenly-spaced"):
        gf_sigma = gr.Slider(
            minimum=1,
            maximum=10,
            label="Gaussian Filter Sigma",
            value=2,
            info="Standard deviation for Gaussian filtering before down-scaling.",
            step=0.1,
        )
        scale = gr.Slider(
            minimum=0.1,
            maximum=0.7,
            label="Scale",
            value=0.2,
            info="Scale factor for the image.",
            step=0.1,
        )
        anti_alias = gr.Radio(
            ["True", "False"],
            label="Anti Aliasing",
            value="False",
            info=(
                "Whether to apply a Gaussian filter to smooth the image prior to down-scaling. "
                "It is crucial to filter when down-sampling the image to avoid aliasing artifacts. "
                "If input image data type is bool, no anti-aliasing is applied."
            ),
        )
        # Single-select dropdown: the default is a plain string, not a
        # one-element list, so the callback receives a scalar choice.
        mode = gr.Dropdown(
            ["constant", "edge", "symmetric", "reflect", "wrap"],
            value="reflect",
            multiselect=False,
            label="mode",
            info="Points outside the boundaries of the input are filled according to the given mode. Modes match the behaviour of numpy.pad.",
        )

    # Clustering controls.
    with gr.Row():
        n_clusters = gr.Slider(
            minimum=2,
            maximum=70,
            label="Number of Clusters",
            value=27,
            info="The number of clusters to find.",
            step=1,
        )
        linkage = gr.Dropdown(
            ["ward", "complete", "average", "single"],
            value="ward",
            multiselect=False,
            label="linkage",
            info="Which linkage criterion to use. The linkage criterion determines which distance to use between sets of observations.",
        )

    output = gr.Textbox(label="Output Box")
    plt_out = gr.Plot()
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=do_submit,
        inputs=[gf_sigma, scale, anti_alias, mode, n_clusters, linkage],
        outputs=[output, plt_out],
    )

if __name__ == "__main__":
    demo.launch()
|
|
|
|