Merge branch 'main' of https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability
Browse files- .gitattributes +0 -31
- DESCRIPTION.md +1 -0
- README.md +6 -7
- app.py +6 -17
.gitattributes
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
23 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
26 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DESCRIPTION.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
This demo takes in 12 inputs from the user in dropdowns and sliders and predicts income. It also has a separate button for explaining the prediction.
|
README.md
CHANGED
@@ -1,13 +1,12 @@
|
|
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 🔥
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: mit
|
11 |
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
|
2 |
---
|
3 |
+
title: xgboost-income-prediction-with-explainability
|
4 |
emoji: 🔥
|
5 |
+
colorFrom: indigo
|
6 |
+
colorTo: indigo
|
7 |
sdk: gradio
|
8 |
+
sdk_version: 3.4
|
9 |
+
|
10 |
app_file: app.py
|
11 |
pinned: false
|
|
|
12 |
---
|
|
|
|
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
-
import random
|
2 |
-
|
3 |
import gradio as gr
|
|
|
4 |
import matplotlib
|
5 |
import matplotlib.pyplot as plt
|
6 |
import pandas as pd
|
@@ -8,14 +7,12 @@ import shap
|
|
8 |
import xgboost as xgb
|
9 |
from datasets import load_dataset
|
10 |
|
11 |
-
matplotlib.use("Agg")
|
12 |
|
|
|
13 |
dataset = load_dataset("scikit-learn/adult-census-income")
|
14 |
-
|
15 |
X_train = dataset["train"].to_pandas()
|
16 |
_ = X_train.pop("fnlwgt")
|
17 |
_ = X_train.pop("race")
|
18 |
-
|
19 |
y_train = X_train.pop("income")
|
20 |
y_train = (y_train == ">50K").astype(int)
|
21 |
categorical_columns = [
|
@@ -28,13 +25,10 @@ categorical_columns = [
|
|
28 |
"native.country",
|
29 |
]
|
30 |
X_train = X_train.astype({col: "category" for col in categorical_columns})
|
31 |
-
|
32 |
-
|
33 |
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
|
34 |
model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
|
35 |
explainer = shap.TreeExplainer(model)
|
36 |
|
37 |
-
|
38 |
def predict(*args):
|
39 |
df = pd.DataFrame([args], columns=X_train.columns)
|
40 |
df = df.astype({col: "category" for col in categorical_columns})
|
@@ -51,8 +45,8 @@ def interpret(*args):
|
|
51 |
fig_m = plt.figure(tight_layout=True)
|
52 |
plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
|
53 |
plt.title("Feature Shap Values")
|
54 |
-
plt.ylabel("
|
55 |
-
plt.xlabel("
|
56 |
plt.tight_layout()
|
57 |
return fig_m
|
58 |
|
@@ -67,12 +61,7 @@ unique_country = sorted(X_train["native.country"].unique())
|
|
67 |
|
68 |
with gr.Blocks() as demo:
|
69 |
gr.Markdown("""
|
70 |
-
|
71 |
-
|
72 |
-
This example shows how to load data from the hugging face hub to train an XGBoost classifier and
|
73 |
-
demo the predictions with gradio.
|
74 |
-
|
75 |
-
The source is [here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability).
|
76 |
""")
|
77 |
with gr.Row():
|
78 |
with gr.Column():
|
@@ -136,7 +125,7 @@ with gr.Blocks() as demo:
|
|
136 |
plot = gr.Plot()
|
137 |
with gr.Row():
|
138 |
predict_btn = gr.Button(value="Predict")
|
139 |
-
interpret_btn = gr.Button(value="
|
140 |
predict_btn.click(
|
141 |
predict,
|
142 |
inputs=[
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import random
|
3 |
import matplotlib
|
4 |
import matplotlib.pyplot as plt
|
5 |
import pandas as pd
|
|
|
7 |
import xgboost as xgb
|
8 |
from datasets import load_dataset
|
9 |
|
|
|
10 |
|
11 |
+
matplotlib.use("Agg")
|
12 |
dataset = load_dataset("scikit-learn/adult-census-income")
|
|
|
13 |
X_train = dataset["train"].to_pandas()
|
14 |
_ = X_train.pop("fnlwgt")
|
15 |
_ = X_train.pop("race")
|
|
|
16 |
y_train = X_train.pop("income")
|
17 |
y_train = (y_train == ">50K").astype(int)
|
18 |
categorical_columns = [
|
|
|
25 |
"native.country",
|
26 |
]
|
27 |
X_train = X_train.astype({col: "category" for col in categorical_columns})
|
|
|
|
|
28 |
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
|
29 |
model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
|
30 |
explainer = shap.TreeExplainer(model)
|
31 |
|
|
|
32 |
def predict(*args):
|
33 |
df = pd.DataFrame([args], columns=X_train.columns)
|
34 |
df = df.astype({col: "category" for col in categorical_columns})
|
|
|
45 |
fig_m = plt.figure(tight_layout=True)
|
46 |
plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
|
47 |
plt.title("Feature Shap Values")
|
48 |
+
plt.ylabel("Feature")
|
49 |
+
plt.xlabel("Shap Value")
|
50 |
plt.tight_layout()
|
51 |
return fig_m
|
52 |
|
|
|
61 |
|
62 |
with gr.Blocks() as demo:
|
63 |
gr.Markdown("""
|
64 |
+
**Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier to predict income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
|
|
|
|
|
|
|
|
|
|
|
65 |
""")
|
66 |
with gr.Row():
|
67 |
with gr.Column():
|
|
|
125 |
plot = gr.Plot()
|
126 |
with gr.Row():
|
127 |
predict_btn = gr.Button(value="Predict")
|
128 |
+
interpret_btn = gr.Button(value="Explain")
|
129 |
predict_btn.click(
|
130 |
predict,
|
131 |
inputs=[
|