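"""Streamlit shopping assistant.

Falcon (via the HuggingFace Hub) distills the user's request into search
keywords, a sentence-transformers model embeds them, Redis vector search
retrieves the closest products, and an OpenAI chat model presents the
results conversationally.
"""
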
import numpy as np
import redis
import streamlit as st
from langchain import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer

from constants import (
    EMBEDDING_MODEL_NAME,
    FALCON_MAX_TOKENS,
    FALCON_REPO_ID,
    FALCON_TEMPERATURE,
    HUGGINGFACEHUB_API_TOKEN,
    ITEM_KEYWORD_EMBEDDING,
    OPENAI_API_KEY,
    OPENAI_MODEL_NAME,
    OPENAI_TEMPERATURE,
    TEMPLATE_1,
    TEMPLATE_2,
    TOPK,
)
from database import create_redis


# connect to the Redis database through a shared connection pool
def connect_to_redis():
    pool = create_redis()
    return redis.Redis(connection_pool=pool)


# build the keyword-extraction chain: Falcon distills a product request into search keywords
def encode_keywords_chain():
    llm = HuggingFaceHub(
        repo_id=FALCON_REPO_ID,
        model_kwargs={"temperature": FALCON_TEMPERATURE, "max_new_tokens": FALCON_MAX_TOKENS},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
    prompt = PromptTemplate(
        input_variables=["product_description"],
        template=TEMPLATE_1,
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain


# build the product-presentation chain: an OpenAI chat model with conversation memory
def present_products_chain():
    template = TEMPLATE_2
    memory = ConversationBufferMemory(memory_key="chat_history")
    prompt = PromptTemplate(input_variables=["chat_history", "user_msg"], template=template)
    chain = LLMChain(
        llm=ChatOpenAI(openai_api_key=OPENAI_API_KEY, temperature=OPENAI_TEMPERATURE, model=OPENAI_MODEL_NAME),
        prompt=prompt,
        verbose=False,
        memory=memory,
    )
    return chain
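

# load the sentence-transformers model used to embed keyword queries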
def instance_embedding_model():
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return embedding_model
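

# Streamlit entry point: render the chat UI and orchestrate retrieval and response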
def main():
    st.title("My Amazon shopping buddy 🏷️")
    st.caption("🤖 Powered by the Falcon open-source AI model")

    redis_conn = connect_to_redis()
    keywords_chain = encode_keywords_chain()

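    # keep the chat chain in session state so its conversation memory survives Streamlit reruns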
if "window_refreshed" not in st.session_state: | |
st.session_state.window_refreshed = True | |
st.session_state.chat_chain = present_products_chain() | |
embedding_model = instance_embedding_model() | |
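    # seed the chat history with a greeting on first load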
if "messages" not in st.session_state: | |
st.session_state["messages"] = [ | |
{"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"} | |
] | |
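    # replay the conversation so far (Streamlit reruns the script on every interaction)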
    for msg in st.session_state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    prompt = st.chat_input(key="user_input")
    if prompt:
        st.session_state["messages"].append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        st.session_state.disabled = True

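        # distill the user's message into search keywords with the Falcon chain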
        keywords = keywords_chain.run(prompt)

        # vectorize the extracted keywords and serialize to float32 bytes for Redis
        query_vector = embedding_model.encode(keywords)
        query_vector_bytes = np.array(query_vector).astype(np.float32).tobytes()

        # prepare the KNN vector-similarity query against the product embeddings
        q = (
            Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
            .sort_by("vector_score")
            .paging(0, TOPK)
            .return_fields("vector_score", "item_name", "item_id", "item_keywords")
            .dialect(2)
        )
        params_dict = {"vec_param": query_vector_bytes}

        # execute the query
        results = redis_conn.ft().search(q, query_params=params_dict)

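        # flatten the retrieved products into a context block for the chat model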
        result_output = ""
        for product in results.docs:
            result_output += f"product_name:{product.item_name}, product_description:{product.item_keywords} \n"

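        # hand the retrieved products plus the original request to the chat chain for presentation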
        result = st.session_state.chat_chain.predict(user_msg=f"{result_output}\n{prompt}")
        st.session_state.messages.append({"role": "assistant", "content": result})
        st.chat_message("assistant").write(result)


if __name__ == "__main__":
    main()
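
# Launch locally with `streamlit run app.py` (assuming this file is saved as app.py).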