yassonee commited on
Commit
0b15508
·
verified ·
1 Parent(s): 2c3deb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +257 -83
app.py CHANGED
@@ -1,96 +1,270 @@
1
- import gradio as gr
 
2
  from transformers import pipeline
3
- import torch
4
- from PIL import Image, ImageDraw
5
- import numpy as np
 
 
 
6
 
7
  # Chargement des modèles
8
- classifier = pipeline("image-classification", model="abhishek/chest-xray-classification")
9
- detector = pipeline("object-detection", model="nickysam/detect-thorax-anomaly-75acc")
 
 
10
 
11
- def draw_boxes(image, predictions):
12
- # Convertir l'image numpy en PIL si nécessaire
13
- if isinstance(image, np.ndarray):
14
- image = Image.fromarray(np.uint8(image))
15
-
16
- draw = ImageDraw.Draw(image)
17
-
18
- # Dessiner les boîtes de détection
19
- for pred in predictions:
20
- box = pred['box']
21
- score = pred['score']
22
- label = pred['label']
23
-
24
- # Coordonnées de la boîte
25
- x1, y1 = box['xmin'], box['ymin']
26
- x2, y2 = box['xmax'], box['ymax']
27
-
28
- # Couleur en fonction du score
29
- color = (255, 0, 0) if score > 0.7 else (255, 165, 0)
30
-
31
- # Dessiner le rectangle
32
- draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
33
-
34
- # Ajouter le label et le score
35
- label_text = f"{label}: {score:.1%}"
36
- draw.text((x1, y1-15), label_text, fill=color)
37
-
38
- return image
39
 
40
- def analyze_xray(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  try:
42
- # Classification générale
43
- classifications = classifier(image)
44
 
45
- # Détection des anomalies
46
- detections = detector(image)
47
 
48
- # Dessiner les boîtes sur l'image
49
- annotated_image = draw_boxes(image, detections)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Préparer les résultats
52
- results = "Classifications:\n"
53
- for pred in classifications:
54
- results += f"{pred['label']}: {pred['score']:.1%}\n"
 
 
 
 
55
 
56
- results += "\nDetected Anomalies:\n"
57
- for det in detections:
58
- results += f"{det['label']}: {det['score']:.1%}\n"
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- return annotated_image, results
61
- except Exception as e:
62
- return image, f"Error: {str(e)}"
63
-
64
- # Interface Gradio
65
- with gr.Blocks(theme=gr.themes.Soft(
66
- primary_hue="gray",
67
- secondary_hue="gray",
68
- )) as demo:
69
- gr.Markdown("""
70
- # Chest X-Ray Analysis
71
- This application analyzes chest X-rays to:
72
- 1. Classify general conditions
73
- 2. Detect and locate specific anomalies
74
- """)
75
-
76
- with gr.Row():
77
- with gr.Column():
78
- input_image = gr.Image(label="Upload X-Ray Image", type="pil")
79
- analyze_btn = gr.Button("Analyze", variant="primary")
80
 
81
- with gr.Column():
82
- output_image = gr.Image(label="Analyzed Image")
83
- output_text = gr.Textbox(label="Results", lines=10)
84
-
85
- analyze_btn.click(
86
- fn=analyze_xray,
87
- inputs=[input_image],
88
- outputs=[output_image, output_text]
89
- )
90
-
91
- gr.Markdown("""
92
- Note: This tool is for demonstration purposes only and should not be used for medical diagnosis.
93
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- # Lancement de l'application
96
- demo.launch()
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import HTMLResponse
3
  from transformers import pipeline
4
+ from PIL import Image
5
+ import io
6
+ import uvicorn
7
+ import base64
8
+
9
+ app = FastAPI()
10
 
11
  # Chargement des modèles
12
+ def load_models():
13
+ return {
14
+ "chest_classifier": pipeline("image-classification", model="codewithdark/vit-chest-xray")
15
+ }
16
 
17
+ models = load_models()
18
+
19
+ def translate_label(label):
20
+ translations = {
21
+ 'Cardiomegaly': 'Kardiomegalie',
22
+ 'Edema': 'Ödem',
23
+ 'Consolidation': 'Konsolidierung',
24
+ 'Pneumonia': 'Lungenentzündung',
25
+ 'No Finding': 'Kein Befund'
26
+ }
27
+ return translations.get(label, label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def image_to_base64(image):
30
+ buffered = io.BytesIO()
31
+ image.save(buffered, format="PNG")
32
+ img_str = base64.b64encode(buffered.getvalue()).decode()
33
+ return f"data:image/png;base64,{img_str}"
34
+
35
+ COMMON_STYLES = """
36
+ body {
37
+ font-family: system-ui, -apple-system, sans-serif;
38
+ background: #f0f2f5;
39
+ margin: 0;
40
+ padding: 20px;
41
+ color: #1a1a1a;
42
+ }
43
+ .container {
44
+ max-width: 1200px;
45
+ margin: 0 auto;
46
+ background: white;
47
+ padding: 20px;
48
+ border-radius: 10px;
49
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
50
+ }
51
+ .button {
52
+ background: #2d2d2d;
53
+ color: white;
54
+ border: none;
55
+ padding: 12px 30px;
56
+ border-radius: 8px;
57
+ cursor: pointer;
58
+ font-size: 1.1em;
59
+ transition: all 0.3s ease;
60
+ position: relative;
61
+ }
62
+ .button:hover {
63
+ background: #404040;
64
+ }
65
+ @keyframes blink {
66
+ 0% { opacity: 1; }
67
+ 50% { opacity: 0; }
68
+ 100% { opacity: 1; }
69
+ }
70
+ #loading {
71
+ display: none;
72
+ color: white;
73
+ margin-top: 10px;
74
+ animation: blink 1s infinite;
75
+ text-align: center;
76
+ }
77
+ .upload-section {
78
+ background: #2d2d2d;
79
+ padding: 40px;
80
+ border-radius: 12px;
81
+ margin: 20px 0;
82
+ text-align: center;
83
+ border: 2px dashed #404040;
84
+ transition: all 0.3s ease;
85
+ color: white;
86
+ }
87
+ .upload-section:hover {
88
+ border-color: #555;
89
+ }
90
+ input[type="file"] {
91
+ font-size: 1.1em;
92
+ margin: 20px 0;
93
+ color: white;
94
+ }
95
+ input[type="file"]::file-selector-button {
96
+ font-size: 1em;
97
+ padding: 10px 20px;
98
+ border-radius: 8px;
99
+ border: 1px solid #404040;
100
+ background: #2d2d2d;
101
+ color: white;
102
+ transition: all 0.3s ease;
103
+ cursor: pointer;
104
+ }
105
+ input[type="file"]::file-selector-button:hover {
106
+ background: #404040;
107
+ }
108
+ .preview-image {
109
+ max-width: 300px;
110
+ margin: 20px auto;
111
+ display: none;
112
+ }
113
+ .results-grid {
114
+ display: grid;
115
+ grid-template-columns: 1fr 1fr;
116
+ gap: 20px;
117
+ margin-top: 20px;
118
+ }
119
+ .result-box {
120
+ background: white;
121
+ padding: 20px;
122
+ border-radius: 12px;
123
+ margin: 10px 0;
124
+ border: 1px solid #e9ecef;
125
+ }
126
+ .analyzed-image {
127
+ max-width: 400px;
128
+ margin: 0 auto;
129
+ }
130
+ .score-high {
131
+ color: #0066cc;
132
+ font-weight: bold;
133
+ }
134
+ .score-medium {
135
+ color: #ffa500;
136
+ font-weight: bold;
137
+ }
138
+ h3 {
139
+ color: #0066cc;
140
+ margin-top: 0;
141
+ }
142
+ @media (max-width: 768px) {
143
+ .results-grid {
144
+ grid-template-columns: 1fr;
145
+ }
146
+ }
147
+ """
148
+
149
+ @app.get("/", response_class=HTMLResponse)
150
+ async def main():
151
+ content = f"""
152
+ <!DOCTYPE html>
153
+ <html>
154
+ <head>
155
+ <title>Röntgenbild-Analyse</title>
156
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
157
+ <style>
158
+ {COMMON_STYLES}
159
+ </style>
160
+ </head>
161
+ <body>
162
+ <div class="container">
163
+ <div class="upload-section">
164
+ <form action="/analyze" method="post" enctype="multipart/form-data"
165
+ onsubmit="document.getElementById('loading').style.display = 'block';">
166
+ <div>
167
+ <input type="file" name="file" accept="image/*" required
168
+ onchange="document.getElementById('preview').src = window.URL.createObjectURL(this.files[0]);
169
+ document.getElementById('preview').style.display = 'block';">
170
+ </div>
171
+ <img id="preview" class="preview-image" src="" alt="Vorschau">
172
+ <button type="submit" class="button">
173
+ Analysieren
174
+ </button>
175
+ <div id="loading">Wird geladen...</div>
176
+ </form>
177
+ </div>
178
+ </div>
179
+ </body>
180
+ </html>
181
+ """
182
+ return content
183
+
184
+ @app.post("/analyze", response_class=HTMLResponse)
185
+ async def analyze_file(file: UploadFile = File(...)):
186
  try:
187
+ contents = await file.read()
188
+ image = Image.open(io.BytesIO(contents))
189
 
190
+ predictions = models["chest_classifier"](image)
191
+ result_image_b64 = image_to_base64(image)
192
 
193
+ results_html = f"""
194
+ <!DOCTYPE html>
195
+ <html>
196
+ <head>
197
+ <title>Ergebnisse</title>
198
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
199
+ <style>
200
+ {COMMON_STYLES}
201
+ </style>
202
+ </head>
203
+ <body>
204
+ <div class="container">
205
+ <div class="results-grid">
206
+ <div class="result-box">
207
+ <h3>Analyse-Ergebnisse</h3>
208
+ """
209
 
210
+ for pred in predictions:
211
+ confidence_class = "score-high" if pred['score'] > 0.7 else "score-medium"
212
+ results_html += f"""
213
+ <div>
214
+ <span class="{confidence_class}">{pred['score']:.1%}</span> -
215
+ {translate_label(pred['label'])}
216
+ </div>
217
+ """
218
 
219
+ results_html += f"""
220
+ </div>
221
+ <div class="result-box">
222
+ <h3>Röntgenbild</h3>
223
+ <img src="{result_image_b64}" alt="Analysiertes Röntgenbild" class="analyzed-image">
224
+ </div>
225
+ </div>
226
+
227
+ <a href="/" class="button back-button">
228
+ ← Zurück
229
+ </a>
230
+ </div>
231
+ </body>
232
+ </html>
233
+ """
234
 
235
+ return results_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
+ except Exception as e:
238
+ return f"""
239
+ <!DOCTYPE html>
240
+ <html>
241
+ <head>
242
+ <title>Fehler</title>
243
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
244
+ <style>
245
+ {COMMON_STYLES}
246
+ .error-box {{
247
+ background: #fee2e2;
248
+ border: 1px solid #ef4444;
249
+ padding: 20px;
250
+ border-radius: 8px;
251
+ margin: 20px 0;
252
+ }}
253
+ </style>
254
+ </head>
255
+ <body>
256
+ <div class="container">
257
+ <div class="error-box">
258
+ <h3>Fehler</h3>
259
+ <p>{str(e)}</p>
260
+ </div>
261
+ <a href="/" class="button back-button">
262
+ ← Zurück
263
+ </a>
264
+ </div>
265
+ </body>
266
+ </html>
267
+ """
268
 
269
+ if __name__ == "__main__":
270
+ uvicorn.run(app, host="0.0.0.0", port=7860)