RAHMAN00700 committed on
Commit
a455ea8
·
unverified ·
1 Parent(s): c73f78c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import tempfile
4
+ from pptx import Presentation
5
+ from docx import Document
6
+
7
+ from langchain.document_loaders import PyPDFLoader, TextLoader
8
+ from langchain.indexes import VectorstoreIndexCreator
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain.embeddings import HuggingFaceEmbeddings
12
+ from langchain.chains import LLMChain
13
+ from langchain.prompts import PromptTemplate
14
+
15
+ from ibm_watson_machine_learning.foundation_models import Model
16
+ from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
17
+ from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
18
+ from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
19
+
20
# Initialize index to None
# Module-level placeholders filled in by the Streamlit flow further down:
# `index` is replaced by a vectorstore index after a successful upload,
# and `rag_chain` by a RetrievalQA chain only when `index` exists.
index = None
rag_chain = None  # Initialize rag_chain as None by default
23
+
24
+ # Custom loader for DOCX files
25
class DocxLoader:
    """Minimal loader that extracts the visible paragraph text from a .docx file."""

    def __init__(self, file_path):
        # Path of the .docx document on disk.
        self.file_path = file_path

    def load(self):
        """Return every paragraph's text, joined with single spaces."""
        document = Document(self.file_path)
        return " ".join(paragraph.text for paragraph in document.paragraphs)
35
+
36
+ # Custom loader for PPTX files
37
class PptxLoader:
    """Minimal loader that collects the text of every shape on every slide of a .pptx file."""

    def __init__(self, file_path):
        # Path of the .pptx presentation on disk.
        self.file_path = file_path

    def load(self):
        """Return all shape texts, slide by slide, joined with single spaces."""
        presentation = Presentation(self.file_path)
        return " ".join(
            shape.text
            for slide in presentation.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        )
49
+
50
+ # Caching function to load various file types
51
def _loaders_from_text(text):
    """Write extracted text to a temp .txt file and wrap it in a TextLoader.

    delete=False is deliberate: TextLoader opens the path lazily, after this
    context manager has already closed the file, so the file must outlive it.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
        temp_file.write(text.encode("utf-8"))
        temp_file_path = temp_file.name
    return [TextLoader(temp_file_path)]

# Caching function to load various file types
@st.cache_resource
def load_file(file_name, file_type):
    """Build and cache a vector index over the uploaded file.

    Parameters:
        file_name: path of the file previously saved to the working directory.
        file_type: lowercase extension — "pdf", "docx", "txt", or "pptx".

    Returns:
        A VectorstoreIndexCreator index, or None for an unsupported type
        (in which case an error is shown in the Streamlit UI).
    """
    if file_type == "pdf":
        loaders = [PyPDFLoader(file_name)]
    elif file_type == "txt":
        loaders = [TextLoader(file_name)]
    elif file_type == "docx":
        # No native LangChain DOCX loader here: extract text, reload as txt.
        loaders = _loaders_from_text(DocxLoader(file_name).load())
    elif file_type == "pptx":
        loaders = _loaders_from_text(PptxLoader(file_name).load())
    else:
        st.error("Unsupported file type.")
        return None

    return VectorstoreIndexCreator(
        embedding=HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2"),
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=450, chunk_overlap=50)
    ).from_loaders(loaders)
87
+
88
def format_history():
    """Return the conversation history used as fallback LLM context.

    Currently always the empty string — history formatting is not implemented.
    """
    return ""
90
+
91
+ # Watsonx API setup using environment variables
92
# Watsonx API setup using environment variables
watsonx_api_key = os.getenv("WATSONX_API_KEY")
watsonx_project_id = os.getenv("WATSONX_PROJECT_ID")

# Surface missing credentials in the UI. NOTE(review): execution continues
# past this point, so later Watsonx calls will still fail — consider st.stop().
if not watsonx_api_key or not watsonx_project_id:
    st.error("API Key or Project ID is not set. Please set them as environment variables.")

# Llama-3-style chat prompt. {context} is filled with retrieved document
# chunks by RetrievalQA, or with format_history() by the plain-chain fallback;
# {question} carries the user's message.
prompt_template_br = PromptTemplate(
    input_variables=["context", "question"],
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
I am a helpful assistant.

<|eot_id|>
{context}
<|start_header_id|>user<|end_header_id|>
{question}<|eot_id|>
"""
)
109
+
110
with st.sidebar:
    st.title("Watsonx RAG with Multiple docs")
    # Generation controls: model choice, output length, decoding strategy.
    watsonx_model = st.selectbox("Model", ["meta-llama/llama-3-405b-instruct", "codellama/codellama-34b-instruct-hf", "ibm/granite-20b-multilingual"])
    max_new_tokens = st.slider("Max output tokens", min_value=100, max_value=4000, value=600, step=100)
    decoding_method = st.radio("Decoding", (DecodingMethods.GREEDY.value, DecodingMethods.SAMPLE.value))
    # Watsonx generation parameters assembled from the widgets above.
    parameters = {
        GenParams.DECODING_METHOD: decoding_method,
        GenParams.MAX_NEW_TOKENS: max_new_tokens,
        GenParams.MIN_NEW_TOKENS: 1,
        GenParams.TEMPERATURE: 0,
        GenParams.TOP_K: 50,
        GenParams.TOP_P: 1,
        GenParams.STOP_SEQUENCES: [],
        GenParams.REPETITION_PENALTY: 1
    }
    st.info("Upload a PDF, DOCX, TXT, or PPTX file to use RAG")
    uploaded_file = st.file_uploader("Upload file", accept_multiple_files=False, type=["pdf", "docx", "txt", "pptx"])
    if uploaded_file is not None:
        bytes_data = uploaded_file.read()
        st.write("Filename:", uploaded_file.name)

        # Persist the upload to the working directory so the path-based
        # loaders in load_file() can open it.
        with open(uploaded_file.name, 'wb') as f:
            f.write(bytes_data)

        file_type = uploaded_file.name.split('.')[-1].lower()
        # Sets the module-level `index` used by the RAG wiring below.
        index = load_file(uploaded_file.name, file_type)

    model_name = watsonx_model

    def clear_messages():
        # Wipe the chat history shown in the main panel.
        st.session_state.messages = []

    st.button('Clear messages', on_click=clear_messages)
143
+
144
st.info("Setting up Watsonx...")

# watsonx.ai credentials; the API key comes from the environment (WATSONX_API_KEY).
my_credentials = {
    "url": "https://us-south.ml.cloud.ibm.com",
    "apikey": watsonx_api_key
}
params = parameters
project_id = watsonx_project_id
space_id = None
verify = False  # NOTE(review): disables TLS certificate verification — confirm this is intended
# Wrap the watsonx foundation model in a LangChain-compatible LLM.
model = WatsonxLLM(model=Model(model_name, my_credentials, params, project_id, space_id, verify))
155
+
156
if model:
    st.info(f"Model {model_name} ready.")
    # Plain prompt chain — the fallback used when no document index exists.
    chain = LLMChain(llm=model, prompt=prompt_template_br, verbose=True)

if chain:
    st.info("Chat ready.")

# Only create rag_chain if index is successfully created
if index is not None:
    rag_chain = RetrievalQA.from_chain_type(
        llm=model,
        chain_type="stuff",  # concatenate all retrieved chunks into one prompt
        retriever=index.vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": prompt_template_br},
        return_source_documents=False,
        verbose=True
    )
    st.info("Document-based retrieval is ready.")
else:
    st.warning("No document uploaded. Answering common queries without retrieval.")
176
+
177
+ # Chat loop for handling queries
178
def _strip_special_tokens(text):
    """Remove Llama-3 chat-template marker substrings from a model response.

    The original used str.strip(marker), but strip() treats its argument as a
    *set of characters* rather than a substring, so it could eat legitimate
    leading/trailing characters (e.g. '<', '|', letters in the marker).
    replace() removes the exact marker substrings instead.
    """
    for marker in ("<|start_header_id|>assistant<|end_header_id|>", "<|eot_id|>"):
        text = text.replace(marker, "")
    return text.strip()

# Chat loop for handling queries
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay prior turns. Roles are stored lowercase ("user"/"assistant") so
# st.chat_message renders the built-in avatars on replay.
for message in st.session_state.messages:
    st.chat_message(message["role"]).markdown(message["content"])

prompt = st.chat_input("Ask your question here", disabled=not chain)

if prompt:
    st.chat_message("user").markdown(prompt)
    # Record the user turn immediately so it survives a failure below.
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Prefer document-grounded answers when an index was built; otherwise fall
    # back to the plain chain with the (currently empty) formatted history.
    if rag_chain:
        response_text = rag_chain.run(prompt).strip()
    else:
        response_text = _strip_special_tokens(chain.run(question=prompt, context=format_history()))

    # Display and store the assistant turn.
    st.chat_message("assistant").markdown(response_text)
    st.session_state.messages.append({'role': 'assistant', 'content': response_text})