pantdipendra commited on
Commit
c38370d
·
verified ·
1 Parent(s): 57c1c6c
Files changed (1) hide show
  1. app.py +104 -43
app.py CHANGED
@@ -52,7 +52,6 @@ class ModelPredictor:
52
  combined_predictions = np.concatenate(predictions)
53
  majority_vote = np.bincount(combined_predictions).argmax()
54
  return majority_vote
55
-
56
  # Based on Equal Interval and Percentage-Based Method
57
  # Severe: 13 to 16 votes (upper 25%) Moderate: 9 to 12 votes (upper-middle 25%) Low: 5 to 8 votes (lower-middle 25%) Very Low: 0 to 4 votes (lower 25%)
58
  def evaluate_severity(self, majority_vote_count):
@@ -74,6 +73,72 @@ model_filenames = [
74
  model_path = "models/"
75
  predictor = ModelPredictor(model_path, model_filenames)
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Mapping user-friendly text to corresponding numeric values for the prediction function
78
  input_mapping = {
79
  'YNURSMDE': {"Yes": 1, "No": 0},
@@ -107,48 +172,7 @@ input_mapping = {
107
  'YMDELT': {"Yes": 1, "No": 2}
108
  }
109
 
110
- def validate_inputs(*args):
111
- for arg in args:
112
- if arg == '' or arg is None: # Assuming empty string or None as unselected
113
- return False
114
- return True
115
-
116
- def predict_with_text(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
117
- if not validate_inputs(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
118
- return "Please select all required fields.", "Validation Error"
119
- user_inputs = {
120
- 'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
121
- 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
122
- 'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
123
- 'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
124
- 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
125
- 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
126
- 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
127
- 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
128
- 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
129
- 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
130
- 'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
131
- 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
132
- 'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
133
- 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
134
- 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
135
- 'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
136
- 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
137
- 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
138
- 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
139
- 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
140
- 'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
141
- 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
142
- 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
143
- 'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
144
- 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
145
- 'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
146
- 'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
147
- 'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
148
- 'YMDELT': input_mapping['YMDELT'][YMDELT]
149
- }
150
- return predict(**user_inputs)
151
-
152
  inputs = [
153
  gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OCC THERA ABOUT MAJOR DEPRESSIVE EPISODE (MDE) IN PAST YEAR (PY)"),
154
  gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEAR MAJOR DEPRESSIVE EPISODE"),
@@ -186,4 +210,41 @@ outputs = [
186
  gr.Textbox(label="Mental Health Severity")
187
  ]
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  gr.Interface(fn=predict_with_text, inputs=inputs, outputs=outputs, title="Adolescents Mental Health Multi-Output Prediction From NSDUH Data").launch()
 
52
  combined_predictions = np.concatenate(predictions)
53
  majority_vote = np.bincount(combined_predictions).argmax()
54
  return majority_vote
 
55
  # Based on Equal Interval and Percentage-Based Method
56
  # Severe: 13 to 16 votes (upper 25%) Moderate: 9 to 12 votes (upper-middle 25%) Low: 5 to 8 votes (lower-middle 25%) Very Low: 0 to 4 votes (lower 25%)
57
  def evaluate_severity(self, majority_vote_count):
 
73
  model_path = "models/"
74
  predictor = ModelPredictor(model_path, model_filenames)
75
 
76
+ def validate_inputs(*args):
77
+ for arg in args:
78
+ if arg == '' or arg is None: # Assuming empty string or None as unselected
79
+ return False
80
+ return True
81
+
82
+ def predict(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
83
+ user_input_data = {
84
+ 'YNURSMDE': [int(YNURSMDE)],
85
+ 'YMDEYR': [int(YMDEYR)],
86
+ 'YSOCMDE': [int(YSOCMDE)],
87
+ 'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
88
+ 'YMSUD5YANY': [int(YMSUD5YANY)],
89
+ 'YUSUITHK': [int(YUSUITHK)],
90
+ 'YMDETXRX': [int(YMDETXRX)],
91
+ 'YUSUITHKYR': [int(YUSUITHKYR)],
92
+ 'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
93
+ 'YUSUIPLNYR': [int(YUSUIPLNYR)],
94
+ 'YCOUNMDE': [int(YCOUNMDE)],
95
+ 'YPSY1MDE': [int(YPSY1MDE)],
96
+ 'YHLTMDE': [int(YHLTMDE)],
97
+ 'YDOCMDE': [int(YDOCMDE)],
98
+ 'YPSY2MDE': [int(YPSY2MDE)],
99
+ 'YMDEHARX': [int(YMDEHARX)],
100
+ 'LVLDIFMEM2': [int(LVLDIFMEM2)],
101
+ 'MDEIMPY': [int(MDEIMPY)],
102
+ 'YMDEHPO': [int(YMDEHPO)],
103
+ 'YMIMS5YANY': [int(YMIMS5YANY)],
104
+ 'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
105
+ 'YMIUD5YANY': [int(YMIUD5YANY)],
106
+ 'YMDEHPRX': [int(YMDEHPRX)],
107
+ 'YMIMI5YANY': [int(YMIMI5YANY)],
108
+ 'YUSUIPLN': [int(YUSUIPLN)],
109
+ 'YTXMDEYR': [int(YTXMDEYR)],
110
+ 'YMDEAUD5YR': [int(YMDEAUD5YR)],
111
+ 'YRXMDEYR': [int(YRXMDEYR)],
112
+ 'YMDELT': [int(YMDELT)]
113
+ }
114
+ user_input = pd.DataFrame(user_input_data)
115
+ predictions = predictor.make_predictions(user_input)
116
+ majority_vote = predictor.get_majority_vote(predictions)
117
+ majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
118
+ severity = predictor.evaluate_severity(majority_vote_count)
119
+
120
+ results = []
121
+ unknown_count = 0
122
+ for i, pred in enumerate(predictions):
123
+ model_name = model_filenames[i].split('.')[0]
124
+ pred_value = pred[0]
125
+ if model_name in predictor.prediction_map:
126
+ if pred_value < len(predictor.prediction_map[model_name]):
127
+ results.append(f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}")
128
+ else:
129
+ results.append(f"Model {model_name}: Unknown prediction value {pred_value}")
130
+ unknown_count += 1
131
+ else:
132
+ results.append(f"Model {model_name}: Unknown model")
133
+ unknown_count += 1
134
+
135
+ formatted_results = "\n".join(results)
136
+
137
+ if unknown_count > len(model_filenames) / 2:
138
+ severity += " (Unknown prediction count is more. So not sure please consult with a human.)"
139
+
140
+ return formatted_results, severity
141
+
142
  # Mapping user-friendly text to corresponding numeric values for the prediction function
143
  input_mapping = {
144
  'YNURSMDE': {"Yes": 1, "No": 0},
 
172
  'YMDELT': {"Yes": 1, "No": 2}
173
  }
174
 
175
+ # Create Gradio inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  inputs = [
177
  gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OCC THERA ABOUT MAJOR DEPRESSIVE EPISODE (MDE) IN PAST YEAR (PY)"),
178
  gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEAR MAJOR DEPRESSIVE EPISODE"),
 
210
  gr.Textbox(label="Mental Health Severity")
211
  ]
212
 
213
+ def predict_with_text(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
214
+ if not validate_inputs(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
215
+ return "Please select all required fields.", "Validation Error"
216
+
217
+ user_inputs = {
218
+ 'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
219
+ 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
220
+ 'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
221
+ 'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
222
+ 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
223
+ 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
224
+ 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
225
+ 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
226
+ 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
227
+ 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
228
+ 'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
229
+ 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
230
+ 'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
231
+ 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
232
+ 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
233
+ 'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
234
+ 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
235
+ 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
236
+ 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
237
+ 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
238
+ 'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
239
+ 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
240
+ 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
241
+ 'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
242
+ 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
243
+ 'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
244
+ 'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
245
+ 'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
246
+ 'YMDELT': input_mapping['YMDELT'][YMDELT]
247
+ }
248
+ return predict(**user_inputs)
249
+
250
  gr.Interface(fn=predict_with_text, inputs=inputs, outputs=outputs, title="Adolescents Mental Health Multi-Output Prediction From NSDUH Data").launch()