RAHMAN00700 commited on
Commit
4896a0b
·
unverified ·
1 Parent(s): 30acf1a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain.document_loaders import PyPDFLoader, TextLoader
3
+ from langchain.indexes import VectorstoreIndexCreator
4
+ from langchain.chains import RetrievalQA
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.chains import LLMChain
8
+ from langchain.prompts import PromptTemplate
9
+ import streamlit as st
10
+ import tempfile
11
+
12
+ from ibm_watson_machine_learning.foundation_models import Model
13
+ from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
14
+ from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
15
+ from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
16
+
17
+ from pptx import Presentation
18
+ from docx import Document
19
+
20
+ # Initialize index to None
21
+ index = None
22
+
23
+ # Custom loader for DOCX files
24
+ class DocxLoader:
25
+ def __init__(self, file_path):
26
+ self.file_path = file_path
27
+
28
+ def load(self):
29
+ document = Document(self.file_path)
30
+ text_content = []
31
+ for para in document.paragraphs:
32
+ text_content.append(para.text)
33
+ return " ".join(text_content)
34
+
35
+ # Custom loader for PPTX files
36
+ class PptxLoader:
37
+ def __init__(self, file_path):
38
+ self.file_path = file_path
39
+
40
+ def load(self):
41
+ presentation = Presentation(self.file_path)
42
+ text_content = []
43
+ for slide in presentation.slides:
44
+ for shape in slide.shapes:
45
+ if hasattr(shape, "text"):
46
+ text_content.append(shape.text)
47
+ return " ".join(text_content)
48
+
49
+ # Caching function to load various file types
50
+ @st.cache_resource
51
+ def load_file(file_name, file_type):
52
+ loaders = []
53
+
54
+ if file_type == "pdf":
55
+ loaders = [PyPDFLoader(file_name)]
56
+ elif file_type == "docx":
57
+ loader = DocxLoader(file_name)
58
+ text = loader.load()
59
+
60
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
61
+ temp_file.write(text.encode("utf-8"))
62
+ temp_file_path = temp_file.name
63
+ loaders = [TextLoader(temp_file_path)]
64
+
65
+ elif file_type == "txt":
66
+ loaders = [TextLoader(file_name)]
67
+
68
+ elif file_type == "pptx":
69
+ loader = PptxLoader(file_name)
70
+ text = loader.load()
71
+
72
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
73
+ temp_file.write(text.encode("utf-8"))
74
+ temp_file_path = temp_file.name
75
+ loaders = [TextLoader(temp_file_path)]
76
+
77
+ else:
78
+ st.error("Unsupported file type.")
79
+ return None
80
+
81
+ index = VectorstoreIndexCreator(
82
+ embedding=HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2"),
83
+ text_splitter=RecursiveCharacterTextSplitter(chunk_size=450, chunk_overlap=50)
84
+ ).from_loaders(loaders)
85
+ return index
86
+
87
+ def format_history():
88
+ return ""
89
+
90
+ # Watsonx API setup using environment variables
91
+ watsonx_api_key = os.getenv("WATSONX_API_KEY")
92
+ watsonx_project_id = os.getenv("WATSONX_PROJECT_ID")
93
+
94
+ if not watsonx_api_key or not watsonx_project_id:
95
+ st.error("API Key or Project ID is not set. Please set them as environment variables.")
96
+
97
+ prompt_template_br = PromptTemplate(
98
+ input_variables=["context", "question"],
99
+ template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
100
+ I am a helpful assistant.
101
+
102
+ <|eot_id|>
103
+ {context}
104
+ <|start_header_id|>user<|end_header_id|>
105
+ {question}<|eot_id|>
106
+ """
107
+ )
108
+
109
+ with st.sidebar:
110
+ st.title("Watsonx RAG Demo")
111
+ watsonx_model = st.selectbox("Model", ["meta-llama/llama-3-405b-instruct", "codellama/codellama-34b-instruct-hf", "ibm/granite-20b-multilingual"])
112
+ max_new_tokens = st.slider("Max output tokens", min_value=100, max_value=1000, value=500, step=100)
113
+ decoding_method = st.radio("Decoding", (DecodingMethods.GREEDY.value, DecodingMethods.SAMPLE.value))
114
+ parameters = {
115
+ GenParams.DECODING_METHOD: decoding_method,
116
+ GenParams.MAX_NEW_TOKENS: max_new_tokens,
117
+ GenParams.MIN_NEW_TOKENS: 1,
118
+ GenParams.TEMPERATURE: 0,
119
+ GenParams.TOP_K: 50,
120
+ GenParams.TOP_P: 1,
121
+ GenParams.STOP_SEQUENCES: [],
122
+ GenParams.REPETITION_PENALTY: 1
123
+ }
124
+ st.info("Upload a PDF, DOCX, TXT, or PPTX file to use RAG")
125
+ uploaded_file = st.file_uploader("Upload file", accept_multiple_files=False, type=["pdf", "docx", "txt", "pptx"])
126
+ if uploaded_file is not None:
127
+ bytes_data = uploaded_file.read()
128
+ st.write("Filename:", uploaded_file.name)
129
+
130
+ with open(uploaded_file.name, 'wb') as f:
131
+ f.write(bytes_data)
132
+
133
+ file_type = uploaded_file.name.split('.')[-1].lower()
134
+ index = load_file(uploaded_file.name, file_type)
135
+
136
+ model_name = watsonx_model
137
+
138
+ def clear_messages():
139
+ st.session_state.messages = []
140
+
141
+ st.button('Clear messages', on_click=clear_messages)
142
+
143
+ st.info("Setting up Watsonx...")
144
+
145
+ my_credentials = {
146
+ "url": "https://us-south.ml.cloud.ibm.com",
147
+ "apikey": watsonx_api_key
148
+ }
149
+ params = parameters
150
+ project_id = watsonx_project_id
151
+ space_id = None
152
+ verify = False
153
+ model = WatsonxLLM(model=Model(model_name, my_credentials, params, project_id, space_id, verify))
154
+
155
+ if model:
156
+ st.info(f"Model {model_name} ready.")
157
+ chain = LLMChain(llm=model, prompt=prompt_template_br, verbose=True)
158
+
159
+ if chain:
160
+ st.info("Chat ready.")
161
+ if index:
162
+ rag_chain = RetrievalQA.from_chain_type(
163
+ llm=model,
164
+ chain_type="stuff",
165
+ retriever=index.vectorstore.as_retriever(),
166
+ chain_type_kwargs={"prompt": prompt_template_br},
167
+ return_source_documents=False,
168
+ verbose=True
169
+ )
170
+ st.info("Chat with document ready.")
171
+
172
+ if "messages" not in st.session_state:
173
+ st.session_state.messages = []
174
+
175
+ for message in st.session_state.messages:
176
+ st.chat_message(message["role"]).markdown(message["content"])
177
+
178
+ prompt = st.chat_input("Ask your question here", disabled=False if chain else True)
179
+
180
+ if prompt:
181
+ st.chat_message("user").markdown(prompt)
182
+
183
+ response_text = None
184
+ if rag_chain:
185
+ response_text = rag_chain.run(prompt).strip()
186
+
187
+ if not response_text:
188
+ response = chain.run(question=prompt, context=format_history())
189
+ response_text = response.strip("<|start_header_id|>assistant<|end_header_id|>").strip("<|eot_id|>")
190
+
191
+ st.session_state.messages.append({'role': 'User', 'content': prompt })
192
+ st.chat_message("assistant").markdown(response_text)
193
+ st.session_state.messages.append({'role': 'Assistant', 'content': response_text })