update
Browse files
app.py
CHANGED
@@ -317,7 +317,7 @@ def get_candidates_from_db(keywords_dict, class_name,limit=3):
|
|
317 |
for embedding in embeddings:
|
318 |
response = (
|
319 |
weaviate_client.query
|
320 |
-
.get(class_name, ['
|
321 |
.with_near_vector({'vector': embedding})
|
322 |
.with_limit(limit)
|
323 |
.with_additional(['distance'])
|
@@ -331,8 +331,11 @@ def get_candidates_from_db(keywords_dict, class_name,limit=3):
|
|
331 |
for result in results:
|
332 |
candidate_list.append({
|
333 |
'distance': result['_additional']['distance'],
|
|
|
|
|
334 |
'summary': result['summary'],
|
335 |
-
'
|
|
|
336 |
})
|
337 |
return candidate_list
|
338 |
|
@@ -353,7 +356,7 @@ def keyword_match(keywords_dict,candidates):
|
|
353 |
return candidate['distance'],candidate['summary'],candidate['keywords']
|
354 |
return 1000,None,None
|
355 |
|
356 |
-
def chatbot_response(message, history, window_size, threshold, user_weight, triggered_weight,candidate_length,api_key):
|
357 |
#初始化openai client
|
358 |
initialize_openai_client(api_key)
|
359 |
|
@@ -379,22 +382,60 @@ def chatbot_response(message, history, window_size, threshold, user_weight, trig
|
|
379 |
#数据库检索,双方平均方式
|
380 |
# distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
|
381 |
#数据库索引,数据库关键词平均方式
|
382 |
-
candidates=get_candidates_from_db(keywords_dict, class_name="
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
|
384 |
#先对候选集的distance进行筛选,保留小于threshold的
|
385 |
candidates.sort(key=lambda x:x['distance'])
|
386 |
candidates=[candidate for candidate in candidates if candidate['distance']<threshold]
|
|
|
387 |
print("----------------------------------------------------------------------")
|
388 |
print(f"keywords:{keywords_dict.keys()}")
|
389 |
print(f"candidates:{candidates}")
|
|
|
|
|
|
|
|
|
390 |
|
|
|
391 |
if(candidates):
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
398 |
brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
|
399 |
brand=random.choice(brands)
|
400 |
ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
|
@@ -439,6 +480,7 @@ demo = gr.ChatInterface(
|
|
439 |
# gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
440 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
441 |
gr.Slider(minimum=0.01, maximum=0.25, value=0.10, step=0.01, label="Distance threshold"),
|
|
|
442 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
443 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
444 |
gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of candidates"),
|
|
|
317 |
for embedding in embeddings:
|
318 |
response = (
|
319 |
weaviate_client.query
|
320 |
+
.get(class_name, ['group_id','keyword_list','keyword', 'summary'])
|
321 |
.with_near_vector({'vector': embedding})
|
322 |
.with_limit(limit)
|
323 |
.with_additional(['distance'])
|
|
|
331 |
for result in results:
|
332 |
candidate_list.append({
|
333 |
'distance': result['_additional']['distance'],
|
334 |
+
'group_id': result['group_id'],
|
335 |
+
'keyword_list':result['keyword_list'],
|
336 |
'summary': result['summary'],
|
337 |
+
'keyword': result['keyword']
|
338 |
+
|
339 |
})
|
340 |
return candidate_list
|
341 |
|
|
|
356 |
return candidate['distance'],candidate['summary'],candidate['keywords']
|
357 |
return 1000,None,None
|
358 |
|
359 |
+
def chatbot_response(message, history, window_size, threshold, score_threshold,user_weight, triggered_weight,candidate_length,api_key):
|
360 |
#初始化openai client
|
361 |
initialize_openai_client(api_key)
|
362 |
|
|
|
382 |
#数据库检索,双方平均方式
|
383 |
# distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
|
384 |
#数据库索引,数据库关键词平均方式
|
385 |
+
candidates=get_candidates_from_db(keywords_dict, class_name="Ad_DB05",limit=candidate_length)
|
386 |
+
|
387 |
+
# #对类别进行判断加权
|
388 |
+
# for candidate in candidates:
|
389 |
+
# if candidate['keyword']!= candidate['keyword_list'].split(',')[0]:
|
390 |
+
# candidate['distance']*=2
|
391 |
+
|
392 |
|
393 |
#先对候选集的distance进行筛选,保留小于threshold的
|
394 |
candidates.sort(key=lambda x:x['distance'])
|
395 |
candidates=[candidate for candidate in candidates if candidate['distance']<threshold]
|
396 |
+
|
397 |
print("----------------------------------------------------------------------")
|
398 |
print(f"keywords:{keywords_dict.keys()}")
|
399 |
print(f"candidates:{candidates}")
|
400 |
+
|
401 |
+
#此时的候选集中所有元素都至少有一个关键词命中了
|
402 |
+
#筛选后的候选集进行投票,选出被投票最多的一条
|
403 |
+
#投中第一个元素加双倍权重
|
404 |
|
405 |
+
group_scores={}
|
406 |
if(candidates):
|
407 |
+
for candidate in candidates:
|
408 |
+
group_id=candidate['group_id']
|
409 |
+
keyword = candidate['keyword']
|
410 |
+
keyword_list = candidate['keyword_list'].split(',')
|
411 |
+
|
412 |
+
# 检查 keyword 是否是 keyword_list 中的第一个元素
|
413 |
+
if keyword == keyword_list[0]:
|
414 |
+
score = 2
|
415 |
+
else:
|
416 |
+
score = 1
|
417 |
+
|
418 |
+
# 更新 group_scores 字典中的分数
|
419 |
+
if group_id in group_scores:
|
420 |
+
group_scores[group_id] += score
|
421 |
+
else:
|
422 |
+
group_scores[group_id] = score
|
423 |
+
print(group_scores[:4])
|
424 |
+
if group_scores:
|
425 |
+
max_group_id = max(group_scores, key=group_scores.get)
|
426 |
+
max_score = group_scores[max_group_id]
|
427 |
+
if(max_score>=score_threshold):
|
428 |
+
distance,ad_summary,ad_keywords=[candidate['distance'],candidate['summary'],candidate['keyword_list'] for candidate in candidates if candidate['group_id']==max_group_id][0]
|
429 |
+
else:
|
430 |
+
distance=1000
|
431 |
+
|
432 |
+
# if(candidates):
|
433 |
+
# # distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
|
434 |
+
# distance,ad_summary,ad_keywords=candidates[0]['distance'],candidates[0]['summary'],candidates[0]['keyword_list']
|
435 |
+
# else:
|
436 |
+
# distance=1000
|
437 |
+
|
438 |
+
if distance and distance < 1000:
|
439 |
brands=['腾讯','阿里巴巴','百度','京东','华为','小米','苹果','微软','谷歌','亚马逊']
|
440 |
brand=random.choice(brands)
|
441 |
ad_message = f"{message} <sep>品牌{brand}<sep>{ad_summary}"
|
|
|
480 |
# gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
481 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
482 |
gr.Slider(minimum=0.01, maximum=0.25, value=0.10, step=0.01, label="Distance threshold"),
|
483 |
+
gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Score threshold"),
|
484 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
485 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
486 |
gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of candidates"),
|