Lennard Schober commited on
Commit
c9793cd
·
1 Parent(s): e1d92e1

Plot everything on page

Browse files
Files changed (1) hide show
  1. app.py +88 -21
app.py CHANGED
@@ -11,6 +11,7 @@ glob_b = -2
11
  glob_c = -4
12
  glob_d = 7
13
 
 
14
  def clear_npz():
15
  current_directory = os.getcwd() # Get the current working directory
16
  for filename in os.listdir(current_directory):
@@ -128,7 +129,28 @@ def plot_heat_equation(m, approx_type):
128
  # Create the figure
129
  fig = go.Figure(data=traces, layout=layout)
130
 
131
- fig.show(config=config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  def plot_errors(m, approx_type):
@@ -213,7 +235,29 @@ def plot_errors(m, approx_type):
213
  # Create the figure
214
  fig = go.Figure(data=traces, layout=layout)
215
 
216
- fig.show(config=config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  def generate_data(n_x=32, n_t=50):
219
  """Generate training data."""
@@ -318,12 +362,12 @@ def plot_function(a, b, c, d, k=0.5):
318
  global glob_a, glob_b, glob_c, glob_d
319
 
320
  glob_a, glob_b, glob_c, glob_d = a, b, c, d
321
-
322
  x = np.linspace(0, 1, 100)
323
  t = np.linspace(0, 5, 500)
324
  X, T = np.meshgrid(x, t) # Create the mesh grid
325
  Z = complex_heat_eq_solution(X, T, a, b, c, d)
326
-
327
  traces = []
328
  traces.append(
329
  go.Surface(
@@ -348,11 +392,10 @@ def plot_function(a, b, c, d, k=0.5):
348
  ),
349
  margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
350
  )
351
-
352
  # Create the figure
353
  fig = go.Figure(data=traces, layout=layout)
354
 
355
- # fig.show(config=config)
356
  fig.update_layout(
357
  modebar_remove=[
358
  "pan",
@@ -370,12 +413,22 @@ def plot_function(a, b, c, d, k=0.5):
370
  "orbitRotation",
371
  "tableRotation",
372
  "toImage",
373
- "resetCameraDefault3d"
374
  ]
375
  )
376
 
377
  return fig
378
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  # Gradio interface
381
  def create_gradio_ui():
@@ -394,10 +447,18 @@ def create_gradio_ui():
394
 
395
  with gr.Row():
396
  with gr.Column():
397
- a_slider = gr.Slider(minimum=-10, maximum=-1, step=1, value=-2, label="a")
398
- b_slider = gr.Slider(minimum=-10, maximum=10, step=1, value=-2, label="b")
399
- c_slider = gr.Slider(minimum=-10, maximum=-1, step=1, value=-4, label="c")
400
- d_slider = gr.Slider(minimum=-10, maximum=10, step=1, value=7, label="d")
 
 
 
 
 
 
 
 
401
 
402
  plot_output = gr.Plot()
403
 
@@ -437,7 +498,7 @@ def create_gradio_ui():
437
  # Output to show status
438
  output = gr.Textbox(label="Status", interactive=False)
439
 
440
- with gr.Row():
441
  # Button to train coefficients
442
  train_button = gr.Button("Train Coefficients")
443
  # Function to trigger training and update dropdown
@@ -446,18 +507,24 @@ def create_gradio_ui():
446
  inputs=[m_slider, kernel_dropdown],
447
  outputs=output,
448
  )
449
- with gr.Row():
450
  approx_button = gr.Button("Plot Approximation")
451
- approx_button.click(
452
- fn=plot_heat_equation, inputs=[m_slider, kernel_dropdown], outputs=None
453
- )
454
 
455
- error_button = gr.Button("Plot Errors")
456
- error_button.click(
457
- fn=plot_errors, inputs=[m_slider, kernel_dropdown], outputs=None
458
- )
 
 
 
 
 
 
459
  demo.load(fn=clear_npz, inputs=None, outputs=None)
460
- demo.load(fn=plot_function, inputs=[a_slider, b_slider, c_slider, d_slider], outputs=[plot_output])
 
 
 
 
461
 
462
  return demo
463
 
 
11
  glob_c = -4
12
  glob_d = 7
13
 
14
+
15
  def clear_npz():
16
  current_directory = os.getcwd() # Get the current working directory
17
  for filename in os.listdir(current_directory):
 
129
  # Create the figure
130
  fig = go.Figure(data=traces, layout=layout)
131
 
132
+ fig.update_layout(
133
+ modebar_remove=[
134
+ "pan",
135
+ "resetCameraLastSave",
136
+ "hoverClosest3d",
137
+ "hoverCompareCartesian",
138
+ "zoomIn",
139
+ "zoomOut",
140
+ "select2d",
141
+ "lasso2d",
142
+ "zoomIn2d",
143
+ "zoomOut2d",
144
+ "sendDataToCloud",
145
+ "zoom3d",
146
+ "orbitRotation",
147
+ "tableRotation",
148
+ "toImage",
149
+ "resetCameraDefault3d",
150
+ ]
151
+ )
152
+
153
+ return fig
154
 
155
 
156
  def plot_errors(m, approx_type):
 
235
  # Create the figure
236
  fig = go.Figure(data=traces, layout=layout)
237
 
238
+ fig.update_layout(
239
+ modebar_remove=[
240
+ "pan",
241
+ "resetCameraLastSave",
242
+ "hoverClosest3d",
243
+ "hoverCompareCartesian",
244
+ "zoomIn",
245
+ "zoomOut",
246
+ "select2d",
247
+ "lasso2d",
248
+ "zoomIn2d",
249
+ "zoomOut2d",
250
+ "sendDataToCloud",
251
+ "zoom3d",
252
+ "orbitRotation",
253
+ "tableRotation",
254
+ "toImage",
255
+ "resetCameraDefault3d",
256
+ ]
257
+ )
258
+
259
+ return fig
260
+
261
 
262
  def generate_data(n_x=32, n_t=50):
263
  """Generate training data."""
 
362
  global glob_a, glob_b, glob_c, glob_d
363
 
364
  glob_a, glob_b, glob_c, glob_d = a, b, c, d
365
+
366
  x = np.linspace(0, 1, 100)
367
  t = np.linspace(0, 5, 500)
368
  X, T = np.meshgrid(x, t) # Create the mesh grid
369
  Z = complex_heat_eq_solution(X, T, a, b, c, d)
370
+
371
  traces = []
372
  traces.append(
373
  go.Surface(
 
392
  ),
393
  margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
394
  )
395
+
396
  # Create the figure
397
  fig = go.Figure(data=traces, layout=layout)
398
 
 
399
  fig.update_layout(
400
  modebar_remove=[
401
  "pan",
 
413
  "orbitRotation",
414
  "tableRotation",
415
  "toImage",
416
+ "resetCameraDefault3d",
417
  ]
418
  )
419
 
420
  return fig
421
 
422
+ def plot_all(m, kernel):
423
+ # Generate the plot content (replace this with your actual plot logic)
424
+ approx_fig = plot_heat_equation(m, kernel) # Replace with your function for approx_plot
425
+ error_fig = plot_errors(m, kernel) # Replace with your function for error_plot
426
+
427
+ # Return the figures and make the plots visible
428
+ return (
429
+ gr.update(visible=True, value=approx_fig),
430
+ gr.update(visible=True, value=error_fig),
431
+ )
432
 
433
  # Gradio interface
434
  def create_gradio_ui():
 
447
 
448
  with gr.Row():
449
  with gr.Column():
450
+ a_slider = gr.Slider(
451
+ minimum=-10, maximum=-1, step=1, value=-2, label="a"
452
+ )
453
+ b_slider = gr.Slider(
454
+ minimum=-10, maximum=10, step=1, value=-2, label="b"
455
+ )
456
+ c_slider = gr.Slider(
457
+ minimum=-10, maximum=-1, step=1, value=-4, label="c"
458
+ )
459
+ d_slider = gr.Slider(
460
+ minimum=-10, maximum=10, step=1, value=7, label="d"
461
+ )
462
 
463
  plot_output = gr.Plot()
464
 
 
498
  # Output to show status
499
  output = gr.Textbox(label="Status", interactive=False)
500
 
501
+ with gr.Column():
502
  # Button to train coefficients
503
  train_button = gr.Button("Train Coefficients")
504
  # Function to trigger training and update dropdown
 
507
  inputs=[m_slider, kernel_dropdown],
508
  outputs=output,
509
  )
 
510
  approx_button = gr.Button("Plot Approximation")
 
 
 
511
 
512
+ with gr.Row():
513
+ approx_plot = gr.Plot(visible=False)
514
+ error_plot = gr.Plot(visible=False)
515
+
516
+ approx_button.click(
517
+ fn=plot_all,
518
+ inputs=[m_slider, kernel_dropdown],
519
+ outputs=[approx_plot, error_plot],
520
+ )
521
+
522
  demo.load(fn=clear_npz, inputs=None, outputs=None)
523
+ demo.load(
524
+ fn=plot_function,
525
+ inputs=[a_slider, b_slider, c_slider, d_slider],
526
+ outputs=[plot_output],
527
+ )
528
 
529
  return demo
530