blip2_test / handler.py
advaitadasein's picture
uploaded handler and requirements for it
df87eca
raw
history blame
820 Bytes
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from typing import Dict, List, Any
import torch
class EndpointHandler:
    """Callable request handler that captions an image with BLIP-2.

    The processor and model are loaded once at construction time and moved
    to the GPU when one is available; each call runs a single ``generate``
    pass and returns the decoded text.
    """

    # Single source of truth for the checkpoint, so the processor and the
    # model can never be loaded from different ids by accident.
    MODEL_ID = "Salesforce/blip2-opt-2.7b"

    def __init__(self, path=""):
        """Load the BLIP-2 processor and model and pick the compute device.

        Args:
            path: Accepted for interface compatibility; it is not read.
                The weights are always loaded from ``MODEL_ID``.
        """
        self.processor = Blip2Processor.from_pretrained(self.MODEL_ID)
        self.model = Blip2ForConditionalGeneration.from_pretrained(self.MODEL_ID)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate text for the image carried by the request payload.

        Args:
            data: Request payload. The value stored under ``"inputs"`` is
                handed to the processor as ``images=``; when that key is
                absent the whole payload is passed instead.
                NOTE(review): presumably a PIL image (or another type the
                processor accepts) -- confirm against the caller.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens skipped.
        """
        # ``get`` rather than ``pop``: read the input without removing the
        # key from the caller's dictionary.
        image = data.get("inputs", data)
        # Preprocess to PyTorch tensors and place them on the model's device.
        processed = self.processor(images=image, return_tensors="pt").to(self.device)
        generated = self.model.generate(**processed)
        # Only the first sequence of the generated batch is decoded.
        return self.processor.decode(generated[0], skip_special_tokens=True)