import io
import os
import uuid

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
from transformers.image_utils import load_image

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

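# Each browser session gets its own id so that the saved page images and the
# retrieval index of one user never collide with another's.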
if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())

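# The ColPali document-retrieval model is loaded once and cached across
# Streamlit reruns via st.cache_resource.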
@st.cache_resource
def load_model_embedding():
    docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
    return docs_retrieval_model


model_embedding = load_model_embedding()

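# The SmolVLM vision-language model is loaded in 8-bit (bitsandbytes) to keep
# GPU memory usage low, and is likewise cached across reruns.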
@st.cache_resource
def load_model_vlm():
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        quantization_config=quantization_config,
    )
    return model, processor


model_vlm, processor_vlm = load_model_vlm()

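# Save every PDF page as a PNG so the retrieval model can index the folder.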
def save_images_to_local(dataset, output_folder="data/"):
    os.makedirs(output_folder, exist_ok=True)

    for image_id, image in enumerate(dataset):
        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        image.save(output_path, format="PNG")

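# Streamlit UI: PDF upload and question input.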
with st.sidebar:
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"

st.title("Image Q&A with VLM")
uploaded_pdf = st.file_uploader("Upload a PDF file", type=["pdf"])
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe the image?",
    disabled=not uploaded_pdf,
)

images_folder = "data/" + st.session_state["session_id"] + "/"
index_name = "index_" + st.session_state["session_id"]

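# Convert the uploaded PDF to page images and build the ColPali index once per
# session; the flag below prevents re-indexing on every Streamlit rerun.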
if uploaded_pdf and "is_index_complete" not in st.session_state:
    images = convert_from_bytes(uploaded_pdf.getvalue())
    save_images_to_local(images, output_folder=images_folder)
    # Keep the page images in session state so they remain available on later
    # reruns, when this indexing branch is skipped.
    st.session_state["images"] = images

    model_embedding.index(
        input_path=images_folder,
        index_name=index_name,
        store_collection_with_index=False,
        overwrite=True,
    )
    st.session_state["is_index_complete"] = True

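# Answer the question: retrieve the page most relevant to the query, then ask
# the vision-language model about that page image.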
if uploaded_pdf and query:
    docs_retrieved = model_embedding.search(query, k=1)
    image_similar_to_query = st.session_state["images"][docs_retrieved[0]["doc_id"]]

    system_prompt = (
        "You are an AI assistant. Your task is to reply to user questions "
        "based on the provided image context."
    )
    chat_template = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": query},
            ],
        },
    ]

    prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
    inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
    inputs = inputs.to(DEVICE)

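    # Generate an answer conditioned on the retrieved page image and the query.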
    generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)

    # Decode only the newly generated tokens so the chat prompt is not echoed
    # back to the user.
    generated_texts = processor_vlm.batch_decode(
        generated_ids[:, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    response = generated_texts[0]

    st.write(response)

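# To try the app locally (assuming this file is saved as app.py):
#   streamlit run app.py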
|