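"""FastAPI service exposing an OmniParser-style screen-parsing pipeline: OCR plus
YOLO icon detection and Florence-2 captioning over an uploaded screenshot,
returning an annotated image together with parsed element descriptions and
their coordinates.
"""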
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import base64
import io
from PIL import Image
import torch
import numpy as np
import os
import logging
# OmniParser utility helpers
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the YOLO icon-detection model via the OmniParser helper and move it to
# the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
yolo_model = yolo_model.to(device)
# Load the Florence-2 caption model and processor
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
try:
    # Prefer GPU with half precision; this raises if CUDA is unavailable
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")
except Exception as e:
    logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    # On CPU, use full precision; float16 inference is poorly supported there
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models.")
app = FastAPI()
class ProcessResponse(BaseModel):
    image: str  # Base64-encoded annotated image (PNG)
    parsed_content_list: str  # Newline-separated descriptions of parsed elements
    label_coordinates: str  # String form of the label-to-coordinate mapping
def process(
image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
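    """Run OCR and YOLO icon detection on the image, caption the detected
    elements with Florence-2, and return the annotated image (base64 PNG)
    together with the parsed content list and label coordinates.
    """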
try:
# Save the input image temporarily
image_save_path = "imgs/saved_image_demo.png"
os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
image_input.save(image_save_path)
image = Image.open(image_save_path)
# Calculate box overlay ratio
box_overlay_ratio = image.size[0] / 3200
draw_bbox_config = {
"text_scale": 0.8 * box_overlay_ratio,
"text_thickness": max(int(2 * box_overlay_ratio), 1),
"text_padding": max(int(3 * box_overlay_ratio), 1),
"thickness": max(int(3 * box_overlay_ratio), 1),
}
# Perform OCR and get bounding boxes
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
image_save_path,
display_img=False,
output_bb_format="xyxy",
goal_filtering=None,
easyocr_args={"paragraph": False, "text_threshold": 0.9},
use_paddleocr=True,
)
text, ocr_bbox = ocr_bbox_rslt
# Get labeled image and coordinates
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
image_save_path,
yolo_model,
BOX_TRESHOLD=box_threshold,
output_coord_in_ratio=True,
ocr_bbox=ocr_bbox,
draw_bbox_config=draw_bbox_config,
caption_model_processor=caption_model_processor,
ocr_text=text,
iou_threshold=iou_threshold,
)
# Ensure dino_labled_img is a base64-encoded string
if isinstance(dino_labled_img, bytes):
dino_labled_img = base64.b64encode(dino_labled_img).decode("utf-8")
elif not isinstance(dino_labled_img, str):
raise ValueError("dino_labled_img must be a base64-encoded string or bytes")
# Decode the base64 image and re-encode it to ensure consistency
image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# Prepare parsed content list
parsed_content_list_str = "\n".join(parsed_content_list)
return ProcessResponse(
image=img_str,
parsed_content_list=parsed_content_list_str,
label_coordinates=str(label_coordinates),
)
except Exception as e:
logger.error(f"Error in process function: {e}")
raise
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
image_file: UploadFile = File(...),
box_threshold: float = 0.05,
iou_threshold: float = 0.1,
):
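    """Accept an uploaded UI screenshot and return the annotated image
    (base64 PNG), the parsed content list, and the label coordinates.

    box_threshold and iou_threshold are forwarded to the detection and
    filtering steps in process().
    """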
try:
contents = await image_file.read()
image_input = Image.open(io.BytesIO(contents)).convert("RGB")
# Log image details
logger.info(f"Processing image: {image_file.filename}")
logger.info(f"Image size: {image_input.size}")
# Process the image
response = process(image_input, box_threshold, iou_threshold)
# Validate response
if not response.image:
raise ValueError("Empty image in response")
return response
except Exception as e:
logger.error(f"Error in process_image endpoint: {e}")
raise HTTPException(status_code=500, detail=str(e))
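
# Optional local entry point; assumes uvicorn is installed. In deployment the
# app is usually launched externally (e.g. `uvicorn main:app --host 0.0.0.0`);
# port 7860 is the Hugging Face Spaces default, adjust as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)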