ZennyKenny commited on
Commit
96b98f3
·
verified ·
1 Parent(s): 5356007

first commit -- grad boost demo

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from sklearn.datasets import load_iris
4
+ from sklearn.ensemble import GradientBoostingClassifier
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import accuracy_score, confusion_matrix
7
+
8
+ # 1. Load dataset
9
+ iris = load_iris()
10
+ X, y = iris.data, iris.target
11
+ feature_names = iris.feature_names
12
+ class_names = iris.target_names
13
+
14
+ # Split into train/test
15
+ X_train, X_test, y_train, y_test = train_test_split(
16
+ X, y, test_size=0.3, random_state=42
17
+ )
18
+
19
+ # 2. Define a function that takes hyperparameters and returns model accuracy + confusion matrix
20
+ def train_and_evaluate(learning_rate, n_estimators, max_depth):
21
+ # Train model
22
+ clf = GradientBoostingClassifier(
23
+ learning_rate=learning_rate,
24
+ n_estimators=n_estimators,
25
+ max_depth=int(max_depth),
26
+ random_state=42
27
+ )
28
+ clf.fit(X_train, y_train)
29
+
30
+ # Predict on test data
31
+ y_pred = clf.predict(X_test)
32
+
33
+ # Calculate metrics
34
+ accuracy = accuracy_score(y_test, y_pred)
35
+ cm = confusion_matrix(y_test, y_pred)
36
+
37
+ # Convert confusion matrix to a more display-friendly format
38
+ cm_display = ""
39
+ for row in cm:
40
+ cm_display += str(row) + "\n"
41
+
42
+ return f"Accuracy: {accuracy:.3f}\nConfusion Matrix:\n{cm_display}"
43
+
44
+ # 3. Define a prediction function for user-supplied feature values
45
+ def predict_species(sepal_length, sepal_width, petal_length, petal_width,
46
+ learning_rate, n_estimators, max_depth):
47
+ # Train a new model using same hyperparams
48
+ clf = GradientBoostingClassifier(
49
+ learning_rate=learning_rate,
50
+ n_estimators=n_estimators,
51
+ max_depth=int(max_depth),
52
+ random_state=42
53
+ )
54
+ clf.fit(X_train, y_train)
55
+
56
+ # Predict species
57
+ user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
58
+ prediction = clf.predict(user_sample)[0]
59
+ return f"Predicted species: {class_names[prediction]}"
60
+
61
+ # 4. Build the Gradio interface
62
+
63
+ # Inputs to tune hyperparameters
64
+ hyperparam_inputs = [
65
+ gr.inputs.Slider(0.01, 1.0, step=0.01, default=0.1, label="learning_rate"),
66
+ gr.inputs.Slider(50, 300, step=50, default=100, label="n_estimators"),
67
+ gr.inputs.Slider(1, 10, step=1, default=3, label="max_depth")
68
+ ]
69
+
70
+ # Button or automatic “live” updates
71
+ training_interface = gr.Interface(
72
+ fn=train_and_evaluate,
73
+ inputs=hyperparam_inputs,
74
+ outputs="text",
75
+ title="Gradient Boosting Training and Evaluation",
76
+ description="Train a GradientBoostingClassifier on the Iris dataset with different hyperparameters."
77
+ )
78
+
79
+ # Inputs for real-time prediction
80
+ feature_inputs = [
81
+ gr.inputs.Number(default=5.1, label=feature_names[0]),
82
+ gr.inputs.Number(default=3.5, label=feature_names[1]),
83
+ gr.inputs.Number(default=1.4, label=feature_names[2]),
84
+ gr.inputs.Number(default=0.2, label=feature_names[3])
85
+ ] + hyperparam_inputs
86
+
87
+ prediction_interface = gr.Interface(
88
+ fn=predict_species,
89
+ inputs=feature_inputs,
90
+ outputs="text",
91
+ title="Iris Species Prediction",
92
+ description="Use a GradientBoostingClassifier to predict Iris species from user input."
93
+ )
94
+
95
+ demo = gr.TabbedInterface([training_interface, prediction_interface],
96
+ ["Train & Evaluate", "Predict"])
97
+
98
+ # Launch the Gradio app
99
+ demo.launch()