import os
import uuid

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())  # Generate unique session ID


@st.cache_resource  # Streamlit caching decorator
def load_model_embedding():
    # Alternative retrieval models:
    # docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha")
    # docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0")
    docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
    return docs_retrieval_model


model_embedding = load_model_embedding()


@st.cache_resource  # Streamlit caching decorator
def load_model_vlm():
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        # torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    return model, processor


model_vlm, processor_vlm = load_model_vlm()


def save_images_to_local(dataset, output_folder="data/"):
    # Save each PDF page as image_<page_id>.png so it can be indexed
    # and looked up again by its doc_id after retrieval.
    os.makedirs(output_folder, exist_ok=True)
    for image_id, image in enumerate(dataset):
        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        image.save(output_path, format="PNG")


# Home page UI
with st.sidebar:
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"

st.title("📝 Image Q&A with VLM")
uploaded_pdf = st.file_uploader("Upload PDF file", type="pdf")
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe the image?",
    disabled=not uploaded_pdf,
)

images_folder = "data/" + st.session_state["session_id"] + "/"
index_name = "index_" + st.session_state["session_id"]

if uploaded_pdf and "is_index_complete" not in st.session_state:
    # Convert every PDF page to an image and save the pages locally
    images = convert_from_bytes(uploaded_pdf.getvalue())
    save_images_to_local(images, output_folder=images_folder)

    # Index the page images using the document retrieval model
    model_embedding.index(
        input_path=images_folder,
        index_name=index_name,
        store_collection_with_index=False,
        overwrite=True,
    )
    st.session_state["is_index_complete"] = True

if uploaded_pdf and query:
    # Retrieve the page most relevant to the query and reload it from disk
    # (the in-memory page list is not repopulated on Streamlit reruns).
    docs_retrieved = model_embedding.search(query, k=1)
    image_path = os.path.join(images_folder, f"image_{docs_retrieved[0]['doc_id']}.png")
    image_similar_to_query = Image.open(image_path)

    # Create input messages
    system_prompt = ("You are an AI assistant. Your task is to reply to user "
                     "questions based on the provided image context.")
    chat_template = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": query},
            ],
        },
    ]

    # Prepare inputs
    prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
    inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
    inputs = inputs.to(DEVICE)

    # Generate outputs
    generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)
    # generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    generated_texts = processor_vlm.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    response = generated_texts[0]
    st.write(response)
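
# Assuming this script is saved as app.py (the filename is an assumption),
# it can be launched locally with:
#   streamlit run app.py
# Python dependencies implied by the imports above: streamlit, torch, byaldi,
# pdf2image, pillow, transformers and bitsandbytes (8-bit loading typically
# also needs accelerate); pdf2image additionally requires the poppler system
# package to be installed.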