# RAG chatbot Streamlit app (page-scrape header removed; was Hugging Face Spaces UI residue)
import os
import tempfile

import faiss
import openai
import pandas as pd
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
# OpenAI API key configuration | |
st.set_page_config(page_title="RAG Chatbot with Files", layout="centered") | |
openai.api_key = st.sidebar.text_input("Enter OpenAI API Key:", type="password") | |
# Initialize FAISS and embedding model | |
embedding_model = SentenceTransformer('all-MiniLM-L6-v2') | |
faiss_index = None | |
data_chunks = [] | |
chunk_mapping = {} | |
# File Upload and Processing | |
def load_files(uploaded_files): | |
global data_chunks, chunk_mapping, faiss_index | |
data_chunks = [] | |
chunk_mapping = {} | |
for uploaded_file in uploaded_files: | |
file_type = uploaded_file.name.split('.')[-1] | |
with tempfile.NamedTemporaryFile(delete=False) as tmp_file: | |
tmp_file.write(uploaded_file.read()) | |
tmp_file_path = tmp_file.name | |
if file_type == "csv": | |
df = pd.read_csv(tmp_file_path) | |
content = "\n".join(df.astype(str).values.flatten()) | |
elif file_type == "xlsx": | |
df = pd.read_excel(tmp_file_path) | |
content = "\n".join(df.astype(str).values.flatten()) | |
elif file_type == "pdf": | |
reader = PdfReader(tmp_file_path) | |
content = "".join([page.extract_text() for page in reader.pages]) | |
else: | |
st.error(f"Unsupported file type: {file_type}") | |
continue | |
# Split into chunks | |
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) | |
chunks = splitter.split_text(content) | |
data_chunks.extend(chunks) | |
chunk_mapping.update({i: (uploaded_file.name, chunk) for i, chunk in enumerate(chunks)}) | |
# Create FAISS index | |
embeddings = embedding_model.encode(data_chunks) | |
faiss_index = faiss.IndexFlatL2(embeddings.shape[1]) | |
faiss_index.add(embeddings) | |
# Query Processing | |
def handle_query(query): | |
if not faiss_index: | |
return "No data available. Please upload files first." | |
# Generate embedding for the query | |
query_embedding = embedding_model.encode([query]) | |
distances, indices = faiss_index.search(query_embedding, k=5) | |
relevant_chunks = [chunk_mapping[idx][1] for idx in indices[0]] | |
# Use OpenAI for summarization | |
prompt = "Summarize the following information:\n" + "\n".join(relevant_chunks) | |
response = openai.Completion.create( | |
engine="text-davinci-003", | |
prompt=prompt, | |
max_tokens=150 | |
) | |
return response['choices'][0]['text'] | |
# Streamlit UI | |
def main(): | |
st.title("RAG Chatbot with Files") | |
st.sidebar.title("Options") | |
uploaded_files = st.sidebar.file_uploader("Upload files (CSV, Excel, PDF):", type=["csv", "xlsx", "pdf"], accept_multiple_files=True) | |
if uploaded_files: | |
load_files(uploaded_files) | |
st.sidebar.success("Files loaded successfully!") | |
query = st.text_input("Ask a question about the data:") | |
if st.button("Get Answer"): | |
if openai.api_key and query: | |
answer = handle_query(query) | |
st.subheader("Answer:") | |
st.write(answer) | |
else: | |
st.error("Please provide a valid API key and query.") | |
if __name__ == "__main__": | |
main() | |