import gradio as gr
from huggingface_hub import InferenceClient
import json
import random
import re
from load_data import load_data
from openai import OpenAI
from transformers import AutoTokenizer, AutoModel
from fetch_from_database import encode, insert_keywords_to_weaviate, fetch_summary_from_database, init_database
import weaviate
import os
import subprocess

# Set the Matplotlib cache directory
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
# Set the Hugging Face Transformers cache directory
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
# Make sure both cache directories exist
os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)

auth_config = weaviate.AuthApiKey(api_key="8wNsHV3Enc2PNVL8Bspadh21qYAfAvnK2ux3")
database_client = weaviate.Client(
    url="https://3a8sbx3s66by10yxginaa.c0.asia-southeast1.gcp.weaviate.cloud",
    auth_client_secret=auth_config
)
class_name = "Lhnjames123321"

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese")

# Load the ad dataset locally
dataset = load_data(file_path='train_2000_modified.json', num_samples=2000)
keyword_lists = [item['content'] for item in dataset if 'content' in item]
summary_lists = [item['summary'] for item in dataset if 'summary' in item]

global_api_key = None
client = None


def initialize_clients(api_key):
    global client
    client = OpenAI(api_key=api_key)


# Note: this loop has no effect; rebinding the loop variable does not modify
# keyword_lists. keyword_match() splits each entry itself.
for item in keyword_lists:
    item = item.split(',')


def get_keywords(message):
    system_message = """
    # 角色
    你是一个关键词提取机器人
    # 指令
    你的目标是从用户的输入中提取关键词,这些关键词应该尽可能是购买意图相关的。
    # 输出格式
    你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
    # 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
    """
    messages = [{"role": "system", "content": system_message}]
    messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        max_tokens=100,
        temperature=0.7,
        top_p=0.9,
    )
    keywords = response.choices[0].message.content.split(' ')
    return ','.join(keywords)


def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size, threshold_for_keyword_match):
    max_matches = 0
    most_matching_list = None
    index = 0
    # query_keywords = query_keywords.split(',')
    # query_keywords = [keyword for keyword in query_keywords if keyword]

    # Matching module
    query_keywords = list(query_keywords_dict.keys())
    for i, lst in enumerate(ad_keywords_lists):
        lst = lst.split(',')
        # Each ad keyword scores the weight of the query keyword it matches
        # (user keywords weigh 2, assistant keywords weigh 1), provided that
        # keyword has not triggered an ad within the last `window_size` turns.
        matches = sum(
            max(
                (
                    query_keywords_dict.get(keyword, 1)
                    for keyword in query_keywords
                    if ad_keyword in keyword and (
                        keyword not in triggered_keywords
                        or triggered_keywords.get(keyword) is None
                        or current_turn - triggered_keywords.get(keyword, 0) > window_size
                    )
                ),
                default=0,
            )
            for ad_keyword in lst
        )
        if matches > max_matches:
            max_matches = matches
            most_matching_list = lst
            index = i
    # Record the keywords that contributed to max_matches
    if max_matches >= threshold_for_keyword_match:
        for keyword in query_keywords:
            if any(ad_keyword in keyword for ad_keyword in most_matching_list):
                triggered_keywords[keyword] = current_turn
    return max_matches, index


def wrapper(message, history, max_tokens, temperature, top_p, window_size, threshold_for_keyword_match, api_key):
    initialize_clients(api_key)
    # Intended to reset the trigger history for this call, but this local dict
    # is unused: respond() reads the module-level triggered_keywords below.
    triggered_keywords = {}
    return respond(message, history, max_tokens, temperature, top_p, window_size, threshold_for_keyword_match)


# Maps each triggered keyword to the turn in which it last triggered an ad
triggered_keywords = {}
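# The ad-retrieval call used in respond() below is imported from fetch_from_database.py,
# which is not part of this file. Based on how its result is consumed, it is assumed to
# return (distance, keyword_list, summary) for the closest ad entry, where the first
# value is a vector distance (smaller = closer). The sketch below is illustration only,
# not the real implementation: the Weaviate property names ("keywords", "summary") and
# the BERT mean-pooling embedding are assumptions.
def _example_fetch_summary(query_keywords, class_name):
    text = " ".join(query_keywords)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    # Mean-pool the last hidden states of bert-base-chinese into one sentence vector
    vec = model(**inputs).last_hidden_state.mean(dim=1).squeeze().tolist()
    result = (
        database_client.query
        .get(class_name, ["keywords", "summary"])  # assumed schema properties
        .with_near_vector({"vector": vec})
        .with_additional(["distance"])
        .with_limit(1)
        .do()
    )
    hit = result["data"]["Get"][class_name][0]
    return hit["_additional"]["distance"], hit["keywords"], hit["summary"]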
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    window_size,
    threshold_for_keyword_match,
):
    system_message_with_ad = """
    # 角色
    你是一个热情的聊天机器人
    # 指令
    你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
    注意要在商品的描述前面加上是来自哪个品牌的广告。
    注意在推荐中不要脑补用户的身份,只是进行简单推荐。
    注意要热情但是语气只要适度热情
    # 输入格式
    用户查询后跟随广告品牌,用分隔,广告品牌后跟随广告描述,再用分隔。
    例如:我想买一条阔腿裤 腾讯 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
    注意: 当没有时,正常回复用户,不插入广告。
    # 输出格式
    始终使用中文,只输出聊天内容,不输出任何自我分析的信息
    """
    system_message_without_ad = """
    你是一个热情的聊天机器人
    """
    print(f"triggered_keywords: {triggered_keywords}")

    # Current dialogue turn
    current_turn = len(history) + 1

    # If the history is long enough, combine the last `window_size` turns with
    # the new message before keyword extraction
    if len(history) >= window_size:
        combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
        combined_message_assistant = " ".join(h[1] for h in history[-window_size:] if h[1])
    else:
        combined_message_user = message
        combined_message_assistant = ""

    key_words_users = get_keywords(combined_message_user)
    key_words_assistant = get_keywords(combined_message_assistant)
    print(f"Initial keywords_users: {key_words_users}")
    print(f"Initial keywords_assistant: {key_words_assistant}")

    # User keywords weigh 2, assistant keywords weigh 1
    keywords_dict = {}
    for keywords in key_words_users.split(','):
        keywords_dict[keywords] = 2
    for keywords in key_words_assistant.split(','):
        keywords_dict[keywords] = 1

    # max_matches, index = keyword_match(keywords_dict, keyword_lists, triggered_keywords, current_turn, window_size, threshold_for_keyword_match)
    query_keywords = list(keywords_dict.keys())

    # Here max_matches is a vector distance returned by the database (smaller = closer)
    class_name = "Lhnjames123321"
    max_matches, top_keywords_list, top_summary = fetch_summary_from_database(query_keywords, class_name)
    print(f"max_matches: {max_matches}")

    # if max_matches >= threshold_for_keyword_match:
    #     ad = summary_lists[index]
    #     messages = [{"role": "system", "content": system_message_with_ad}]
    #     for val in history:
    #         if val[0]:
    #             messages.append({"role": "user", "content": val[0]})
    #         if val[1]:
    #             messages.append({"role": "assistant", "content": val[1]})
    #     brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
    #     brand = random.choice(brands)
    #     messages.append({"role": "user", "content": f"{message} {brand}的 {ad}"})
    # else:
    #     messages = [{"role": "system", "content": system_message_without_ad}]
    #     for val in history:
    #         if val[0]:
    #             messages.append({"role": "user", "content": val[0]})
    #         if val[1]:
    #             messages.append({"role": "assistant", "content": val[1]})
    #     messages.append({"role": "user", "content": message})

    if max_matches < 0.2:
        ad = top_summary
        messages = [{"role": "system", "content": system_message_with_ad}]
        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
        brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
        brand = random.choice(brands)
        messages.append({"role": "user", "content": f"{message} {brand}的 {ad}"})
        if max_matches >= threshold_for_keyword_match:
            for keyword in query_keywords:
                if any(ad_keyword in keyword for ad_keyword in top_keywords_list):
                    triggered_keywords[keyword] = current_turn
    else:
        messages = [{"role": "system", "content": system_message_without_ad}]
        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
        messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return response.choices[0].message.content


# def chat_interface(message, history, max_tokens, temperature, top_p, window_size, threshold_for_keyword_match):
#     global triggered_keywords
#     response, triggered_keywords = respond(
#         message,
#         history,
#         max_tokens,
#         temperature,
#         top_p,
#         window_size,
#         threshold_for_keyword_match,
#         triggered_keywords
#     )
#     return response, history + [(message, response)]
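# Gradio wiring note: gr.ChatInterface passes each additional_inputs component to
# `wrapper` as an extra positional argument after (message, history), so the order
# of the sliders and the api_key textbox below must match wrapper's signature.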
demo = gr.ChatInterface(
    wrapper,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
        gr.Slider(minimum=1, maximum=3, value=2, step=1, label="Threshold for keyword match"),
        gr.Textbox(label="api_key"),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)