Spaces:
Sleeping
Sleeping
import chromadb | |
from chromadb.utils import embedding_functions | |
from chromadb.config import Settings | |
from transformers import pipeline | |
import streamlit as st | |
import fitz # PyMuPDF for PDF parsing | |
from PIL import Image | |
# ChromaDB settings object (on-disk directory "./chromadb_data").
# NOTE(review): nothing in this file passes `config` to a client;
# setup_chromadb() configures persistence through
# chromadb.PersistentClient(path=...) instead. `chroma_db_impl` looks like a
# pre-0.4 Chroma option -- confirm whether this object is still needed.
config = Settings(
    **{
        "persist_directory": "./chromadb_data",
        "chroma_db_impl": "sqlite",
    }
)
# Embedding model for every collection this module creates. setup_chromadb()
# and clear_collection() must agree on it: a recreated collection embedded
# with a different model than the one used for queries would return
# mismatched (or dimension-incompatible) results.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


def _make_embedding_function():
    """Build the sentence-transformers embedding function used for all collections."""
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=_EMBEDDING_MODEL_NAME
    )


def setup_chromadb(path="./chromadb_data", collection_name="pdf_data"):
    """Open the on-disk Chroma store and get (or create) the PDF collection.

    Args:
        path: Directory holding the persistent Chroma database.
        collection_name: Name of the collection that stores PDF lines.

    Returns:
        A ``(client, collection)`` tuple.
    """
    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=_make_embedding_function(),
    )
    return client, collection


def clear_collection(client, collection_name):
    """Delete *collection_name* and recreate it empty.

    Args:
        client: Chroma client that owns the collection.
        collection_name: Name of the collection to reset.

    Returns:
        The newly created, empty collection.

    Raises:
        Whatever ``client.delete_collection`` raises if the collection does
        not exist (the exception type depends on the chromadb version).
    """
    client.delete_collection(name=collection_name)
    return client.get_or_create_collection(
        name=collection_name,
        embedding_function=_make_embedding_function(),
    )
def extract_text_from_pdf(uploaded_file):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        uploaded_file: Binary file-like object containing PDF data (in this
            app, the Streamlit upload). It is consumed with ``read()`` from
            its current position.

    Returns:
        The ``page.get_text()`` output of each page, joined with no separator.
    """
    with fitz.open(stream=uploaded_file.read(), filetype="pdf") as doc:
        # str.join builds the result in one pass; `text += ...` in a loop can
        # degrade to quadratic copying.
        return "".join(page.get_text() for page in doc)
def add_pdf_text_to_db(collection, pdf_text, batch_size=500):
    """Store each non-blank line of *pdf_text* as its own document in *collection*.

    Line ``i`` (0-based, blank lines included in the count) gets id
    ``"pdf_text_{i}"``, its unmodified text as the document, and metadata
    ``{"line_number": i, "text": line}``. Lines are sent to ``collection.add``
    in batches, so the embedding model processes many lines per call instead
    of one call per line, while keeping each call bounded in size (Chroma
    caps how many records a single add may carry).

    Args:
        collection: Chroma collection (anything with a compatible ``add``).
        pdf_text: Extracted PDF text; split on ``"\\n"``.
        batch_size: Maximum number of lines per ``collection.add`` call.

    Raises:
        ValueError: If *batch_size* is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    # Keep the original line index so ids/line numbers reflect the line's
    # position in the extracted text; skip whitespace-only lines.
    entries = [
        (idx, line) for idx, line in enumerate(pdf_text.split("\n")) if line.strip()
    ]
    for start in range(0, len(entries), batch_size):
        batch = entries[start:start + batch_size]
        collection.add(
            ids=[f"pdf_text_{idx}" for idx, _ in batch],
            documents=[line for _, line in batch],
            metadatas=[{"line_number": idx, "text": line} for idx, line in batch],
        )
def query_pdf_data(collection, query, retriever_model, n_results=3):
    """Retrieve the lines most similar to *query* and generate an answer from them.

    Args:
        collection: Chroma collection to search.
        query: The user's question.
        retriever_model: Callable that takes a prompt string and returns the
            generation output (in this app, a transformers text2text pipeline).
        n_results: Number of nearest lines to retrieve as context.

    Returns:
        ``(answer, metadatas)``: the raw output of ``retriever_model`` and the
        ``"metadatas"`` entry of the query result (one list per query text).
    """
    results = collection.query(query_texts=[query], n_results=n_results)
    # Only one query text was sent, so its hits are in the first result list.
    context = " ".join(results["documents"][0])
    answer = retriever_model(f"Context: {context}\nQuestion: {query}")
    return answer, results["metadatas"]
# Streamlit interface


@st.cache_resource
def _load_generator():
    """Load the flan-t5-small text2text-generation pipeline once per process.

    Streamlit re-runs ``main`` on every widget interaction; caching the
    pipeline avoids re-instantiating the model for each query.
    """
    return pipeline("text2text-generation", model="google/flan-t5-small")  # Free LLM


def _show_logo(path="LOGO.PNG", width=250):
    """Display the app logo, or a warning if the image file is missing."""
    try:
        image = Image.open(path)
    except FileNotFoundError:
        # The logo is decorative; a missing file must not take the app down.
        st.warning(f"Logo not found: {path}")
        return
    st.image(image, width=width)


def main():
    """Run the Streamlit app: upload a PDF, index its lines, answer questions.

    Streamlit re-executes this function from the top on every interaction.
    The identity of the last fully indexed upload and its extracted text are
    kept in ``st.session_state``, so asking a question does not clear and
    re-index the same PDF again.

    NOTE(review): every session calls setup_chromadb() and
    clear_collection(client, "pdf_data") with the same arguments, so all
    users of one server share a single index and can clear each other's data.
    """
    _show_logo()
    st.title("PDF Chatbot with Retrieval-Augmented Generation")
    st.write("Upload a PDF, and ask questions about its content!")

    # Initialize components
    client, collection = setup_chromadb()
    retriever_model = _load_generator()

    # File upload
    uploaded_file = st.file_uploader("Upload your PDF file", type="pdf")
    if uploaded_file:
        file_key = (
            getattr(uploaded_file, "file_id", None),
            uploaded_file.name,
            uploaded_file.size,
        )
        if st.session_state.get("indexed_file_key") == file_key:
            # Indexed on an earlier rerun of this session; just re-display it.
            st.text_area("Extracted Text:", st.session_state["indexed_text"], height=300)
        else:
            # Clearing below invalidates any previously indexed upload, so
            # forget it until this one has been fully added.
            st.session_state.pop("indexed_file_key", None)
            try:
                # Clear existing data
                collection = clear_collection(client, "pdf_data")
                st.info("Existing data cleared from the database.")
                # Extract and add new data
                pdf_text = extract_text_from_pdf(uploaded_file)
                st.success("Text extracted successfully!")
                st.session_state["indexed_text"] = pdf_text
                st.text_area("Extracted Text:", pdf_text, height=300)
                add_pdf_text_to_db(collection, pdf_text)
                st.session_state["indexed_file_key"] = file_key
                st.success("PDF text has been added to the database. You can now query it!")
            except Exception as e:
                st.error(f"Error extracting text: {e}")

    query = st.text_input("Enter your query about the PDF:")
    if query:
        try:
            answer, metadata = query_pdf_data(collection, query, retriever_model)
            st.subheader("Answer:")
            st.write(answer[0]['generated_text'])
            st.subheader("Retrieved Context:")
            # Each metadata dict holds the retrieved line and its line number;
            # the generation output itself is already shown above.
            for meta in metadata[0]:
                st.write(meta)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()