thefish1 committed on
Commit
b031362
·
1 Parent(s): bb0dffa

update 0711 lk

Browse files
Files changed (3) hide show
  1. app.py +37 -10
  2. load_data.py +21 -0
  3. train_2000_modified.json +0 -0
app.py CHANGED
@@ -1,15 +1,26 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
- # from ad_matching import fetch_top_ad
 
 
 
4
 
5
 
 
6
 
7
- """
8
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
9
- """
10
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
11
 
12
 
 
 
 
 
 
 
 
 
 
13
  def get_keywords(message):
14
  system_message = """
15
  #角色
@@ -30,6 +41,20 @@ def get_keywords(message):
30
  response+=token
31
 
32
  keywords=response.split(' ')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  def respond(
@@ -67,11 +92,13 @@ def respond(
67
  messages.append({"role": "assistant", "content": val[1]})
68
 
69
  key_words=get_keywords(message)
70
- # ad=fetch_top_ad(key_words)
71
- # if ad:
72
- # messages.append({"role": "assistant", "content": f"<sep> {ad}"})
73
-
74
- messages.append({"role": "user", "content": message})
 
 
75
 
76
  response = ""
77
 
@@ -116,4 +143,4 @@ demo = gr.ChatInterface(
116
 
117
 
118
  if __name__ == "__main__":
119
- demo.launch(share=True)
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ import json
4
+ import random
5
+ import re
6
+ from load_data import load_data
7
 
8
 
9
+ # from ad_matching import fetch_top_ad
10
 
11
+ #对话模型
 
 
12
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
13
 
14
 
15
# Load the ad dataset from the local JSON-lines file.
dataset = load_data(file_path='train_2000.json', num_samples=2000)
keyword_lists = [item['content'] for item in dataset if 'content' in item]
summary_lists = [item['summary'] for item in dataset if 'summary' in item]

# Split each comma-separated keyword string into a list of keywords so
# downstream matching compares whole keywords rather than substrings.
# NOTE: the previous loop (`for item in keyword_lists: item = item.split(',')`)
# only rebound the loop variable and never modified the list.
keyword_lists = [keywords.split(',') for keywords in keyword_lists]
22
+
23
+
24
  def get_keywords(message):
25
  system_message = """
26
  #角色
 
41
  response+=token
42
 
43
  keywords=response.split(' ')
44
+ return keywords
45
+
46
+
47
def keyword_match(query_keywords, ad_keywords_lists):
    """Find the ad keyword list sharing the most keywords with the query.

    Args:
        query_keywords: iterable of keyword strings extracted from the user
            message.
        ad_keywords_lists: sequence of per-ad keyword collections.

    Returns:
        Tuple ``(max_matches, index)``: the highest overlap count found and
        the position of the best-matching list. ``(0, 0)`` when nothing
        matches (callers treat a low count as "no ad").
    """
    max_matches = 0
    index = 0
    for i, candidate in enumerate(ad_keywords_lists):
        # Count how many query keywords appear in this ad's keyword list.
        matches = sum(keyword in candidate for keyword in query_keywords)
        # Strict '>' keeps the earliest candidate on ties, matching the
        # original behavior. (The original also tracked the best list
        # itself in a variable that was never used; that dead state is
        # removed here.)
        if matches > max_matches:
            max_matches = matches
            index = i
    return max_matches, index
58
 
59
 
60
  def respond(
 
92
  messages.append({"role": "assistant", "content": val[1]})
93
 
94
  key_words=get_keywords(message)
95
+ max_matches,index=keyword_match(key_words,keyword_lists)
96
+
97
+ if max_matches>1:
98
+ ad=summary_lists[index]
99
+ messages.append({"role": "user", "content": f"{message} <sep> {ad}"})
100
+ else :
101
+ messages.append({"role": "user", "content": message})
102
 
103
  response = ""
104
 
 
143
 
144
 
145
  if __name__ == "__main__":
146
+ demo.launch(share=True)
load_data.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+
4
def load_data(file_path='train.json', num_samples=2000):
    """Load up to *num_samples* records from a JSON-lines file.

    Each non-empty line of *file_path* is expected to hold one JSON object.
    Malformed lines are reported to stdout and skipped rather than aborting
    the whole load.

    Args:
        file_path: path to the JSON-lines file (one JSON object per line).
        num_samples: maximum number of records to return.

    Returns:
        A list of at most *num_samples* parsed objects, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Skip blank lines (e.g. a trailing newline) instead of
            # reporting them as JSON decode errors.
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e}")
    return data[:num_samples]
13
+
14
+
15
if __name__ == '__main__':
    # Dataset slicing: read the default source file (capped at 2000
    # records by load_data) and re-emit it as a JSON-lines file.
    records = load_data()
    lines = [json.dumps(record, ensure_ascii=False) + '\n' for record in records]
    with open('train_2000.json', 'w', encoding='utf-8') as out:
        out.writelines(lines)
train_2000_modified.json ADDED
The diff for this file is too large to render. See raw diff