fix imports
- .env_example +3 -1
- .pre-commit-config.yaml +24 -0
- app.py +115 -118
- constants.py +14 -0
- database.py +10 -7
- preprocess.py +30 -34
- utilities.py +0 -32
- utils.py +38 -0
.env_example
CHANGED
@@ -1,3 +1,5 @@
 REDIS_KEY = ''
 OPENAI_API_KEY = ''
-HUGGINGFACEHUB_API_TOKEN = ''
+HUGGINGFACEHUB_API_TOKEN = ''
+REDIS_HOST = ''
+REDIS_PORT = ''
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,24 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+  - repo: https://github.com/psf/black
+    rev: 22.10.0
+    hooks:
+      - id: black
+        args: ["--line-length=118"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        args: ["--profile", "black", "--filter-files"]
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+        args: ["--max-line-length=118", "--ignore=E501,E266,E203,W503"]
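
Note: the hooks above only take effect once pre-commit is activated in each clone; the usual one-time setup (assuming pre-commit is installed from PyPI) is:

pip install pre-commit
pre-commit install           # registers the git hook
pre-commit run --all-files   # one-off pass over the existing files

The end-of-file-fixer and trailing-whitespace hooks also explain the otherwise no-op-looking remove/re-add pairs elsewhere in this diff (e.g. the last line of .env_example and the trailing blank lines dropped from database.py and preprocess.py).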
app.py
CHANGED
@@ -1,124 +1,121 @@
-import os
-import streamlit as st
-import numpy as np
+import os
+
+import numpy as np
 import redis
-from dotenv import load_dotenv
+import streamlit as st
+from dotenv import load_dotenv
 from langchain import HuggingFaceHub
 from langchain.chains import LLMChain
-from langchain.memory import ConversationBufferMemory
 from langchain.chat_models import ChatOpenAI
-from langchain.prompts import PromptTemplate
-from langchain.callbacks.base import BaseCallbackHandler
-from sentence_transformers import SentenceTransformer
-from redis.commands.search.query import Query
-
-
-load_dotenv()
-redis_key = os.getenv('REDIS_KEY')
-HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
-repo_id = 'tiiuae/falcon-7b-instruct'
-
-class StreamHandler(BaseCallbackHandler):
-    def __init__(self, container, initial_text="", display_method='markdown'):
-        self.container = container
-        self.text = initial_text
-        self.display_method = display_method
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        self.text += token + " "
-        display_function = getattr(self.container, self.display_method, None)
-        if display_function is not None:
-            display_function(self.text)
-        else:
-            raise ValueError(f"Invalid display_method: {self.display_method}")
-
-
-st.title('My Amazon shopping buddy 🏷️')
-st.caption('🤖 Powered by Falcon Open Source AI model')
-
-#connect to redis database
-@st.cache_resource()
-def redis_connect():
-    redis_conn = redis.Redis(
-        host='redis-12882.c259.us-central1-2.gce.cloud.redislabs.com',
-        port=12882,
-        password=redis_key)
-    return redis_conn
-
-redis_conn = redis_connect()
-
-#the encoding keywords chain
-@st.cache_resource()
-def encode_keywords_chain():
-    falcon_llm_1 = HuggingFaceHub(repo_id = repo_id, model_kwargs={'temperature':0.1,'max_new_tokens':500},huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN)
-    prompt = PromptTemplate(
-        input_variables=["product_description"],
-        template="Create comma seperated product keywords to perform a query on a amazon dataset for this user input: {product_description}",
-    )
-    chain = LLMChain(llm=falcon_llm_1, prompt=prompt)
-    return chain
-chain = encode_keywords_chain()
-#the present products chain
-
-@st.cache_resource()
-def present_products_chain():
-    template = """You are a salesman. Be kind, detailed and nice. take the given context and Present the given queried search result in a nice way as answer to the user_msg. dont ask questions back or freestyle and invent followup conversation!
-    {chat_history}
-    user:{user_msg}
-    Chatbot:"""
-    prompt = PromptTemplate(
-        input_variables=["chat_history", "user_msg"],
-        template=template
-    )
-    memory = ConversationBufferMemory(memory_key="chat_history")
-    llm_chain = LLMChain(
-        llm = ChatOpenAI(openai_api_key=os.getenv('OPENAI_API_KEY'),temperature=0.8,model='gpt-3.5-turbo'),
-        prompt=prompt,
-        verbose=False,
-        memory=memory,
-    )
-    return llm_chain
-
-
-
-llm_chain = present_products_chain()
-
-@st.cache_resource()
-def embedding_model():
-    embedding_model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
-    return embedding_model
-
-embedding_model = embedding_model()
-
-if "messages" not in st.session_state:
-    st.session_state["messages"] = [{"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"}]
-for msg in st.session_state["messages"]:
-    st.chat_message(msg["role"]).write(msg["content"])
-
-prompt = st.chat_input(key="user_input" )
-
-if prompt:
-    st.session_state["messages"].append({"role": "user", "content": prompt})
-    st.chat_message('user').write(prompt)
-    st.session_state.disabled = True
-    keywords = chain.run(prompt)
-    #vectorize the query
-    query_vector = embedding_model.encode(keywords)
-    query_vector = np.array(query_vector).astype(np.float32).tobytes()
-    #prepare the query
-    ITEM_KEYWORD_EMBEDDING_FIELD = 'item_vector'
-    topK=5
-    q = Query(f'*=>[KNN {topK} @{ITEM_KEYWORD_EMBEDDING_FIELD} $vec_param AS vector_score]').sort_by('vector_score').paging(0,topK).return_fields('vector_score','item_name','item_id','item_keywords').dialect(2)
-    params_dict = {"vec_param": query_vector}
-    #Execute the query
-    results = redis_conn.ft().search(q, query_params = params_dict)
-
-    full_result_string = ''
-    for product in results.docs:
-        full_result_string += product.item_name + ' ' + product.item_keywords + "\n\n\n"
-
-    result = llm_chain.predict(user_msg=f"{full_result_string} ---\n\n {prompt}")
-    st.session_state.messages.append({"role": "assistant", "content": result})
-    st.chat_message('assistant').write(result)
-
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import PromptTemplate
+from redis.commands.search.query import Query
+from sentence_transformers import SentenceTransformer
+
+from constants import (
+    EMBEDDING_MODEL_NAME,
+    FALCON_MAX_TOKENS,
+    FALCON_REPO_ID,
+    FALCON_TEMPERATURE,
+    OPENAI_MODEL_NAME,
+    OPENAI_TEMPERATURE,
+    TEMPLATE_1,
+    TEMPLATE_2,
+)
+from database import create_redis
+
+load_dotenv()
+HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+ITEM_KEYWORD_EMBEDDING = "item_vector"
+TOPK = 5
+
+
+def main():
+    # connect to redis database
+    @st.cache_resource()
+    def connect_to_redis():
+        pool = create_redis()
+        return redis.Redis(connection_pool=pool)
+
+    # the encoding keywords chain
+    @st.cache_resource()
+    def encode_keywords_chain():
+        falcon_llm_1 = HuggingFaceHub(
+            repo_id=FALCON_REPO_ID,
+            model_kwargs={"temperature": FALCON_TEMPERATURE, "max_new_tokens": FALCON_MAX_TOKENS},
+            huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
+        )
+        prompt = PromptTemplate(
+            input_variables=["product_description"],
+            template=TEMPLATE_1,
+        )
+        chain = LLMChain(llm=falcon_llm_1, prompt=prompt)
+        return chain
+
+    # the present products chain
+    @st.cache_resource()
+    def present_products_chain():
+        template = TEMPLATE_2
+        prompt = PromptTemplate(input_variables=["chat_history", "user_msg"], template=template)
+        memory = ConversationBufferMemory(memory_key="chat_history")
+        llm_chain = LLMChain(
+            llm=ChatOpenAI(
+                openai_api_key=os.getenv("OPENAI_API_KEY"), temperature=OPENAI_TEMPERATURE, model=OPENAI_MODEL_NAME
+            ),
+            prompt=prompt,
+            verbose=False,
+            memory=memory,
+        )
+        return llm_chain
+
+    @st.cache_resource()
+    def instance_embedding_model():
+        embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
+        return embedding_model
+
+    st.title("My Amazon shopping buddy 🏷️")
+    st.caption("🤖 Powered by Falcon Open Source AI model")
+    redis_conn = connect_to_redis()
+    keywords_chain = encode_keywords_chain()
+    chat_chain = present_products_chain()
+    embedding_model = instance_embedding_model()
+
+    if "messages" not in st.session_state:
+        st.session_state["messages"] = [
+            {"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"}
+        ]
+    for msg in st.session_state["messages"]:
+        st.chat_message(msg["role"]).write(msg["content"])
+
+    prompt = st.chat_input(key="user_input")
+
+    if prompt:
+        st.session_state["messages"].append({"role": "user", "content": prompt})
+        st.chat_message("user").write(prompt)
+        st.session_state.disabled = True
+        keywords = keywords_chain.run(prompt)
+        # vectorize the query
+        query_vector = embedding_model.encode(keywords)
+        query_vector_bytes = np.array(query_vector).astype(np.float32).tobytes()
+        # prepare the query
+
+        q = (
+            Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
+            .sort_by("vector_score")
+            .paging(0, TOPK)
+            .return_fields("vector_score", "item_name", "item_id", "item_keywords")
+            .dialect(2)
+        )
+        params_dict = {"vec_param": query_vector_bytes}
+        # Execute the query
+        results = redis_conn.ft().search(q, query_params=params_dict)
+        result_output = ""
+        for product in results.docs:
+            result_output += f"product_name:{product.item_name}, product_description:{product.item_keywords} \n"
+        result = chat_chain.predict(user_msg=f"{result_output}\n{prompt}")
+        st.session_state.messages.append({"role": "assistant", "content": result})
+        st.chat_message("assistant").write(result)
+
+
+if __name__ == "__main__":
+    main()
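
A note on the retrieval step: with TOPK = 5 and ITEM_KEYWORD_EMBEDDING = "item_vector", the f-string above renders to the RediSearch clause sketched here.

# Rendered KNN query (illustrative):
#   *=>[KNN 5 @item_vector $vec_param AS vector_score]
# "*" matches every indexed hash; the KNN clause keeps the 5 vectors closest to
# $vec_param (the float32 bytes passed via query_params), exposes each distance
# as vector_score, and sort_by("vector_score") then orders results nearest-first.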
constants.py
ADDED
@@ -0,0 +1,14 @@
+FALCON_REPO_ID = "tiiuae/falcon-7b-instruct"
+FALCON_TEMPERATURE = 0.1
+FALCON_MAX_TOKENS = 500
+
+OPENAI_MODEL_NAME = "gpt-3.5-turbo"
+OPENAI_TEMPERATURE = 0.8
+
+EMBEDDING_MODEL_NAME = "sentence-transformers/all-distilroberta-v1"
+
+TEMPLATE_1 = "Create comma seperated product keywords to perform a query on a amazon dataset for this user input: {product_description}"
+TEMPLATE_2 = """You are a salesman.Present the given product results in a nice way as answer to the user_msg. Dont ask questions back,
+{chat_history}
+user:{user_msg}
+Chatbot:"""
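
Both templates are consumed through langchain's PromptTemplate, which substitutes the bracketed variables. A minimal sketch (the product description is invented):

from langchain.prompts import PromptTemplate

from constants import TEMPLATE_1

prompt = PromptTemplate(input_variables=["product_description"], template=TEMPLATE_1)
print(prompt.format(product_description="noise cancelling headphones"))
# -> Create comma seperated product keywords to perform a query on a amazon
#    dataset for this user input: noise cancelling headphones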
database.py
CHANGED
@@ -1,13 +1,16 @@
-import redis
 import os
+
+import redis
 from dotenv import load_dotenv
 
 load_dotenv()
-redis_key = os.getenv('REDIS_KEY')
-
 
 
-
-
-
-
+def create_redis():
+    return redis.ConnectionPool(
+        host=os.getenv("REDIS_HOST"),
+        port=os.getenv("REDIS_PORT"),
+        password=os.getenv("REDIS_KEY"),
+        db=0,
+        decode_responses=True,
+    )
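
create_redis now returns a ConnectionPool rather than a client, so app.py and preprocess.py can share one pool per process. A minimal smoke test, assuming REDIS_HOST, REDIS_PORT and REDIS_KEY are filled in per .env_example:

import redis

from database import create_redis

conn = redis.Redis(connection_pool=create_redis())
print(conn.ping())  # True on success; raises redis.exceptions.ConnectionError otherwise

(redis-py casts the port to int internally, so the string returned by os.getenv is accepted.)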
preprocess.py
CHANGED
@@ -1,48 +1,44 @@
-from langchain.embeddings import OpenAIEmbeddings
-from sentence_transformers import SentenceTransformer
-import os
-import pandas as pd
 import numpy as np
-
-from utilities import create_flat_index, load_vectors
+import pandas as pd
+import redis
+from sentence_transformers import SentenceTransformer
 
+from database import create_redis
+from utils import create_flat_index, load_vectors
 
+pool = create_redis()
+redis_conn = redis.Redis(connection_pool=pool)
+# set maximum length for text fields
 MAX_TEXT_LENGTH = 512
+TEXT_EMBEDDING_DIMENSION = 768
+NUMBER_PRODUCTS = 10000
 
-def auto_truncate(text:str):
+def auto_truncate(text: str):
     return text[0:MAX_TEXT_LENGTH]
 
-data = pd.read_csv('product_data.csv', converters={'bullet_point': auto_truncate, 'item_keywords': auto_truncate, 'item_name': auto_truncate})
+data = pd.read_csv(
+    "product_data.csv",
+    converters={"bullet_point": auto_truncate, "item_keywords": auto_truncate, "item_name": auto_truncate},
+)
+data["primary_key"] = data["item_id"] + "-" + data["domain_name"]
+data.drop(columns=["item_id", "domain_name"], inplace=True)
+data["item_keywords"].replace("", np.nan, inplace=True)
+data.dropna(subset=["item_keywords"], inplace=True)
 data.reset_index(drop=True, inplace=True)
-data_metadata = data.head(500).to_dict(orient='index')
+data_metadata = data.head(10000).to_dict(orient="index")
 
-#generating embeddings (vectors) for the item keywords
-embedding_model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
+# generating embeddings (vectors) for the item keywords
+embedding_model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
 # embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
 
-#get the item keywords attribute for each product and encode them into vector embeddings
-item_keywords = [data_metadata[i]['item_keywords'] for i in data_metadata.keys()]
+# get the item keywords attribute for each product and encode them into vector embeddings
+item_keywords = [data_metadata[i]["item_keywords"] for i in data_metadata.keys()]
 item_keywords_vectors = [embedding_model.encode(item) for item in item_keywords]
 
-NUMBER_PRODUCTS=500
-
-print ('Loading and Indexing + ' + str(NUMBER_PRODUCTS) + ' products')
-#flush all data
+# flush all data
 redis_conn.flushall()
-#create flat index & load vectors
-create_flat_index(redis_conn,NUMBER_PRODUCTS,TEXT_EMBEDDING_DIMENSION,'COSINE')
-load_vectors(redis_conn,data_metadata,item_keywords_vectors)
-
-
-
-
-
+# create flat index & load vectors
+create_flat_index(redis_conn, NUMBER_PRODUCTS, TEXT_EMBEDDING_DIMENSION, "COSINE")
+load_vectors(redis_conn, data_metadata, item_keywords_vectors)
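
TEXT_EMBEDDING_DIMENSION = 768 has to match the output width of the chosen sentence-transformers model; a quick check (the keyword string is invented):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
vec = model.encode("waterproof bluetooth speaker, portable, outdoor")
print(vec.shape)  # (768,) for all-distilroberta-v1

Note also that NUMBER_PRODUCTS = 10000 now matches the data.head(10000) slice, so the index capacity is sized to the number of vectors actually loaded.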
utilities.py
DELETED
@@ -1,32 +0,0 @@
-from redis import Redis
-from redis.commands.search.field import VectorField
-from redis.commands.search.field import TextField
-from redis.commands.search.field import TagField
-from redis.commands.search.result import Result
-import numpy as np
-
-def load_vectors(client:Redis, product_metadata, vector_dict):
-    p = client.pipeline(transaction=False)
-    for index in product_metadata.keys():
-        #hash key
-        key='product:'+ str(index)+ ':' + product_metadata[index]['primary_key']
-
-        #hash values
-        item_metadata = product_metadata[index]
-        item_keywords_vector = np.array(vector_dict[index], dtype=np.float32).tobytes()
-        item_metadata['item_vector']=item_keywords_vector
-
-        # HSET
-        p.hset(key,mapping=item_metadata)
-
-    p.execute()
-
-def create_flat_index (redis_conn, number_of_vectors, vector_dimensions=512, distance_metric='L2'):
-    redis_conn.ft().create_index([
-        VectorField('item_vector', "FLAT", {"TYPE": "FLOAT32", "DIM": vector_dimensions, "DISTANCE_METRIC": distance_metric, "INITIAL_CAP": number_of_vectors, "BLOCK_SIZE":number_of_vectors }),
-        TagField("product_type"),
-        TextField("item_name"),
-        TextField("item_keywords"),
-        TagField("country")
-    ])
-
utils.py
ADDED
@@ -0,0 +1,38 @@
+import numpy as np
+from redis import Redis
+from redis.commands.search.field import TagField, TextField, VectorField
+
+
+def load_vectors(client: Redis, product_metadata, vector_dict):
+    p = client.pipeline(transaction=False)
+    for index in product_metadata.keys():
+        # hash key
+        key = "product:" + str(index) + ":" + product_metadata[index]["primary_key"]
+        # hash values
+        item_metadata = product_metadata[index]
+        item_keywords_vector = np.array(vector_dict[index], dtype=np.float32).tobytes()
+        item_metadata["item_vector"] = item_keywords_vector
+        p.hset(key, mapping=item_metadata)
+    p.execute()
+
+
+def create_flat_index(redis_conn, number_of_vectors, vector_dimensions=512, distance_metric="L2"):
+    redis_conn.ft().create_index(
+        [
+            VectorField(
+                "item_vector",
+                "FLAT",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": vector_dimensions,
+                    "DISTANCE_METRIC": distance_metric,
+                    "INITIAL_CAP": number_of_vectors,
+                    "BLOCK_SIZE": number_of_vectors,
+                },
+            ),
+            TagField("product_type"),
+            TextField("item_name"),
+            TextField("item_keywords"),
+            TagField("country"),
+        ]
+    )
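
load_vectors packs each embedding into float32 bytes under the item_vector hash field. A hedged round-trip sketch (the key is hypothetical but follows the product:{index}:{primary_key} scheme; a client without decode_responses is used because the field is binary):

import numpy as np
import redis

conn = redis.Redis(host="localhost", port=6379)  # hypothetical instance, binary-safe (no decode_responses)
raw = conn.hget("product:0:B000EXAMPLE-amazon.com", "item_vector")  # hypothetical key
vec = np.frombuffer(raw, dtype=np.float32)
print(vec.shape)  # (768,) given the all-distilroberta-v1 embeddings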