Spaces:
Running
Running
File size: 6,113 Bytes
4bbacde a21ebb4 eae8229 26ec01f 76d8292 00b35b6 26ec01f c205145 a21ebb4 4bbacde a21ebb4 e40bb12 3ce3e54 eae8229 26ec01f eae8229 454cd1f eae8229 be57c73 8759844 00658f1 eae8229 00658f1 be57c73 eae8229 be57c73 4f68b98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import base64
import io
import json
import os
import random
import re
import tempfile

import cv2
import google.ai.generativelanguage as glm
import google.generativeai as genai
import numpy as np
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from PIL import Image

from pillmodel import get_prediction
def parse_api_keys(raw):
    """Split a comma-separated API-key string into clean, non-empty keys.

    Args:
        raw: The raw environment-variable value, or None when it is unset.

    Returns:
        A list of keys with surrounding whitespace stripped. Empty entries
        (e.g. from a trailing comma) are dropped, and None or "" yields [].
    """
    if not raw:
        return []
    return [key.strip() for key in raw.split(",") if key.strip()]


# Gemini API keys, supplied as a comma-separated list in GEMINI_API_KEYS.
# An unset variable no longer crashes the whole app at import time (the pill
# and wheat endpoints do not need Gemini); only requests that need a key fail.
api_keys = parse_api_keys(os.getenv("GEMINI_API_KEYS"))
# Report only how many keys were loaded; printing the keys leaked secrets to logs.
print(f"Loaded {len(api_keys)} Gemini API key(s)")
from inference_sdk import InferenceHTTPClient
# The single ASGI application that every route in this module registers on.
app = FastAPI()

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette echo back whatever Origin a credentialed request sends, so any
# site can call this API with the user's cookies. Confirm that is intended.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Count capsules and tablets in an uploaded photo.

    The upload is decoded with OpenCV and passed to ``get_prediction``, which
    returns an annotated image plus a dict of per-class counts.

    Returns:
        JSONResponse with:
          - ``message``: a human-readable summary;
          - ``count``: the dict returned by ``get_prediction``;
          - ``predicted_image``: the annotated image as a base64-encoded JPEG.
        HTTP 400 with an ``error`` field if the upload is empty or cannot be
        decoded as an image.
    """
    contents = await image.read()
    # cv2.imdecode rejects an empty buffer and returns None for bytes that are
    # not an image. Fail fast with a client error instead of a 500 from the model.
    img = None
    if contents:
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Could not decode the uploaded file as an image."},
        )

    predicted_image, count_dict = get_prediction(img)

    # Ship the annotated image inline as base64-encoded JPEG bytes.
    _, buffer = cv2.imencode(".jpg", predicted_image)
    predicted_image_str = base64.b64encode(buffer).decode("utf-8")

    capsules = count_dict.get("capsules", 0)
    tablets = count_dict.get("tablets", 0)
    message_to_send = (
        f"There are {capsules} capsules and {tablets} tablets. "
        f"A total of {capsules + tablets} pills."
    )
    return JSONResponse(content={"message": message_to_send, "count": count_dict, "predicted_image": predicted_image_str})
@app.post("/predict_wheat")
async def predict_wheat(image: UploadFile = File(...), model_id: str = "grian/1"):
    """Count wheat grains in an uploaded photo using a Roboflow hosted model.

    Args:
        image: The uploaded photo.
        model_id: The Roboflow model identifier passed to ``InferenceHTTPClient.infer``.

    Returns:
        JSONResponse with:
          - ``message``: a human-readable summary;
          - ``count``: the number of predictions;
          - ``predicted_image``: the photo with a green box per prediction,
            as a base64-encoded JPEG.
        HTTP 400 with an ``error`` field if the upload cannot be decoded.
    """
    contents = await image.read()
    img = None
    if contents:
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Could not decode the uploaded file as an image."},
        )

    # NOTE(security): this key was committed in source and must be rotated.
    # Set ROBOFLOW_API_KEY in the environment and then drop the fallback.
    client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        api_key=os.getenv("ROBOFLOW_API_KEY", "PpEebXofNuob5VSx7YP3"),
    )

    # Use a unique temp file per request. The old fixed "temp_image.jpg" path
    # could be overwritten by a concurrent request before inference read it.
    fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        cv2.imwrite(temp_image_path, img)
        result = client.infer(temp_image_path, model_id=model_id)
    finally:
        os.remove(temp_image_path)

    predictions = result["predictions"]
    predicted_count = len(predictions)
    message_to_send = f"There are {predicted_count} wheat grains."

    for prediction in predictions:
        # Roboflow reports (x, y) as the CENTRE of the box, not its top-left corner.
        half_w = prediction["width"] / 2
        half_h = prediction["height"] / 2
        top_left = (int(prediction["x"] - half_w), int(prediction["y"] - half_h))
        bottom_right = (int(prediction["x"] + half_w), int(prediction["y"] + half_h))
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 2)

    # Ship the annotated image inline as base64-encoded JPEG bytes.
    _, buffer = cv2.imencode(".jpg", img)
    predicted_image_str = base64.b64encode(buffer).decode("utf-8")
    return JSONResponse(content={"message": message_to_send, "count": predicted_count, "predicted_image": predicted_image_str})
def process_image(file: UploadFile):
    """Re-encode an uploaded image as an RGB JPEG wrapped in a ``glm.Blob``.

    Args:
        file: The uploaded file. PIL reads its underlying ``file.file`` stream.

    Returns:
        A ``glm.Blob`` whose ``mime_type`` is ``'image/jpeg'`` and whose
        ``data`` is the JPEG-encoded bytes.
    """
    picture = Image.open(file.file)
    # Pillow cannot write modes such as RGBA or P as JPEG, so normalise to RGB first.
    picture = picture if picture.mode == 'RGB' else picture.convert('RGB')

    with io.BytesIO() as jpeg_buffer:
        picture.save(jpeg_buffer, format='JPEG')
        jpeg_bytes = jpeg_buffer.getvalue()

    return glm.Blob(data=jpeg_bytes, mime_type='image/jpeg')
@app.post("/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    """Ask Gemini to moderate an uploaded product photo for the "unipall" marketplace.

    Each request picks one key at random from the module-level ``api_keys``
    list (presumably to spread quota across keys). It re-encodes the image
    with ``process_image`` and asks ``gemini-1.5-flash`` for a JSON verdict.

    Returns:
        JSONResponse whose body is a JSON-encoded *string*, not an object. The
        string is one of:
          - the pretty-printed JSON from a ```json fenced block in the reply,
            when one can be extracted and parsed;
          - otherwise, the raw reply text.
        NOTE(review): this double encoding is preserved because existing
        clients may depend on it.
        HTTP 400 if the upload is not a readable image.
        HTTP 503 if no API keys are configured.
    """
    if not api_keys:
        return JSONResponse(status_code=503, content={"error": "No Gemini API keys configured."})
    selected_api_key = random.choice(api_keys)
    # Never log the full secret; the last four characters are enough to tell keys apart.
    print(f"Selected API Key: ...{selected_api_key[-4:]}")
    genai.configure(api_key=selected_api_key)

    generation_config = {
        "temperature": 1,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": 8192,
        "response_mime_type": "text/plain",
    }
    # All Gemini safety filters are disabled; the moderation verdict comes from the prompt instead.
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    # PIL raises an OSError subclass (UnidentifiedImageError) for non-image uploads.
    try:
        blob = process_image(file)
    except OSError:
        return JSONResponse(status_code=400, content={"error": "Could not read the uploaded file as an image."})

    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    prompt = """
give a safety score for a website called unipall which is a olx, now when a user is uploading a product,
tell me this in json like:
only give this json nothing else not be too harmful
when a picture contains some accessories in a scene focus on them and don't flag it
don't flag text on the product
{
useable_on_website: true/false,
safety_score: /100,
category: "",
reason: "",
suggested_product_title: "",
suggested_product_description: ""
}
"""
    response = model.generate_content([prompt, blob])
    response_text = response.text

    # Pull the JSON out of a Markdown ```json fence. Without a complete fence,
    # or if its contents are not valid JSON, return the raw text instead of
    # raising (re.search returning None used to crash with AttributeError).
    match = re.search(r'```json(.*?)```', response_text, re.DOTALL)
    if match is None:
        return JSONResponse(content=response_text, media_type="application/json")
    try:
        data = json.loads(match.group(1).strip())
    except json.JSONDecodeError:
        return JSONResponse(content=response_text, media_type="application/json")

    return JSONResponse(content=json.dumps(data, indent=4), media_type="application/json")
@app.get("/")
async def home():
    """Redirect the site root to the static front-end page ``/index.html``."""
    return HTMLResponse(content="<html><head><meta http-equiv='refresh' content='0; url=/index.html'></head></html>")


# Mount the static files LAST. A mount at "/" matches every path, and Starlette
# tries routes in registration order. Registering it before the "/" route above
# would make that route unreachable.
app.mount("/", StaticFiles(directory="static"), name="static")
|