# Author: merve (Hugging Face staff).
# Last commit bce304b: "Update app.py".
import time
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter
from skimage.data import coins
from skimage.transform import rescale
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.image import grid_to_graph
# Select the non-interactive Agg backend: this app hands its figures to a
# Gradio Plot output instead of opening GUI windows.
plt.switch_backend('agg')
def cluster(gf_sigma, scale, anti_alias, mode, n_clusters, linkage):
    """Segment the scikit-image ``coins`` picture with structured hierarchical clustering.

    The image is smoothed with a Gaussian filter and down-scaled. Every pixel
    is then clustered on its grey level under a pixel-grid connectivity
    constraint, so each segmented region is spatially connected.

    Args:
        gf_sigma: Standard deviation of the Gaussian filter applied before
            down-scaling.
        scale: Scale factor passed to ``skimage.transform.rescale``.
        anti_alias: Passed as ``anti_aliasing`` to ``rescale``.
        mode: Boundary mode passed to ``rescale`` (e.g. ``"reflect"``).
            A one-element list/tuple is unwrapped first.
        n_clusters: Number of clusters to find.
        linkage: Linkage criterion for ``AgglomerativeClustering``
            (``"ward"``, ``"complete"``, ``"average"`` or ``"single"``).
            A one-element list/tuple is unwrapped first.

    Returns:
        A ``(report, figure)`` tuple. ``report`` is a text report with the
        elapsed fit time, the pixel count and the number of distinct labels.
        ``figure`` is a matplotlib figure showing the cluster contours drawn
        over the rescaled image.
    """
    # The UI Dropdowns declare list-valued defaults (e.g. ["reflect"]), so a
    # single-element list may arrive here; the libraries need a plain string.
    if isinstance(mode, (list, tuple)):
        mode = mode[0]
    if isinstance(linkage, (list, tuple)):
        linkage = linkage[0]

    orig_coins = coins()
    # Smoothing with a Gaussian filter before down-scaling reduces aliasing
    # artifacts.
    smoothened_coins = gaussian_filter(orig_coins, sigma=gf_sigma)
    # Down-scale the image to speed up the clustering.
    rescaled_coins = rescale(
        smoothened_coins,
        scale=scale,
        mode=mode,
        anti_aliasing=anti_alias,
    )

    # One sample per pixel, with its grey level as the only feature.
    X = np.reshape(rescaled_coins, (-1, 1))
    # Restrict merges to neighbouring pixels so every cluster is one piece.
    connectivity = grid_to_graph(*rescaled_coins.shape)

    report_lines = ["Compute structured hierarchical clustering..."]
    start = time.time()
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage=linkage, connectivity=connectivity
    )
    model.fit(X)
    label = np.reshape(model.labels_, rescaled_coins.shape)
    report_lines.append(f"Elapsed time: {time.time() - start:.3f}s ")
    report_lines.append(f"Number of pixels: {label.size} ")
    report_lines.append(f"Number of clusters: {np.unique(label).size} ")
    result = "\n".join(report_lines) + "\n"

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(rescaled_coins, cmap=plt.cm.gray)
    for cluster_id in range(n_clusters):
        ax.contour(
            label == cluster_id,
            colors=[plt.cm.nipy_spectral(cluster_id / float(n_clusters))],
        )
    ax.axis("off")
    # Detach the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures. The Figure object itself remains usable
    # for rendering by the caller.
    plt.close(fig)
    return result, fig
# Adapted from the scikit-learn gallery example:
# https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html
title = (
    "A demo of structured Ward hierarchical clustering "
    "on an image of coins"
)
def do_submit(gf_sigma, scale, anti_alias, mode, n_clusters, linkage):
    """Convert raw Gradio widget values and run ``cluster`` on them.

    ``gf_sigma`` and ``scale`` are cast to float and ``n_clusters`` to int.
    The ``anti_alias`` Radio value becomes a bool, which is true only for the
    string ``"True"``. ``mode`` and ``linkage`` are forwarded unchanged.

    Returns:
        The ``(report, figure)`` pair produced by ``cluster``.
    """
    # The Radio delivers the strings "True"/"False"; only "True" enables it.
    anti_alias_flag = anti_alias == "True"
    report, figure = cluster(
        float(gf_sigma),
        float(scale),
        anti_alias_flag,
        mode,
        int(n_clusters),
        linkage,
    )
    return report, figure
# Build the Gradio UI: preprocessing controls, clustering controls, outputs.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown("This is an interactive demo for this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html).")
    gr.Markdown(
        "Compute the segmentation of a 2D image with Ward hierarchical clustering. "
        "The clustering is spatially constrained in order for each segmented region to be in one piece."
    )

    # Image preprocessing parameters (Gaussian smoothing + rescaling).
    with gr.Row(variant="evenly-spaced"):
        gf_sigma = gr.Slider(
            minimum=1, maximum=10, label="Gaussian Filter Sigma", value=2,
            info="Standard deviation for Gaussian filtering before down-scaling.", step=0.1,
        )
        scale = gr.Slider(
            minimum=0.1, maximum=0.7, label="Scale", value=0.2,
            info="Scale factor for the image.", step=0.1,
        )
        anti_alias = gr.Radio(
            ["True", "False"], label="Anti Aliasing", value="False",
            # The original continuation lines dropped the space before "If".
            info="Whether to apply a Gaussian filter to smooth the image prior to down-scaling. "
            "It is crucial to filter when down-sampling the image to avoid aliasing artifacts. "
            "If input image data type is bool, no anti-aliasing is applied.",
        )
        # Single-select dropdown: the default must be a plain string, not a list.
        mode = gr.Dropdown(
            ["constant", "edge", "symmetric", "reflect", "wrap"], value="reflect",
            multiselect=False, label="mode",
            info="Points outside the boundaries of the input are filled according to the given mode. Modes match the behaviour of numpy.pad.",
        )

    # Agglomerative Clustering parameters.
    with gr.Row():
        n_clusters = gr.Slider(
            minimum=2, maximum=70, label="Number of Clusters", value=27,
            info="The number of clusters to find.", step=1,
        )
        # Single-select dropdown: the default must be a plain string, not a list.
        linkage = gr.Dropdown(
            ["ward", "complete", "average", "single"], value="ward",
            multiselect=False, label="linkage",
            info="Which linkage criterion to use. The linkage criterion determines which distance to use between sets of observation. ",
        )

    output = gr.Textbox(label="Output Box")
    plt_out = gr.Plot()
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=do_submit,
        inputs=[gf_sigma, scale, anti_alias, mode, n_clusters, linkage],
        outputs=[output, plt_out],
    )

if __name__ == "__main__":
    demo.launch()