import gradio as gr import yaml import fitz import warnings from PIL import Image from langchain_chroma import Chroma from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from langchain.document_loaders import PyPDFLoader from langchain.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.retrievers import BM25Retriever, EnsembleRetriever, ContextualCompressionRetriever from langchain.retrievers.document_compressors import FlashrankRerank from langchain_huggingface import HuggingFacePipeline, ChatHuggingFace, HuggingFaceEmbeddings from utils import get_page_from_documents warnings.filterwarnings("ignore") class PDFChatBot: def __init__(self, config_path="config/config.yaml"): """ Initializes the PDFChatBot instance by loading configurations and setting up default attributes. Args: config_path (str): Path to the YAML configuration file. """ self.processed = False self.current_page = 0 self.chat_history = [] self.config = self._load_config(config_path) # Initialize attributes self.documents = None self.embeddings = None self.retriever = None self.pipeline = None self.hyde_chain = None def _load_config(self, file_path): """ Load the YAML configuration file. Args: file_path (str): Path to the YAML file. Returns: dict: Parsed YAML configuration as a dictionary. """ try: with open(file_path, "r") as file: return yaml.safe_load(file) except yaml.YAMLError as e: raise RuntimeError(f"Failed to load config: {e}") def _initialize_embeddings(self): """ Load embeddings model using HuggingFace. """ model_name = self.config.get("EMBEDDING_MODEL") if not model_name: raise ValueError("Embedding model not specified in configuration.") self.embeddings = HuggingFaceEmbeddings(model_name=model_name) def _initialize_retriever(self): """ Create and set up the retriever with ensemble and contextual reranking. """ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=100, chunk_overlap=25 ) docs = text_splitter.split_documents(self.documents) # Vector-based retriever vector_store = Chroma.from_documents(docs, self.embeddings) vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3}) # Keyword-based retriever keyword_retriever = BM25Retriever.from_documents(docs) keyword_retriever.k = 3 # Ensemble of retrievers ensemble_retriever = EnsembleRetriever( retrievers=[vector_retriever, keyword_retriever], weights=[0.5, 0.5], ) # Reranker compressor_model = self.config.get("RERANK_MODEL") if not compressor_model: raise ValueError("Reranker model not specified in configuration.") compressor = FlashrankRerank(model=compressor_model, top_n=5) # Contextual retriever self.retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=ensemble_retriever ) def _initialize_pipeline(self): """ Initialize the LLM pipeline for text generation. """ model_name = self.config.get("MODEL_DEEPSEEK_R1_1_5B") if not model_name: raise ValueError("LLM model not specified in configuration.") tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", trust_remote_code=True ) generation_pipeline = pipeline( task="text-generation", model=model, tokenizer=tokenizer, temperature=0.1, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0, do_sample=True, return_full_text=False, ) self.pipeline = ChatHuggingFace(llm=HuggingFacePipeline(pipeline=generation_pipeline)) def _create_hyde_chain(self): """ Create a HyDE chain for document generation and retrieval. """ hyde_template = """ You are an intelligent assistant. Generate a concise hypothetical answer or document based on the user's question. This will be used to retrieve relevant information. Question: {question} Hypothetical Answer: """ hyde_prompt = ChatPromptTemplate.from_template(hyde_template) generate_docs_chain = ( hyde_prompt | self.pipeline | StrOutputParser() ) self.hyde_chain = generate_docs_chain | self.retriever def process_file(self, file): """ Process the uploaded PDF file by loading its content and initializing necessary components. Args: file: File object for the uploaded PDF. """ loader = PyPDFLoader(file) self.documents = loader.load() self._initialize_embeddings() self._initialize_retriever() self._initialize_pipeline() self._create_hyde_chain() self.processed = True def rag_pipeline(self, query): retrieved_docs = self.hyde_chain.invoke({"question": query}) context_template = """ Use the following context to answer the question: {context} Question: {question} Answer: """ rag_prompt = ChatPromptTemplate.from_template(context_template) answer_chain = rag_prompt | self.pipeline | StrOutputParser() answer = answer_chain.invoke({"context": retrieved_docs, "question": query}) return { "answer": answer, "documents": [{"content": doc.page_content, "metadata": doc.metadata} for doc in retrieved_docs], } def generate_response(self, query, file): """ Generate a response to the user's query using the retrieval-augmented generation pipeline. Args: query (str): User's query. Returns: dict: Response with the answer and source documents. """ if not query: raise gr.Error(message='Submit a question') if not file: raise gr.Error(message='Upload a PDF') if not self.processed: self.process_file(file) result = self.rag_pipeline(query) self.current_page = get_page_from_documents(result["documents"]) return result["answer"] def render_page(self, file): """ Render the current page of the PDF as an image. Args: file: File object for the PDF. Returns: PIL.Image.Image: Rendered image of the current page. """ doc = fitz.open(file) page = doc[self.current_page] pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72)) return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)