ZennyKenny commited on
Commit
edfc8c7
·
verified ·
1 Parent(s): f238ce0

update gradio syntax for latest versions

Browse files
Files changed (1) hide show
  1. app.py +46 -48
app.py CHANGED
@@ -16,9 +16,8 @@ X_train, X_test, y_train, y_test = train_test_split(
16
  X, y, test_size=0.3, random_state=42
17
  )
18
 
19
- # 2. Define a function that takes hyperparameters and returns model accuracy + confusion matrix
20
  def train_and_evaluate(learning_rate, n_estimators, max_depth):
21
- # Train model
22
  clf = GradientBoostingClassifier(
23
  learning_rate=learning_rate,
24
  n_estimators=n_estimators,
@@ -26,25 +25,17 @@ def train_and_evaluate(learning_rate, n_estimators, max_depth):
26
  random_state=42
27
  )
28
  clf.fit(X_train, y_train)
29
-
30
- # Predict on test data
31
  y_pred = clf.predict(X_test)
32
 
33
- # Calculate metrics
34
  accuracy = accuracy_score(y_test, y_pred)
35
  cm = confusion_matrix(y_test, y_pred)
36
-
37
- # Convert confusion matrix to a more display-friendly format
38
- cm_display = ""
39
- for row in cm:
40
- cm_display += str(row) + "\n"
41
 
42
  return f"Accuracy: {accuracy:.3f}\nConfusion Matrix:\n{cm_display}"
43
 
44
  # 3. Define a prediction function for user-supplied feature values
45
  def predict_species(sepal_length, sepal_width, petal_length, petal_width,
46
  learning_rate, n_estimators, max_depth):
47
- # Train a new model using same hyperparams
48
  clf = GradientBoostingClassifier(
49
  learning_rate=learning_rate,
50
  n_estimators=n_estimators,
@@ -53,47 +44,54 @@ def predict_species(sepal_length, sepal_width, petal_length, petal_width,
53
  )
54
  clf.fit(X_train, y_train)
55
 
56
- # Predict species
57
  user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
58
  prediction = clf.predict(user_sample)[0]
59
  return f"Predicted species: {class_names[prediction]}"
60
 
61
  # 4. Build the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # Inputs to tune hyperparameters
64
- hyperparam_inputs = [
65
- gr.inputs.Slider(0.01, 1.0, step=0.01, default=0.1, label="learning_rate"),
66
- gr.inputs.Slider(50, 300, step=50, default=100, label="n_estimators"),
67
- gr.inputs.Slider(1, 10, step=1, default=3, label="max_depth")
68
- ]
69
-
70
- # Button or automatic “live” updates
71
- training_interface = gr.Interface(
72
- fn=train_and_evaluate,
73
- inputs=hyperparam_inputs,
74
- outputs="text",
75
- title="Gradient Boosting Training and Evaluation",
76
- description="Train a GradientBoostingClassifier on the Iris dataset with different hyperparameters."
77
- )
78
-
79
- # Inputs for real-time prediction
80
- feature_inputs = [
81
- gr.inputs.Number(default=5.1, label=feature_names[0]),
82
- gr.inputs.Number(default=3.5, label=feature_names[1]),
83
- gr.inputs.Number(default=1.4, label=feature_names[2]),
84
- gr.inputs.Number(default=0.2, label=feature_names[3])
85
- ] + hyperparam_inputs
86
-
87
- prediction_interface = gr.Interface(
88
- fn=predict_species,
89
- inputs=feature_inputs,
90
- outputs="text",
91
- title="Iris Species Prediction",
92
- description="Use a GradientBoostingClassifier to predict Iris species from user input."
93
- )
94
-
95
- demo = gr.TabbedInterface([training_interface, prediction_interface],
96
- ["Train & Evaluate", "Predict"])
97
-
98
- # Launch the Gradio app
99
  demo.launch()
 
16
  X, y, test_size=0.3, random_state=42
17
  )
18
 
19
+ # 2. Define a function that trains & evaluates a model given hyperparameters
20
  def train_and_evaluate(learning_rate, n_estimators, max_depth):
 
21
  clf = GradientBoostingClassifier(
22
  learning_rate=learning_rate,
23
  n_estimators=n_estimators,
 
25
  random_state=42
26
  )
27
  clf.fit(X_train, y_train)
 
 
28
  y_pred = clf.predict(X_test)
29
 
 
30
  accuracy = accuracy_score(y_test, y_pred)
31
  cm = confusion_matrix(y_test, y_pred)
32
+ cm_display = "\n".join([str(row) for row in cm])
 
 
 
 
33
 
34
  return f"Accuracy: {accuracy:.3f}\nConfusion Matrix:\n{cm_display}"
35
 
36
  # 3. Define a prediction function for user-supplied feature values
37
  def predict_species(sepal_length, sepal_width, petal_length, petal_width,
38
  learning_rate, n_estimators, max_depth):
 
39
  clf = GradientBoostingClassifier(
40
  learning_rate=learning_rate,
41
  n_estimators=n_estimators,
 
44
  )
45
  clf.fit(X_train, y_train)
46
 
 
47
  user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
48
  prediction = clf.predict(user_sample)[0]
49
  return f"Predicted species: {class_names[prediction]}"
50
 
51
  # 4. Build the Gradio interface
52
+ with gr.Blocks() as demo:
53
+ with gr.Tab("Train & Evaluate"):
54
+ gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
55
+ learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
56
+ n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
57
+ max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
58
+
59
+ train_button = gr.Button("Train & Evaluate")
60
+ output_text = gr.Textbox(label="Results")
61
+
62
+ train_button.click(
63
+ fn=train_and_evaluate,
64
+ inputs=[learning_rate_slider, n_estimators_slider, max_depth_slider],
65
+ outputs=output_text,
66
+ )
67
+
68
+ with gr.Tab("Predict"):
69
+ gr.Markdown("## Predict Iris Species with GradientBoostingClassifier")
70
+ sepal_length_input = gr.Number(value=5.1, label=feature_names[0])
71
+ sepal_width_input = gr.Number(value=3.5, label=feature_names[1])
72
+ petal_length_input = gr.Number(value=1.4, label=feature_names[2])
73
+ petal_width_input = gr.Number(value=0.2, label=feature_names[3])
74
+
75
+ # Hyperparams for the model that will do the prediction
76
+ learning_rate_slider2 = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
77
+ n_estimators_slider2 = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
78
+ max_depth_slider2 = gr.Slider(1, 10, value=3, step=1, label="max_depth")
79
+
80
+ predict_button = gr.Button("Predict")
81
+ prediction_text = gr.Textbox(label="Prediction")
82
+
83
+ predict_button.click(
84
+ fn=predict_species,
85
+ inputs=[
86
+ sepal_length_input,
87
+ sepal_width_input,
88
+ petal_length_input,
89
+ petal_width_input,
90
+ learning_rate_slider2,
91
+ n_estimators_slider2,
92
+ max_depth_slider2,
93
+ ],
94
+ outputs=prediction_text
95
+ )
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  demo.launch()