Lennard Schober commited on
Commit
0bb1ad3
·
1 Parent(s): c9793cd

Fix heat equation solution

Browse files
Files changed (1) hide show
  1. app.py +39 -32
app.py CHANGED
@@ -6,10 +6,10 @@ import numpy as np
6
  import gradio as gr
7
  import plotly.graph_objs as go
8
 
9
- glob_a = -2
10
- glob_b = -2
11
- glob_c = -4
12
- glob_d = 7
13
 
14
 
15
  def clear_npz():
@@ -26,12 +26,13 @@ def clear_npz():
26
  print(f"Failed to delete {file_path}. Reason: {e}")
27
 
28
 
29
- def complex_heat_eq_solution(x, t, a=glob_a, b=glob_b, c=glob_c, d=glob_d, k=0.5):
30
- global glob_a, glob_b, glob_c, glob_d
 
31
  return (
32
- np.exp(-k * t) * np.sin(np.pi * x)
33
- + 0.5 * np.exp(glob_a * k * t) * np.sin(glob_b * np.pi * x)
34
- + 0.25 * np.exp(glob_c * k * t) * np.sin(glob_d * np.pi * x)
35
  )
36
 
37
 
@@ -358,15 +359,15 @@ def train_coefficients(m, kernel):
358
  return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
359
 
360
 
361
- def plot_function(a, b, c, d, k=0.5):
362
- global glob_a, glob_b, glob_c, glob_d
363
 
364
- glob_a, glob_b, glob_c, glob_d = a, b, c, d
365
 
366
  x = np.linspace(0, 1, 100)
367
  t = np.linspace(0, 5, 500)
368
  X, T = np.meshgrid(x, t) # Create the mesh grid
369
- Z = complex_heat_eq_solution(X, T, a, b, c, d)
370
 
371
  traces = []
372
  traces.append(
@@ -419,19 +420,25 @@ def plot_function(a, b, c, d, k=0.5):
419
 
420
  return fig
421
 
 
422
  def plot_all(m, kernel):
423
  # Generate the plot content (replace this with your actual plot logic)
424
- approx_fig = plot_heat_equation(m, kernel) # Replace with your function for approx_plot
425
- error_fig = plot_errors(m, kernel) # Replace with your function for error_plot
426
-
 
 
427
  # Return the figures and make the plots visible
428
  return (
429
  gr.update(visible=True, value=approx_fig),
430
  gr.update(visible=True, value=error_fig),
431
  )
432
 
 
433
  # Gradio interface
434
  def create_gradio_ui():
 
 
435
  # Get the initial available files
436
  with gr.Blocks() as demo:
437
  gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
@@ -447,39 +454,39 @@ def create_gradio_ui():
447
 
448
  with gr.Row():
449
  with gr.Column():
 
 
 
450
  a_slider = gr.Slider(
451
- minimum=-10, maximum=-1, step=1, value=-2, label="a"
452
  )
453
  b_slider = gr.Slider(
454
- minimum=-10, maximum=10, step=1, value=-2, label="b"
455
  )
456
  c_slider = gr.Slider(
457
- minimum=-10, maximum=-1, step=1, value=-4, label="c"
458
- )
459
- d_slider = gr.Slider(
460
- minimum=-10, maximum=10, step=1, value=7, label="d"
461
  )
462
 
463
  plot_output = gr.Plot()
464
-
465
- a_slider.change(
466
  fn=plot_function,
467
- inputs=[a_slider, b_slider, c_slider, d_slider],
468
  outputs=[plot_output],
469
  )
470
- b_slider.change(
471
  fn=plot_function,
472
- inputs=[a_slider, b_slider, c_slider, d_slider],
473
  outputs=[plot_output],
474
  )
475
- c_slider.change(
476
  fn=plot_function,
477
- inputs=[a_slider, b_slider, c_slider, d_slider],
478
  outputs=[plot_output],
479
  )
480
- d_slider.change(
481
  fn=plot_function,
482
- inputs=[a_slider, b_slider, c_slider, d_slider],
483
  outputs=[plot_output],
484
  )
485
 
@@ -522,7 +529,7 @@ def create_gradio_ui():
522
  demo.load(fn=clear_npz, inputs=None, outputs=None)
523
  demo.load(
524
  fn=plot_function,
525
- inputs=[a_slider, b_slider, c_slider, d_slider],
526
  outputs=[plot_output],
527
  )
528
 
 
6
  import gradio as gr
7
  import plotly.graph_objs as go
8
 
9
+ glob_k = 0.0025
10
+ glob_a = -2.
11
+ glob_b = 4.
12
+ glob_c = 7.5
13
 
14
 
15
  def clear_npz():
 
26
  print(f"Failed to delete {file_path}. Reason: {e}")
27
 
28
 
29
+ def complex_heat_eq_solution(x, t, k=glob_k, a=glob_a, b=glob_b, c=glob_c):
30
+ global glob_k, glob_a, glob_b, glob_c
31
+ glob_k, glob_a, glob_b, glob_c = k, a, b, c
32
  return (
33
+ np.exp(-glob_k * (glob_b * np.pi) ** 2 * t) * np.cos(glob_a * np.pi * x)
34
+ + np.exp(-glob_k * (glob_b * np.pi) ** 2 * t) * np.sin(glob_b * np.pi * x)
35
+ + np.exp(-glob_k * (glob_c * np.pi) ** 2 * t) * np.sin(glob_c * np.pi * x)
36
  )
37
 
38
 
 
359
  return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
360
 
361
 
362
+ def plot_function(k, a, b, c):
363
+ global glob_k, glob_a, glob_b, glob_c
364
 
365
+ glob_k, glob_a, glob_b, glob_c = k, a, b, c
366
 
367
  x = np.linspace(0, 1, 100)
368
  t = np.linspace(0, 5, 500)
369
  X, T = np.meshgrid(x, t) # Create the mesh grid
370
+ Z = complex_heat_eq_solution(X, T, glob_k, glob_a, glob_b, glob_c)
371
 
372
  traces = []
373
  traces.append(
 
420
 
421
  return fig
422
 
423
+
424
  def plot_all(m, kernel):
425
  # Generate the plot content (replace this with your actual plot logic)
426
+ approx_fig = plot_heat_equation(
427
+ m, kernel
428
+ ) # Replace with your function for approx_plot
429
+ error_fig = plot_errors(m, kernel) # Replace with your function for error_plot
430
+
431
  # Return the figures and make the plots visible
432
  return (
433
  gr.update(visible=True, value=approx_fig),
434
  gr.update(visible=True, value=error_fig),
435
  )
436
 
437
+
438
  # Gradio interface
439
  def create_gradio_ui():
440
+ global glob_k, glob_a, glob_b, glob_c
441
+
442
  # Get the initial available files
443
  with gr.Blocks() as demo:
444
  gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
 
454
 
455
  with gr.Row():
456
  with gr.Column():
457
+ k_slider = gr.Slider(
458
+ minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
459
+ )
460
  a_slider = gr.Slider(
461
+ minimum=-10, maximum=10, step=0.1, value=glob_a, label="a"
462
  )
463
  b_slider = gr.Slider(
464
+ minimum=-10, maximum=10, step=0.1, value=glob_b, label="b"
465
  )
466
  c_slider = gr.Slider(
467
+ minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
 
 
 
468
  )
469
 
470
  plot_output = gr.Plot()
471
+
472
+ k_slider.change(
473
  fn=plot_function,
474
+ inputs=[k_slider, a_slider, b_slider, c_slider],
475
  outputs=[plot_output],
476
  )
477
+ a_slider.change(
478
  fn=plot_function,
479
+ inputs=[k_slider, a_slider, b_slider, c_slider],
480
  outputs=[plot_output],
481
  )
482
+ b_slider.change(
483
  fn=plot_function,
484
+ inputs=[k_slider, a_slider, b_slider, c_slider],
485
  outputs=[plot_output],
486
  )
487
+ c_slider.change(
488
  fn=plot_function,
489
+ inputs=[k_slider, a_slider, b_slider, c_slider],
490
  outputs=[plot_output],
491
  )
492
 
 
529
  demo.load(fn=clear_npz, inputs=None, outputs=None)
530
  demo.load(
531
  fn=plot_function,
532
+ inputs=[k_slider, a_slider, b_slider, c_slider],
533
  outputs=[plot_output],
534
  )
535