foxxy-hm commited on
Commit
baa297c
·
1 Parent(s): 50d2c82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -79
app.py CHANGED
@@ -1,83 +1,5 @@
1
  import streamlit as st
2
- # from src.models.predict_model import *
3
-
4
- from src.models.pairwise_model import *
5
- from src.features.text_utils import *
6
- import regex as re
7
- from src.models.bm25_utils import BM25Gensim
8
- from src.models.qa_model import *
9
- from tqdm.auto import tqdm
10
- tqdm.pandas()
11
- from datasets import load_dataset
12
-
13
- df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
14
- df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
15
- df_wiki.title = df_wiki.title.apply(str)
16
-
17
- entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
18
- new_dict = dict()
19
- for key, val in entity_dict.items():
20
- val = val[0].replace("wiki/", "").replace("_", " ")
21
- entity_dict[key] = val
22
- key = preprocess(key)
23
- new_dict[key.lower()] = val
24
- entity_dict.update(new_dict)
25
- title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
26
-
27
- qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["models/qa_model_robust.bin"], entity_dict)
28
- pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
29
- pairwise_model_stage1.load_state_dict(torch.load("models/pairwise_v2.bin", map_location=torch.device('cpu')))
30
- pairwise_model_stage1.eval()
31
-
32
- pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
33
- pairwise_model_stage2.load_state_dict(torch.load("models/pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
34
-
35
- bm25_model_stage1 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
36
- bm25_model_stage2_full = BM25Gensim("models/bm25_stage2/full_text/", entity_dict, title2idx)
37
- bm25_model_stage2_title = BM25Gensim("models/bm25_stage2/title/", entity_dict, title2idx)
38
-
39
- def get_answer_e2e(question):
40
- #Bm25 retrieval for top200 candidates
41
- query = preprocess(question).lower()
42
- top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
43
- titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
44
- texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
45
-
46
- #Reranking with pairwise model for top10
47
- question = preprocess(question)
48
- ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
49
- ranking_scores = ranking_preds * bm25_scores
50
-
51
- #Question answering
52
- best_idxs = np.argsort(ranking_scores)[-10:]
53
- ranking_scores = np.array(ranking_scores)[best_idxs]
54
- texts = np.array(texts)[best_idxs]
55
- best_answer = qa_model(question, texts, ranking_scores)
56
- if best_answer is None:
57
- return "Chịu"
58
- bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
59
-
60
- #Entity mapping
61
- if not check_number(bm25_answer):
62
- bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
63
- bm25_question_answer = bm25_question + " " + bm25_answer
64
- candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
65
- titles = [df_wiki.title.values[i] for i in candidates]
66
- texts = [df_wiki.text.values[i] for i in candidates]
67
- ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
68
- if ranking_preds.max() >= 0.1:
69
- final_answer = titles[ranking_preds.argmax()]
70
- else:
71
- candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
72
- titles = [df_wiki.title.values[i] for i in candidates] + titles
73
- texts = [df_wiki.text.values[i] for i in candidates] + texts
74
- ranking_preds = np.concatenate(
75
- [pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
76
- final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
77
- else:
78
- final_answer = bm25_answer.lower()
79
- return final_answer
80
-
81
 
82
  with st.sidebar:
83
  st.write("# 🤖 Language Models")
 
1
  import streamlit as st
2
+ from src.models.predict_model import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  with st.sidebar:
5
  st.write("# 🤖 Language Models")