Spaces:
Sleeping
Sleeping
Lennard Schober
commited on
Commit
·
0bb1ad3
1
Parent(s):
c9793cd
Fix heat equation solution
Browse files
app.py
CHANGED
@@ -6,10 +6,10 @@ import numpy as np
|
|
6 |
import gradio as gr
|
7 |
import plotly.graph_objs as go
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
|
14 |
|
15 |
def clear_npz():
|
@@ -26,12 +26,13 @@ def clear_npz():
|
|
26 |
print(f"Failed to delete {file_path}. Reason: {e}")
|
27 |
|
28 |
|
29 |
-
def complex_heat_eq_solution(x, t, a=glob_a, b=glob_b, c=glob_c
|
30 |
-
global glob_a, glob_b, glob_c
|
|
|
31 |
return (
|
32 |
-
np.exp(-
|
33 |
-
+
|
34 |
-
+
|
35 |
)
|
36 |
|
37 |
|
@@ -358,15 +359,15 @@ def train_coefficients(m, kernel):
|
|
358 |
return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
|
359 |
|
360 |
|
361 |
-
def plot_function(a, b, c
|
362 |
-
global glob_a, glob_b, glob_c
|
363 |
|
364 |
-
glob_a, glob_b, glob_c
|
365 |
|
366 |
x = np.linspace(0, 1, 100)
|
367 |
t = np.linspace(0, 5, 500)
|
368 |
X, T = np.meshgrid(x, t) # Create the mesh grid
|
369 |
-
Z = complex_heat_eq_solution(X, T,
|
370 |
|
371 |
traces = []
|
372 |
traces.append(
|
@@ -419,19 +420,25 @@ def plot_function(a, b, c, d, k=0.5):
|
|
419 |
|
420 |
return fig
|
421 |
|
|
|
422 |
def plot_all(m, kernel):
|
423 |
# Generate the plot content (replace this with your actual plot logic)
|
424 |
-
approx_fig = plot_heat_equation(
|
425 |
-
|
426 |
-
|
|
|
|
|
427 |
# Return the figures and make the plots visible
|
428 |
return (
|
429 |
gr.update(visible=True, value=approx_fig),
|
430 |
gr.update(visible=True, value=error_fig),
|
431 |
)
|
432 |
|
|
|
433 |
# Gradio interface
|
434 |
def create_gradio_ui():
|
|
|
|
|
435 |
# Get the initial available files
|
436 |
with gr.Blocks() as demo:
|
437 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
@@ -447,39 +454,39 @@ def create_gradio_ui():
|
|
447 |
|
448 |
with gr.Row():
|
449 |
with gr.Column():
|
|
|
|
|
|
|
450 |
a_slider = gr.Slider(
|
451 |
-
minimum=-10, maximum
|
452 |
)
|
453 |
b_slider = gr.Slider(
|
454 |
-
minimum=-10, maximum=10, step=1, value
|
455 |
)
|
456 |
c_slider = gr.Slider(
|
457 |
-
minimum=-10, maximum
|
458 |
-
)
|
459 |
-
d_slider = gr.Slider(
|
460 |
-
minimum=-10, maximum=10, step=1, value=7, label="d"
|
461 |
)
|
462 |
|
463 |
plot_output = gr.Plot()
|
464 |
-
|
465 |
-
|
466 |
fn=plot_function,
|
467 |
-
inputs=[a_slider, b_slider, c_slider
|
468 |
outputs=[plot_output],
|
469 |
)
|
470 |
-
|
471 |
fn=plot_function,
|
472 |
-
inputs=[a_slider, b_slider, c_slider
|
473 |
outputs=[plot_output],
|
474 |
)
|
475 |
-
|
476 |
fn=plot_function,
|
477 |
-
inputs=[a_slider, b_slider, c_slider
|
478 |
outputs=[plot_output],
|
479 |
)
|
480 |
-
|
481 |
fn=plot_function,
|
482 |
-
inputs=[a_slider, b_slider, c_slider
|
483 |
outputs=[plot_output],
|
484 |
)
|
485 |
|
@@ -522,7 +529,7 @@ def create_gradio_ui():
|
|
522 |
demo.load(fn=clear_npz, inputs=None, outputs=None)
|
523 |
demo.load(
|
524 |
fn=plot_function,
|
525 |
-
inputs=[a_slider, b_slider, c_slider
|
526 |
outputs=[plot_output],
|
527 |
)
|
528 |
|
|
|
6 |
import gradio as gr
|
7 |
import plotly.graph_objs as go
|
8 |
|
9 |
+
glob_k = 0.0025
|
10 |
+
glob_a = -2.
|
11 |
+
glob_b = 4.
|
12 |
+
glob_c = 7.5
|
13 |
|
14 |
|
15 |
def clear_npz():
|
|
|
26 |
print(f"Failed to delete {file_path}. Reason: {e}")
|
27 |
|
28 |
|
29 |
+
def complex_heat_eq_solution(x, t, k=glob_k, a=glob_a, b=glob_b, c=glob_c):
|
30 |
+
global glob_k, glob_a, glob_b, glob_c
|
31 |
+
glob_k, glob_a, glob_b, glob_c = k, a, b, c
|
32 |
return (
|
33 |
+
np.exp(-glob_k * (glob_b * np.pi) ** 2 * t) * np.cos(glob_a * np.pi * x)
|
34 |
+
+ np.exp(-glob_k * (glob_b * np.pi) ** 2 * t) * np.sin(glob_b * np.pi * x)
|
35 |
+
+ np.exp(-glob_k * (glob_c * np.pi) ** 2 * t) * np.sin(glob_c * np.pi * x)
|
36 |
)
|
37 |
|
38 |
|
|
|
359 |
return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
|
360 |
|
361 |
|
362 |
+
def plot_function(k, a, b, c):
|
363 |
+
global glob_k, glob_a, glob_b, glob_c
|
364 |
|
365 |
+
glob_k, glob_a, glob_b, glob_c = k, a, b, c
|
366 |
|
367 |
x = np.linspace(0, 1, 100)
|
368 |
t = np.linspace(0, 5, 500)
|
369 |
X, T = np.meshgrid(x, t) # Create the mesh grid
|
370 |
+
Z = complex_heat_eq_solution(X, T, glob_k, glob_a, glob_b, glob_c)
|
371 |
|
372 |
traces = []
|
373 |
traces.append(
|
|
|
420 |
|
421 |
return fig
|
422 |
|
423 |
+
|
424 |
def plot_all(m, kernel):
|
425 |
# Generate the plot content (replace this with your actual plot logic)
|
426 |
+
approx_fig = plot_heat_equation(
|
427 |
+
m, kernel
|
428 |
+
) # Replace with your function for approx_plot
|
429 |
+
error_fig = plot_errors(m, kernel) # Replace with your function for error_plot
|
430 |
+
|
431 |
# Return the figures and make the plots visible
|
432 |
return (
|
433 |
gr.update(visible=True, value=approx_fig),
|
434 |
gr.update(visible=True, value=error_fig),
|
435 |
)
|
436 |
|
437 |
+
|
438 |
# Gradio interface
|
439 |
def create_gradio_ui():
|
440 |
+
global glob_k, glob_a, glob_b, glob_c
|
441 |
+
|
442 |
# Get the initial available files
|
443 |
with gr.Blocks() as demo:
|
444 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
|
|
454 |
|
455 |
with gr.Row():
|
456 |
with gr.Column():
|
457 |
+
k_slider = gr.Slider(
|
458 |
+
minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
|
459 |
+
)
|
460 |
a_slider = gr.Slider(
|
461 |
+
minimum=-10, maximum=10, step=0.1, value=glob_a, label="a"
|
462 |
)
|
463 |
b_slider = gr.Slider(
|
464 |
+
minimum=-10, maximum=10, step=0.1, value=glob_b, label="b"
|
465 |
)
|
466 |
c_slider = gr.Slider(
|
467 |
+
minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
|
|
|
|
|
|
|
468 |
)
|
469 |
|
470 |
plot_output = gr.Plot()
|
471 |
+
|
472 |
+
k_slider.change(
|
473 |
fn=plot_function,
|
474 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
475 |
outputs=[plot_output],
|
476 |
)
|
477 |
+
a_slider.change(
|
478 |
fn=plot_function,
|
479 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
480 |
outputs=[plot_output],
|
481 |
)
|
482 |
+
b_slider.change(
|
483 |
fn=plot_function,
|
484 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
485 |
outputs=[plot_output],
|
486 |
)
|
487 |
+
c_slider.change(
|
488 |
fn=plot_function,
|
489 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
490 |
outputs=[plot_output],
|
491 |
)
|
492 |
|
|
|
529 |
demo.load(fn=clear_npz, inputs=None, outputs=None)
|
530 |
demo.load(
|
531 |
fn=plot_function,
|
532 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
533 |
outputs=[plot_output],
|
534 |
)
|
535 |
|