dgbkn commited on
Commit
a21ebb4
·
1 Parent(s): 502613c

added things

Browse files
main.py CHANGED
@@ -6,6 +6,12 @@ import cv2
6
  import numpy as np
7
  from pillmodel import get_prediction
8
  import base64
 
 
 
 
 
 
9
 
10
  app = FastAPI()
11
 
@@ -17,6 +23,7 @@ app.add_middleware(
17
  allow_headers=["*"],
18
  )
19
 
 
20
 
21
 
22
  @app.post("/predict")
@@ -43,7 +50,49 @@ async def predict(image: UploadFile = File(...)):
43
 
44
  return JSONResponse(content={"message": message_to_send, "count": count_dict, "predicted_image": predicted_image_str})
45
 
46
- @app.get("/", response_class=HTMLResponse)
47
- async def read_root():
48
- with open("upload.html", "r") as file:
49
- return HTMLResponse(content=file.read(), status_code=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import numpy as np
7
  from pillmodel import get_prediction
8
  import base64
9
+ from fastapi.staticfiles import StaticFiles
10
+ import os
11
+
12
+ from inference_sdk import InferenceHTTPClient
13
+
14
+
15
 
16
  app = FastAPI()
17
 
 
23
  allow_headers=["*"],
24
  )
25
 
26
+ app.mount("/", StaticFiles(directory="static"), name="static")
27
 
28
 
29
  @app.post("/predict")
 
50
 
51
  return JSONResponse(content={"message": message_to_send, "count": count_dict, "predicted_image": predicted_image_str})
52
 
53
+
54
+
55
+
56
+
57
+
58
+ @app.post("/predict_wheat")
59
+ async def predict_wheat(image: UploadFile = File(...), model_id: str = "grian/1"):
60
+ contents = await image.read()
61
+ nparr = np.frombuffer(contents, np.uint8)
62
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
63
+
64
+ # delete the image if exists
65
+ try:
66
+ os.remove("temp_image.jpg")
67
+ except:
68
+ print("temp_image.jpg does not exist")
69
+
70
+ # Save the image to a temporary location
71
+ temp_image_path = "temp_image.jpg"
72
+ cv2.imwrite(temp_image_path, img)
73
+
74
+ CLIENT = InferenceHTTPClient(
75
+ api_url="https://detect.roboflow.com",
76
+ api_key="PpEebXofNuob5VSx7YP3"
77
+ )
78
+
79
+
80
+ result = CLIENT.infer("temp_image.jpg", model_id=model_id)
81
+ # Prediction
82
+ predicted_count = len(result['predictions'])
83
+ message_to_send = (
84
+ f"There are {predicted_count} wheat grains."
85
+ )
86
+
87
+ for prediction in result['predictions']:
88
+ x = int(prediction['x'])
89
+ y = int(prediction['y'])
90
+ width = int(prediction['width'])
91
+ height = int(prediction['height'])
92
+ cv2.rectangle(img, (x, y), (x + width, y + height), (0, 255, 0), 2)
93
+ # Encode predicted image to base64
94
+ _, buffer = cv2.imencode('.jpg', img)
95
+ predicted_image_str = base64.b64encode(buffer).decode('utf-8')
96
+
97
+
98
+ return JSONResponse(content={"message": message_to_send, "count": predicted_count, "predicted_image": predicted_image_str})
requirements.txt CHANGED
@@ -9,4 +9,7 @@ shapely
9
  ultralytics
10
 
11
  fastapi
12
- uvicorn
 
 
 
 
9
  ultralytics
10
 
11
  fastapi
12
+ uvicorn
13
+
14
+ inference_sdk
15
+ # roboflow
static/index.html ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Detection Models</title>
7
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
8
+ </head>
9
+ <body class="bg-gray-100 h-screen flex items-center justify-center">
10
+ <div class="bg-white p-8 rounded shadow-md w-full max-w-sm">
11
+ <h1 class="text-2xl font-bold mb-4 text-center">Choose a Detection Model</h1>
12
+ <div class="flex flex-col space-y-4">
13
+ <a href="/pill_upload.html" class="bg-blue-500 text-white text-center py-2 rounded hover:bg-blue-600">Pill Detection</a>
14
+ <a href="/wheat_upload.html" class="bg-green-500 text-white text-center py-2 rounded hover:bg-green-600">Wheat Grain Detection</a>
15
+ </div>
16
+ </div>
17
+ </body>
18
+ </html>
upload.html → static/pill_upload.html RENAMED
@@ -9,14 +9,28 @@
9
  body {
10
  font-family: Arial, sans-serif;
11
  }
 
 
 
 
 
 
 
 
 
12
  </style>
13
  </head>
14
  <body class="flex flex-col items-center justify-center h-screen bg-gray-100">
15
  <h1 class="text-2xl font-bold mb-4">Upload an Image</h1>
16
- <input type="file" id="fileInput" accept="image/*" class="mb-4">
 
17
  <button onclick="uploadImage()" class="bg-blue-500 text-white px-4 py-2 rounded">Upload</button>
18
  <div id="output" class="mt-4"></div>
19
  <script>
 
 
 
 
20
  async function uploadImage() {
21
  const fileInput = document.getElementById('fileInput');
22
  const output = document.getElementById('output');
 
9
  body {
10
  font-family: Arial, sans-serif;
11
  }
12
+ .clickable-image {
13
+ cursor: pointer;
14
+ border: 2px solid #ddd;
15
+ border-radius: 8px;
16
+ transition: border-color 0.3s;
17
+ }
18
+ .clickable-image:hover {
19
+ border-color: #007bff;
20
+ }
21
  </style>
22
  </head>
23
  <body class="flex flex-col items-center justify-center h-screen bg-gray-100">
24
  <h1 class="text-2xl font-bold mb-4">Upload an Image</h1>
25
+ <input type="file" id="fileInput" accept="image/*" capture="environment" class="hidden">
26
+ <img src="/upload.png" alt="Click to Upload" class="clickable-image mb-4" onclick="triggerFileInput()">
27
  <button onclick="uploadImage()" class="bg-blue-500 text-white px-4 py-2 rounded">Upload</button>
28
  <div id="output" class="mt-4"></div>
29
  <script>
30
+ function triggerFileInput() {
31
+ document.getElementById('fileInput').click();
32
+ }
33
+
34
  async function uploadImage() {
35
  const fileInput = document.getElementById('fileInput');
36
  const output = document.getElementById('output');
static/upload.png ADDED
static/wheat_upload.html ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Image Upload</title>
7
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
8
+ <style>
9
+ body {
10
+ font-family: Arial, sans-serif;
11
+ }
12
+ .clickable-image {
13
+ cursor: pointer;
14
+ border: 2px solid #ddd;
15
+ border-radius: 8px;
16
+ transition: border-color 0.3s;
17
+ }
18
+ .clickable-image:hover {
19
+ border-color: #007bff;
20
+ }
21
+ </style>
22
+ </head>
23
+ <body class="flex flex-col items-center justify-center h-screen bg-gray-100">
24
+ <h1 class="text-2xl font-bold mb-4">Upload an Image</h1>
25
+ <input type="file" id="fileInput" accept="image/*" capture="environment" class="hidden">
26
+ <img src="/upload.png" alt="Click to Upload" class="clickable-image mb-4" onclick="triggerFileInput()">
27
+
28
+ <!-- Dropdown for selecting model_id -->
29
+ <select id="modelSelect" class="mb-4 px-4 py-2 border rounded">
30
+ <option value="grian/1">Grian Model</option>
31
+ <option value="wheat-dataset-new/2">Wheat Dataset New Model</option>
32
+ </select>
33
+
34
+ <button onclick="uploadImage()" class="bg-blue-500 text-white px-4 py-2 rounded">Upload</button>
35
+ <div id="output" class="mt-4"></div>
36
+
37
+ <script>
38
+ function triggerFileInput() {
39
+ document.getElementById('fileInput').click();
40
+ }
41
+
42
+ async function uploadImage() {
43
+ const fileInput = document.getElementById('fileInput');
44
+ const modelSelect = document.getElementById('modelSelect');
45
+ const output = document.getElementById('output');
46
+ const file = fileInput.files[0];
47
+ const modelId = modelSelect.value;
48
+
49
+ if (!file) {
50
+ output.innerHTML = 'No image selected.';
51
+ return;
52
+ }
53
+
54
+ const formData = new FormData();
55
+ formData.append('image', file);
56
+ formData.append('model_id', modelId); // Append model_id to FormData
57
+
58
+ output.innerHTML = 'Uploading...';
59
+
60
+ try {
61
+ const response = await fetch('/predict_wheat', {
62
+ method: 'POST',
63
+ body: formData
64
+ });
65
+ const result = await response.json();
66
+ const predictedImageSrc = `data:image/jpeg;base64,${result.predicted_image}`;
67
+ output.innerHTML = `
68
+ <p>${result.message}</p>
69
+ <img src="${predictedImageSrc}" alt="Predicted Image" class="mt-4">
70
+ `;
71
+ } catch (error) {
72
+ output.innerHTML = 'Failed to get prediction';
73
+ console.error(error);
74
+ }
75
+ }
76
+ </script>
77
+ </body>
78
+ </html>