LenixC commited on
Commit
cd447f7
·
1 Parent(s): e38d1c8

Added Preview of Dataset to model.

Browse files
Files changed (1) hide show
  1. app.py +58 -3
app.py CHANGED
@@ -163,6 +163,54 @@ def plot_on_dataset(X, y, models, name):
163
 
164
  return plt
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  title = "Compare Stochastic learning strategies for MLPClassifier"
167
  with gr.Blocks() as demo:
168
  gr.Markdown(f" # {title}")
@@ -174,9 +222,11 @@ with gr.Blocks() as demo:
174
  with gr.Tabs():
175
  with gr.TabItem("Model and Data Selection"):
176
  with gr.Row():
177
- dataset = gr.Dropdown(["Iris", "Digits", "Circles", "Moons"],
178
- value="Iris",
179
- type="index")
 
 
180
  models = gr.CheckboxGroup(["Constant Learning-Rate",
181
  "Constant with Momentum",
182
  "Constant with Nesterov's Momentum",
@@ -224,6 +274,11 @@ with gr.Blocks() as demo:
224
  adam_lr],
225
  outputs=[stoch_graph]
226
  )
 
 
 
 
 
227
 
228
  if __name__ == '__main__':
229
  demo.launch()
 
163
 
164
  return plt
165
 
166
+
167
+ def plot_example(dataset):
168
+ if dataset == 0: # Iris
169
+ fig = plt.figure()
170
+ iris = datasets.load_iris()
171
+ col_1 = iris.data[:, 0]
172
+ col_2 = iris.data[:, 1]
173
+ target = iris.target
174
+ plt.scatter(col_1, col_2, c=target)
175
+ plt.title("Sepal Width vs. Sepal Height")
176
+ return fig
177
+
178
+ if dataset == 1: # Digits
179
+ digits = datasets.load_digits()
180
+
181
+ images = digits.images[:16]
182
+ labels = digits.target[:16]
183
+
184
+ fig, axes = plt.subplots(4, 4, figsize=(8, 8))
185
+
186
+ for i, ax in enumerate(axes.flat):
187
+ ax.imshow(images[i], cmap='gray')
188
+ ax.set_title(f"Label: {labels[i]}")
189
+ ax.axis('off')
190
+
191
+ plt.suptitle("First 16 Handwritten Digits")
192
+ plt.tight_layout()
193
+ return fig
194
+
195
+ if dataset == 2: # Circles
196
+ circles = datasets.make_circles(noise=0.2, factor=0.5, random_state=1),
197
+ X = circles[0][0]
198
+ y = circles[0][1]
199
+ fig = plt.figure()
200
+ plt.scatter(X[:, 0], X[:, 1], c=y)
201
+ plt.title("Circles Toy Dataset")
202
+ return fig
203
+
204
+ if dataset == 3: # Moons
205
+ moons = datasets.make_moons(noise=0.3, random_state=0),
206
+ X = moons[0][0]
207
+ y = moons[0][1]
208
+ fig = plt.figure()
209
+ plt.scatter(X[:, 0], X[:, 1], c=y)
210
+ plt.title("Moons Toy Dataset")
211
+ return fig
212
+
213
+
214
  title = "Compare Stochastic learning strategies for MLPClassifier"
215
  with gr.Blocks() as demo:
216
  gr.Markdown(f" # {title}")
 
222
  with gr.Tabs():
223
  with gr.TabItem("Model and Data Selection"):
224
  with gr.Row():
225
+ with gr.Column():
226
+ dataset = gr.Dropdown(["Iris", "Digits", "Circles", "Moons"],
227
+ value="Iris",
228
+ type="index")
229
+ example_plot = gr.Plot(label="Dataset")
230
  models = gr.CheckboxGroup(["Constant Learning-Rate",
231
  "Constant with Momentum",
232
  "Constant with Nesterov's Momentum",
 
274
  adam_lr],
275
  outputs=[stoch_graph]
276
  )
277
+ dataset.change(
278
+ fn=plot_example,
279
+ inputs=[dataset],
280
+ outputs=[example_plot]
281
+ )
282
 
283
  if __name__ == '__main__':
284
  demo.launch()