# LegalAlly / app.py
# Rohil Bansal — "New structure" (commit 7a7b50b)
# raw | history | blame · 8.58 kB
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
import streamlit as st
import time
import logging
import os , sys
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.prompts import PromptTemplate
from src.settings import load_env_variables
from src.logger import setup_logger
from src.vector_db import load_vector_db, save_vector_db
from src.embeddings import get_embeddings, get_model, test_openai_key
from src.dataloader import dataloader
def reset_conversation():
    """Wipe the chat transcript and the conversation memory held in session state.

    Registered as the ``on_click`` callback of the "Reset All Chat" button.
    """
    print("Resetting conversation")
    state = st.session_state
    state.messages = []  # displayed transcript
    state.memory.clear()  # windowed memory object stored under "memory"
    print("Conversation reset complete")
# CSS injected into the page: dark background, button colours, and hiding of
# Streamlit chrome (main menu, footer, deploy button, fullscreen button).
CUSTOM_CSS = """
<style>
.stApp, .ea3mdgi6{ background-color:#000000; }
div.stButton > button:first-child { background-color: #ffd0d0; }
div.stButton > button:active { background-color: #ff6262; }
div[data-testid="stStatusWidget"] div button { display: none; }
.reportview-container { margin-top: -2em; }
#MainMenu {visibility: hidden;}
.stDeployButton {display:none;}
footer {visibility: hidden;}
#stDecoration {display:none;}
button[title="View fullscreen"]{ visibility: hidden;}
button:first-child{ background-color : transparent !important; }
</style>
"""

# Prompt for the retrieval chain: {context} receives the retrieved documents,
# {chat_history} the windowed memory and {question} the user's message.
prompt_template = """
This is a chat template and As a legal chat bot specializing in Indian Penal Code queries, your primary objective is to provide accurate and concise information based on the user's questions. Do not generate your own questions and answers. You will adhere strictly to the instructions provided, offering relevant context from the knowledge base while avoiding unnecessary details. Your responses will be brief, to the point, and in compliance with the established format. If a question falls outside the given context, you will refrain from utilizing the chat history and instead rely on your own knowledge base to generate an appropriate response. You will prioritize the user's query and refrain from posing additional questions. The aim is to deliver professional, precise, and contextually relevant information pertaining to the Indian Penal Code.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
"""

# Same instructions without the {context} slot, for use when no retriever is
# available. ConversationChain requires the prompt variables to be exactly
# the memory key plus the chain's input key.
FALLBACK_PROMPT_TEMPLATE = prompt_template.replace("CONTEXT: {context}\n", "")


def _build_llm(model_name, openai_api_key):
    """Return the LangChain LLM wrapper that matches *model_name*.

    - Names containing "gpt-4" or "gpt-3.5-turbo" use ``ChatOpenAI``.
    - Names containing "mistral" (case-insensitive) use ``HuggingFaceHub``.
    - Anything else uses the completion-style ``OpenAI`` class.

    All are created with temperature 0.5.
    """
    if "gpt-4" in model_name or "gpt-3.5-turbo" in model_name:
        from langchain.chat_models import ChatOpenAI
        return ChatOpenAI(model_name=model_name, temperature=0.5, openai_api_key=openai_api_key)
    if "mistral" in model_name.lower():
        # NOTE(review): no token is passed here; presumably HuggingFaceHub picks
        # it up from the environment — confirm.
        from langchain.llms import HuggingFaceHub
        return HuggingFaceHub(repo_id=model_name, model_kwargs={"temperature": 0.5})
    return OpenAI(model_name=model_name, temperature=0.5, openai_api_key=openai_api_key)


def _build_qa_chain(llm, retriever, prompt, memory):
    """Build the question-answering chain around the given *memory*.

    With a *retriever*, the result is a ``ConversationalRetrievalChain`` that
    fills *prompt* with the retrieved documents.

    Without one (``None``), the result is a plain ``ConversationChain`` using
    FALLBACK_PROMPT_TEMPLATE. Its input and output keys are set to "question"
    and "answer", so both chains can be invoked with a plain string and read
    via ``result["answer"]``.
    """
    if retriever is not None:
        return ConversationalRetrievalChain.from_llm(
            llm=llm,
            memory=memory,
            retriever=retriever,
            combine_docs_chain_kwargs={'prompt': prompt},
        )
    fallback_prompt = PromptTemplate(
        template=FALLBACK_PROMPT_TEMPLATE,
        input_variables=['chat_history', 'question'],
    )
    return ConversationChain(
        llm=llm,
        memory=memory,
        prompt=fallback_prompt,
        input_key="question",
        output_key="answer",
    )


def _stream_answer(answer):
    """Render *answer* below a fixed disclaimer with a typewriter effect.

    The answer is iterated piece by piece (character by character for a
    string), sleeping 0.02 s per piece. A block cursor is shown while
    typing and removed by the final render.
    """
    message_placeholder = st.empty()
    full_response = "⚠️ **_Note: Information provided may be inaccurate._** \n\n\n"
    for chunk in answer:
        full_response += chunk
        time.sleep(0.02)
        message_placeholder.markdown(full_response + " ▌")
    # Final render without the cursor so the finished reply doesn't look cut off.
    message_placeholder.markdown(full_response)


print("Starting app.py")
try:
    # Load environment variables and set up logging.
    print("Loading environment variables and setting up logging")
    openai_api_key = load_env_variables()
    setup_logger()
    print("Environment variables loaded and logging set up")

    # Test the OpenAI API key. Only a message is printed here; the actual
    # model/embedding choice is made by get_model/get_embeddings below.
    print("Testing OpenAI API key")
    if not test_openai_key(openai_api_key):
        print("OpenAI API key is invalid or has no credits. Falling back to Mistral.")
    else:
        print("OpenAI API key is valid and has credits")

    st.set_page_config(page_title="LawGPT")
    print("Streamlit page config set")
    _, col2, _ = st.columns([1, 4, 1])
    with col2:
        try:
            st.image("assets/Black Bold Initial AI Business Logo.jpg")
            print("Logo image loaded successfully")
        except Exception as e:
            # The logo is decorative; failing to load it must not stop the app.
            print(f"Error loading logo image: {str(e)}")

    print("Applying custom CSS")
    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)

    # Streamlit re-runs this script on every interaction; session_state is
    # what survives between runs, so the transcript and the memory live there.
    print("Initializing session state")
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "memory" not in st.session_state:
        st.session_state["memory"] = ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True)
    print("Session state initialized")

    # Embeddings are created once, inside the error handling.
    print("Setting up embeddings")
    try:
        embeddings = get_embeddings(openai_api_key)
    except Exception as e:
        print(f"Error setting up OpenAI embeddings: {str(e)}")
        st.error("An error occurred while setting up OpenAI embeddings. Please check your API key and try again.")
        st.stop()
    print(f"Using embeddings: {type(embeddings).__name__}")

    print("Getting appropriate model")
    model_name = get_model(openai_api_key)
    print(f"Using model: {model_name}")

    # Source document used to build the vector database.
    file_name = 'Indian_Penal_Code_Book.pdf'
    data = dataloader(file_name)

    print("Loading vector database")
    db_path = "./ipc_vector_db/vectordb"
    os.makedirs(os.path.dirname(db_path), exist_ok=True)
    print(f"Ensured directory exists: {os.path.dirname(db_path)}")
    vector_db = load_vector_db(db_path, embeddings, data)
    if vector_db is not None:
        db_retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
        print("Vector database loaded successfully")
    else:
        # Lets _build_qa_chain fall back to a retrieval-free chain.
        db_retriever = None
        print("Vector database unavailable; continuing without retrieval")

    print("Setting up prompt template")
    prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question', 'chat_history'])

    print("Setting up OpenAI LLM")
    try:
        llm = _build_llm(model_name, openai_api_key)
        print(f"LLM set up successfully: {type(llm).__name__}")
    except Exception as e:
        print(f"Error setting up OpenAI LLM: {str(e)}")
        raise

    print("Setting up ConversationalRetrievalChain")
    try:
        # Use the session-state memory so history persists across reruns
        # and is actually cleared by reset_conversation().
        qa = _build_qa_chain(llm, db_retriever, prompt, st.session_state.memory)
        print("ConversationalRetrievalChain (or fallback) set up successfully")
    except Exception as e:
        print(f"Error setting up ConversationalRetrievalChain: {str(e)}")
        raise

    print("Displaying chat messages")
    for message in st.session_state.get("messages", []):
        with st.chat_message(message.get("role")):
            st.write(message.get("content"))

    input_prompt = st.chat_input("Say something")
    if input_prompt:
        print(f"Received input: {input_prompt}")
        with st.chat_message("user"):
            st.write(input_prompt)
        st.session_state.messages.append({"role": "user", "content": input_prompt})

        answer = None
        with st.chat_message("assistant"):
            try:
                with st.spinner("Thinking 💡..."):
                    print("Invoking ConversationalRetrievalChain")
                    result = qa.invoke(input=input_prompt)
                print("ConversationalRetrievalChain invoked successfully")
                answer = result["answer"]
                _stream_answer(answer)
                print("Response displayed successfully")
            except Exception as e:
                print(f"Error generating or displaying response: {str(e)}")
                st.error("An error occurred while processing your request. Please try again.")
            st.button('Reset All Chat 🗑️', on_click=reset_conversation)
        # Record the reply only if the chain produced one; a failed invoke
        # leaves nothing to store.
        if answer is not None:
            st.session_state.messages.append({"role": "assistant", "content": answer})
except Exception as e:
    print(f"Unhandled exception in app.py: {str(e)}")
    logging.exception("Unhandled exception in app.py")
    st.error("An unexpected error occurred. Please try again later.")
print("End of app.py")