thefish1 committed on
Commit
57c1fc4
·
1 Parent(s): 212abdf
Files changed (1) hide show
  1. app.py +52 -10
app.py CHANGED
@@ -317,7 +317,7 @@ def get_candidates_from_db(keywords_dict, class_name,limit=3):
317
  for embedding in embeddings:
318
  response = (
319
  weaviate_client.query
320
- .get(class_name, ['keywords', 'summary'])
321
  .with_near_vector({'vector': embedding})
322
  .with_limit(limit)
323
  .with_additional(['distance'])
@@ -331,8 +331,11 @@ def get_candidates_from_db(keywords_dict, class_name,limit=3):
331
  for result in results:
332
  candidate_list.append({
333
  'distance': result['_additional']['distance'],
 
 
334
  'summary': result['summary'],
335
- 'keywords': result['keywords']
 
336
  })
337
  return candidate_list
338
 
@@ -353,7 +356,7 @@ def keyword_match(keywords_dict,candidates):
353
  return candidate['distance'],candidate['summary'],candidate['keywords']
354
  return 1000,None,None
355
 
356
- def chatbot_response(message, history, window_size, threshold, user_weight, triggered_weight,candidate_length,api_key):
357
  #初始化openai client
358
  initialize_openai_client(api_key)
359
 
@@ -379,22 +382,60 @@ def chatbot_response(message, history, window_size, threshold, user_weight, trig
379
  #数据库检索,双方平均方式
380
  # distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
381
  #数据库索引,数据库关键词平均方式
382
- candidates=get_candidates_from_db(keywords_dict, class_name="Ad_DB03",limit=candidate_length)
 
 
 
 
 
 
383
 
384
  #先对候选集的distance进行筛选,保留小于threshold的
385
  candidates.sort(key=lambda x:x['distance'])
386
  candidates=[candidate for candidate in candidates if candidate['distance']<threshold]
 
387
  print("----------------------------------------------------------------------")
388
  print(f"keywords:{keywords_dict.keys()}")
389
  print(f"candidates:{candidates}")
 
 
 
 
390
 
 
391
  if(candidates):
392
- distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
393
- else:
394
- distance=1000
395
-
396
- #判断相似度
397
- if distance and distance < threshold:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
399
  brand=random.choice(brands)
400
  ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
@@ -439,6 +480,7 @@ demo = gr.ChatInterface(
439
  # gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
440
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
441
  gr.Slider(minimum=0.01, maximum=0.25, value=0.10, step=0.01, label="Distance threshold"),
 
442
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
443
  gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
444
  gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of candidates"),
 
317
  for embedding in embeddings:
318
  response = (
319
  weaviate_client.query
320
+ .get(class_name, ['group_id','keyword_list','keyword', 'summary'])
321
  .with_near_vector({'vector': embedding})
322
  .with_limit(limit)
323
  .with_additional(['distance'])
 
331
  for result in results:
332
  candidate_list.append({
333
  'distance': result['_additional']['distance'],
334
+ 'group_id': result['group_id'],
335
+ 'keyword_list':result['keyword_list'],
336
  'summary': result['summary'],
337
+ 'keyword': result['keyword']
338
+
339
  })
340
  return candidate_list
341
 
 
356
  return candidate['distance'],candidate['summary'],candidate['keywords']
357
  return 1000,None,None
358
 
359
+ def chatbot_response(message, history, window_size, threshold, score_threshold,user_weight, triggered_weight,candidate_length,api_key):
360
  #初始化openai client
361
  initialize_openai_client(api_key)
362
 
 
382
  #数据库检索,双方平均方式
383
  # distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
384
  #数据库索引,数据库关键词平均方式
385
+ candidates=get_candidates_from_db(keywords_dict, class_name="Ad_DB05",limit=candidate_length)
386
+
387
+ # #对类别进行判断加权
388
+ # for candidate in candidates:
389
+ # if candidate['keyword']!= candidate['keyword_list'].split(',')[0]:
390
+ # candidate['distance']*=2
391
+
392
 
393
  #先对候选集的distance进行筛选,保留小于threshold的
394
  candidates.sort(key=lambda x:x['distance'])
395
  candidates=[candidate for candidate in candidates if candidate['distance']<threshold]
396
+
397
  print("----------------------------------------------------------------------")
398
  print(f"keywords:{keywords_dict.keys()}")
399
  print(f"candidates:{candidates}")
400
+
401
+ #此时的候选集中所有元素都至少有一个关键词命中了
402
+ #筛选后的候选集进行投票,选出被投票最多的一条
403
+ #投中第一个元素加双倍权重
404
 
405
+ group_scores={}
406
  if(candidates):
407
+ for candidate in candidates:
408
+ group_id=candidate['group_id']
409
+ keyword = candidate['keyword']
410
+ keyword_list = candidate['keyword_list'].split(',')
411
+
412
+ # 检查 keyword 是否是 keyword_list 中的第一个元素
413
+ if keyword == keyword_list[0]:
414
+ score = 2
415
+ else:
416
+ score = 1
417
+
418
+ # 更新 group_scores 字典中的分数
419
+ if group_id in group_scores:
420
+ group_scores[group_id] += score
421
+ else:
422
+ group_scores[group_id] = score
423
+ print(group_scores[:4])
424
+ if group_scores:
425
+ max_group_id = max(group_scores, key=group_scores.get)
426
+ max_score = group_scores[max_group_id]
427
+ if(max_score>=score_threshold):
428
+ distance,ad_summary,ad_keywords=[candidate['distance'],candidate['summary'],candidate['keyword_list'] for candidate in candidates if candidate['group_id']==max_group_id][0]
429
+ else:
430
+ distance=1000
431
+
432
+ # if(candidates):
433
+ # # distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
434
+ # distance,ad_summary,ad_keywords=candidates[0]['distance'],candidates[0]['summary'],candidates[0]['keyword_list']
435
+ # else:
436
+ # distance=1000
437
+
438
+ if distance and distance < 1000:
439
  brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
440
  brand=random.choice(brands)
441
  ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
 
480
  # gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
481
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
482
  gr.Slider(minimum=0.01, maximum=0.25, value=0.10, step=0.01, label="Distance threshold"),
483
+ gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Score threshold"),
484
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
485
  gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
486
  gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of candidates"),