import os
import io
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# settings
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
warnings.filterwarnings('ignore')
st.set_page_config()


class OwlViT:
    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold

    def process(self, processor, model):
        image = Image.open(self.image_path)
        # single-channel (grayscale) images must be converted to RGB
        if len(image.split()) == 1:
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        target_sizes = torch.tensor([[image.height, image.width]])
        # post_process() is deprecated in recent transformers releases;
        # threshold=0.0 keeps every candidate box, filtering happens in result_image()
        self.results = processor.post_process_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=0.0)
        self.image = image
        return self.result_image()

    def result_image(self):
        boxes = self.results[0]["boxes"]
        scores = self.results[0]["scores"]
        labels = self.results[0]["labels"]
        fig = plt.figure()
        plt.imshow(self.image)
        ax = plt.gca()
        for box, score, label in zip(boxes, scores, labels):
            if score >= self.threshold:
                box = box.numpy()
                label = label.item()
                color = list(mcolors.CSS4_COLORS.keys())[label]
                ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1],
                                           fill=False, color=color, linewidth=3))
                ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}",
                        fontsize=15, color=color)
        plt.tight_layout()
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png')
        plt.close(fig)  # release figure state so Streamlit reruns don't stack plots
        return Image.open(img_buf)
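

# Minimal standalone sketch (not executed by the app): how the OwlViT class above
# can be used outside Streamlit, assuming the refs/baseball.jpg example asset exists:
#
#   processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
#   model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
#   detector = OwlViT("refs/baseball.jpg", ["batter", "umpire", "catcher"], 0.5)
#   annotated = detector.process(processor, model)  # returns a PIL.Image
#   annotated.save("detections.png")
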
@st.cache_resource  # load the model once and reuse it across Streamlit reruns
def load_model():
    with st.spinner('Getting Neurons in Order ...'):
        processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    return processor, model


def show_detects(image):
    st.title("Results")
    st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)


def process(upload, text, threshold, processor, model):
    # save the upload to disk so it can be opened by path
    os.makedirs('images', exist_ok=True)
    filetype = upload.name.split('.')[-1]
    name = len(os.listdir('images')) + 1
    file_path = os.path.join('images', f'{name}.{filetype}')
    with open(file_path, 'wb') as f:
        f.write(upload.getbuffer())

    # predict detections and show results
    detector = OwlViT(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)

    # clean up - if over 1000 images in the folder, delete the oldest one
    if len(os.listdir('images')) > 1000:
        oldest = min(os.listdir('images'),
                     key=lambda f: os.path.getctime(os.path.join('images', f)))
        os.remove(os.path.join('images', oldest))


def main(processor, model):
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)

    # title and project description
    st.title("OWL-ViT")
    st.markdown(
        "**OWL-ViT** is a zero-shot, text-conditioned object detection model. OWL-ViT uses CLIP as its "
        "multi-modal backbone, with a ViT-like Transformer to extract visual features and a causal "
        "language model to extract text features. To use CLIP for detection, OWL-ViT removes the final "
        "token-pooling layer of the vision model and attaches a lightweight classification and box head "
        "to each transformer output token. Open-vocabulary classification is enabled by replacing the "
        "fixed classification-layer weights with the class-name embeddings obtained from the text model. "
        "The authors first train CLIP from scratch, then fine-tune it end-to-end with the classification "
        "and box heads on standard detection datasets using a bipartite matching loss. One or multiple "
        "text queries per image can be used to perform zero-shot, text-conditioned object detection.",
        unsafe_allow_html=True)

    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = OwlViT(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
    if st.button("Clear Example"):
        st.markdown("")  # any button press reruns the script, clearing the example output

    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma separated)', "batter, umpire, catcher")
        text = [x.strip() for x in text.split(',') if x.strip()]

    # process
    if upload is not None and text:
        filetype = upload.name.split('.')[-1]
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting and Counting Single Image...'):
                process(upload, text, threshold, processor, model)
        else:
            st.warning('Unsupported file type.')


if __name__ == '__main__':
    processor, model = load_model()
    main(processor, model)
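

# Usage sketch, assuming this file is saved as app.py with the refs/ assets
# alongside it and the dependencies installed:
#
#   pip install streamlit torch transformers pillow matplotlib
#   streamlit run app.py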