Jackie2235 committed on
Commit 773120e · 1 Parent(s): 487d029

Upload app.py, version from Jiaqi

Files changed (1):
  1. app.py +181 -217
app.py CHANGED
@@ -1,16 +1,11 @@
 import streamlit as st
-from streamlit_tags import st_tags, st_tags_sidebar
-from keytotext import pipeline
+
 from PIL import Image
 
 import json
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
-import gzip
-import os
-import torch
 import pickle
-import random
-import numpy as np
+import pandas as pd
 
 ############
 ## Main page
@@ -40,7 +35,7 @@ option1 = st.sidebar.selectbox(
     ('multi-qa-MiniLM-L6-cos-v1','null','null'))
 
 option2 = st.sidebar.selectbox(
-    'Which corss-encoder model would you like to be selected?',
+    'Which cross-encoder model would you like to be selected?',
     ('cross-encoder/ms-marco-MiniLM-L-6-v2','null','null'))
 
 st.sidebar.success("Load Successfully!")
@@ -49,22 +44,28 @@ st.sidebar.success("Load Successfully!")
 # print("Warning: No GPU found. Please add GPU to your notebook")
 
 #We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
-bi_encoder = SentenceTransformer(option1,device='cpu')
+@st.cache_resource
+def load_encoders(sentence_enc, cross_enc):
+    return SentenceTransformer(sentence_enc,device='cpu'), CrossEncoder(cross_enc,device='cpu')
+bi_encoder, cross_encoder = load_encoders(option1,option2)
 bi_encoder.max_seq_length = 256     #Truncate long passages to 256 tokens
 top_k = 32                          #Number of passages we want to retrieve with the bi-encoder
 
-#The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
-cross_encoder = CrossEncoder(option2, device='cpu')
-
 passages = []
 
 # load pre-train embeedings files
+@st.cache_resource
+def load_pickle(path):
+    with open(path, "rb") as fIn:
+        cache_data = pickle.load(fIn)
+        passages = cache_data['sentences']
+        corpus_embeddings = cache_data['embeddings']
+    print("Load pre-computed embeddings from disc")
+    return passages,corpus_embeddings
+
 embedding_cache_path = 'etsy-embeddings-cpu.pkl'
-print("Load pre-computed embeddings from disc")
-with open(embedding_cache_path, "rb") as fIn:
-    cache_data = pickle.load(fIn)
-    passages = cache_data['sentences']
-    corpus_embeddings = cache_data['embeddings']
+passages,corpus_embeddings = load_pickle(embedding_cache_path)
+
 
 from rank_bm25 import BM25Okapi
 from sklearn.feature_extraction import _stop_words
@@ -75,15 +76,24 @@ import re
 
 import yake
 
-language = "en"
-max_ngram_size = 3
-deduplication_threshold = 0.9
-deduplication_algo = 'seqm'
-windowSize = 3
-numOfKeywords = 3
-
-custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)
-
+@st.cache_resource
+def load_model():
+    language = "en"
+    max_ngram_size = 3
+    deduplication_threshold = 0.9
+    deduplication_algo = 'seqm'
+    windowSize = 3
+    numOfKeywords = 3
+    return yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)
+custom_kw_extractor = load_model()
+# load query GMS information
+@st.cache_resource
+def load_json(path):
+    with open(path, 'r') as file:
+        query_gms_dict = json.load(file)
+    return query_gms_dict
+
+query_gms_dict = load_json('query_gms.json')
 # We lower case our text and remove stop-words from indexing
 def bm25_tokenizer(text):
     tokenized_doc = []
@@ -94,10 +104,14 @@ def bm25_tokenizer(text):
             tokenized_doc.append(token)
     return tokenized_doc
 
-tokenized_corpus = []
-for passage in tqdm(passages):
-    tokenized_corpus.append(bm25_tokenizer(passage))
+@st.cache_resource
+def get_tokenized_corpus(passages,_tokenizer):
+    tokenized_corpus = []
+    for passage in passages:
+        tokenized_corpus.append(_tokenizer(passage))
+    return tokenized_corpus
 
+tokenized_corpus = get_tokenized_corpus(passages,bm25_tokenizer)
 bm25 = BM25Okapi(tokenized_corpus)
 
 def word_len(s):
@@ -106,205 +120,155 @@ def word_len(s):
 
 # This function will search all wikipedia articles for passages that
 # answer the query
-def search(query):
+DEFAULT_SCORE = -100.0
+def clean_string(input_string):
+    string_sub1 = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', input_string)
+    string_sub2 = re.sub("\x20\x20", "\n", string_sub1)
+    string_strip = string_sub2.strip().lower()
+    output_string = []
+    if len(string_strip) > 20:
+        keywords = custom_kw_extractor.extract_keywords(string_strip)
+        for tokens in keywords:
+            string_clean = tokens[0]
+            if word_len(string_clean) > 1:
+                output_string.append(string_clean)
+    else:
+        output_string.append(string_strip)
+    return output_string
+
+# def add_gms_score_for_candidates(candidates, query_gms_dict):
+#     for query_candidate in candidates:
+#         value = candidates[query_candidate]
+#         value['gms'] = query_gms_dict.get(query_candidate, 0)
+#         candidates[query_candidate] = value
+#     return candidates
+
+def generate_query_expansion_candidates(query):
     print("Input query:", query)
-    total_qe = []
+    expanded_query_set = {}
 
     ##### BM25 search (lexical search) #####
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
-    top_n = np.argpartition(bm25_scores, -5)[-5:]
-    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
-    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
-
-    #print("Top-10 lexical search (BM25) hits")
-    qe_string = []
-    for hit in bm25_hits[0:1000]:
-        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-
-    sub_string = []
-    for item in qe_string:
-        for sub_item in item.split(","):
-            sub_string.append(sub_item)
-    #print(sub_string)
-    total_qe.append(sub_string)
-
-    ##### Sematic Search #####
-    # Encode the query using the bi-encoder and find potentially relevant passages
-    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
-    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
-    hits = hits[0]  # Get the hits for the first query
+    # finds the indices of the top n scores
+    top_n_indices = np.argpartition(bm25_scores, -5)[-5:]
+    bm25_hits = [{'corpus_id': idx, 'bm25_score': bm25_scores[idx]} for idx in top_n_indices]
+    # bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
 
-    ##### Re-Ranking #####
-    # Now, score all retrieved passages with the cross_encoder
-    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
-    cross_scores = cross_encoder.predict(cross_inp)
-
-    # Sort results by the cross-encoder scores
-    for idx in range(len(cross_scores)):
-        hits[idx]['cross-score'] = cross_scores[idx]
-
-    # Output of top-10 hits from bi-encoder
-    #print("\n-------------------------\n")
-    #print("Top-N Bi-Encoder Retrieval hits")
-    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
-    qe_string = []
-    for hit in hits[0:1000]:
-        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-    #print(qe_string)
-    total_qe.append(qe_string)
-
-    # Output of top-10 hits from re-ranker
-    #print("\n-------------------------\n")
-    #print("Top-N Cross-Encoder Re-ranker hits")
-    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-    qe_string = []
-    for hit in hits[0:1000]:
-        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-    #print(qe_string)
-    total_qe.append(qe_string)
-
-    # Total Results
-    total_qe.append(qe_string)
-    st.write("E-Commerce Query Expansion Results: \n")
-
-    res = []
-    for sub_list in total_qe:
-        for i in sub_list:
-            rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
-            rs_final = re.sub("\x20\x20", "\n", rs)
-            #st.write(rs_final.strip())
-            res.append(rs_final.strip())
-
-    res_clean = []
-    for out in res:
-        if len(out) > 20:
-            keywords = custom_kw_extractor.extract_keywords(out)
-            for key in keywords:
-                res_clean.append(key[0])
-        else:
-            res_clean.append(out)
-
-    show_out = []
-    for i in res_clean:
-        num = word_len(i)
-        if num > 1:
-            show_out.append(i)
-    unique_list = list(set(show_out))
-    new_unique_list = [item for item in unique_list if item != query]
-    Lowercasing_list = [item.lower() for item in new_unique_list]
-    st.write(Lowercasing_list[0:maxtags_sidebar])
-
-    return Lowercasing_list
-
-def search_nolog(query):
-    total_qe = []
-    ##### BM25 search (lexical search) #####
-    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
-    top_n = np.argpartition(bm25_scores, -5)[-5:]
-    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
-    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
-
-    qe_string = []
-    for hit in bm25_hits[0:1000]:
-        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-
-    sub_string = []
-    for item in qe_string:
-        for sub_item in item.split(","):
-            sub_string.append(sub_item)
-    total_qe.append(sub_string)
 
     ##### Sematic Search #####
     # Encode the query using the bi-encoder and find potentially relevant passages
     query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
-    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
-    hits = hits[0]  # Get the hits for the first query
+    # query_embedding = query_embedding.cuda()
+    # Get the hits for the first query
+    encoder_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
 
-    ##### Re-Ranking #####
-    # Now, score all retrieved passages with the cross_encoder
-    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
+    # For all retrieved passages, add the cross_encoder scores
+    cross_inp = [[query, passages[hit['corpus_id']]] for hit in encoder_hits]
     cross_scores = cross_encoder.predict(cross_inp)
-
-    # Sort results by the cross-encoder scores
     for idx in range(len(cross_scores)):
-        hits[idx]['cross-score'] = cross_scores[idx]
-
-    # Output of top-10 hits from bi-encoder
-    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
-    qe_string = []
-    for hit in hits[0:1000]:
-        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-    total_qe.append(qe_string)
-
-    # Output of top-10 hits from re-ranker
-    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-    qe_string = []
-    for hit in hits[0:1000]:
-        if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-            qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-    total_qe.append(qe_string)
-
+        encoder_hits[idx]['cross_score'] = cross_scores[idx]
+
+    candidates = {}
+    for hit in bm25_hits:
+        corpus_id = hit['corpus_id']
+        if corpus_id not in candidates:
+            candidates[corpus_id] = {'bm25_score': hit['bm25_score'], 'bi_score': DEFAULT_SCORE, 'cross_score': DEFAULT_SCORE}
+    for hit in encoder_hits:
+        corpus_id = hit['corpus_id']
+        if corpus_id not in candidates:
+            candidates[corpus_id] = {'bm25_score': DEFAULT_SCORE, 'bi_score': hit['score'], 'cross_score': hit['cross_score']}
+        else:
+            bm25_score = candidates[corpus_id]['bm25_score']
+            candidates[corpus_id].update({'bm25_score': bm25_score, 'bi_score': hit['score'], 'cross_score': hit['cross_score']})
+
+    final_candidates = {}
+    for key, value in candidates.items():
+        input_string = passages[key].replace("\n", "")
+        string_set = set(clean_string(input_string))
+        for item in string_set:
+            final_candidates[item.replace("\n", " ")] = value
+    # remove the query itself from candidates
+    if query in final_candidates:
+        del final_candidates[query]
+    # print(final_candidates)
+    # add gms column
+    df = pd.DataFrame(final_candidates).T
+    df['gms'] = [query_gms_dict.get(i,0) for i in df.index]
     # Total Results
-    total_qe.append(qe_string)
-
-    res = []
-    for sub_list in total_qe:
-        for i in sub_list:
-            rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
-            rs_final = re.sub("\x20\x20", "\n", rs)
-            res.append(rs_final.strip())
-
-    res_clean = []
-    for out in res:
-        if len(out) > 20:
-            keywords = custom_kw_extractor.extract_keywords(out)
-            for key in keywords:
-                res_clean.append(key[0])
-        else:
-            res_clean.append(out)
-
-    show_out = []
-    for i in res_clean:
-        num = word_len(i)
-        if num > 1:
-            show_out.append(i)
-
-    return show_out
-
-def reranking():
-    rerank_list = []
-    reres = []
-    remove_dup = []
-    rerank_list = search_nolog(query = user_query)
-    unique_list = list(set(rerank_list))
-    Lowercasing_list = [item.lower() for item in unique_list]
-    new_unique_list = [item for item in Lowercasing_list if item != user_query]
-
-    for i in new_unique_list:
-        clean_string = i.strip()
-        if clean_string not in remove_dup:
-            remove_dup.append(clean_string)
-
-    st.write("E-Commerce Query Expansion Results: \n")
-    st.write(remove_dup[0:maxtags_sidebar])
-
-    for i in remove_dup[0:maxtags_sidebar]:
-        reres.append(i)
-    np.random.seed(7)
-    np.random.shuffle(reres)
-    st.write("Reranking Results: \n")
-    st.write(reres)
-
-st.write("## Results:")
-if st.button('Generated Expansion'):
-    out_res = search(query = user_query)
-    #st.success(out_res)
 
-if st.button('Rerank'):
-    out_res = reranking()
-    #st.success(out_res)
+    return df.to_dict('index')
+
+def re_rank_candidates(query, candidates, method):
+    if method == 'bm25':
+        # Filter and sort by bm25_score
+        filtered_sorted_result = sorted(
+            [(k, v) for k, v in candidates.items() if v['bm25_score'] > DEFAULT_SCORE],
+            key=lambda x: x[1]['bm25_score'],
+            reverse=True
+        )
+    elif method == 'bi_encoder':
+        # Filter and sort by bi_score
+        filtered_sorted_result = sorted(
+            [(k, v) for k, v in candidates.items() if v['bi_score'] > DEFAULT_SCORE],
+            key=lambda x: x[1]['bi_score'],
+            reverse=True
+        )
+    elif method == 'cross_encoder':
+        # Filter and sort by cross_score
+        filtered_sorted_result = sorted(
+            [(k, v) for k, v in candidates.items() if v['cross_score'] > DEFAULT_SCORE],
+            key=lambda x: x[1]['cross_score'],
+            reverse=True
+        )
+    elif method == 'gms':
+        filtered_sorted_by_encoder = sorted(
+            [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
+            key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
+            reverse=True
+        )
+        # first sort by cross_score + bi_score
+        filtered_sorted_result = sorted(filtered_sorted_by_encoder, key=lambda x: x[1]['gms'], reverse=True
+        )
+    else:
+        # use default method cross_score + bi_score
+        # Filter and sort by cross_score + bi_score
+        filtered_sorted_result = sorted(
+            [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
+            key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
+            reverse=True
+        )
+    data_dicts = [{'query': item[0], **item[1]} for item in filtered_sorted_result]
+    # Convert the list of dictionaries into a DataFrame
+    df = pd.DataFrame(data_dicts)
+    return df
+
+
+# st.write("## Raw Candidates:")
+if st.button('Generated Expansion'):
+    col1, col2 = st.columns(2)
+    candidates = generate_query_expansion_candidates(query = user_query)
+
+    with col1:
+        st.subheader('Original Ranking')
+        ranking_cross = re_rank_candidates(user_query, candidates, method='cross_encoder')
+        ranking_cross.index = ranking_cross.index+1
+        st.table(ranking_cross['query'][:maxtags_sidebar])
+
+    with col2:
+        st.subheader('GMS-sorted Ranking')
+        ranking_gms = re_rank_candidates(user_query, candidates, method='gms')
+        ranking_gms.index = ranking_gms.index + 1
+        st.table(ranking_gms[['query', 'gms']][:maxtags_sidebar])
+
+    ## convert into dataframe
+    # data_dicts = [{'query': key, **values} for key, values in candidates.items()]
+    # df = pd.DataFrame(data_dicts)
+    # st.write(list(candidates.keys())[0:maxtags_sidebar])
+    # st.write(df)
+    # st.dataframe(df)
+    # st.success(raw_candidates)
+
+#if st.button('Rerank By GMS'):
+    #candidates = generate_query_expansion_candidates(query = user_query)
+    #df = re_rank_candidates(user_query, candidates, method='gms')
+    #st.dataframe(df[['query', 'gms']][:maxtags_sidebar])
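
A minimal standalone sketch (not part of the commit) of the GMS-based re-ranking idea introduced by the new re_rank_candidates: candidates missing an encoder score are filtered out via the DEFAULT_SCORE sentinel and the rest are ordered by GMS. The candidate queries, scores, and GMS values below are invented for illustration only.

    # Illustration only: mimics the method='gms' branch of re_rank_candidates() on toy data.
    import pandas as pd

    DEFAULT_SCORE = -100.0  # sentinel meaning "this retriever did not score the candidate"

    # Hypothetical candidates in the same shape generate_query_expansion_candidates() returns.
    toy_candidates = {
        "leather wallet":      {"bm25_score": 11.2, "bi_score": 0.71, "cross_score": 4.1, "gms": 500},
        "mens leather wallet": {"bm25_score": DEFAULT_SCORE, "bi_score": 0.68, "cross_score": 5.3, "gms": 900},
        "wallet":              {"bm25_score": 9.8, "bi_score": DEFAULT_SCORE, "cross_score": DEFAULT_SCORE, "gms": 300},
    }

    def rerank_by_gms(candidates):
        # Keep candidates scored by both encoders, then order by GMS (descending).
        scored = [(q, s) for q, s in candidates.items()
                  if s["cross_score"] > DEFAULT_SCORE and s["bi_score"] > DEFAULT_SCORE]
        scored.sort(key=lambda item: item[1]["gms"], reverse=True)
        return pd.DataFrame([{"query": q, **s} for q, s in scored])

    print(rerank_by_gms(toy_candidates))
    # "mens leather wallet" (gms=900) ranks above "leather wallet" (gms=500);
    # "wallet" is dropped because it lacks both encoder scores.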