# countThings / main.py
# Committed by: merasabkuch
# Commit: "Update main.py" (c205145, verified)
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse,HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import cv2
import numpy as np
from pillmodel import get_prediction
import base64
from fastapi.staticfiles import StaticFiles
import os
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
import google.ai.generativelanguage as glm
from PIL import Image
import io
import random
import re
import json
# Gemini API keys, supplied as a comma-separated list in the GEMINI_API_KEYS
# environment variable.  A missing variable yields an empty list instead of
# crashing at import time, and blank entries (for example from a trailing
# comma or stray spaces) are dropped.  /analyze-image picks one key per
# request, so it needs at least one entry here.
api_keys = [
    key.strip()
    for key in os.getenv('GEMINI_API_KEYS', '').split(',')
    if key.strip()
]
# Report only how many keys were loaded; the keys themselves are secrets and
# must never end up in the logs.
print(f"Loaded {len(api_keys)} Gemini API key(s)")
from inference_sdk import InferenceHTTPClient
# Application instance that every route in this module is registered on.
app = FastAPI()

# CORS is left fully open: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive -- confirm this is what production should expose.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Count capsules and tablets in an uploaded image.

    Args:
        image: The uploaded image file (multipart form field ``image``).

    Returns:
        A JSON response with ``message`` (human-readable summary), ``count``
        (the dict returned by ``get_prediction``) and ``predicted_image``
        (the image returned by ``get_prediction``, JPEG-encoded as base64).
        Responds with HTTP 400 when the upload is empty or cannot be decoded
        as an image, and HTTP 500 when the result cannot be JPEG-encoded.
    """
    contents = await image.read()
    nparr = np.frombuffer(contents, np.uint8)
    # cv2.imdecode rejects an empty buffer and returns None for bytes that are
    # not a decodable image; both cases are client errors, not server errors.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        return JSONResponse(
            status_code=400,
            content={"detail": "The uploaded file is not a valid image."},
        )

    # Prediction
    predicted_image, count_dict = get_prediction(img)

    # Encode predicted image to base64
    encoded_ok, buffer = cv2.imencode('.jpg', predicted_image)
    if not encoded_ok:
        return JSONResponse(
            status_code=500,
            content={"detail": "Failed to encode the predicted image."},
        )
    predicted_image_str = base64.b64encode(buffer).decode('utf-8')

    # Missing classes count as zero so the summary is always well-formed.
    capsules = count_dict.get('capsules', 0)
    tablets = count_dict.get('tablets', 0)
    message_to_send = (
        f"There are {capsules} capsules and {tablets} tablets. "
        f"A total of {capsules + tablets} pills."
    )

    return JSONResponse(
        content={
            "message": message_to_send,
            "count": count_dict,
            "predicted_image": predicted_image_str,
        }
    )
def _prediction_corners(prediction):
    """Return ((left, top), (right, bottom)) pixel corners for one detection.

    Roboflow reports ``x``/``y`` as the centre of the box, so half the width
    and height are subtracted/added to obtain the corners.
    """
    half_width = prediction['width'] / 2
    half_height = prediction['height'] / 2
    top_left = (
        int(round(prediction['x'] - half_width)),
        int(round(prediction['y'] - half_height)),
    )
    bottom_right = (
        int(round(prediction['x'] + half_width)),
        int(round(prediction['y'] + half_height)),
    )
    return top_left, bottom_right


@app.post("/predict_wheat")
async def predict_wheat(image: UploadFile = File(...), model_id: str = "grian/1"):
    """Count wheat grains in an uploaded image using a Roboflow hosted model.

    Args:
        image: The uploaded image file (multipart form field ``image``).
        model_id: Roboflow model identifier passed to the inference client.

    Returns:
        A JSON response with ``message`` (human-readable summary), ``count``
        (number of detections) and ``predicted_image`` (the input image with
        a green box per detection, JPEG-encoded as base64).  Responds with
        HTTP 400 when the upload is empty or cannot be decoded as an image,
        and HTTP 500 when the image cannot be written or encoded.
    """
    import tempfile  # only this endpoint needs a scratch file

    contents = await image.read()
    nparr = np.frombuffer(contents, np.uint8)
    # cv2.imdecode rejects an empty buffer and returns None for bytes that are
    # not a decodable image; both cases are client errors.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        return JSONResponse(
            status_code=400,
            content={"detail": "The uploaded file is not a valid image."},
        )

    # SECURITY: the key should come from the environment.  The literal is kept
    # only as a backward-compatible fallback and should be rotated and removed.
    client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        api_key=os.getenv("ROBOFLOW_API_KEY", "PpEebXofNuob5VSx7YP3"),
    )

    # A unique scratch file per request: a fixed "temp_image.jpg" would be
    # shared (and overwritten) by concurrent requests.
    handle, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(handle)
    try:
        if not cv2.imwrite(temp_image_path, img):
            return JSONResponse(
                status_code=500,
                content={"detail": "Failed to store the uploaded image."},
            )
        result = client.infer(temp_image_path, model_id=model_id)
    finally:
        try:
            os.remove(temp_image_path)
        except OSError:
            # Best-effort cleanup; a leftover scratch file must not fail the request.
            pass

    # Prediction
    predictions = result['predictions']
    predicted_count = len(predictions)
    message_to_send = (
        f"There are {predicted_count} wheat grains."
    )

    for prediction in predictions:
        top_left, bottom_right = _prediction_corners(prediction)
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 2)

    # Encode predicted image to base64
    encoded_ok, buffer = cv2.imencode('.jpg', img)
    if not encoded_ok:
        return JSONResponse(
            status_code=500,
            content={"detail": "Failed to encode the predicted image."},
        )
    predicted_image_str = base64.b64encode(buffer).decode('utf-8')

    return JSONResponse(
        content={
            "message": message_to_send,
            "count": predicted_count,
            "predicted_image": predicted_image_str,
        }
    )
def process_image(file: UploadFile):
    """Re-encode an uploaded image as an RGB JPEG wrapped in a ``glm.Blob``.

    Args:
        file: The upload whose underlying file object is opened with PIL.

    Returns:
        A ``glm.Blob`` with MIME type ``image/jpeg`` holding the JPEG bytes.
    """
    picture = Image.open(file.file)

    # Anything that is not already RGB is converted before saving as JPEG.
    rgb_picture = picture if picture.mode == 'RGB' else picture.convert('RGB')

    # Serialise into memory rather than onto disk.
    jpeg_buffer = io.BytesIO()
    rgb_picture.save(jpeg_buffer, format='JPEG')
    jpeg_bytes = jpeg_buffer.getvalue()

    return glm.Blob(mime_type='image/jpeg', data=jpeg_bytes)
# Prompt sent to Gemini together with the uploaded image.
_IMAGE_REVIEW_PROMPT = """
give a safety score for a website called unipall which is a olx, now when a user is uploading a product,
tell me this in json like:
only give this json nothing else not be too harmful
when a picture contains some accessories in a scene focus on them and don't flag it
don't flag text on the product
{
useable_on_website: true/false,
safety_score: /100,
category: "",
reason: "",
suggested_product_title: "",
suggested_product_description: ""
}
"""


def _extract_fenced_json(text):
    """Return the stripped content of the first ```json fence in ``text``.

    ``text`` must contain the opening marker.  When the closing fence is
    missing, everything after the opening marker is returned instead.
    """
    match = re.search(r'```json(.*?)```', text, re.DOTALL)
    if match is not None:
        return match.group(1).strip()
    return text.split('```json', 1)[1].strip()


@app.post("/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    """Ask Gemini to rate an uploaded product image.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        On success, a JSON response whose body is a JSON-encoded *string*:
        the model's reply as-is when it has no ```json fence, otherwise the
        fenced content (pretty-printed when it parses as JSON).  Because the
        body is a string, clients have to parse it a second time; this shape
        is kept for backward compatibility.
        Responds with HTTP 503 when no API key is configured, HTTP 400 when
        the upload cannot be read as an image, and HTTP 502 when the model
        reply contains no text.
    """
    if not api_keys:
        return JSONResponse(
            status_code=503,
            content={"detail": "No Gemini API key is configured."},
        )

    # One key is picked at random per request.  The key is a secret, so it is
    # deliberately not logged.
    genai.configure(api_key=random.choice(api_keys))

    generation_config = {
        "temperature": 1,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": 8192,
        "response_mime_type": "text/plain",
    }

    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    # Process the image
    try:
        blob = process_image(file)
    except OSError:
        # PIL signals unreadable or unsupported image data with OSError subclasses.
        return JSONResponse(
            status_code=400,
            content={"detail": "The uploaded file is not a valid image."},
        )

    # Initialize the Generative Model
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    # Generate content using the AI model
    response = model.generate_content([_IMAGE_REVIEW_PROMPT, blob])
    try:
        reply_text = response.text
    except ValueError:
        # The ``text`` accessor raises when the reply carries no text part.
        return JSONResponse(
            status_code=502,
            content={"detail": "The model returned no usable text."},
        )

    if '```json' not in reply_text:
        return JSONResponse(content=reply_text, media_type="application/json")

    # Extract JSON string from Markdown-formatted JSON string
    cleaned_response = _extract_fenced_json(reply_text)

    try:
        payload = json.dumps(json.loads(cleaned_response), indent=4)
    except json.JSONDecodeError:
        # The model does not always emit strict JSON; hand back the fenced
        # text unchanged, like the unfenced branch, instead of failing.
        payload = cleaned_response

    # Return the AI-generated response
    return JSONResponse(content=payload, media_type="application/json")
@app.get("/")
async def home():
    """Send browsers that open the site root on to ``/index.html``."""
    return HTMLResponse(content="<html><head><meta http-equiv='refresh' content='0; url=/index.html'></head></html>")


# Routes are matched in registration order and a mount at "/" matches every
# path, so the static mount has to come after all explicit routes; registered
# earlier, it would shadow the "/" handler above.
app.mount("/", StaticFiles(directory="static"), name="static")