import os
import tempfile
from dotenv import load_dotenv
import streamlit as st
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.chains import LLMChain, RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
from pptx import Presentation
from docx import Document
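
# Dependencies (assumed): streamlit, python-dotenv, langchain (a version that
# still exposes langchain.document_loaders / langchain.embeddings),
# sentence-transformers, ibm-watson-machine-learning, python-docx, python-pptx, pypdf.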
# Load environment variables
load_dotenv()
# Watsonx API setup
watsonx_api_key = os.getenv("API_KEY")
watsonx_project_id = os.getenv("PROJECT_ID")
watsonx_url = "https://us-south.ml.cloud.ibm.com"
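# us-south is the Dallas endpoint; point this at a different region if your
# watsonx project is hosted elsewhere.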
if not watsonx_api_key or not watsonx_project_id:
    st.error("API Key or Project ID is not set. Please set them as environment variables.")
    st.stop()  # Halt the script here; nothing below can work without credentials.
# Custom loader for DOCX files
class DocxLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        document = Document(self.file_path)
        text_content = [para.text for para in document.paragraphs]
        return " ".join(text_content)
# Custom loader for PPTX files
class PptxLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        presentation = Presentation(self.file_path)
        text_content = []
        for slide in presentation.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    text_content.append(shape.text)
        return " ".join(text_content)
# Caching function to load various file types
@st.cache_resource
def load_file(uploaded_file, file_type):
    # Save the uploaded file to a temporary path so the loaders can read it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as temp_file:
        temp_file.write(uploaded_file.read())
        temp_file_path = temp_file.name

    if file_type == "pdf":
        loaders = [PyPDFLoader(temp_file_path)]
    elif file_type == "txt":
        loaders = [TextLoader(temp_file_path)]
    elif file_type in ("docx", "pptx"):
        # The custom loaders return a plain string rather than langchain
        # Document objects, so write the text to an intermediate .txt file
        # that TextLoader can consume.
        loader = DocxLoader(temp_file_path) if file_type == "docx" else PptxLoader(temp_file_path)
        text = loader.load()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_txt_file:
            temp_txt_file.write(text.encode("utf-8"))
            temp_txt_file_path = temp_txt_file.name
        loaders = [TextLoader(temp_txt_file_path)]
    else:
        st.error("Unsupported file type.")
        return None

    # Create the index with the loaded documents
    index = VectorstoreIndexCreator(
        embedding=HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2"),
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=450, chunk_overlap=50),
    ).from_loaders(loaders)
    return index
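# Note: the temporary files above are created with delete=False and never
# removed; that is acceptable for a short-lived demo, but a longer-running app
# should os.unlink() them once the index has been built.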
# Prompt template
prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
I am a helpful assistant.
<|eot_id|>
{context}
<|start_header_id|>user<|end_header_id|>
{question}<|eot_id|>
""",
)
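# The <|...|> markers follow the Llama 3 chat template; they are assumed to be
# harmless for the non-Llama models in the sidebar, but those models may not
# interpret them as chat-role delimiters.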
# Sidebar settings
with st.sidebar:
    st.title("Watsonx RAG Demo")
    model_name = st.selectbox("Model", ["meta-llama/llama-3-405b-instruct", "codellama/codellama-34b-instruct-hf", "ibm/granite-20b-multilingual"])
    max_new_tokens = st.slider("Max output tokens", min_value=100, max_value=1000, value=300, step=100)
    decoding_method = st.radio("Decoding Method", [DecodingMethods.GREEDY.value, DecodingMethods.SAMPLE.value])
    st.info("Upload a PDF, DOCX, TXT, or PPTX file for RAG")
    uploaded_file = st.file_uploader("Upload file", accept_multiple_files=False, type=["pdf", "docx", "txt", "pptx"])
# Build the index only when a file has been uploaded; otherwise the app falls
# back to plain chat below.
index = None
if uploaded_file:
    file_type = uploaded_file.name.split('.')[-1].lower()
    index = load_file(uploaded_file, file_type)
# Watsonx Model setup with UI feedback
credentials = {
    "url": watsonx_url,
    "apikey": watsonx_api_key,
}
parameters = {
    GenParams.DECODING_METHOD: decoding_method,
    GenParams.MAX_NEW_TOKENS: max_new_tokens,
    GenParams.MIN_NEW_TOKENS: 1,
    GenParams.TEMPERATURE: 0.7,
    GenParams.TOP_K: 50,
    GenParams.TOP_P: 1,
    GenParams.REPETITION_PENALTY: 1.0,
}
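# TEMPERATURE, TOP_K, and TOP_P only take effect when the decoding method is
# "sample"; greedy decoding ignores them.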
# Display setup status
status_placeholder = st.empty()
status_placeholder.markdown("**Setting up Watsonx...**")
model = None  # Ensure the name exists even if initialization fails below.
try:
    model = WatsonxLLM(Model(model_name, credentials, parameters, project_id=watsonx_project_id))
    status_placeholder.markdown(f"**Model [{model_name}] ready.**")
except Exception as e:
    st.error(f"Failed to initialize model: {str(e)}")
# Chat History Setup
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    st.chat_message(message["role"]).markdown(message["content"])
# User Input
prompt = st.chat_input("Ask your question here", disabled=model is None)
# Process User Input
if prompt:
    st.chat_message("user").markdown(prompt)
    if index:
        # A document is indexed: answer with retrieval-augmented generation.
        rag_chain = RetrievalQA.from_chain_type(
            llm=model,
            chain_type="stuff",
            retriever=index.vectorstore.as_retriever(),
            chain_type_kwargs={"prompt": prompt_template},
            verbose=True,
        )
        response_text = rag_chain.run(prompt).strip()
    else:
        # No document uploaded: fall back to a plain LLM call with an empty context.
        chain = LLMChain(llm=model, prompt=prompt_template)
        # str.strip(chars) removes characters, not substrings, so trim the
        # chat-format markers with replace() instead.
        response_text = (
            chain.run(context="", question=prompt)
            .replace("<|start_header_id|>assistant<|end_header_id|>", "")
            .replace("<|eot_id|>", "")
            .strip()
        )
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    st.chat_message("assistant").markdown(response_text)
    st.session_state.messages.append({'role': 'assistant', 'content': response_text})
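
# To launch the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py
# with API_KEY and PROJECT_ID set in the environment or in a .env file.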