pantdipendra
commited on
See
Browse files
app.py
CHANGED
@@ -52,7 +52,6 @@ class ModelPredictor:
|
|
52 |
combined_predictions = np.concatenate(predictions)
|
53 |
majority_vote = np.bincount(combined_predictions).argmax()
|
54 |
return majority_vote
|
55 |
-
|
56 |
# Based on Equal Interval and Percentage-Based Method
|
57 |
# Severe: 13 to 16 votes (upper 25%) Moderate: 9 to 12 votes (upper-middle 25%) Low: 5 to 8 votes (lower-middle 25%) Very Low: 0 to 4 votes (lower 25%)
|
58 |
def evaluate_severity(self, majority_vote_count):
|
@@ -74,6 +73,72 @@ model_filenames = [
|
|
74 |
model_path = "models/"
|
75 |
predictor = ModelPredictor(model_path, model_filenames)
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
# Mapping user-friendly text to corresponding numeric values for the prediction function
|
78 |
input_mapping = {
|
79 |
'YNURSMDE': {"Yes": 1, "No": 0},
|
@@ -107,48 +172,7 @@ input_mapping = {
|
|
107 |
'YMDELT': {"Yes": 1, "No": 2}
|
108 |
}
|
109 |
|
110 |
-
|
111 |
-
for arg in args:
|
112 |
-
if arg == '' or arg is None: # Assuming empty string or None as unselected
|
113 |
-
return False
|
114 |
-
return True
|
115 |
-
|
116 |
-
def predict_with_text(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
|
117 |
-
if not validate_inputs(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
|
118 |
-
return "Please select all required fields.", "Validation Error"
|
119 |
-
user_inputs = {
|
120 |
-
'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
|
121 |
-
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
122 |
-
'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
|
123 |
-
'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
|
124 |
-
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
125 |
-
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
126 |
-
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
127 |
-
'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
|
128 |
-
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
129 |
-
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
130 |
-
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
131 |
-
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
132 |
-
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
133 |
-
'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
|
134 |
-
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
135 |
-
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
136 |
-
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
137 |
-
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
138 |
-
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
139 |
-
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
140 |
-
'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
|
141 |
-
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
142 |
-
'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
|
143 |
-
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
144 |
-
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
|
145 |
-
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
146 |
-
'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
|
147 |
-
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
148 |
-
'YMDELT': input_mapping['YMDELT'][YMDELT]
|
149 |
-
}
|
150 |
-
return predict(**user_inputs)
|
151 |
-
|
152 |
inputs = [
|
153 |
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OCC THERA ABOUT MAJOR DEPRESSIVE EPISODE (MDE) IN PAST YEAR (PY)"),
|
154 |
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEAR MAJOR DEPRESSIVE EPISODE"),
|
@@ -186,4 +210,41 @@ outputs = [
|
|
186 |
gr.Textbox(label="Mental Health Severity")
|
187 |
]
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
gr.Interface(fn=predict_with_text, inputs=inputs, outputs=outputs, title="Adolescents Mental Health Multi-Output Prediction From NSDUH Data").launch()
|
|
|
52 |
combined_predictions = np.concatenate(predictions)
|
53 |
majority_vote = np.bincount(combined_predictions).argmax()
|
54 |
return majority_vote
|
|
|
55 |
# Based on Equal Interval and Percentage-Based Method
|
56 |
# Severe: 13 to 16 votes (upper 25%) Moderate: 9 to 12 votes (upper-middle 25%) Low: 5 to 8 votes (lower-middle 25%) Very Low: 0 to 4 votes (lower 25%)
|
57 |
def evaluate_severity(self, majority_vote_count):
|
|
|
73 |
model_path = "models/"
|
74 |
predictor = ModelPredictor(model_path, model_filenames)
|
75 |
|
76 |
+
def validate_inputs(*args):
|
77 |
+
for arg in args:
|
78 |
+
if arg == '' or arg is None: # Assuming empty string or None as unselected
|
79 |
+
return False
|
80 |
+
return True
|
81 |
+
|
82 |
+
def predict(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
|
83 |
+
user_input_data = {
|
84 |
+
'YNURSMDE': [int(YNURSMDE)],
|
85 |
+
'YMDEYR': [int(YMDEYR)],
|
86 |
+
'YSOCMDE': [int(YSOCMDE)],
|
87 |
+
'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
|
88 |
+
'YMSUD5YANY': [int(YMSUD5YANY)],
|
89 |
+
'YUSUITHK': [int(YUSUITHK)],
|
90 |
+
'YMDETXRX': [int(YMDETXRX)],
|
91 |
+
'YUSUITHKYR': [int(YUSUITHKYR)],
|
92 |
+
'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
|
93 |
+
'YUSUIPLNYR': [int(YUSUIPLNYR)],
|
94 |
+
'YCOUNMDE': [int(YCOUNMDE)],
|
95 |
+
'YPSY1MDE': [int(YPSY1MDE)],
|
96 |
+
'YHLTMDE': [int(YHLTMDE)],
|
97 |
+
'YDOCMDE': [int(YDOCMDE)],
|
98 |
+
'YPSY2MDE': [int(YPSY2MDE)],
|
99 |
+
'YMDEHARX': [int(YMDEHARX)],
|
100 |
+
'LVLDIFMEM2': [int(LVLDIFMEM2)],
|
101 |
+
'MDEIMPY': [int(MDEIMPY)],
|
102 |
+
'YMDEHPO': [int(YMDEHPO)],
|
103 |
+
'YMIMS5YANY': [int(YMIMS5YANY)],
|
104 |
+
'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
|
105 |
+
'YMIUD5YANY': [int(YMIUD5YANY)],
|
106 |
+
'YMDEHPRX': [int(YMDEHPRX)],
|
107 |
+
'YMIMI5YANY': [int(YMIMI5YANY)],
|
108 |
+
'YUSUIPLN': [int(YUSUIPLN)],
|
109 |
+
'YTXMDEYR': [int(YTXMDEYR)],
|
110 |
+
'YMDEAUD5YR': [int(YMDEAUD5YR)],
|
111 |
+
'YRXMDEYR': [int(YRXMDEYR)],
|
112 |
+
'YMDELT': [int(YMDELT)]
|
113 |
+
}
|
114 |
+
user_input = pd.DataFrame(user_input_data)
|
115 |
+
predictions = predictor.make_predictions(user_input)
|
116 |
+
majority_vote = predictor.get_majority_vote(predictions)
|
117 |
+
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
118 |
+
severity = predictor.evaluate_severity(majority_vote_count)
|
119 |
+
|
120 |
+
results = []
|
121 |
+
unknown_count = 0
|
122 |
+
for i, pred in enumerate(predictions):
|
123 |
+
model_name = model_filenames[i].split('.')[0]
|
124 |
+
pred_value = pred[0]
|
125 |
+
if model_name in predictor.prediction_map:
|
126 |
+
if pred_value < len(predictor.prediction_map[model_name]):
|
127 |
+
results.append(f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}")
|
128 |
+
else:
|
129 |
+
results.append(f"Model {model_name}: Unknown prediction value {pred_value}")
|
130 |
+
unknown_count += 1
|
131 |
+
else:
|
132 |
+
results.append(f"Model {model_name}: Unknown model")
|
133 |
+
unknown_count += 1
|
134 |
+
|
135 |
+
formatted_results = "\n".join(results)
|
136 |
+
|
137 |
+
if unknown_count > len(model_filenames) / 2:
|
138 |
+
severity += " (Unknown prediction count is more. So not sure please consult with a human.)"
|
139 |
+
|
140 |
+
return formatted_results, severity
|
141 |
+
|
142 |
# Mapping user-friendly text to corresponding numeric values for the prediction function
|
143 |
input_mapping = {
|
144 |
'YNURSMDE': {"Yes": 1, "No": 0},
|
|
|
172 |
'YMDELT': {"Yes": 1, "No": 2}
|
173 |
}
|
174 |
|
175 |
+
# Create Gradio inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
inputs = [
|
177 |
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OCC THERA ABOUT MAJOR DEPRESSIVE EPISODE (MDE) IN PAST YEAR (PY)"),
|
178 |
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEAR MAJOR DEPRESSIVE EPISODE"),
|
|
|
210 |
gr.Textbox(label="Mental Health Severity")
|
211 |
]
|
212 |
|
213 |
+
def predict_with_text(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
|
214 |
+
if not validate_inputs(YNURSMDE, YMDEYR, YSOCMDE, YMDESUD5ANYO, YMSUD5YANY, YUSUITHK, YMDETXRX, YUSUITHKYR, YMDERSUD5ANY, YUSUIPLNYR, YCOUNMDE, YPSY1MDE, YHLTMDE, YDOCMDE, YPSY2MDE, YMDEHARX, LVLDIFMEM2, MDEIMPY, YMDEHPO, YMIMS5YANY, YMDEIMAD5YR, YMIUD5YANY, YMDEHPRX, YMIMI5YANY, YUSUIPLN, YTXMDEYR, YMDEAUD5YR, YRXMDEYR, YMDELT):
|
215 |
+
return "Please select all required fields.", "Validation Error"
|
216 |
+
|
217 |
+
user_inputs = {
|
218 |
+
'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
|
219 |
+
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
220 |
+
'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
|
221 |
+
'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
|
222 |
+
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
223 |
+
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
224 |
+
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
225 |
+
'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
|
226 |
+
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
227 |
+
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
228 |
+
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
229 |
+
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
230 |
+
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
231 |
+
'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
|
232 |
+
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
233 |
+
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
234 |
+
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
235 |
+
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
236 |
+
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
237 |
+
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
238 |
+
'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
|
239 |
+
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
240 |
+
'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
|
241 |
+
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
242 |
+
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
|
243 |
+
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
244 |
+
'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
|
245 |
+
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
246 |
+
'YMDELT': input_mapping['YMDELT'][YMDELT]
|
247 |
+
}
|
248 |
+
return predict(**user_inputs)
|
249 |
+
|
250 |
gr.Interface(fn=predict_with_text, inputs=inputs, outputs=outputs, title="Adolescents Mental Health Multi-Output Prediction From NSDUH Data").launch()
|