Commit
·
84126a5
1
Parent(s):
0af320d
Better layout
Browse files
app.py
CHANGED
@@ -48,11 +48,12 @@ def interpret(*args):
|
|
48 |
shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
|
49 |
scores_desc = list(zip(shap_values[0], X_train.columns))
|
50 |
scores_desc = sorted(scores_desc)
|
51 |
-
fig_m = plt.figure()
|
52 |
plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
|
53 |
plt.title("Feature Shap Values")
|
54 |
plt.ylabel("Shap Value")
|
55 |
plt.xlabel("Feature")
|
|
|
56 |
return fig_m
|
57 |
|
58 |
|
@@ -124,44 +125,44 @@ with gr.Blocks() as demo:
|
|
124 |
with gr.Column():
|
125 |
label = gr.Label()
|
126 |
plot = gr.Plot()
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
|
167 |
demo.launch()
|
|
|
48 |
shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
|
49 |
scores_desc = list(zip(shap_values[0], X_train.columns))
|
50 |
scores_desc = sorted(scores_desc)
|
51 |
+
fig_m = plt.figure(tight_layout=True)
|
52 |
plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
|
53 |
plt.title("Feature Shap Values")
|
54 |
plt.ylabel("Shap Value")
|
55 |
plt.xlabel("Feature")
|
56 |
+
plt.tight_layout()
|
57 |
return fig_m
|
58 |
|
59 |
|
|
|
125 |
with gr.Column():
|
126 |
label = gr.Label()
|
127 |
plot = gr.Plot()
|
128 |
+
with gr.Row():
|
129 |
+
predict_btn = gr.Button(value="Predict")
|
130 |
+
interpret_btn = gr.Button(value="Interpret")
|
131 |
+
predict_btn.click(
|
132 |
+
predict,
|
133 |
+
inputs=[
|
134 |
+
age,
|
135 |
+
work_class,
|
136 |
+
education,
|
137 |
+
years,
|
138 |
+
marital_status,
|
139 |
+
occupation,
|
140 |
+
relationship,
|
141 |
+
sex,
|
142 |
+
capital_gain,
|
143 |
+
capital_loss,
|
144 |
+
hours_per_week,
|
145 |
+
country,
|
146 |
+
],
|
147 |
+
outputs=[label],
|
148 |
+
)
|
149 |
+
interpret_btn.click(
|
150 |
+
interpret,
|
151 |
+
inputs=[
|
152 |
+
age,
|
153 |
+
work_class,
|
154 |
+
education,
|
155 |
+
years,
|
156 |
+
marital_status,
|
157 |
+
occupation,
|
158 |
+
relationship,
|
159 |
+
sex,
|
160 |
+
capital_gain,
|
161 |
+
capital_loss,
|
162 |
+
hours_per_week,
|
163 |
+
country,
|
164 |
+
],
|
165 |
+
outputs=[plot],
|
166 |
+
)
|
167 |
|
168 |
demo.launch()
|