from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import spaces
from PIL import Image, ImageDraw

# Load models: OWL-ViT for zero-shot detection, SAM for box-prompted segmentation
checkpoint = "google/owlvit-base-patch16"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

@spaces.GPU
def query(image, texts, threshold):
    # Split the comma-separated labels and strip stray whitespace ("cat, dog" -> ["cat", "dog"])
    texts = [t.strip() for t in texts.split(",")]

    # --- Zero-shot object detection with OWL-ViT ---
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )
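    # Each prediction is a dict with "score", "label", and a "box" dict of
    # pixel coordinates keyed by xmin / ymin / xmax / ymax.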

    result_labels = []
    draw = ImageDraw.Draw(image)  # Create a drawing object for the image

    for pred in predictions:
        box = pred["box"]
        score = pred["score"]
        label = pred["label"]

        # Convert the box dict (xmin, ymin, xmax, ymax) into a rounded coordinate list for SAM and drawing
        box = [round(coord, 2) for coord in list(box.values())]

        # --- Segmentation: prompt SAM with the detected box ---
        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],  # Note: SAM expects a nested list
            return_tensors="pt"
        ).to("cuda")

        with torch.no_grad():
            outputs = sam_model(**inputs)

        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0][0].numpy()
        mask = mask[np.newaxis, ...]
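        # post_process_masks resizes the low-res masks back to the original image
        # size; the [0][0][0] indexing pulls out a single 2-D mask for this box
        # (SAM returns multiple mask proposals per prompt), and np.newaxis adds a
        # leading axis before the mask is handed to gr.AnnotatedImage.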
        result_labels.append((mask, label))

        # --- Draw Bounding Box ---
        draw.rectangle(box, outline="red", width=3)  # Draw rectangle with a red outline
        draw.text((box[0], box[1] - 10), label, fill="red")  # Add label above the box

    return image, result_labels  # Annotated image plus (mask, label) pairs for gr.AnnotatedImage


import gradio as gr

description = "This DSA2024 Demo Space combines OWL-ViT, a state-of-the-art zero-shot object detection model, with SAM, a state-of-the-art mask generation model. SAM normally doesn't accept text input; combining it with OWL-ViT makes SAM text-promptable. Try the example, or upload an image and enter comma-separated candidate labels to segment."
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold"),
    ],
    outputs="annotatedimage",
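    # "annotatedimage" renders query()'s return value: (image, [(mask, label), ...])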
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1],
    ],
    cache_examples=True
)
demo.launch(debug=True)
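
# --- Optional local sanity check (not part of the Space UI) ---
# A minimal sketch of calling query() directly, bypassing Gradio. It assumes a
# CUDA GPU is available and that the example image "cats.png" sits next to this
# file; uncomment to try.
#
# img = Image.open("./cats.png").convert("RGB")
# annotated, mask_label_pairs = query(img, "cat", 0.1)
# annotated.save("annotated.png")
# for mask, label in mask_label_pairs:
#     print(label, mask.shape)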