thefish1 commited on
Commit
941679b
·
1 Parent(s): 2bd0872
Files changed (1) hide show
  1. app.py +29 -4
app.py CHANGED
@@ -204,6 +204,9 @@ import torch
204
  from tqdm import tqdm
205
  import numpy as np
206
  import time
 
 
 
207
 
208
  # 设置缓存目录
209
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
@@ -212,10 +215,32 @@ os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
212
  os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
213
 
214
  # Weaviate 连接配置
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
216
  WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
 
 
 
217
  weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
218
- weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)
 
 
 
 
 
219
 
220
  # 预训练模型配置
221
  MODEL_NAME = "bert-base-chinese"
@@ -421,13 +446,13 @@ def chatbot_response(message, history, window_size, threshold, score_threshold,u
421
  else:
422
  group_scores[group_id] = score
423
 
 
424
  if group_scores:
425
  max_group_id = max(group_scores, key=group_scores.get)
426
  max_score = group_scores[max_group_id]
427
  if(max_score>=score_threshold):
428
  distance,ad_summary,ad_keywords=[(candidate['distance'],candidate['summary'],candidate['keyword_list']) for candidate in candidates if candidate['group_id']==max_group_id][0]
429
- else:
430
- distance=1000
431
 
432
  # if(candidates):
433
  # # distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
@@ -435,7 +460,7 @@ def chatbot_response(message, history, window_size, threshold, score_threshold,u
435
  # else:
436
  # distance=1000
437
 
438
- if distance and distance < 1000:
439
  brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
440
  brand=random.choice(brands)
441
  ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
 
204
  from tqdm import tqdm
205
  import numpy as np
206
  import time
207
+ import requests
208
+ from requests.adapters import HTTPAdapter
209
+ from requests.packages.urllib3.util.retry import Retry
210
 
211
  # 设置缓存目录
212
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
 
215
  os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
216
 
217
  # Weaviate 连接配置
218
+ retry_strategy = Retry(
219
+ total=3, # 总共重试次数
220
+ status_forcelist=[429, 500, 502, 503, 504], # 需要重试的状态码
221
+ method_whitelist=["HEAD", "GET", "OPTIONS", "POST"], # 需要重试的方法
222
+ backoff_factor=1 # 重试间隔时间的倍数
223
+ )
224
+ adapter = HTTPAdapter(max_retries=retry_strategy)
225
+
226
+ http = requests.Session()
227
+ http.mount("https://", adapter)
228
+ http.mount("http://", adapter)
229
+
230
+ timeout =10 # 超时时间(秒)
231
+
232
  WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
233
  WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
234
+ # weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
235
+ # weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)
236
+ # 创建 Weaviate 客户端
237
  weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
238
+ weaviate_client = weaviate.Client(
239
+ url=WEAVIATE_URL,
240
+ auth_client_secret=weaviate_auth_config,
241
+ timeout_config=(timeout, timeout) # 连接超时和读取超时
242
+ )
243
+
244
 
245
  # 预训练模型配置
246
  MODEL_NAME = "bert-base-chinese"
 
446
  else:
447
  group_scores[group_id] = score
448
 
449
+ distance=1000
450
  if group_scores:
451
  max_group_id = max(group_scores, key=group_scores.get)
452
  max_score = group_scores[max_group_id]
453
  if(max_score>=score_threshold):
454
  distance,ad_summary,ad_keywords=[(candidate['distance'],candidate['summary'],candidate['keyword_list']) for candidate in candidates if candidate['group_id']==max_group_id][0]
455
+
 
456
 
457
  # if(candidates):
458
  # # distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
 
460
  # else:
461
  # distance=1000
462
 
463
+ if distance < 1000:
464
  brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
465
  brand=random.choice(brands)
466
  ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"