ZennyKenny commited on
Commit
43728f4
Β·
verified Β·
1 Parent(s): edfc8c7

add visualisation elements

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -1,23 +1,22 @@
1
  import gradio as gr
2
  import numpy as np
 
3
  from sklearn.datasets import load_iris
4
  from sklearn.ensemble import GradientBoostingClassifier
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.metrics import accuracy_score, confusion_matrix
7
 
8
- # 1. Load dataset
9
  iris = load_iris()
10
  X, y = iris.data, iris.target
11
  feature_names = iris.feature_names
12
  class_names = iris.target_names
13
 
14
- # Split into train/test
15
  X_train, X_test, y_train, y_test = train_test_split(
16
  X, y, test_size=0.3, random_state=42
17
  )
18
 
19
- # 2. Define a function that trains & evaluates a model given hyperparameters
20
  def train_and_evaluate(learning_rate, n_estimators, max_depth):
 
21
  clf = GradientBoostingClassifier(
22
  learning_rate=learning_rate,
23
  n_estimators=n_estimators,
@@ -25,15 +24,31 @@ def train_and_evaluate(learning_rate, n_estimators, max_depth):
25
  random_state=42
26
  )
27
  clf.fit(X_train, y_train)
28
- y_pred = clf.predict(X_test)
29
 
 
 
30
  accuracy = accuracy_score(y_test, y_pred)
31
  cm = confusion_matrix(y_test, y_pred)
 
 
32
  cm_display = "\n".join([str(row) for row in cm])
33
 
34
- return f"Accuracy: {accuracy:.3f}\nConfusion Matrix:\n{cm_display}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # 3. Define a prediction function for user-supplied feature values
37
  def predict_species(sepal_length, sepal_width, petal_length, petal_width,
38
  learning_rate, n_estimators, max_depth):
39
  clf = GradientBoostingClassifier(
@@ -43,12 +58,10 @@ def predict_species(sepal_length, sepal_width, petal_length, petal_width,
43
  random_state=42
44
  )
45
  clf.fit(X_train, y_train)
46
-
47
  user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
48
  prediction = clf.predict(user_sample)[0]
49
  return f"Predicted species: {class_names[prediction]}"
50
 
51
- # 4. Build the Gradio interface
52
  with gr.Blocks() as demo:
53
  with gr.Tab("Train & Evaluate"):
54
  gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
@@ -58,11 +71,12 @@ with gr.Blocks() as demo:
58
 
59
  train_button = gr.Button("Train & Evaluate")
60
  output_text = gr.Textbox(label="Results")
 
61
 
62
  train_button.click(
63
  fn=train_and_evaluate,
64
  inputs=[learning_rate_slider, n_estimators_slider, max_depth_slider],
65
- outputs=output_text,
66
  )
67
 
68
  with gr.Tab("Predict"):
@@ -71,8 +85,7 @@ with gr.Blocks() as demo:
71
  sepal_width_input = gr.Number(value=3.5, label=feature_names[1])
72
  petal_length_input = gr.Number(value=1.4, label=feature_names[2])
73
  petal_width_input = gr.Number(value=0.2, label=feature_names[3])
74
-
75
- # Hyperparams for the model that will do the prediction
76
  learning_rate_slider2 = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
77
  n_estimators_slider2 = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
78
  max_depth_slider2 = gr.Slider(1, 10, value=3, step=1, label="max_depth")
 
1
  import gradio as gr
2
  import numpy as np
3
+ import matplotlib.pyplot as plt
4
  from sklearn.datasets import load_iris
5
  from sklearn.ensemble import GradientBoostingClassifier
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.metrics import accuracy_score, confusion_matrix
8
 
 
9
  iris = load_iris()
10
  X, y = iris.data, iris.target
11
  feature_names = iris.feature_names
12
  class_names = iris.target_names
13
 
 
14
  X_train, X_test, y_train, y_test = train_test_split(
15
  X, y, test_size=0.3, random_state=42
16
  )
17
 
 
18
  def train_and_evaluate(learning_rate, n_estimators, max_depth):
19
+ # Train model
20
  clf = GradientBoostingClassifier(
21
  learning_rate=learning_rate,
22
  n_estimators=n_estimators,
 
24
  random_state=42
25
  )
26
  clf.fit(X_train, y_train)
 
27
 
28
+ # Predict and compute metrics
29
+ y_pred = clf.predict(X_test)
30
  accuracy = accuracy_score(y_test, y_pred)
31
  cm = confusion_matrix(y_test, y_pred)
32
+
33
+ # Convert confusion matrix to readable string
34
  cm_display = "\n".join([str(row) for row in cm])
35
 
36
+ # Create a feature importance bar chart
37
+ importances = clf.feature_importances_
38
+ fig, ax = plt.subplots()
39
+ ax.barh(range(len(feature_names)), importances, color='skyblue')
40
+ ax.set_yticks(range(len(feature_names)))
41
+ ax.set_yticklabels(feature_names)
42
+ ax.set_xlabel("Importance")
43
+ ax.set_title("Feature Importances (Gradient Boosting)")
44
+
45
+ # Convert the Matplotlib figure to a Gradio-readable format
46
+ # (returns a temporary .png file path)
47
+ return (
48
+ f"Accuracy: {accuracy:.3f}\nConfusion Matrix:\n{cm_display}",
49
+ fig
50
+ )
51
 
 
52
  def predict_species(sepal_length, sepal_width, petal_length, petal_width,
53
  learning_rate, n_estimators, max_depth):
54
  clf = GradientBoostingClassifier(
 
58
  random_state=42
59
  )
60
  clf.fit(X_train, y_train)
 
61
  user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
62
  prediction = clf.predict(user_sample)[0]
63
  return f"Predicted species: {class_names[prediction]}"
64
 
 
65
  with gr.Blocks() as demo:
66
  with gr.Tab("Train & Evaluate"):
67
  gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
 
71
 
72
  train_button = gr.Button("Train & Evaluate")
73
  output_text = gr.Textbox(label="Results")
74
+ output_plot = gr.Plot(label="Feature Importance")
75
 
76
  train_button.click(
77
  fn=train_and_evaluate,
78
  inputs=[learning_rate_slider, n_estimators_slider, max_depth_slider],
79
+ outputs=[output_text, output_plot],
80
  )
81
 
82
  with gr.Tab("Predict"):
 
85
  sepal_width_input = gr.Number(value=3.5, label=feature_names[1])
86
  petal_length_input = gr.Number(value=1.4, label=feature_names[2])
87
  petal_width_input = gr.Number(value=0.2, label=feature_names[3])
88
+
 
89
  learning_rate_slider2 = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
90
  n_estimators_slider2 = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
91
  max_depth_slider2 = gr.Slider(1, 10, value=3, step=1, label="max_depth")