from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import cv2
import numpy as np
from pillmodel import get_prediction
import base64
import os
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
import google.ai.generativelanguage as glm
from PIL import Image
import io
import random
import re
import json
from inference_sdk import InferenceHTTPClient

# Gemini API keys are supplied as a comma-separated environment variable
api_keys = [k.strip() for k in os.getenv('GEMINI_API_KEYS', '').split(',') if k.strip()]
print(f"Loaded {len(api_keys)} Gemini API key(s)")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict")  # endpoint path assumed; counts pills in an uploaded image
async def predict(image: UploadFile = File(...)):
    contents = await image.read()
    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # Save the image to a temporary location
    # temp_image_path = "temp_image.jpg"
    # cv2.imwrite(temp_image_path, img)
    # Prediction
    predicted_image, count_dict = get_prediction(img)
    # Encode predicted image to base64
    _, buffer = cv2.imencode('.jpg', predicted_image)
    predicted_image_str = base64.b64encode(buffer).decode('utf-8')
    # Send a confirmation message
    message_to_send = (
        f"There are {count_dict.get('capsules', 0)} capsules and {count_dict.get('tablets', 0)} tablets. "
        f"A total of {count_dict.get('capsules', 0) + count_dict.get('tablets', 0)} pills."
    )
    return JSONResponse(content={"message": message_to_send, "count": count_dict, "predicted_image": predicted_image_str})
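

# --- Hedged usage sketch (not part of the service) ----------------------------
# A minimal client for the endpoint above, assuming the app is reachable at
# http://localhost:7860 and the route is "/predict" (both are assumptions, not
# confirmed by the original source). It uploads an image, prints the count
# summary, and writes the annotated image that comes back base64-encoded.
def _example_predict_client(image_path: str = "pills.jpg",
                            url: str = "http://localhost:7860/predict"):
    import requests  # third-party client library; pip install requests

    with open(image_path, "rb") as f:
        resp = requests.post(url, files={"image": f})
    resp.raise_for_status()
    payload = resp.json()
    print(payload["message"])  # e.g. "There are 3 capsules and 5 tablets. ..."
    with open("predicted.jpg", "wb") as out:
        out.write(base64.b64decode(payload["predicted_image"]))
    return payload["count"]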


@app.post("/predict_wheat")  # endpoint path assumed; counts wheat grains via a Roboflow model
async def predict_wheat(image: UploadFile = File(...), model_id: str = "grian/1"):
    contents = await image.read()
    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # Delete the temporary image if it already exists
    try:
        os.remove("temp_image.jpg")
    except FileNotFoundError:
        print("temp_image.jpg does not exist")
    # Save the image to a temporary location
    temp_image_path = "temp_image.jpg"
    cv2.imwrite(temp_image_path, img)
    client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        # Prefer an environment variable; the original hard-coded value is kept as a fallback
        api_key=os.getenv("ROBOFLOW_API_KEY", "PpEebXofNuob5VSx7YP3")
    )
    result = client.infer(temp_image_path, model_id=model_id)
    # Prediction
    predicted_count = len(result['predictions'])
    message_to_send = f"There are {predicted_count} wheat grains."
    for prediction in result['predictions']:
        # Roboflow reports the box centre point, so convert to corner coordinates before drawing
        x = int(prediction['x'])
        y = int(prediction['y'])
        width = int(prediction['width'])
        height = int(prediction['height'])
        top_left = (x - width // 2, y - height // 2)
        bottom_right = (x + width // 2, y + height // 2)
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 2)
    # Encode predicted image to base64
    _, buffer = cv2.imencode('.jpg', img)
    predicted_image_str = base64.b64encode(buffer).decode('utf-8')
    return JSONResponse(content={"message": message_to_send, "count": predicted_count, "predicted_image": predicted_image_str})
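

# --- Hedged helper sketch ------------------------------------------------------
# Roboflow detection results give the box centre plus width/height. This small,
# self-contained helper (the name is illustrative, not part of the original
# service) shows the centre-to-corner conversion used when drawing rectangles above.
def _roboflow_box_to_corners(prediction: dict) -> tuple:
    """Return ((x1, y1), (x2, y2)) corner points for a Roboflow prediction dict."""
    x, y = int(prediction['x']), int(prediction['y'])
    w, h = int(prediction['width']), int(prediction['height'])
    return (x - w // 2, y - h // 2), (x + w // 2, y + h // 2)

# Example: _roboflow_box_to_corners({'x': 100, 'y': 80, 'width': 40, 'height': 20})
# -> ((80, 70), (120, 90))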


def process_image(file: UploadFile):
    image = Image.open(file.file)
    # Convert the image to RGB if not already
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Convert the image to a byte array
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='JPEG')
    # Create a Blob object
    blob = glm.Blob(
        mime_type='image/jpeg',
        data=img_byte_arr.getvalue()
    )
    return blob


@app.post("/analyze_image")  # endpoint path assumed; scores a product photo with Gemini
async def analyze_image(file: UploadFile = File(...)):
    # Rotate between the configured Gemini keys to spread quota usage
    selected_api_key = random.choice(api_keys)
    genai.configure(api_key=selected_api_key)
    generation_config = {
        "temperature": 1,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": 8192,
        "response_mime_type": "text/plain",
    }
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }
    # Process the image
    blob = process_image(file)
    # Initialize the Generative Model
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings=safety_settings
    )
    # Prompt for content generation
    prompt = """
    Give a safety score for the website unipall (an OLX-style marketplace) for the product a user is uploading.
    Reply with only this JSON and nothing else; do not be overly harsh when flagging.
    When a picture shows accessories in a scene, focus on them and don't flag it.
    Don't flag text on the product.
    {
        useable_on_website: true/false,
        safety_score: /100,
        category: "",
        reason: "",
        suggested_product_title: "",
        suggested_product_description: ""
    }
    """
    # Generate content using the AI model
    response = model.generate_content([prompt, blob])
    if '```json' not in response.text:
        return JSONResponse(content=response.text, media_type="application/json")
    # Extract the JSON payload from the Markdown-fenced block in the reply
    json_string = re.search(r'```json(.*?)```', response.text, re.DOTALL).group(1)
    # Clean and parse the extracted JSON
    cleaned_response = json_string.strip()
    data = json.loads(cleaned_response)
    # Return the parsed object directly instead of a re-encoded JSON string
    return JSONResponse(content=data)
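

# --- Hedged sketch: defensive parsing of the model reply -----------------------
# The endpoint above assumes the model wraps its JSON in a ```json fence and that
# the payload parses cleanly. A more defensive variant (name and fallback object
# are illustrative, not part of the original service) could fall back to an error
# object instead of raising on malformed output:
def _parse_model_json(text: str) -> dict:
    """Extract the first ```json fenced block from text and parse it."""
    match = re.search(r'```json(.*?)```', text, re.DOTALL)
    candidate = match.group(1) if match else text
    try:
        return json.loads(candidate.strip())
    except json.JSONDecodeError:
        return {"useable_on_website": False, "reason": "model reply was not valid JSON"}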


@app.get("/")  # root route assumed; redirects to the static front-end
async def home():
    return HTMLResponse(content="<html><head><meta http-equiv='refresh' content='0; url=/index.html'></head></html>")


# Register the static mount after the routes above so they take precedence
app.mount("/", StaticFiles(directory="static"), name="static")
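

# --- Hedged run sketch ---------------------------------------------------------
# How the app might be launched for local development. The host/port values are
# assumptions (Hugging Face Spaces conventionally exposes port 7860); the Space
# itself may start the server differently.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)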