update
Browse files
app.py
CHANGED
@@ -204,6 +204,9 @@ import torch
|
|
204 |
from tqdm import tqdm
|
205 |
import numpy as np
|
206 |
import time
|
|
|
|
|
|
|
207 |
|
208 |
# 设置缓存目录
|
209 |
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
@@ -212,10 +215,32 @@ os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
|
212 |
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
213 |
|
214 |
# Weaviate 连接配置
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
|
216 |
WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
|
|
|
|
|
|
217 |
weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
|
218 |
-
weaviate_client = weaviate.Client(
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
# 预训练模型配置
|
221 |
MODEL_NAME = "bert-base-chinese"
|
@@ -421,13 +446,13 @@ def chatbot_response(message, history, window_size, threshold, score_threshold,u
|
|
421 |
else:
|
422 |
group_scores[group_id] = score
|
423 |
|
|
|
424 |
if group_scores:
|
425 |
max_group_id = max(group_scores, key=group_scores.get)
|
426 |
max_score = group_scores[max_group_id]
|
427 |
if(max_score>=score_threshold):
|
428 |
distance,ad_summary,ad_keywords=[(candidate['distance'],candidate['summary'],candidate['keyword_list']) for candidate in candidates if candidate['group_id']==max_group_id][0]
|
429 |
-
|
430 |
-
distance=1000
|
431 |
|
432 |
# if(candidates):
|
433 |
# # distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
|
@@ -435,7 +460,7 @@ def chatbot_response(message, history, window_size, threshold, score_threshold,u
|
|
435 |
# else:
|
436 |
# distance=1000
|
437 |
|
438 |
-
if distance
|
439 |
brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
|
440 |
brand=random.choice(brands)
|
441 |
ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
|
|
|
204 |
from tqdm import tqdm
|
205 |
import numpy as np
|
206 |
import time
|
207 |
+
import requests
|
208 |
+
from requests.adapters import HTTPAdapter
|
209 |
+
from requests.packages.urllib3.util.retry import Retry
|
210 |
|
211 |
# 设置缓存目录
|
212 |
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
|
|
215 |
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
216 |
|
217 |
# Weaviate 连接配置
|
218 |
+
retry_strategy = Retry(
|
219 |
+
total=3, # 总共重试次数
|
220 |
+
status_forcelist=[429, 500, 502, 503, 504], # 需要重试的状态码
|
221 |
+
method_whitelist=["HEAD", "GET", "OPTIONS", "POST"], # 需要重试的方法
|
222 |
+
backoff_factor=1 # 重试间隔时间的倍数
|
223 |
+
)
|
224 |
+
adapter = HTTPAdapter(max_retries=retry_strategy)
|
225 |
+
|
226 |
+
http = requests.Session()
|
227 |
+
http.mount("https://", adapter)
|
228 |
+
http.mount("http://", adapter)
|
229 |
+
|
230 |
+
timeout =10 # 超时时间(秒)
|
231 |
+
|
232 |
WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
|
233 |
WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
234 |
+
# weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
|
235 |
+
# weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)
|
236 |
+
# 创建 Weaviate 客户端
|
237 |
weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
|
238 |
+
weaviate_client = weaviate.Client(
|
239 |
+
url=WEAVIATE_URL,
|
240 |
+
auth_client_secret=weaviate_auth_config,
|
241 |
+
timeout_config=(timeout, timeout) # 连接超时和读取超时
|
242 |
+
)
|
243 |
+
|
244 |
|
245 |
# 预训练模型配置
|
246 |
MODEL_NAME = "bert-base-chinese"
|
|
|
446 |
else:
|
447 |
group_scores[group_id] = score
|
448 |
|
449 |
+
distance=1000
|
450 |
if group_scores:
|
451 |
max_group_id = max(group_scores, key=group_scores.get)
|
452 |
max_score = group_scores[max_group_id]
|
453 |
if(max_score>=score_threshold):
|
454 |
distance,ad_summary,ad_keywords=[(candidate['distance'],candidate['summary'],candidate['keyword_list']) for candidate in candidates if candidate['group_id']==max_group_id][0]
|
455 |
+
|
|
|
456 |
|
457 |
# if(candidates):
|
458 |
# # distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
|
|
|
460 |
# else:
|
461 |
# distance=1000
|
462 |
|
463 |
+
if distance < 1000:
|
464 |
brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
|
465 |
brand=random.choice(brands)
|
466 |
ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
|