"""Streamlit app for multimodal RAG: index PDF pages with ColPali (via byaldi)
and answer questions about the retrieved page image with SmolVLM."""

import asyncio
import io
import logging
import os
import threading
import uuid

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
from transformers.image_utils import load_image

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Capture logs
#log_stream = io.StringIO()
#logging.basicConfig(stream=log_stream, level=logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())  # Generate a unique session ID


# Async function to load the embedding (document retrieval) model
async def load_model_embedding_async():
    st.session_state["loading_model_embedding"] = True  # Show loading status
    await asyncio.sleep(0.1)  # Allow UI updates
    model_embedding = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
    st.session_state["model_embedding"] = model_embedding
    st.session_state["loading_model_embedding"] = False  # Model is ready


# Run the async loader in a fresh event loop (called from a separate thread)
def load_model_embedding():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(load_model_embedding_async())


# Start embedding-model loading in a background thread
if "model_embedding" not in st.session_state:
    with st.status("Loading embedding model... ⏳"):
        threading.Thread(target=load_model_embedding, daemon=True).start()


# Async function to load the vision-language model (VLM)
async def load_model_vlm_async():
    st.session_state["loading_model_vlm"] = True  # Show loading status
    await asyncio.sleep(0.1)  # Allow UI updates
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor_vlm = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model_vlm = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        #torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    st.session_state["model_vlm"] = model_vlm
    st.session_state["processor_vlm"] = processor_vlm
    st.session_state["loading_model_vlm"] = False  # Model is ready


# Run the async loader in a fresh event loop (called from a separate thread)
def load_model_vlm():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(load_model_vlm_async())


# Start VLM loading in a background thread
if "model_vlm" not in st.session_state:
    with st.status("Loading VLM model... ⏳"):
        threading.Thread(target=load_model_vlm, daemon=True).start()


def save_images_to_local(dataset, output_folder="data/"):
    """Save each PDF page image as a PNG so the retriever can index the folder."""
    os.makedirs(output_folder, exist_ok=True)
    for image_id, image in enumerate(dataset):
        #if isinstance(image, str):
        #    image = Image.open(image)
        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        #image = Image.open(io.BytesIO(image_data))
        image.save(output_path, format="PNG")


# Home page UI
with st.sidebar:
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"

st.title("📝 Image Q&A with VLM")
#st.text_area("Logs:", log_stream.getvalue(), height=200)

uploaded_pdf = st.file_uploader("Upload PDF file", type="pdf")
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe the image?",
    disabled=not uploaded_pdf,
)

if st.session_state.get("loading_model_embedding", True):
    st.warning("Loading embedding model...")
else:
    st.success("Embedding model loaded successfully! 🎉")
🎉") if st.session_state.get("loading_model_vlm", True): st.warning("Loading VLM model....") else: st.success("VLM Model loaded successfully! 🎉") images = [] images_folder = "data/" + st.session_state["session_id"] + "/" index_name = "index_" + st.session_state["session_id"] if uploaded_pdf and "model_embedding" in st.session_state and "is_index_complete" not in st.session_state: images = convert_from_bytes(uploaded_pdf.getvalue()) save_images_to_local(images, output_folder=images_folder) # index documents using the document retrieval model st.session_state["model_embedding"].index( input_path=images_folder, index_name=index_name, store_collection_with_index=False, overwrite=True ) logging.info(f"{len(images)} number of images extracted from PDF and indexed") st.session_state["is_index_complete"] = True if uploaded_pdf and query and "model_embedding" in st.session_state and "model_vlm" in st.session_state: docs_retrieved = st.session_state["model_embedding"].search(query, k=1) logging.info(f"{len(docs_retrieved)} number of images retrieved as relevant to query") image_id = docs_retrieved[0]["doc_id"] logging.info(f"Image id:{image_id} retrieved" ) image_similar_to_query = images[image_id] model_vlm, processor_vlm = st.session_state["model_vlm"], st.session_state["processor_vlm"] # Create input messages system_prompt = "You are an AI assistant. Your task is reply to user questions based on the provided image context." chat_template = [ {"role": "system", "content": system_prompt}, { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": query} ] }, ] # Prepare inputs prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True) inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt") inputs = inputs.to(DEVICE) # Generate outputs generated_ids = model_vlm.generate(**inputs, max_new_tokens=500) #generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] generated_texts = processor_vlm.batch_decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, ) response = generated_texts[0] st.write(response)