import streamlit as st
from PIL import Image
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import time
import re
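
# Run on the GPU when available; CPU inference with a 7B vision-language model is very slow.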
|
device = "cuda" if torch.cuda.is_available() else "cpu"
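
# Cache the heavy model objects so Streamlit reruns don't reload them on every interaction.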
|
@st.cache_resource
def initialize_models():
    # ColPali (via Byaldi) builds and searches a visual index over document images.
    multimodal_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
    # Qwen2-VL reads the page image and generates the extracted text.
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device).eval()

    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

    return multimodal_model, qwen_model, qwen_processor
|

multimodal_model, qwen_model, qwen_processor = initialize_models()
|

st.title("Document Text Extraction")
doc_file = st.file_uploader("Upload Image File", type=["png", "jpg", "jpeg"])
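
# Persist the extraction result across Streamlit reruns.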
|
if "document_text" not in st.session_state:
    st.session_state.document_text = None
|

if doc_file is not None:
    document_image = Image.open(doc_file)
    st.image(document_image, caption="Document Preview", use_column_width=True)
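
    # A timestamp-based name gives each upload its own index.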
|
    index_id = f"doc_index_{int(time.time())}"
|

    if st.session_state.document_text is None:
        st.write(f"Indexing document with unique ID: {index_id}...")
        # Byaldi indexes from a file path, so save the upload to disk first.
        temp_image_path = "temp_image.png"
        document_image.save(temp_image_path)
|

        multimodal_model.index(
            input_path=temp_image_path,
            index_name=index_id,
            store_collection_with_index=False,
            overwrite=False,
        )
|

        extraction_query = "Extract all English and Hindi text from this document"
        st.write("Querying the document with the extraction query...")
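
        # Retrieve the best-matching page for the query. With a single-page index this is
        # always the uploaded image, so the image itself is passed straight to Qwen2-VL below.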
|
        search_results = multimodal_model.search(extraction_query, k=1)
|

        input_message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": document_image},
                    {"type": "text", "text": extraction_query},
                ],
            }
        ]
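
        # Build the chat prompt and collect the vision inputs in the format Qwen2-VL expects.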
|
        input_text = qwen_processor.apply_chat_template(input_message, tokenize=False, add_generation_prompt=True)
        vision_inputs, _ = process_vision_info(input_message)

        model_inputs = qwen_processor(
            text=[input_text],
            images=vision_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to(device)
|

        st.write("Generating extracted text...")
        # max_new_tokens caps the output length; raise it for dense documents,
        # since 100 tokens will truncate longer extractions.
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=100)
|

        # Strip the prompt tokens so only the newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]
        extracted_output = qwen_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
|

        st.session_state.document_text = extracted_output[0]
|

    extracted_text_content = st.session_state.document_text
    structured_data = {"extracted_text": extracted_text_content}

    st.subheader("Extracted Text in JSON:")
    st.json(structured_data)
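
# Simple keyword search over the extracted text, with matches bolded via Markdown.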
|
if st.session_state.document_text:
    with st.form(key='text_search_form'):
        search_input = st.text_input("Enter a keyword to search within the extracted text:")
        search_action = st.form_submit_button("Search")
|

    if search_action and search_input:
        full_text = st.session_state.document_text
        lines = full_text.split('\n')

        results = []
        for line in lines:
            if re.search(re.escape(search_input), line, re.IGNORECASE):
                # Wrap each case-insensitive match in ** so st.markdown renders it in bold.
                highlighted_line = re.sub(f"({re.escape(search_input)})", r"**\1**", line, flags=re.IGNORECASE)
                results.append(highlighted_line)
|

        st.subheader("Search Results:")
        if not results:
            st.write("No matches found.")
        else:
            for result in results:
                st.markdown(result)