freddyaboulton HF staff committed on
Commit
0af320d
·
1 Parent(s): 517f747
Files changed (2) hide show
  1. app.py +167 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gradio as gr
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import pandas as pd
7
+ import shap
8
+ import xgboost as xgb
9
+ from datasets import load_dataset
10
+
11
# Select the Agg backend for matplotlib before any figure is created.
matplotlib.use("Agg")

# Fetch the adult census income dataset and work on its "train" split as a
# pandas DataFrame.
dataset = load_dataset("scikit-learn/adult-census-income")
X_train = dataset["train"].to_pandas()

# "fnlwgt" and "race" are not used as model features.
X_train = X_train.drop(columns=["fnlwgt", "race"])

# Binary target: 1 when income is ">50K", 0 otherwise.
y_train = X_train.pop("income").eq(">50K").astype(int)

# Columns that are cast to the pandas "category" dtype so that XGBoost can
# consume them with enable_categorical=True.
categorical_columns = [
    "workclass",
    "education",
    "marital.status",
    "occupation",
    "relationship",
    "sex",
    "native.country",
]
X_train = X_train.astype(dict.fromkeys(categorical_columns, "category"))

# Train a binary logistic booster on the full frame and build the SHAP
# explainer that interpret() uses.
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
explainer = shap.TreeExplainer(model)
36
+
37
+
38
def predict(*args):
    """Score a single set of feature values with the trained model.

    ``args`` are the feature values, given positionally in the same order as
    ``X_train.columns``.

    Returns a dict mapping the labels ``">50K"`` and ``"<=50K"`` to the
    model's output for the row and its complement, respectively.
    """
    row = pd.DataFrame([args], columns=X_train.columns)
    # Cast the categorical features so the DMatrix accepts them with
    # enable_categorical=True.
    category_dtypes = {name: "category" for name in categorical_columns}
    row = row.astype(category_dtypes)
    dmatrix = xgb.DMatrix(row, enable_categorical=True)
    over_50k = float(model.predict(dmatrix)[0])
    return {">50K": over_50k, "<=50K": 1 - over_50k}
43
+
44
+
45
def interpret(*args):
    """Plot the SHAP values of a single set of feature values.

    ``args`` are the feature values, given positionally in the same order as
    ``X_train.columns``.

    Returns the matplotlib figure holding a horizontal bar chart with one bar
    per feature, sorted by ascending SHAP value.
    """
    row = pd.DataFrame([args], columns=X_train.columns)
    # Cast the categorical features so the DMatrix accepts them with
    # enable_categorical=True.
    row = row.astype({col: "category" for col in categorical_columns})
    shap_values = explainer.shap_values(xgb.DMatrix(row, enable_categorical=True))

    # Pair each SHAP value of the (only) row with its feature name and sort
    # ascending by value.
    scores_desc = sorted(zip(shap_values[0], X_train.columns))

    fig_m = plt.figure()
    plt.barh(
        [feature for _, feature in scores_desc],
        [score for score, _ in scores_desc],
    )
    plt.title("Feature Shap Values")
    # barh draws the categories along the y-axis and the bar lengths along
    # the x-axis, so the SHAP values are on x and the features on y.
    plt.xlabel("Shap Value")
    plt.ylabel("Feature")
    return fig_m
57
+
58
+
59
# Sorted distinct values of each categorical feature; these populate the
# dropdown choices in the UI.
(
    unique_class,
    unique_education,
    unique_marital_status,
    unique_relationship,
    unique_occupation,
    unique_sex,
    unique_country,
) = (
    sorted(X_train[column].unique())
    for column in (
        "workclass",
        "education",
        "marital.status",
        "relationship",
        "occupation",
        "sex",
        "native.country",
    )
)
66
+
67
# Gradio UI: one column of feature inputs, one column with the prediction
# label, the SHAP plot and the two action buttons.
# NOTE(review): the nesting below is reconstructed; the original indentation
# was not visible in the source -- confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
            work_class = gr.Dropdown(
                label="Workclass",
                choices=unique_class,
                value=lambda: random.choice(unique_class),
            )
            education = gr.Dropdown(
                label="Education Level",
                choices=unique_education,
                value=lambda: random.choice(unique_education),
            )
            years = gr.Slider(
                label="Years of schooling",
                minimum=1,
                maximum=16,
                step=1,
                randomize=True,
            )
            marital_status = gr.Dropdown(
                label="Marital Status",
                choices=unique_marital_status,
                value=lambda: random.choice(unique_marital_status),
            )
            occupation = gr.Dropdown(
                label="Occupation",
                choices=unique_occupation,
                # The default must be drawn from the occupation choices (it
                # was previously drawn from the education levels).
                value=lambda: random.choice(unique_occupation),
            )
            relationship = gr.Dropdown(
                label="Relationship Status",
                choices=unique_relationship,
                value=lambda: random.choice(unique_relationship),
            )
            sex = gr.Dropdown(
                label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
            )
            capital_gain = gr.Slider(
                label="Capital Gain",
                minimum=0,
                maximum=100000,
                step=500,
                randomize=True,
            )
            capital_loss = gr.Slider(
                label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
            )
            hours_per_week = gr.Slider(
                label="Hours Per Week Worked", minimum=1, maximum=99, step=1
            )
            country = gr.Dropdown(
                label="Native Country",
                choices=unique_country,
                value=lambda: random.choice(unique_country),
            )
        with gr.Column():
            label = gr.Label()
            plot = gr.Plot()
            with gr.Row():
                predict_btn = gr.Button(value="Predict")
                interpret_btn = gr.Button(value="Interpret")

            # predict() and interpret() build their DataFrame positionally
            # with columns=X_train.columns, so this order must line up with
            # the column order of X_train.
            feature_inputs = [
                age,
                work_class,
                education,
                years,
                marital_status,
                occupation,
                relationship,
                sex,
                capital_gain,
                capital_loss,
                hours_per_week,
                country,
            ]
            predict_btn.click(
                predict,
                inputs=feature_inputs,
                outputs=[label],
            )
            interpret_btn.click(
                interpret,
                inputs=feature_inputs,
                outputs=[plot],
            )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ shap
3
+ xgboost
4
+ pandas
5
+ datasets