# multimodal-rag / app.py
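# Streamlit app: upload a PDF, index its pages with a multimodal retrieval model,
# then answer questions by passing the best-matching page image to a small VLM.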
import io
import os
import uuid
import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
from transformers.image_utils import load_image
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
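
# Give each browser session its own ID so uploaded pages and the index stay separate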
if "session_id" not in st.session_state:
st.session_state["session_id"] = str(uuid.uuid4()) # Generate unique session ID
@st.cache_resource  # Streamlit Caching decorator
def load_model_embedding():
    # docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha")
    # docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0")
    docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
    return docs_retrieval_model


model_embedding = load_model_embedding()
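
# Vision-language model (SmolVLM-Instruct) used to answer questions about the
# retrieved page image; loaded in 8-bit via bitsandbytes to reduce memory usage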
@st.cache_resource  # Streamlit Caching decorator
def load_model_vlm():
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        # torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    return model, processor


model_vlm, processor_vlm = load_model_vlm()
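
# Write each PDF page to disk as a PNG so the retrieval model can index the folder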
def save_images_to_local(dataset, output_folder="data/"):
    os.makedirs(output_folder, exist_ok=True)
    for image_id, image in enumerate(dataset):
        # if isinstance(image, str):
        #     image = Image.open(image)
        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        image.save(output_path, format="PNG")
# Home page UI
with st.sidebar:
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"
st.title("πŸ“ Image Q&A with VLM")
uploaded_pdf = st.file_uploader("Upload PDF file", type=("pdf"))
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe the image?",
    disabled=not uploaded_pdf,
)
images = []
images_folder = "data/" + st.session_state["session_id"] + "/"
index_name = "index_" + st.session_state["session_id"]
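
# Convert the uploaded PDF into page images and index them once per session;
# the "is_index_complete" flag keeps the slow indexing step from re-running on
# every Streamlit rerun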
if uploaded_pdf and "is_index_complete" not in st.session_state:
images = convert_from_bytes(uploaded_pdf.getvalue())
save_images_to_local(images, output_folder=images_folder)
# index documents using the document retrieval model
model_embedding.index(
input_path=images_folder, index_name=index_name, store_collection_with_index=False, overwrite=True
)
st.session_state["is_index_complete"] = True
if uploaded_pdf and query:
    docs_retrieved = model_embedding.search(query, k=1)
    # Page images were cached in session state when the PDF was indexed
    image_similar_to_query = st.session_state["images"][docs_retrieved[0]["doc_id"]]
    # Create input messages
    system_prompt = (
        "You are an AI assistant. Your task is to reply to user questions "
        "based on the provided image context."
    )
    chat_template = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": query},
            ],
        },
    ]
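    # The {"type": "image"} entry is a placeholder; the actual page image is passed
    # to the processor below via the images= argument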
    # Prepare inputs
    prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
    inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
    inputs = inputs.to(DEVICE)

    # Generate outputs
    generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)
    # generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    generated_texts = processor_vlm.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    response = generated_texts[0]
    st.write(response)