deepakkarkala committed
Commit 1564fda · 1 Parent(s): 04743bf

Basic retrieval and generation

Files changed (2)
  1. app.py +49 -17
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,7 +1,10 @@
import io
+ import os

import streamlit as st
import torch
+ from byaldi import RAGMultiModalModel
+ from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
@@ -9,9 +12,13 @@ from transformers.image_utils import load_image

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+ @st.cache_resource  # Streamlit Caching decorator
+ def load_model_embedding():
+     docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha")
+ model_embedding = load_model_embedding()

@st.cache_resource  # Streamlit Caching decorator
- def load_model():
+ def load_model_vlm():
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -20,51 +27,76 @@ def load_model():
        #torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
-     return model
- model = load_model()
+     return model, processor
+ model_vlm, processor_vlm = load_model_vlm()


- with st.sidebar:
-     "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"
+
+ def save_images_to_local(dataset, output_folder="data/"):
+     os.makedirs(output_folder, exist_ok=True)
+
+     for image_id, image in enumerate(dataset):
+         #if isinstance(image, str):
+         #    image = Image.open(image)
+
+         output_path = os.path.join(output_folder, f"image_{image_id}.png")
+         image.save(output_path, format="PNG")
+


# Home page UI
+ with st.sidebar:
+     "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"
+
st.title("📝 Image Q&A with VLM")
- uploaded_file = st.file_uploader("Upload an image", type=("png", "jpg"))
- question = st.text_input(
+ uploaded_pdf = st.file_uploader("Upload PDF file", type=("pdf"))
+ query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe me the image ?",
-     disabled=not uploaded_file,
+     disabled=not uploaded_pdf,
)

+ images = []
+ if uploaded_pdf:
+     images = convert_from_bytes(uploaded_pdf.getvalue())
+     save_images_to_local(images)
+     # index documents using the document retrieval model
+     model_embedding.index(
+         input_path="data/", index_name="image_index", store_collection_with_index=False, overwrite=True
+     )

- if uploaded_file and question:
-     image_bytes = uploaded_file.read()
-     image = Image.open(io.BytesIO(image_bytes))
+
+
+ if uploaded_pdf and query:
+     docs_retrieved = model_embedding.search(query, k=1)
+     image_similar_to_query = images[docs_retrieved[0]["doc_id"]]

    # Create input messages
    system_prompt = "You are an AI assistant. Your task is reply to user questions based on the provided image context."
-     messages = [
+     chat_template = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
-                 {"type": "text", "text": question}
+                 {"type": "text", "text": query}
            ]
        },
    ]

    # Prepare inputs
-     prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-     inputs = processor(text=prompt, images=[image], return_tensors="pt")
+     prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
+     inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
    inputs = inputs.to(DEVICE)

    # Generate outputs
-     generated_ids = model.generate(**inputs, max_new_tokens=500)
-     generated_texts = processor.batch_decode(
+     generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)
+     #generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+
+     generated_texts = processor_vlm.batch_decode(
        generated_ids,
        skip_special_tokens=True,
+         clean_up_tokenization_spaces=False,
    )
    response = generated_texts[0]
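Note: as committed, load_model_embedding() builds the byaldi retriever but never returns it, so model_embedding ends up as None and the later model_embedding.index(...) and model_embedding.search(...) calls would raise AttributeError. A minimal sketch of the cached loader with the missing return added (same model id and byaldi calls as in the diff above; not part of this commit):

import streamlit as st
from byaldi import RAGMultiModalModel

@st.cache_resource  # cache the ColSmolVLM-based retriever across Streamlit reruns
def load_model_embedding():
    # load the document-retrieval model used to index and search the PDF page images
    docs_retrieval_model = RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha")
    return docs_retrieval_model  # this return is what the committed version is missing

model_embedding = load_model_embedding()

With that return in place, the flow this commit adds is: convert the uploaded PDF into page images (pdf2image), save and index them with byaldi, retrieve the page most similar to the query, and pass that single page image together with the question to SmolVLM for generation.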
 
requirements.txt CHANGED
@@ -11,3 +11,6 @@ transformers
accelerate>=0.26.0
bitsandbytes
pillow
+ flash-attn
+ byaldi
+ pdf2image
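Note: pdf2image is a wrapper around the Poppler command-line tools (pdftoppm/pdftocairo), so the Poppler binaries must be installed on the system in addition to the pip package. On a Hugging Face Space this is typically done by listing the Debian package in a packages.txt file at the repository root (an assumption here, since no packages.txt is part of this commit):

poppler-utils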