freddyaboulton HF staff commited on
Commit
84126a5
·
1 Parent(s): 0af320d

Better layout

Browse files
Files changed (1) hide show
  1. app.py +41 -40
app.py CHANGED
@@ -48,11 +48,12 @@ def interpret(*args):
48
  shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
49
  scores_desc = list(zip(shap_values[0], X_train.columns))
50
  scores_desc = sorted(scores_desc)
51
- fig_m = plt.figure()
52
  plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
53
  plt.title("Feature Shap Values")
54
  plt.ylabel("Shap Value")
55
  plt.xlabel("Feature")
 
56
  return fig_m
57
 
58
 
@@ -124,44 +125,44 @@ with gr.Blocks() as demo:
124
  with gr.Column():
125
  label = gr.Label()
126
  plot = gr.Plot()
127
- with gr.Row():
128
- predict_btn = gr.Button(value="Predict")
129
- interpret_btn = gr.Button(value="Interpret")
130
- predict_btn.click(
131
- predict,
132
- inputs=[
133
- age,
134
- work_class,
135
- education,
136
- years,
137
- marital_status,
138
- occupation,
139
- relationship,
140
- sex,
141
- capital_gain,
142
- capital_loss,
143
- hours_per_week,
144
- country,
145
- ],
146
- outputs=[label],
147
- )
148
- interpret_btn.click(
149
- interpret,
150
- inputs=[
151
- age,
152
- work_class,
153
- education,
154
- years,
155
- marital_status,
156
- occupation,
157
- relationship,
158
- sex,
159
- capital_gain,
160
- capital_loss,
161
- hours_per_week,
162
- country,
163
- ],
164
- outputs=[plot],
165
- )
166
 
167
  demo.launch()
 
48
  shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
49
  scores_desc = list(zip(shap_values[0], X_train.columns))
50
  scores_desc = sorted(scores_desc)
51
+ fig_m = plt.figure(tight_layout=True)
52
  plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
53
  plt.title("Feature Shap Values")
54
  plt.ylabel("Shap Value")
55
  plt.xlabel("Feature")
56
+ plt.tight_layout()
57
  return fig_m
58
 
59
 
 
125
  with gr.Column():
126
  label = gr.Label()
127
  plot = gr.Plot()
128
+ with gr.Row():
129
+ predict_btn = gr.Button(value="Predict")
130
+ interpret_btn = gr.Button(value="Interpret")
131
+ predict_btn.click(
132
+ predict,
133
+ inputs=[
134
+ age,
135
+ work_class,
136
+ education,
137
+ years,
138
+ marital_status,
139
+ occupation,
140
+ relationship,
141
+ sex,
142
+ capital_gain,
143
+ capital_loss,
144
+ hours_per_week,
145
+ country,
146
+ ],
147
+ outputs=[label],
148
+ )
149
+ interpret_btn.click(
150
+ interpret,
151
+ inputs=[
152
+ age,
153
+ work_class,
154
+ education,
155
+ years,
156
+ marital_status,
157
+ occupation,
158
+ relationship,
159
+ sex,
160
+ capital_gain,
161
+ capital_loss,
162
+ hours_per_week,
163
+ country,
164
+ ],
165
+ outputs=[plot],
166
+ )
167
 
168
  demo.launch()