#!/usr/bin/env python

"""A demo of the DAB-DETR model."""

import pathlib
import tempfile

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import supervision as sv
import torch
import tqdm.auto
from transformers import AutoProcessor, AutoModelForObjectDetection

DESCRIPTION = """
# DAB-DETR
##### [ArXiv](https://arxiv.org/abs/2201.12329) | [Docs](https://huggingface.co/docs/transformers/main/en/model_doc/dab-detr)
"""

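# Maximum number of frames processed from an input video.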
MAX_NUM_FRAMES = 300

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

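# Load the DAB-DETR checkpoint and its image processor; `device_map` places the weights on the selected device.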
checkpoint = "IDEA-Research/dab-detr-resnet-50-dc5-pat3"
image_processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForObjectDetection.from_pretrained(checkpoint, device_map=device)


@spaces.GPU(duration=5)
@torch.inference_mode()
def process_image(image: PIL.Image.Image) -> PIL.Image.Image:
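    # Preprocess the image into model inputs (resized, normalized pixel values) on the target device.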
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
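    # Convert raw model outputs into boxes, labels, and scores, keeping detections above a 0.3 confidence threshold.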
    results = image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
    )
    result = results[0]  # take first image results
    boxes_xyxy = result["boxes"].cpu().numpy()
    indexes = result["labels"].cpu().numpy()
    scores = result["scores"].cpu().numpy()
    text_labels = [
        f"{model.config.id2label[index]} [{score.item():.2f}]" for index, score in zip(indexes, scores)
    ]

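    # Wrap the detections so supervision can draw bounding boxes and labels.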
    detections = sv.Detections(xyxy=boxes_xyxy, class_id=indexes, confidence=scores)
    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)
    label_annotator = sv.LabelAnnotator()

    # annotate bounding boxes
    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=text_labels)

    return annotated_frame


@spaces.GPU(duration=90)
def process_video(
    video_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> str:
    cap = cv2.VideoCapture(video_path)

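    # Read the source video properties so the output keeps the same resolution and frame rate.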
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
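    # Write annotated frames to a temporary MP4, truncating long videos to MAX_NUM_FRAMES.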
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
        writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
        for _ in tqdm.auto.tqdm(range(min(MAX_NUM_FRAMES, num_frames))):
            ok, frame = cap.read()
            if not ok:
                break
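            # OpenCV decodes frames as BGR; flip the channel order to RGB before handing them to PIL.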
            rgb_frame = frame[:, :, ::-1]
            annotated_frame = process_image(PIL.Image.fromarray(rgb_frame))
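            # Convert the annotated RGB frame back to BGR for the OpenCV writer.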
            writer.write(np.asarray(annotated_frame)[:, :, ::-1])
        writer.release()
    cap.release()
    return out_file.name


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil")
                    run_button_image = gr.Button()
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
            gr.Examples(
                examples=sorted(pathlib.Path("images").glob("*.jpg")),
                inputs=input_image,
                outputs=[output_image],
                fn=process_image,
            )

            run_button_image.click(
                fn=process_image,
                inputs=input_image,
                outputs=[output_image],
            )

        with gr.Tab("Video"):
            gr.Markdown(f"The input video will be truncated to {MAX_NUM_FRAMES} frames.")

            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    run_button_video = gr.Button()
                with gr.Column():
                    output_video = gr.Video(label="Output Video")

            gr.Examples(
                examples=sorted(pathlib.Path("videos").glob("*.mp4")),
                inputs=input_video,
                outputs=output_video,
                fn=process_video,
                cache_examples=False,
            )
            run_button_video.click(
                fn=process_video,
                inputs=input_video,
                outputs=output_video,
            )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()