Lennard Schober commited on
Commit
e581117
·
1 Parent(s): 7a286ff

Fix dynamic function change

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -7,8 +7,8 @@ import gradio as gr
7
  import plotly.graph_objs as go
8
 
9
  glob_k = 0.0025
10
- glob_a = -2.
11
- glob_b = 4.
12
  glob_c = 7.5
13
 
14
 
@@ -26,7 +26,7 @@ def clear_npz():
26
  print(f"Failed to delete {file_path}. Reason: {e}")
27
 
28
 
29
- def complex_heat_eq_solution(x, t, k=glob_k, a=glob_a, b=glob_b, c=glob_c):
30
  global glob_k, glob_a, glob_b, glob_c
31
  glob_k, glob_a, glob_b, glob_c = k, a, b, c
32
  return (
@@ -37,6 +37,7 @@ def complex_heat_eq_solution(x, t, k=glob_k, a=glob_a, b=glob_b, c=glob_c):
37
 
38
 
39
  def plot_heat_equation(m, approx_type):
 
40
  # Define grid dimensions
41
  n_x = 32 # Fixed spatial grid resolution
42
  n_t = 50
@@ -54,7 +55,7 @@ def plot_heat_equation(m, approx_type):
54
  X, T = np.meshgrid(x, t)
55
 
56
  # Compute the real solution over the grid
57
- U_real = complex_heat_eq_solution(X, T)
58
 
59
  # Compute the selected approximation
60
  U_approx = np.zeros_like(U_real)
@@ -172,7 +173,7 @@ def plot_errors(m, approx_type):
172
  X, T = np.meshgrid(x, t)
173
 
174
  # Compute the real solution over the grid
175
- U_real = complex_heat_eq_solution(X, T)
176
 
177
  # Compute the selected approximation
178
  U_approx = np.zeros_like(U_real)
@@ -261,13 +262,14 @@ def plot_errors(m, approx_type):
261
 
262
 
263
  def generate_data(n_x=32, n_t=50):
 
264
  """Generate training data."""
265
  x = np.linspace(0, 1, n_x) # spatial points
266
  t = np.linspace(0, 5, n_t) # temporal points
267
  X, T = np.meshgrid(x, t)
268
  a_train = np.c_[X.ravel(), T.ravel()] # shape (n_x * n_t, 2)
269
  u_train = complex_heat_eq_solution(
270
- a_train[:, 0], a_train[:, 1]
271
  ) # shape (n_x * n_t,)
272
  return a_train, u_train, x, t
273
 
@@ -321,6 +323,7 @@ def polyfit2d(x, y, z, kx=3, ky=3, order=None):
321
 
322
 
323
  def train_coefficients(m, kernel):
 
324
  # Start time for training
325
  start_time = time.time()
326
 
@@ -328,6 +331,8 @@ def train_coefficients(m, kernel):
328
  n_x, n_t = 32, 50
329
  a_train, u_train, x, t = generate_data(n_x, n_t)
330
 
 
 
331
  # Define random features
332
  theta = np.column_stack(
333
  (
@@ -340,7 +345,7 @@ def train_coefficients(m, kernel):
340
  Phi = design_matrix(a_train, theta, kernel)
341
  alpha = learn_coefficients(Phi, u_train)
342
  # Validate and animate results
343
- u_real = np.array([complex_heat_eq_solution(x, t_i) for t_i in t])
344
  a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
345
  u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
346
 
@@ -364,6 +369,8 @@ def plot_function(k, a, b, c):
364
 
365
  glob_k, glob_a, glob_b, glob_c = k, a, b, c
366
 
 
 
367
  x = np.linspace(0, 1, 100)
368
  t = np.linspace(0, 5, 500)
369
  X, T = np.meshgrid(x, t) # Create the mesh grid
@@ -434,10 +441,11 @@ def plot_all(m, kernel):
434
  gr.update(visible=True, value=error_fig),
435
  )
436
 
 
437
  # Gradio interface
438
  def create_gradio_ui():
439
  global glob_k, glob_a, glob_b, glob_c
440
-
441
  # Get the initial available files
442
  with gr.Blocks() as demo:
443
  gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
@@ -467,7 +475,7 @@ def create_gradio_ui():
467
  )
468
 
469
  plot_output = gr.Plot()
470
-
471
  k_slider.change(
472
  fn=plot_function,
473
  inputs=[k_slider, a_slider, b_slider, c_slider],
 
7
  import plotly.graph_objs as go
8
 
9
  glob_k = 0.0025
10
+ glob_a = -2.0
11
+ glob_b = 4.0
12
  glob_c = 7.5
13
 
14
 
 
26
  print(f"Failed to delete {file_path}. Reason: {e}")
27
 
28
 
29
+ def complex_heat_eq_solution(x, t, k, a, b, c):
30
  global glob_k, glob_a, glob_b, glob_c
31
  glob_k, glob_a, glob_b, glob_c = k, a, b, c
32
  return (
 
37
 
38
 
39
  def plot_heat_equation(m, approx_type):
40
+ global glob_k, glob_a, glob_b, glob_c
41
  # Define grid dimensions
42
  n_x = 32 # Fixed spatial grid resolution
43
  n_t = 50
 
55
  X, T = np.meshgrid(x, t)
56
 
57
  # Compute the real solution over the grid
58
+ U_real = complex_heat_eq_solution(X, T, glob_k, glob_a, glob_b, glob_c)
59
 
60
  # Compute the selected approximation
61
  U_approx = np.zeros_like(U_real)
 
173
  X, T = np.meshgrid(x, t)
174
 
175
  # Compute the real solution over the grid
176
+ U_real = complex_heat_eq_solution(X, T, glob_k, glob_a, glob_b, glob_c)
177
 
178
  # Compute the selected approximation
179
  U_approx = np.zeros_like(U_real)
 
262
 
263
 
264
  def generate_data(n_x=32, n_t=50):
265
+ global glob_k, glob_a, glob_b, glob_c
266
  """Generate training data."""
267
  x = np.linspace(0, 1, n_x) # spatial points
268
  t = np.linspace(0, 5, n_t) # temporal points
269
  X, T = np.meshgrid(x, t)
270
  a_train = np.c_[X.ravel(), T.ravel()] # shape (n_x * n_t, 2)
271
  u_train = complex_heat_eq_solution(
272
+ a_train[:, 0], a_train[:, 1], glob_k, glob_a, glob_b, glob_c
273
  ) # shape (n_x * n_t,)
274
  return a_train, u_train, x, t
275
 
 
323
 
324
 
325
  def train_coefficients(m, kernel):
326
+ global glob_k, glob_a, glob_b, glob_c
327
  # Start time for training
328
  start_time = time.time()
329
 
 
331
  n_x, n_t = 32, 50
332
  a_train, u_train, x, t = generate_data(n_x, n_t)
333
 
334
+ print("in train coeffs: ", glob_k)
335
+
336
  # Define random features
337
  theta = np.column_stack(
338
  (
 
345
  Phi = design_matrix(a_train, theta, kernel)
346
  alpha = learn_coefficients(Phi, u_train)
347
  # Validate and animate results
348
+ u_real = np.array([complex_heat_eq_solution(x, t_i, glob_k, glob_a, glob_b, glob_c) for t_i in t])
349
  a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
350
  u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
351
 
 
369
 
370
  glob_k, glob_a, glob_b, glob_c = k, a, b, c
371
 
372
+ print(glob_k, k)
373
+
374
  x = np.linspace(0, 1, 100)
375
  t = np.linspace(0, 5, 500)
376
  X, T = np.meshgrid(x, t) # Create the mesh grid
 
441
  gr.update(visible=True, value=error_fig),
442
  )
443
 
444
+
445
  # Gradio interface
446
  def create_gradio_ui():
447
  global glob_k, glob_a, glob_b, glob_c
448
+
449
  # Get the initial available files
450
  with gr.Blocks() as demo:
451
  gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
 
475
  )
476
 
477
  plot_output = gr.Plot()
478
+
479
  k_slider.change(
480
  fn=plot_function,
481
  inputs=[k_slider, a_slider, b_slider, c_slider],