Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,621 Bytes
426e73b ad7e7b4 6c899dd 426e73b 6c899dd 426e73b ad7e7b4 426e73b 6c899dd 426e73b 6c899dd 426e73b 6c899dd 426e73b ba59054 426e73b 559b321 426e73b 6c303c0 426e73b 9694456 426e73b 6c303c0 426e73b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import spaces
from PIL import Image, ImageDraw
# --- Model loading -----------------------------------------------------------
# Zero-shot detector: turns free-text candidate labels into bounding boxes.
checkpoint = "google/owlvit-base-patch16"
detector = pipeline(task="zero-shot-object-detection", model=checkpoint)

# SAM turns each detected box into a pixel mask. The model weights are moved to
# the GPU; the processor handles the pre/post-processing around the model call.
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
def _parse_labels(texts):
    """Split a comma-separated string into stripped, non-empty candidate labels."""
    return [label.strip() for label in texts.split(",") if label.strip()]


def _box_to_xyxy(box):
    """Convert the detector's box dict to an ``[xmin, ymin, xmax, ymax]`` list."""
    # Read the keys explicitly rather than relying on the dict's insertion order.
    return [round(box[key], 2) for key in ("xmin", "ymin", "xmax", "ymax")]


def _segment_box(image, box):
    """Run SAM on *image* prompted with one box and return a ``(1, H, W)`` mask."""
    inputs = sam_processor(
        image,
        input_boxes=[[[box]]],  # SAM's processor expects a nested list of boxes
        return_tensors="pt",
    ).to("cuda")
    with torch.no_grad():
        outputs = sam_model(**inputs)
    # Upsample SAM's low-resolution masks back to the original image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # First image, first box prompt, first candidate mask
    # (NOTE(review): SAM presumably returns several candidates per prompt by default).
    mask = masks[0][0][0].numpy()
    return mask[np.newaxis, ...]


@spaces.GPU
def query(image, texts, threshold):
    """Detect objects matching free-text labels and segment each one with SAM.

    Args:
        image: PIL image to analyse. It is left unmodified; boxes are drawn on
            a copy so that later SAM calls never see earlier annotations.
        texts: Comma-separated candidate labels, e.g. ``"cat, remote control"``.
            Surrounding whitespace is stripped and empty entries are ignored.
        threshold: Minimum detector score for a prediction to be kept.

    Returns:
        ``(annotated_image, [(mask, label), ...])``: a copy of *image* with a red
        box and caption per detection, plus one ``(1, H, W)`` mask per detection.
    """
    labels = _parse_labels(texts)
    annotated = image.copy()
    if not labels:
        # Nothing to look for; avoid feeding an empty label to the detector.
        return annotated, []

    predictions = detector(image, candidate_labels=labels, threshold=threshold)

    draw = ImageDraw.Draw(annotated)
    result_labels = []
    for pred in predictions:
        box = _box_to_xyxy(pred["box"])
        label = pred["label"]
        # Segment the pristine input: drawing happens only on `annotated`.
        result_labels.append((_segment_box(image, box), label))
        draw.rectangle(box, outline="red", width=3)
        # Keep the caption inside the image when a box touches the top edge.
        draw.text((box[0], max(0, box[1] - 10)), label, fill="red")
    return annotated, result_labels
import gradio as gr

# User-facing blurb. The detector loaded above is OWL-ViT
# ("google/owlvit-base-patch16"), so the text names OWL-ViT rather than OWLv2.
description = (
    "This DSA2024 Demo Space combines OWL-ViT, a zero-shot object detection model, "
    "with SAM, a state-of-the-art mask generation model. SAM normally doesn't accept "
    "text input. Combining SAM with OWL-ViT makes SAM text-promptable. Try the "
    "example, or upload an image and enter comma-separated candidate labels to segment."
)

# `query` returns (image, [(mask, label), ...]), which is the value format of the
# "annotatedimage" output component.
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold"),
    ],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1],
    ],
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch(debug=True)