Spaces:
Runtime error
Runtime error
Upload 18 files
Browse files- ChatHaruhi/BaiChuan2GPT.py +83 -0
- ChatHaruhi/BaiChuanAPIGPT.py +112 -0
- ChatHaruhi/BaseDB.py +27 -0
- ChatHaruhi/BaseLLM.py +56 -0
- ChatHaruhi/ChatGLM2GPT.py +79 -0
- ChatHaruhi/ChatHaruhi.py +450 -0
- ChatHaruhi/ChatHaruhi_safe.py +337 -0
- ChatHaruhi/ChromaDB.py +61 -0
- ChatHaruhi/ErnieGPT.py +72 -0
- ChatHaruhi/GLMPro.py +90 -0
- ChatHaruhi/LangChainGPT.py +78 -0
- ChatHaruhi/PrintLLM.py +61 -0
- ChatHaruhi/Qwen118k2GPT.py +85 -0
- ChatHaruhi/SparkApi.py +139 -0
- ChatHaruhi/SparkGPT.py +75 -0
- ChatHaruhi/__init__.py +26 -0
- ChatHaruhi/role_name_to_file.py +67 -0
- ChatHaruhi/utils.py +431 -0
ChatHaruhi/BaiChuan2GPT.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .BaseLLM import BaseLLM
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
from transformers.generation.utils import GenerationConfig
|
5 |
+
from peft import PeftModel
|
6 |
+
|
7 |
+
tokenizer_BaiChuan = None
|
8 |
+
model_BaiChuan = None
|
9 |
+
|
10 |
+
def initialize_BaiChuan2LORA():
|
11 |
+
global model_BaiChuan, tokenizer_BaiChuan
|
12 |
+
|
13 |
+
if model_BaiChuan is None:
|
14 |
+
model_BaiChuan = AutoModelForCausalLM.from_pretrained(
|
15 |
+
"baichuan-inc/Baichuan2-13B-Chat",
|
16 |
+
device_map="auto",
|
17 |
+
torch_dtype=torch.bfloat16,
|
18 |
+
trust_remote_code=True,
|
19 |
+
)
|
20 |
+
model_BaiChuan = PeftModel.from_pretrained(
|
21 |
+
model_BaiChuan,
|
22 |
+
"silk-road/Chat-Haruhi-Fusion_Baichuan2_13B"
|
23 |
+
)
|
24 |
+
model_BaiChuan.generation_config = GenerationConfig.from_pretrained(
|
25 |
+
"baichuan-inc/Baichuan2-13B-Chat"
|
26 |
+
)
|
27 |
+
|
28 |
+
if tokenizer_BaiChuan is None:
|
29 |
+
tokenizer_BaiChuan = AutoTokenizer.from_pretrained(
|
30 |
+
"baichuan-inc/Baichuan2-13B-Chat",
|
31 |
+
use_fast=True,
|
32 |
+
trust_remote_code=True
|
33 |
+
)
|
34 |
+
|
35 |
+
return model_BaiChuan, tokenizer_BaiChuan
|
36 |
+
|
37 |
+
def BaiChuan_tokenizer(text):
|
38 |
+
return len(tokenizer_BaiChuan.encode(text))
|
39 |
+
|
40 |
+
class BaiChuan2GPT(BaseLLM):
|
41 |
+
def __init__(self, model = "haruhi-fusion-baichuan"):
|
42 |
+
super(BaiChuan2GPT, self).__init__()
|
43 |
+
if model == "baichuan2-13b":
|
44 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
45 |
+
"baichuan-inc/Baichuan2-13B-Chat",
|
46 |
+
use_fast=True,
|
47 |
+
trust_remote_code=True
|
48 |
+
),
|
49 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
50 |
+
"baichuan-inc/Baichuan2-13B-Chat",
|
51 |
+
device_map="auto",
|
52 |
+
torch_dtype=torch.bfloat16,
|
53 |
+
trust_remote_code=True,
|
54 |
+
)
|
55 |
+
self.model.generation_config = GenerationConfig.from_pretrained(
|
56 |
+
"baichuan-inc/Baichuan2-13B-Chat"
|
57 |
+
)
|
58 |
+
elif model == "haruhi-fusion-baichuan":
|
59 |
+
self.model, self.tokenizer = initialize_BaiChuan2LORA()
|
60 |
+
else:
|
61 |
+
raise Exception("Unknown BaiChuan Model! Currently supported: [BaiChuan2-13B, haruhi-fusion-baichuan]")
|
62 |
+
self.messages = []
|
63 |
+
|
64 |
+
def initialize_message(self):
|
65 |
+
self.messages = []
|
66 |
+
|
67 |
+
def ai_message(self, payload):
|
68 |
+
self.messages.append({"role": "assistant", "content": payload})
|
69 |
+
|
70 |
+
def system_message(self, payload):
|
71 |
+
self.messages.append({"role": "system", "content": payload})
|
72 |
+
|
73 |
+
def user_message(self, payload):
|
74 |
+
self.messages.append({"role": "user", "content": payload})
|
75 |
+
|
76 |
+
def get_response(self):
|
77 |
+
with torch.no_grad():
|
78 |
+
response = self.model.chat(self.tokenizer, self.messages)
|
79 |
+
return response
|
80 |
+
|
81 |
+
def print_prompt(self):
|
82 |
+
print(type(self.messages))
|
83 |
+
print(self.messages)
|
ChatHaruhi/BaiChuanAPIGPT.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import hashlib
|
5 |
+
import requests
|
6 |
+
import copy
|
7 |
+
|
8 |
+
from .BaseLLM import BaseLLM
|
9 |
+
|
10 |
+
BAICHUAN_API_AK = os.getenv("BAICHUAN_API_AK")
|
11 |
+
BAICHUAN_API_SK = os.getenv("BAICHUAN_API_SK")
|
12 |
+
|
13 |
+
def sign(secret_key, data):
|
14 |
+
json_data = json.dumps(data)
|
15 |
+
time_stamp = int(time.time())
|
16 |
+
input_string = secret_key + json_data + str(time_stamp)
|
17 |
+
md5 = hashlib.md5()
|
18 |
+
md5.update(input_string.encode('utf-8'))
|
19 |
+
encrypted = md5.hexdigest()
|
20 |
+
return encrypted
|
21 |
+
|
22 |
+
def do_request(messages, api_key, secret_key):
|
23 |
+
url = "https://api.baichuan-ai.com/v1/chat"
|
24 |
+
|
25 |
+
data = {
|
26 |
+
"model": "Baichuan2-53B",
|
27 |
+
"messages": messages
|
28 |
+
}
|
29 |
+
|
30 |
+
signature = sign(secret_key, data)
|
31 |
+
|
32 |
+
headers = {
|
33 |
+
"Content-Type": "application/json",
|
34 |
+
"Authorization": "Bearer " + api_key,
|
35 |
+
"X-BC-Request-Id": "your requestId",
|
36 |
+
"X-BC-Timestamp": str(int(time.time())),
|
37 |
+
"X-BC-Signature": signature,
|
38 |
+
"X-BC-Sign-Algo": "MD5",
|
39 |
+
}
|
40 |
+
|
41 |
+
response = requests.post(url, data=json.dumps(data), headers=headers)
|
42 |
+
if response.status_code == 200:
|
43 |
+
return response.json()
|
44 |
+
else:
|
45 |
+
return None
|
46 |
+
|
47 |
+
class BaiChuanAPIGPT(BaseLLM):
|
48 |
+
def __init__(self, model="baichuan-api", api_key=None, secret_key=None, verbose=False, if_trick = True):
|
49 |
+
self.if_trick = if_trick
|
50 |
+
super(BaiChuanAPIGPT, self).__init__()
|
51 |
+
self.api_key = api_key or BAICHUAN_API_AK
|
52 |
+
self.secret_key = secret_key or BAICHUAN_API_SK
|
53 |
+
self.verbose = verbose
|
54 |
+
self.model_name = model
|
55 |
+
self.messages = []
|
56 |
+
if self.verbose:
|
57 |
+
print('model name, ', self.model_name)
|
58 |
+
if self.api_key is None or self.secret_key is None:
|
59 |
+
print('Please set BAICHUAN_API_AK and BAICHUAN_API_SK')
|
60 |
+
|
61 |
+
def initialize_message(self):
|
62 |
+
self.messages = []
|
63 |
+
|
64 |
+
|
65 |
+
def ai_message(self, payload):
|
66 |
+
if len(self.messages) == 0:
|
67 |
+
self.user_message("请根据我的要求进行角色扮演:")
|
68 |
+
elif len(self.messages) % 2 == 1:
|
69 |
+
self.messages.append({"role":"assistant","content":payload})
|
70 |
+
elif len(self.messages)% 2 == 0:
|
71 |
+
self.messages[-1]["content"] += "\n"+ payload
|
72 |
+
|
73 |
+
def system_message(self, payload):
|
74 |
+
|
75 |
+
self.messages.append({"role":"user","content":payload})
|
76 |
+
|
77 |
+
|
78 |
+
def user_message(self, payload):
|
79 |
+
if len(self.messages) % 2 == 0:
|
80 |
+
self.messages.append({"role":"user","content":payload})
|
81 |
+
# self.messages[-1]["content"] +=
|
82 |
+
elif len(self.messages)% 2 == 1:
|
83 |
+
self.messages[-1]["content"] += "\n"+ payload
|
84 |
+
|
85 |
+
def get_response(self):
|
86 |
+
max_try = 5
|
87 |
+
sleep_interval = 3
|
88 |
+
|
89 |
+
chat_messages = copy.deepcopy(self.messages)
|
90 |
+
|
91 |
+
if self.if_trick == True:
|
92 |
+
lines = chat_messages[-1]["content"].split('\n')
|
93 |
+
lines.insert(-1, '请请模仿上述经典桥段进行回复\n')
|
94 |
+
chat_messages[-1]["content"] = '\n'.join(lines)
|
95 |
+
|
96 |
+
for i in range(max_try):
|
97 |
+
response = do_request(chat_messages, self.api_key, self.secret_key)
|
98 |
+
if response is not None:
|
99 |
+
if self.verbose:
|
100 |
+
print('Get Baichuan API response success')
|
101 |
+
messages = response['data']['messages']
|
102 |
+
if len(messages) > 0:
|
103 |
+
return messages[-1]['content'].strip("\"'")
|
104 |
+
else:
|
105 |
+
if self.verbose:
|
106 |
+
print('Get Baichuan API response failed, retrying...')
|
107 |
+
time.sleep(sleep_interval)
|
108 |
+
|
109 |
+
def print_prompt(self):
|
110 |
+
for message in self.messages:
|
111 |
+
print(f"{message['role']}: {message['content']}")
|
112 |
+
|
ChatHaruhi/BaseDB.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BaseDB.py
|
2 |
+
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
|
5 |
+
class BaseDB(ABC):
|
6 |
+
|
7 |
+
@abstractmethod
|
8 |
+
def init_db(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
def save(self, file_path):
|
13 |
+
pass
|
14 |
+
|
15 |
+
@abstractmethod
|
16 |
+
def load(self, file_path):
|
17 |
+
pass
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
def search(self, vector, n_results):
|
21 |
+
pass
|
22 |
+
|
23 |
+
@abstractmethod
|
24 |
+
def init_from_docs(self, vectors, documents):
|
25 |
+
pass
|
26 |
+
|
27 |
+
|
ChatHaruhi/BaseLLM.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
|
2 |
+
#
|
3 |
+
# ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
|
4 |
+
#
|
5 | |
6 |
+
#
|
7 |
+
# Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
|
8 |
+
# Weishi Mi is pursuing a job or a PhD position, which who will be available next year
|
9 |
+
#
|
10 |
+
# homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
|
11 |
+
#
|
12 |
+
# ChatHaruhi is a chatbot that can revive anime characters in reality.
|
13 |
+
# the 2.0 version was built by Cheng Li and Weishi Mi.
|
14 |
+
#
|
15 |
+
# Please cite our paper if you use this code for research:
|
16 |
+
#
|
17 |
+
# @misc{li2023chatharuhi,
|
18 |
+
# title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
|
19 |
+
# author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
|
20 |
+
# year={2023},
|
21 |
+
# eprint={2308.09597},
|
22 |
+
# archivePrefix={arXiv},
|
23 |
+
# primaryClass={cs.CL}
|
24 |
+
# }
|
25 |
+
from abc import ABC, abstractmethod
|
26 |
+
|
27 |
+
class BaseLLM(ABC):
|
28 |
+
|
29 |
+
def __init__(self):
|
30 |
+
pass
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def initialize_message(self):
|
34 |
+
pass
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def ai_message(self, payload):
|
38 |
+
pass
|
39 |
+
|
40 |
+
@abstractmethod
|
41 |
+
def system_message(self, payload):
|
42 |
+
pass
|
43 |
+
|
44 |
+
@abstractmethod
|
45 |
+
def user_message(self, payload):
|
46 |
+
pass
|
47 |
+
|
48 |
+
@abstractmethod
|
49 |
+
def get_response(self):
|
50 |
+
pass
|
51 |
+
|
52 |
+
@abstractmethod
|
53 |
+
def print_prompt(self):
|
54 |
+
pass
|
55 |
+
|
56 |
+
|
ChatHaruhi/ChatGLM2GPT.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .BaseLLM import BaseLLM
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
from peft import PeftModel
|
5 |
+
|
6 |
+
tokenizer_GLM = None
|
7 |
+
model_GLM = None
|
8 |
+
|
9 |
+
def initialize_GLM2LORA():
|
10 |
+
global model_GLM, tokenizer_GLM
|
11 |
+
|
12 |
+
if model_GLM is None:
|
13 |
+
model_GLM = AutoModel.from_pretrained(
|
14 |
+
"THUDM/chatglm2-6b",
|
15 |
+
torch_dtype=torch.float16,
|
16 |
+
device_map="auto",
|
17 |
+
trust_remote_code=True
|
18 |
+
)
|
19 |
+
model_GLM = PeftModel.from_pretrained(
|
20 |
+
model_GLM,
|
21 |
+
"silk-road/Chat-Haruhi-Fusion_B"
|
22 |
+
)
|
23 |
+
|
24 |
+
if tokenizer_GLM is None:
|
25 |
+
tokenizer_GLM = AutoTokenizer.from_pretrained(
|
26 |
+
"THUDM/chatglm2-6b",
|
27 |
+
use_fast=True,
|
28 |
+
trust_remote_code=True
|
29 |
+
)
|
30 |
+
|
31 |
+
return model_GLM, tokenizer_GLM
|
32 |
+
|
33 |
+
def GLM_tokenizer(text):
|
34 |
+
return len(tokenizer_GLM.encode(text))
|
35 |
+
|
36 |
+
class ChatGLM2GPT(BaseLLM):
|
37 |
+
def __init__(self, model = "haruhi-fusion"):
|
38 |
+
super(ChatGLM2GPT, self).__init__()
|
39 |
+
if model == "glm2-6b":
|
40 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
41 |
+
"THUDM/chatglm2-6b",
|
42 |
+
use_fast=True,
|
43 |
+
trust_remote_code=True
|
44 |
+
)
|
45 |
+
self.model = AutoModel.from_pretrained(
|
46 |
+
"THUDM/chatglm2-6b",
|
47 |
+
torch_dtype=torch.float16,
|
48 |
+
device_map="auto",
|
49 |
+
trust_remote_code=True
|
50 |
+
)
|
51 |
+
if model == "haruhi-fusion":
|
52 |
+
self.model, self.tokenizer = initialize_GLM2LORA()
|
53 |
+
else:
|
54 |
+
raise Exception("Unknown GLM model")
|
55 |
+
self.messages = ""
|
56 |
+
|
57 |
+
def initialize_message(self):
|
58 |
+
self.messages = ""
|
59 |
+
|
60 |
+
def ai_message(self, payload):
|
61 |
+
self.messages = self.messages + "\n " + payload
|
62 |
+
|
63 |
+
def system_message(self, payload):
|
64 |
+
self.messages = self.messages + "\n " + payload
|
65 |
+
|
66 |
+
def user_message(self, payload):
|
67 |
+
self.messages = self.messages + "\n " + payload
|
68 |
+
|
69 |
+
def get_response(self):
|
70 |
+
with torch.no_grad():
|
71 |
+
response, history = self.model.chat(self.tokenizer, self.messages, history=[])
|
72 |
+
# print(response)
|
73 |
+
return response
|
74 |
+
|
75 |
+
def print_prompt(self):
|
76 |
+
print(type(self.messages))
|
77 |
+
print(self.messages)
|
78 |
+
|
79 |
+
|
ChatHaruhi/ChatHaruhi.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ChromaDB import ChromaDB
|
2 |
+
import os
|
3 |
+
|
4 |
+
from .utils import luotuo_openai_embedding, tiktokenizer
|
5 |
+
|
6 |
+
from .utils import response_postprocess
|
7 |
+
|
8 |
+
def get_text_from_data( data ):
|
9 |
+
if "text" in data:
|
10 |
+
return data['text']
|
11 |
+
elif "enc_text" in data:
|
12 |
+
from .utils import base64_to_string
|
13 |
+
return base64_to_string( data['enc_text'] )
|
14 |
+
else:
|
15 |
+
print("warning! failed to get text from data ", data)
|
16 |
+
return ""
|
17 |
+
|
18 |
+
class ChatHaruhi:
|
19 |
+
|
20 |
+
def __init__(self, system_prompt = None, \
|
21 |
+
role_name = None, role_from_hf = None,
|
22 |
+
role_from_jsonl = None, \
|
23 |
+
story_db=None, story_text_folder = None, \
|
24 |
+
llm = 'openai', \
|
25 |
+
embedding = 'luotuo_openai', \
|
26 |
+
max_len_story = None, max_len_history = None,
|
27 |
+
verbose = False):
|
28 |
+
super(ChatHaruhi, self).__init__()
|
29 |
+
self.verbose = verbose
|
30 |
+
|
31 |
+
# constants
|
32 |
+
self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
|
33 |
+
self.k_search = 19
|
34 |
+
self.narrator = ['旁白', '', 'scene','Scene','narrator' , 'Narrator']
|
35 |
+
self.dialogue_divide_token = '\n###\n'
|
36 |
+
self.dialogue_bra_token = '「'
|
37 |
+
self.dialogue_ket_token = '」'
|
38 |
+
|
39 |
+
if system_prompt:
|
40 |
+
self.system_prompt = self.check_system_prompt( system_prompt )
|
41 |
+
|
42 |
+
# TODO: embedding should be the seperately defined, so refactor this part later
|
43 |
+
if llm == 'openai':
|
44 |
+
# self.llm = LangChainGPT()
|
45 |
+
self.llm, self.tokenizer = self.get_models('openai')
|
46 |
+
elif llm == 'debug':
|
47 |
+
self.llm, self.tokenizer = self.get_models('debug')
|
48 |
+
elif llm == 'spark':
|
49 |
+
self.llm, self.tokenizer = self.get_models('spark')
|
50 |
+
elif llm == 'GLMPro':
|
51 |
+
self.llm, self.tokenizer = self.get_models('GLMPro')
|
52 |
+
elif llm == 'ChatGLM2GPT':
|
53 |
+
self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
|
54 |
+
self.story_prefix_prompt = '\n'
|
55 |
+
elif llm == "BaiChuan2GPT":
|
56 |
+
self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
|
57 |
+
elif llm == "BaiChuanAPIGPT":
|
58 |
+
self.llm, self.tokenizer = self.get_models('BaiChuanAPIGPT')
|
59 |
+
elif llm == "ernie3.5":
|
60 |
+
self.llm, self.tokenizer = self.get_models('ernie3.5')
|
61 |
+
elif llm == "ernie4.0":
|
62 |
+
self.llm, self.tokenizer = self.get_models('ernie4.0')
|
63 |
+
elif "qwen" in llm:
|
64 |
+
self.llm, self.tokenizer = self.get_models(llm)
|
65 |
+
else:
|
66 |
+
print(f'warning! undefined llm {llm}, use openai instead.')
|
67 |
+
self.llm, self.tokenizer = self.get_models('openai')
|
68 |
+
|
69 |
+
if embedding == 'luotuo_openai':
|
70 |
+
self.embedding = luotuo_openai_embedding
|
71 |
+
elif embedding == 'bge_en':
|
72 |
+
from .utils import get_bge_embedding
|
73 |
+
self.embedding = get_bge_embedding
|
74 |
+
elif embedding == 'bge_zh':
|
75 |
+
from .utils import get_bge_zh_embedding
|
76 |
+
self.embedding = get_bge_zh_embedding
|
77 |
+
else:
|
78 |
+
print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
|
79 |
+
self.embedding = luotuo_openai_embedding
|
80 |
+
|
81 |
+
if role_name:
|
82 |
+
# TODO move into a function
|
83 |
+
from .role_name_to_file import get_folder_role_name
|
84 |
+
# correct role_name to folder_role_name
|
85 |
+
role_name, url = get_folder_role_name(role_name)
|
86 |
+
|
87 |
+
unzip_folder = f'./temp_character_folder/temp_{role_name}'
|
88 |
+
db_folder = os.path.join(unzip_folder, f'content/{role_name}')
|
89 |
+
system_prompt = os.path.join(unzip_folder, f'content/system_prompt.txt')
|
90 |
+
|
91 |
+
if not os.path.exists(unzip_folder):
|
92 |
+
# not yet downloaded
|
93 |
+
# url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
|
94 |
+
import requests, zipfile, io
|
95 |
+
r = requests.get(url)
|
96 |
+
z = zipfile.ZipFile(io.BytesIO(r.content))
|
97 |
+
z.extractall(unzip_folder)
|
98 |
+
|
99 |
+
if self.verbose:
|
100 |
+
print(f'loading pre-defined character {role_name}...')
|
101 |
+
|
102 |
+
self.db = ChromaDB()
|
103 |
+
self.db.load(db_folder)
|
104 |
+
self.system_prompt = self.check_system_prompt(system_prompt)
|
105 |
+
elif role_from_hf:
|
106 |
+
# TODO move into a function
|
107 |
+
from datasets import load_dataset
|
108 |
+
|
109 |
+
if role_from_hf.count("/") == 1:
|
110 |
+
dataset = load_dataset(role_from_hf)
|
111 |
+
datas = dataset["train"]
|
112 |
+
elif role_from_hf.count("/") >= 2:
|
113 |
+
split_index = role_from_hf.index('/')
|
114 |
+
second_split_index = role_from_hf.index('/', split_index+1)
|
115 |
+
dataset_name = role_from_hf[:second_split_index]
|
116 |
+
split_name = role_from_hf[second_split_index+1:]
|
117 |
+
|
118 |
+
fname = split_name + '.jsonl'
|
119 |
+
dataset = load_dataset(dataset_name,data_files={'train':fname})
|
120 |
+
datas = dataset["train"]
|
121 |
+
|
122 |
+
if embedding == 'luotuo_openai':
|
123 |
+
embed_name = 'luotuo_openai'
|
124 |
+
elif embedding == 'bge_en':
|
125 |
+
embed_name = 'bge_en_s15'
|
126 |
+
elif embedding == 'bge_zh':
|
127 |
+
embed_name = 'bge_zh_s15'
|
128 |
+
else:
|
129 |
+
print('warning! unkown embedding name ', embedding ,' while loading role')
|
130 |
+
embed_name = 'luotuo_openai'
|
131 |
+
|
132 |
+
texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)
|
133 |
+
|
134 |
+
self.build_story_db_from_vec( texts, vecs )
|
135 |
+
|
136 |
+
elif role_from_jsonl:
|
137 |
+
import json
|
138 |
+
datas = []
|
139 |
+
with open( role_from_jsonl , encoding="utf-8") as f:
|
140 |
+
for line in f:
|
141 |
+
try:
|
142 |
+
data = json.loads(line)
|
143 |
+
# 逐行处理JSON数据
|
144 |
+
datas.append(data)
|
145 |
+
except:
|
146 |
+
print("warning! failed to load json line ", line)
|
147 |
+
|
148 |
+
if embedding == 'luotuo_openai':
|
149 |
+
embed_name = 'luotuo_openai'
|
150 |
+
elif embedding == 'bge_en':
|
151 |
+
embed_name = 'bge_en_s15'
|
152 |
+
elif embedding == 'bge_zh':
|
153 |
+
embed_name = 'bge_zh_s15'
|
154 |
+
else:
|
155 |
+
print('warning! unkown embedding name ', embedding ,' while loading role')
|
156 |
+
embed_name = 'luotuo_openai'
|
157 |
+
|
158 |
+
texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)
|
159 |
+
|
160 |
+
self.build_story_db_from_vec( texts, vecs )
|
161 |
+
|
162 |
+
elif story_db:
|
163 |
+
self.db = ChromaDB()
|
164 |
+
self.db.load(story_db)
|
165 |
+
elif story_text_folder:
|
166 |
+
# print("Building story database from texts...")
|
167 |
+
self.db = self.build_story_db(story_text_folder)
|
168 |
+
else:
|
169 |
+
self.db = None
|
170 |
+
print('warning! database not yet figured out, both story_db and story_text_folder are not inputted.')
|
171 |
+
# raise ValueError("Either story_db or story_text_folder must be provided")
|
172 |
+
|
173 |
+
|
174 |
+
self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')
|
175 |
+
|
176 |
+
if max_len_history is not None:
|
177 |
+
self.max_len_history = max_len_history
|
178 |
+
# user setting will override default setting
|
179 |
+
|
180 |
+
if max_len_story is not None:
|
181 |
+
self.max_len_story = max_len_story
|
182 |
+
# user setting will override default setting
|
183 |
+
|
184 |
+
self.dialogue_history = []
|
185 |
+
|
186 |
+
def extract_text_vec_from_datas( self, datas, embed_name ):
|
187 |
+
# extract text and vec from huggingface dataset
|
188 |
+
# return texts, vecs
|
189 |
+
from .utils import base64_to_float_array
|
190 |
+
|
191 |
+
texts = []
|
192 |
+
vecs = []
|
193 |
+
for data in datas:
|
194 |
+
if data[embed_name] == 'system_prompt':
|
195 |
+
system_prompt = get_text_from_data( data )
|
196 |
+
elif data[embed_name] == 'config':
|
197 |
+
pass
|
198 |
+
else:
|
199 |
+
vec = base64_to_float_array( data[embed_name] )
|
200 |
+
text = get_text_from_data( data )
|
201 |
+
vecs.append( vec )
|
202 |
+
texts.append( text )
|
203 |
+
return texts, vecs, system_prompt
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
def check_system_prompt(self, system_prompt):
|
208 |
+
# if system_prompt end with .txt, read the file with utf-8
|
209 |
+
# else, return the string directly
|
210 |
+
if system_prompt.endswith('.txt'):
|
211 |
+
with open(system_prompt, 'r', encoding='utf-8') as f:
|
212 |
+
return f.read()
|
213 |
+
else:
|
214 |
+
return system_prompt
|
215 |
+
|
216 |
+
|
217 |
+
def get_models(self, model_name):
|
218 |
+
|
219 |
+
# TODO: if output only require tokenizer model, no need to initialize llm
|
220 |
+
|
221 |
+
# return the combination of llm, embedding and tokenizer
|
222 |
+
if model_name == 'openai':
|
223 |
+
from .LangChainGPT import LangChainGPT
|
224 |
+
return (LangChainGPT(), tiktokenizer)
|
225 |
+
elif model_name == 'debug':
|
226 |
+
from .PrintLLM import PrintLLM
|
227 |
+
return (PrintLLM(), tiktokenizer)
|
228 |
+
elif model_name == 'spark':
|
229 |
+
from .SparkGPT import SparkGPT
|
230 |
+
return (SparkGPT(), tiktokenizer)
|
231 |
+
elif model_name == 'GLMPro':
|
232 |
+
from .GLMPro import GLMPro
|
233 |
+
return (GLMPro(), tiktokenizer)
|
234 |
+
elif model_name == 'ernie3.5':
|
235 |
+
from .ErnieGPT import ErnieGPT
|
236 |
+
return (ErnieGPT(), tiktokenizer)
|
237 |
+
elif model_name == 'ernie4.0':
|
238 |
+
from .ErnieGPT import ErnieGPT
|
239 |
+
return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
|
240 |
+
elif model_name == "ChatGLM2GPT":
|
241 |
+
from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
|
242 |
+
return (ChatGLM2GPT(), GLM_tokenizer)
|
243 |
+
elif model_name == "BaiChuan2GPT":
|
244 |
+
from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
|
245 |
+
return (BaiChuan2GPT(), BaiChuan_tokenizer)
|
246 |
+
elif model_name == "BaiChuanAPIGPT":
|
247 |
+
from .BaiChuanAPIGPT import BaiChuanAPIGPT
|
248 |
+
return (BaiChuanAPIGPT(), tiktokenizer)
|
249 |
+
elif "qwen" in model_name:
|
250 |
+
if model_name == "qwen118k_raw":
|
251 |
+
from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
|
252 |
+
return (Qwen118k2GPT(model = "Qwen/Qwen-1_8B-Chat"), Qwen_tokenizer)
|
253 |
+
from huggingface_hub import HfApi
|
254 |
+
from huggingface_hub.hf_api import ModelFilter
|
255 |
+
qwen_api = HfApi()
|
256 |
+
qwen_models = qwen_api.list_models(
|
257 |
+
filter = ModelFilter(model_name=model_name),
|
258 |
+
author = "silk-road"
|
259 |
+
)
|
260 |
+
qwen_models_id = []
|
261 |
+
for qwen_model in qwen_models:
|
262 |
+
qwen_models_id.append(qwen_model.id)
|
263 |
+
# print(model.id)
|
264 |
+
if "silk-road/" + model_name in qwen_models_id:
|
265 |
+
from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
|
266 |
+
return (Qwen118k2GPT(model = "silk-road/" + model_name), Qwen_tokenizer)
|
267 |
+
else:
|
268 |
+
print(f'warning! undefined model {model_name}, use openai instead.')
|
269 |
+
from .LangChainGPT import LangChainGPT
|
270 |
+
return (LangChainGPT(), tiktokenizer)
|
271 |
+
# print(models_id)
|
272 |
+
else:
|
273 |
+
print(f'warning! undefined model {model_name}, use openai instead.')
|
274 |
+
from .LangChainGPT import LangChainGPT
|
275 |
+
return (LangChainGPT(), tiktokenizer)
|
276 |
+
|
277 |
+
def get_tokenlen_setting( self, model_name ):
|
278 |
+
# return the setting of story and history token length
|
279 |
+
if model_name == 'openai':
|
280 |
+
return (1500, 1200)
|
281 |
+
else:
|
282 |
+
print(f'warning! undefined model {model_name}, use openai instead.')
|
283 |
+
return (1500, 1200)
|
284 |
+
|
285 |
+
def build_story_db_from_vec( self, texts, vecs ):
|
286 |
+
self.db = ChromaDB()
|
287 |
+
|
288 |
+
self.db.init_from_docs( vecs, texts)
|
289 |
+
|
290 |
+
def build_story_db(self, text_folder):
|
291 |
+
# 实现读取文本文件夹,抽取向量的逻辑
|
292 |
+
db = ChromaDB()
|
293 |
+
|
294 |
+
strs = []
|
295 |
+
|
296 |
+
# scan all txt file from text_folder
|
297 |
+
for file in os.listdir(text_folder):
|
298 |
+
# if file name end with txt
|
299 |
+
if file.endswith(".txt"):
|
300 |
+
file_path = os.path.join(text_folder, file)
|
301 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
302 |
+
strs.append(f.read())
|
303 |
+
|
304 |
+
if self.verbose:
|
305 |
+
print(f'starting extract embedding... for { len(strs) } files')
|
306 |
+
|
307 |
+
vecs = []
|
308 |
+
|
309 |
+
## TODO: 建立一个新的embedding batch test的单元测试
|
310 |
+
## 新的支持list batch test的embedding代码
|
311 |
+
## 用新的代码替换下面的for循环
|
312 |
+
## Luotuo-bert-en也发布了,所以可以避开使用openai
|
313 |
+
|
314 |
+
for mystr in strs:
|
315 |
+
vecs.append(self.embedding(mystr))
|
316 |
+
|
317 |
+
db.init_from_docs(vecs, strs)
|
318 |
+
|
319 |
+
return db
|
320 |
+
|
321 |
+
def save_story_db(self, db_path):
|
322 |
+
self.db.save(db_path)
|
323 |
+
|
324 |
+
def generate_prompt( self, text, role):
|
325 |
+
from langchain.schema import (
|
326 |
+
AIMessage,
|
327 |
+
HumanMessage,
|
328 |
+
SystemMessage
|
329 |
+
)
|
330 |
+
messages = self.generate_messages( text, role )
|
331 |
+
prompt = ""
|
332 |
+
for msg in messages:
|
333 |
+
if isinstance(msg, HumanMessage):
|
334 |
+
prompt += msg.content + "\n"
|
335 |
+
elif isinstance(msg, AIMessage):
|
336 |
+
prompt += msg.content + "\n"
|
337 |
+
elif isinstance(msg, SystemMessage):
|
338 |
+
prompt += msg.content + "\n"
|
339 |
+
return prompt
|
340 |
+
|
341 |
+
|
342 |
+
def generate_messages( self, text, role):
|
343 |
+
# add system prompt
|
344 |
+
self.llm.initialize_message()
|
345 |
+
self.llm.system_message(self.system_prompt)
|
346 |
+
|
347 |
+
# add story
|
348 |
+
query = self.get_query_string(text, role)
|
349 |
+
self.add_story( query )
|
350 |
+
self.last_query = query
|
351 |
+
|
352 |
+
# add query
|
353 |
+
self.llm.user_message(query)
|
354 |
+
|
355 |
+
return self.llm.messages
|
356 |
+
|
357 |
+
def append_response( self, response, last_query = None ):
|
358 |
+
if last_query == None:
|
359 |
+
last_query_record = ""
|
360 |
+
if hasattr( self, "last_query" ):
|
361 |
+
last_query_record = self.last_query
|
362 |
+
else:
|
363 |
+
last_query_record = last_query
|
364 |
+
|
365 |
+
# record dialogue history
|
366 |
+
self.dialogue_history.append((last_query_record, response))
|
367 |
+
|
368 |
+
def chat(self, text, role):
|
369 |
+
# add system prompt
|
370 |
+
self.llm.initialize_message()
|
371 |
+
self.llm.system_message(self.system_prompt)
|
372 |
+
|
373 |
+
|
374 |
+
# add story
|
375 |
+
query = self.get_query_string(text, role)
|
376 |
+
self.add_story( query )
|
377 |
+
|
378 |
+
# add history
|
379 |
+
self.add_history()
|
380 |
+
|
381 |
+
# add query
|
382 |
+
self.llm.user_message(query)
|
383 |
+
|
384 |
+
# get response
|
385 |
+
response_raw = self.llm.get_response()
|
386 |
+
|
387 |
+
response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)
|
388 |
+
|
389 |
+
# record dialogue history
|
390 |
+
self.dialogue_history.append((query, response))
|
391 |
+
|
392 |
+
|
393 |
+
|
394 |
+
return response
|
395 |
+
|
396 |
+
def get_query_string(self, text, role):
|
397 |
+
if role in self.narrator:
|
398 |
+
return role + ":" + text
|
399 |
+
else:
|
400 |
+
return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"
|
401 |
+
|
402 |
+
def add_story(self, query):
|
403 |
+
|
404 |
+
if self.db is None:
|
405 |
+
return
|
406 |
+
|
407 |
+
query_vec = self.embedding(query)
|
408 |
+
|
409 |
+
stories = self.db.search(query_vec, self.k_search)
|
410 |
+
|
411 |
+
story_string = self.story_prefix_prompt
|
412 |
+
sum_story_token = self.tokenizer(story_string)
|
413 |
+
|
414 |
+
for story in stories:
|
415 |
+
story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
|
416 |
+
if sum_story_token + story_token > self.max_len_story:
|
417 |
+
break
|
418 |
+
else:
|
419 |
+
sum_story_token += story_token
|
420 |
+
story_string += story + self.dialogue_divide_token
|
421 |
+
|
422 |
+
self.llm.user_message(story_string)
|
423 |
+
|
424 |
+
def add_history(self):
|
425 |
+
|
426 |
+
if len(self.dialogue_history) == 0:
|
427 |
+
return
|
428 |
+
|
429 |
+
sum_history_token = 0
|
430 |
+
flag = 0
|
431 |
+
for query, response in reversed(self.dialogue_history):
|
432 |
+
current_count = 0
|
433 |
+
if query is not None:
|
434 |
+
current_count += self.tokenizer(query)
|
435 |
+
if response is not None:
|
436 |
+
current_count += self.tokenizer(response)
|
437 |
+
sum_history_token += current_count
|
438 |
+
if sum_history_token > self.max_len_history:
|
439 |
+
break
|
440 |
+
else:
|
441 |
+
flag += 1
|
442 |
+
|
443 |
+
if flag == 0:
|
444 |
+
print('warning! no history added. the last dialogue is too long.')
|
445 |
+
|
446 |
+
for (query, response) in self.dialogue_history[-flag:]:
|
447 |
+
if query is not None:
|
448 |
+
self.llm.user_message(query)
|
449 |
+
if response is not None:
|
450 |
+
self.llm.ai_message(response)
|
ChatHaruhi/ChatHaruhi_safe.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ChromaDB import ChromaDB
|
2 |
+
import os
|
3 |
+
|
4 |
+
from .utils import luotuo_openai_embedding, tiktokenizer
|
5 |
+
|
6 |
+
from .utils import response_postprocess
|
7 |
+
|
8 |
+
from .utils import text_censor
|
9 |
+
|
10 |
+
class ChatHaruhi_safe:
|
11 |
+
|
12 |
+
def __init__(self, system_prompt = None, \
|
13 |
+
role_name = None, role_from_hf = None, \
|
14 |
+
story_db=None, story_text_folder = None, \
|
15 |
+
llm = 'openai', \
|
16 |
+
embedding = 'luotuo_openai', \
|
17 |
+
max_len_story = None, max_len_history = None,
|
18 |
+
verbose = False):
|
19 |
+
super(ChatHaruhi_safe, self).__init__()
|
20 |
+
self.verbose = verbose
|
21 |
+
|
22 |
+
# constants
|
23 |
+
self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
|
24 |
+
self.k_search = 19
|
25 |
+
self.narrator = ['旁白', '', 'scene','Scene','narrator' , 'Narrator']
|
26 |
+
self.dialogue_divide_token = '\n###\n'
|
27 |
+
self.dialogue_bra_token = '「'
|
28 |
+
self.dialogue_ket_token = '」'
|
29 |
+
|
30 |
+
if system_prompt:
|
31 |
+
self.system_prompt = self.check_system_prompt( system_prompt )
|
32 |
+
|
33 |
+
# TODO: embedding should be the seperately defined, so refactor this part later
|
34 |
+
if llm == 'openai':
|
35 |
+
# self.llm = LangChainGPT()
|
36 |
+
self.llm, self.tokenizer = self.get_models('openai')
|
37 |
+
elif llm == 'debug':
|
38 |
+
self.llm, self.tokenizer = self.get_models('debug')
|
39 |
+
elif llm == 'spark':
|
40 |
+
self.llm, self.tokenizer = self.get_models('spark')
|
41 |
+
elif llm == 'GLMPro':
|
42 |
+
self.llm, self.tokenizer = self.get_models('GLMPro')
|
43 |
+
elif llm == 'ChatGLM2GPT':
|
44 |
+
self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
|
45 |
+
self.story_prefix_prompt = '\n'
|
46 |
+
elif llm == "BaiChuan2GPT":
|
47 |
+
self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
|
48 |
+
elif llm == "BaiChuanAPIGPT":
|
49 |
+
self.llm, self.tokenizer = self.get_models('BaiChuanAPIGPT')
|
50 |
+
elif llm == "ernie3.5":
|
51 |
+
self.llm, self.tokenizer = self.get_models('ernie3.5')
|
52 |
+
elif llm == "ernie4.0":
|
53 |
+
self.llm, self.tokenizer = self.get_models('ernie4.0')
|
54 |
+
else:
|
55 |
+
print(f'warning! undefined llm {llm}, use openai instead.')
|
56 |
+
self.llm, self.tokenizer = self.get_models('openai')
|
57 |
+
|
58 |
+
if embedding == 'luotuo_openai':
|
59 |
+
self.embedding = luotuo_openai_embedding
|
60 |
+
elif embedding == 'bge_en':
|
61 |
+
from .utils import get_bge_embedding
|
62 |
+
self.embedding = get_bge_embedding
|
63 |
+
else:
|
64 |
+
print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
|
65 |
+
self.embedding = luotuo_openai_embedding
|
66 |
+
|
67 |
+
if role_name:
|
68 |
+
# TODO move into a function
|
69 |
+
from .role_name_to_file import get_folder_role_name
|
70 |
+
# correct role_name to folder_role_name
|
71 |
+
role_name, url = get_folder_role_name(role_name)
|
72 |
+
|
73 |
+
unzip_folder = f'./temp_character_folder/temp_{role_name}'
|
74 |
+
db_folder = os.path.join(unzip_folder, f'content/{role_name}')
|
75 |
+
system_prompt = os.path.join(unzip_folder, f'content/system_prompt.txt')
|
76 |
+
|
77 |
+
if not os.path.exists(unzip_folder):
|
78 |
+
# not yet downloaded
|
79 |
+
# url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
|
80 |
+
import requests, zipfile, io
|
81 |
+
r = requests.get(url)
|
82 |
+
z = zipfile.ZipFile(io.BytesIO(r.content))
|
83 |
+
z.extractall(unzip_folder)
|
84 |
+
|
85 |
+
if self.verbose:
|
86 |
+
print(f'loading pre-defined character {role_name}...')
|
87 |
+
|
88 |
+
self.db = ChromaDB()
|
89 |
+
self.db.load(db_folder)
|
90 |
+
self.system_prompt = self.check_system_prompt(system_prompt)
|
91 |
+
elif role_from_hf:
|
92 |
+
# TODO move into a function
|
93 |
+
from datasets import load_dataset
|
94 |
+
|
95 |
+
if role_from_hf.count("/") == 1:
|
96 |
+
dataset = load_dataset(role_from_hf)
|
97 |
+
datas = dataset["train"]
|
98 |
+
elif role_from_hf.count("/") >= 2:
|
99 |
+
split_index = role_from_hf.index('/')
|
100 |
+
second_split_index = role_from_hf.index('/', split_index+1)
|
101 |
+
dataset_name = role_from_hf[:second_split_index]
|
102 |
+
split_name = role_from_hf[second_split_index+1:]
|
103 |
+
|
104 |
+
fname = split_name + '.jsonl'
|
105 |
+
dataset = load_dataset(dataset_name,data_files={'train':fname})
|
106 |
+
datas = dataset["train"]
|
107 |
+
|
108 |
+
|
109 |
+
from .utils import base64_to_float_array
|
110 |
+
|
111 |
+
if embedding == 'luotuo_openai':
|
112 |
+
embed_name = 'luotuo_openai'
|
113 |
+
elif embedding == 'bge_en':
|
114 |
+
embed_name = 'bge_en_s15'
|
115 |
+
else:
|
116 |
+
print('warning! unkown embedding name ', embedding ,' while loading role')
|
117 |
+
embed_name = 'luotuo_openai'
|
118 |
+
|
119 |
+
texts = []
|
120 |
+
vecs = []
|
121 |
+
for data in datas:
|
122 |
+
if data[embed_name] == 'system_prompt':
|
123 |
+
self.system_prompt = data['text']
|
124 |
+
elif data[embed_name] == 'config':
|
125 |
+
pass
|
126 |
+
else:
|
127 |
+
vec = base64_to_float_array( data[embed_name] )
|
128 |
+
text = data['text']
|
129 |
+
vecs.append( vec )
|
130 |
+
texts.append( text )
|
131 |
+
|
132 |
+
self.build_story_db_from_vec( texts, vecs )
|
133 |
+
|
134 |
+
elif story_db:
|
135 |
+
self.db = ChromaDB()
|
136 |
+
self.db.load(story_db)
|
137 |
+
elif story_text_folder:
|
138 |
+
# print("Building story database from texts...")
|
139 |
+
self.db = self.build_story_db(story_text_folder)
|
140 |
+
else:
|
141 |
+
self.db = None
|
142 |
+
print('warning! database not yet figured out, both story_db and story_text_folder are not inputted.')
|
143 |
+
# raise ValueError("Either story_db or story_text_folder must be provided")
|
144 |
+
|
145 |
+
|
146 |
+
self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')
|
147 |
+
|
148 |
+
if max_len_history is not None:
|
149 |
+
self.max_len_history = max_len_history
|
150 |
+
# user setting will override default setting
|
151 |
+
|
152 |
+
if max_len_story is not None:
|
153 |
+
self.max_len_story = max_len_story
|
154 |
+
# user setting will override default setting
|
155 |
+
|
156 |
+
self.dialogue_history = []
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
def check_system_prompt(self, system_prompt):
|
161 |
+
# if system_prompt end with .txt, read the file with utf-8
|
162 |
+
# else, return the string directly
|
163 |
+
if system_prompt.endswith('.txt'):
|
164 |
+
with open(system_prompt, 'r', encoding='utf-8') as f:
|
165 |
+
return f.read()
|
166 |
+
else:
|
167 |
+
return system_prompt
|
168 |
+
|
169 |
+
|
170 |
+
def get_models(self, model_name):
|
171 |
+
|
172 |
+
# TODO: if output only require tokenizer model, no need to initialize llm
|
173 |
+
|
174 |
+
# return the combination of llm, embedding and tokenizer
|
175 |
+
if model_name == 'openai':
|
176 |
+
from .LangChainGPT import LangChainGPT
|
177 |
+
return (LangChainGPT(), tiktokenizer)
|
178 |
+
elif model_name == 'debug':
|
179 |
+
from .PrintLLM import PrintLLM
|
180 |
+
return (PrintLLM(), tiktokenizer)
|
181 |
+
elif model_name == 'spark':
|
182 |
+
from .SparkGPT import SparkGPT
|
183 |
+
return (SparkGPT(), tiktokenizer)
|
184 |
+
elif model_name == 'GLMPro':
|
185 |
+
from .GLMPro import GLMPro
|
186 |
+
return (GLMPro(), tiktokenizer)
|
187 |
+
elif model_name == 'ernie3.5':
|
188 |
+
from .ErnieGPT import ErnieGPT
|
189 |
+
return (ErnieGPT(), tiktokenizer)
|
190 |
+
elif model_name == 'ernie4.0':
|
191 |
+
from .ErnieGPT import ErnieGPT
|
192 |
+
return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
|
193 |
+
elif model_name == "ChatGLM2GPT":
|
194 |
+
from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
|
195 |
+
return (ChatGLM2GPT(), GLM_tokenizer)
|
196 |
+
elif model_name == "BaiChuan2GPT":
|
197 |
+
from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
|
198 |
+
return (BaiChuan2GPT(), BaiChuan_tokenizer)
|
199 |
+
elif model_name == "BaiChuanAPIGPT":
|
200 |
+
from .BaiChuanAPIGPT import BaiChuanAPIGPT
|
201 |
+
return (BaiChuanAPIGPT(), tiktokenizer)
|
202 |
+
else:
|
203 |
+
print(f'warning! undefined model {model_name}, use openai instead.')
|
204 |
+
from .LangChainGPT import LangChainGPT
|
205 |
+
return (LangChainGPT(), tiktokenizer)
|
206 |
+
|
207 |
+
def get_tokenlen_setting( self, model_name ):
|
208 |
+
# return the setting of story and history token length
|
209 |
+
if model_name == 'openai':
|
210 |
+
return (1500, 1200)
|
211 |
+
else:
|
212 |
+
print(f'warning! undefined model {model_name}, use openai instead.')
|
213 |
+
return (1500, 1200)
|
214 |
+
|
215 |
+
def build_story_db_from_vec( self, texts, vecs ):
|
216 |
+
self.db = ChromaDB()
|
217 |
+
|
218 |
+
self.db.init_from_docs( vecs, texts)
|
219 |
+
|
220 |
+
def build_story_db(self, text_folder):
|
221 |
+
# 实现读取文本文件夹,抽取向量的逻辑
|
222 |
+
db = ChromaDB()
|
223 |
+
|
224 |
+
strs = []
|
225 |
+
|
226 |
+
# scan all txt file from text_folder
|
227 |
+
for file in os.listdir(text_folder):
|
228 |
+
# if file name end with txt
|
229 |
+
if file.endswith(".txt"):
|
230 |
+
file_path = os.path.join(text_folder, file)
|
231 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
232 |
+
strs.append(f.read())
|
233 |
+
|
234 |
+
if self.verbose:
|
235 |
+
print(f'starting extract embedding... for { len(strs) } files')
|
236 |
+
|
237 |
+
vecs = []
|
238 |
+
|
239 |
+
## TODO: 建立一个新的embedding batch test的单元测试
|
240 |
+
## 新的支持list batch test的embedding代码
|
241 |
+
## 用新的代码替换下面的for循环
|
242 |
+
## Luotuo-bert-en也发布了,所以可以避开使用openai
|
243 |
+
|
244 |
+
for mystr in strs:
|
245 |
+
vecs.append(self.embedding(mystr))
|
246 |
+
|
247 |
+
db.init_from_docs(vecs, strs)
|
248 |
+
|
249 |
+
return db
|
250 |
+
|
251 |
+
def save_story_db(self, db_path):
|
252 |
+
self.db.save(db_path)
|
253 |
+
|
254 |
+
def chat(self, text, role):
|
255 |
+
# add system prompt
|
256 |
+
self.llm.initialize_message()
|
257 |
+
self.llm.system_message(self.system_prompt)
|
258 |
+
|
259 |
+
|
260 |
+
# add story
|
261 |
+
query = self.get_query_string(text, role)
|
262 |
+
self.add_story( query )
|
263 |
+
|
264 |
+
# add history
|
265 |
+
self.add_history()
|
266 |
+
|
267 |
+
# add query
|
268 |
+
self.llm.user_message(query)
|
269 |
+
|
270 |
+
# get response
|
271 |
+
response_raw = self.llm.get_response()
|
272 |
+
|
273 |
+
response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)
|
274 |
+
|
275 |
+
# record dialogue history
|
276 |
+
self.dialogue_history.append((query, response))
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
return response
|
281 |
+
|
282 |
+
def get_query_string(self, text, role):
|
283 |
+
if role in self.narrator:
|
284 |
+
return role + ":" + text
|
285 |
+
else:
|
286 |
+
return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"
|
287 |
+
|
288 |
+
def add_story(self, query):
|
289 |
+
|
290 |
+
if self.db is None:
|
291 |
+
return
|
292 |
+
|
293 |
+
query_vec = self.embedding(query)
|
294 |
+
|
295 |
+
stories = self.db.search(query_vec, self.k_search)
|
296 |
+
|
297 |
+
story_string = self.story_prefix_prompt
|
298 |
+
sum_story_token = self.tokenizer(story_string)
|
299 |
+
|
300 |
+
for story in stories:
|
301 |
+
story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
|
302 |
+
if sum_story_token + story_token > self.max_len_story:
|
303 |
+
break
|
304 |
+
else:
|
305 |
+
sum_story_token += story_token
|
306 |
+
story_string += story + self.dialogue_divide_token
|
307 |
+
|
308 |
+
if text_censor(story_string):
|
309 |
+
self.llm.user_message(story_string)
|
310 |
+
|
311 |
+
def add_history(self):
|
312 |
+
|
313 |
+
if len(self.dialogue_history) == 0:
|
314 |
+
return
|
315 |
+
|
316 |
+
sum_history_token = 0
|
317 |
+
flag = 0
|
318 |
+
for query, response in reversed(self.dialogue_history):
|
319 |
+
current_count = 0
|
320 |
+
if query is not None:
|
321 |
+
current_count += self.tokenizer(query)
|
322 |
+
if response is not None:
|
323 |
+
current_count += self.tokenizer(response)
|
324 |
+
sum_history_token += current_count
|
325 |
+
if sum_history_token > self.max_len_history:
|
326 |
+
break
|
327 |
+
else:
|
328 |
+
flag += 1
|
329 |
+
|
330 |
+
if flag == 0:
|
331 |
+
print('warning! no history added. the last dialogue is too long.')
|
332 |
+
|
333 |
+
for (query, response) in self.dialogue_history[-flag:]:
|
334 |
+
if query is not None:
|
335 |
+
self.llm.user_message(query)
|
336 |
+
if response is not None:
|
337 |
+
self.llm.ai_message(response)
|
ChatHaruhi/ChromaDB.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chromadb
|
2 |
+
from .BaseDB import BaseDB
|
3 |
+
import random
|
4 |
+
import string
|
5 |
+
import os
|
6 |
+
|
7 |
+
class ChromaDB(BaseDB):
|
8 |
+
|
9 |
+
def __init__(self):
|
10 |
+
self.client = None
|
11 |
+
self.collection = None
|
12 |
+
self.path = None
|
13 |
+
|
14 |
+
def init_db(self):
|
15 |
+
|
16 |
+
if self.client is not None:
|
17 |
+
print('ChromaDB has already been initialized')
|
18 |
+
return
|
19 |
+
|
20 |
+
folder_name = ''
|
21 |
+
|
22 |
+
while os.path.exists(folder_name) or folder_name == '':
|
23 |
+
# try to create a folder named temp_<random string> which is not yet existed
|
24 |
+
folder_name = "tempdb_" + ''.join(random.sample(string.ascii_letters + string.digits, 8))
|
25 |
+
|
26 |
+
self.path = folder_name
|
27 |
+
self.client = chromadb.PersistentClient(path = folder_name)
|
28 |
+
|
29 |
+
self.collection = self.client.get_or_create_collection("search")
|
30 |
+
|
31 |
+
def save(self, file_path):
|
32 |
+
if file_path != self.path:
|
33 |
+
# copy all files in self.path to file_path, with overwrite
|
34 |
+
os.system("cp -r " + self.path + " " + file_path)
|
35 |
+
previous_path = self.path
|
36 |
+
self.path = file_path
|
37 |
+
self.client = chromadb.PersistentClient(path = file_path)
|
38 |
+
# remove previous path if it start with tempdb
|
39 |
+
if previous_path.startswith("tempdb"):
|
40 |
+
os.system("rm -rf " + previous_path)
|
41 |
+
|
42 |
+
|
43 |
+
def load(self, file_path):
|
44 |
+
self.path = file_path
|
45 |
+
self.client = chromadb.PersistentClient(path = file_path)
|
46 |
+
self.collection = self.client.get_collection("search")
|
47 |
+
|
48 |
+
def search(self, vector, n_results):
|
49 |
+
results = self.collection.query(query_embeddings=[vector], n_results=n_results)
|
50 |
+
return results['documents'][0]
|
51 |
+
|
52 |
+
def init_from_docs(self, vectors, documents):
|
53 |
+
if self.client is None:
|
54 |
+
self.init_db()
|
55 |
+
|
56 |
+
ids = []
|
57 |
+
for i, doc in enumerate(documents):
|
58 |
+
first_four_chat = doc[:min(4, len(doc))]
|
59 |
+
ids.append( str(i) + "_" + doc)
|
60 |
+
self.collection.add(embeddings=vectors, documents=documents, ids = ids)
|
61 |
+
|
ChatHaruhi/ErnieGPT.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ErnieGPT.py
|
2 |
+
from pyexpat import model
|
3 |
+
import erniebot
|
4 |
+
#以下密钥信息从os环境获取
|
5 |
+
import os
|
6 |
+
import copy
|
7 |
+
|
8 |
+
# appid = os.environ['APPID']
|
9 |
+
# api_secret = os.environ['APISecret']
|
10 |
+
# api_key = os.environ['APIKey']
|
11 |
+
erniebot.api_type = os.environ["APIType"]
|
12 |
+
erniebot.access_token = os.environ["ErnieAccess"]
|
13 |
+
|
14 |
+
from .BaseLLM import BaseLLM
|
15 |
+
|
16 |
+
class ErnieGPT(BaseLLM):
|
17 |
+
|
18 |
+
def __init__(self,model="ernie-bot", ernie_trick = True ):
|
19 |
+
super(ErnieGPT,self).__init__()
|
20 |
+
self.model = model
|
21 |
+
if model not in ["ernie-bot", "ernie-bot-turbo", "ernie-vilg-v2", "ernie-text-embedding", "ernie-bot-8k", "ernie-bot-4"]:
|
22 |
+
raise Exception("Unknown Ernie model")
|
23 |
+
# SparkApi.answer =""
|
24 |
+
self.messages = []
|
25 |
+
|
26 |
+
self.ernie_trick = ernie_trick
|
27 |
+
|
28 |
+
|
29 |
+
def initialize_message(self):
|
30 |
+
self.messages = []
|
31 |
+
|
32 |
+
def ai_message(self, payload):
|
33 |
+
if len(self.messages) == 0:
|
34 |
+
self.user_message("请根据我的要求进行角色扮演:")
|
35 |
+
elif len(self.messages) % 2 == 1:
|
36 |
+
self.messages.append({"role":"assistant","content":payload})
|
37 |
+
elif len(self.messages)% 2 == 0:
|
38 |
+
self.messages[-1]["content"] += "\n"+ payload
|
39 |
+
|
40 |
+
def system_message(self, payload):
|
41 |
+
|
42 |
+
self.messages.append({"role":"user","content":payload})
|
43 |
+
|
44 |
+
|
45 |
+
def user_message(self, payload):
|
46 |
+
if len(self.messages) % 2 == 0:
|
47 |
+
self.messages.append({"role":"user","content":payload})
|
48 |
+
# self.messages[-1]["content"] +=
|
49 |
+
elif len(self.messages)% 2 == 1:
|
50 |
+
self.messages[-1]["content"] += "\n"+ payload
|
51 |
+
|
52 |
+
def get_response(self):
|
53 |
+
# question = checklen(getText("user",Input))
|
54 |
+
chat_messages = copy.deepcopy(self.messages)
|
55 |
+
|
56 |
+
lines = chat_messages[-1]["content"].split('\n')
|
57 |
+
|
58 |
+
if self.ernie_trick:
|
59 |
+
lines.insert(-1, '请请模仿上述经典桥段进行回复\n')
|
60 |
+
|
61 |
+
chat_messages[-1]["content"] = '\n'.join(lines)
|
62 |
+
|
63 |
+
# chat_messages[-1]["content"] = "请请模仿上述经典桥段进行回复\n" + chat_messages[-1]["content"]
|
64 |
+
response = erniebot.ChatCompletion.create(model=self.model, messages=chat_messages)
|
65 |
+
# message_json = [{"role": "user", "content": self.messages}]
|
66 |
+
# SparkApi.answer =""
|
67 |
+
# SparkApi.main(appid,api_key,api_secret,self.Spark_url,self.domain,message_json)
|
68 |
+
return response["result"]
|
69 |
+
|
70 |
+
def print_prompt(self):
|
71 |
+
for message in self.messages:
|
72 |
+
print(f"{message['role']}: {message['content']}")
|
ChatHaruhi/GLMPro.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .BaseLLM import BaseLLM
|
2 |
+
import os
|
3 |
+
|
4 |
+
zhipu_api = os.environ['ZHIPU_API']
|
5 |
+
|
6 |
+
import zhipuai
|
7 |
+
import time
|
8 |
+
|
9 |
+
class GLMPro( BaseLLM ):
|
10 |
+
def __init__(self, model="chatglm_pro", verbose = False ):
|
11 |
+
super(GLMPro,self).__init__()
|
12 |
+
|
13 |
+
zhipuai.api_key = zhipu_api
|
14 |
+
|
15 |
+
self.verbose = verbose
|
16 |
+
|
17 |
+
self.model_name = model
|
18 |
+
|
19 |
+
self.prompts = []
|
20 |
+
|
21 |
+
if self.verbose == True:
|
22 |
+
print('model name, ', self.model_name )
|
23 |
+
if len( zhipu_api ) > 8:
|
24 |
+
print( 'found apikey ', zhipu_api[:4], '****', zhipu_api[-4:] )
|
25 |
+
else:
|
26 |
+
print( 'found apikey but too short, ' )
|
27 |
+
|
28 |
+
|
29 |
+
def initialize_message(self):
|
30 |
+
self.prompts = []
|
31 |
+
|
32 |
+
def ai_message(self, payload):
|
33 |
+
self.prompts.append({"role":"assistant","content":payload})
|
34 |
+
|
35 |
+
def system_message(self, payload):
|
36 |
+
self.prompts.append({"role":"user","content":payload})
|
37 |
+
|
38 |
+
def user_message(self, payload):
|
39 |
+
self.prompts.append({"role":"user","content":payload})
|
40 |
+
|
41 |
+
def get_response(self):
|
42 |
+
zhipuai.api_key = zhipu_api
|
43 |
+
max_test_name = 5
|
44 |
+
sleep_interval = 3
|
45 |
+
|
46 |
+
request_id = None
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
# try submit asychonize request until success
|
51 |
+
for test_time in range( max_test_name ):
|
52 |
+
response = zhipuai.model_api.async_invoke(
|
53 |
+
model = self.model_name,
|
54 |
+
prompt = self.prompts,
|
55 |
+
temperature = 0)
|
56 |
+
if response['success'] == True:
|
57 |
+
request_id = response['data']['task_id']
|
58 |
+
|
59 |
+
if self.verbose == True:
|
60 |
+
print('submit request, id = ', request_id )
|
61 |
+
break
|
62 |
+
else:
|
63 |
+
print('submit GLM request failed, retrying...')
|
64 |
+
time.sleep( sleep_interval )
|
65 |
+
|
66 |
+
if request_id:
|
67 |
+
# try get response until success
|
68 |
+
for test_time in range( 2 * max_test_name ):
|
69 |
+
result = zhipuai.model_api.query_async_invoke_result( request_id )
|
70 |
+
if result['code'] == 200 and result['data']['task_status'] == 'SUCCESS':
|
71 |
+
|
72 |
+
if self.verbose == True:
|
73 |
+
print('get GLM response success' )
|
74 |
+
|
75 |
+
choices = result['data']['choices']
|
76 |
+
if len( choices ) > 0:
|
77 |
+
return choices[-1]['content'].strip("\"'")
|
78 |
+
|
79 |
+
# other wise means failed
|
80 |
+
if self.verbose == True:
|
81 |
+
print('get GLM response failed, retrying...')
|
82 |
+
# sleep for 1 second
|
83 |
+
time.sleep( sleep_interval )
|
84 |
+
else:
|
85 |
+
print('submit GLM request failed, please check your api key and model name')
|
86 |
+
return ''
|
87 |
+
|
88 |
+
def print_prompt(self):
|
89 |
+
for message in self.prompts:
|
90 |
+
print(f"{message['role']}: {message['content']}")
|
ChatHaruhi/LangChainGPT.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
|
2 |
+
#
|
3 |
+
# ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
|
4 |
+
#
|
5 | |
6 |
+
#
|
7 |
+
# Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
|
8 |
+
# Weishi Mi is pursuing a job or a PhD position, which who will be available next year
|
9 |
+
#
|
10 |
+
# homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
|
11 |
+
#
|
12 |
+
# ChatHaruhi is a chatbot that can revive anime characters in reality.
|
13 |
+
# the 2.0 version was built by Cheng Li and Weishi Mi.
|
14 |
+
#
|
15 |
+
# Please cite our paper if you use this code for research:
|
16 |
+
#
|
17 |
+
# @misc{li2023chatharuhi,
|
18 |
+
# title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
|
19 |
+
# author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
|
20 |
+
# year={2023},
|
21 |
+
# eprint={2308.09597},
|
22 |
+
# archivePrefix={arXiv},
|
23 |
+
# primaryClass={cs.CL}
|
24 |
+
# }
|
25 |
+
|
26 |
+
|
27 |
+
from langchain.chat_models import ChatOpenAI
|
28 |
+
from langchain.prompts.chat import (
|
29 |
+
ChatPromptTemplate,
|
30 |
+
SystemMessagePromptTemplate,
|
31 |
+
AIMessagePromptTemplate,
|
32 |
+
HumanMessagePromptTemplate,
|
33 |
+
)
|
34 |
+
from langchain.schema import (
|
35 |
+
AIMessage,
|
36 |
+
HumanMessage,
|
37 |
+
SystemMessage
|
38 |
+
)
|
39 |
+
from .BaseLLM import BaseLLM
|
40 |
+
|
41 |
+
import os
|
42 |
+
from dotenv import load_dotenv
|
43 |
+
|
44 |
+
|
45 |
+
class LangChainGPT(BaseLLM):
|
46 |
+
|
47 |
+
def __init__(self, model="gpt-3.5-turbo"):
|
48 |
+
super(LangChainGPT, self).__init__()
|
49 |
+
self.model = model
|
50 |
+
if "OPENAI_API_BASE" in os.environ:
|
51 |
+
load_dotenv()
|
52 |
+
api_base = os.environ["OPENAI_API_BASE"]
|
53 |
+
api_key = os.environ["OPENAI_API_KEY"]
|
54 |
+
self.chat = ChatOpenAI(model=self.model, openai_api_base=api_base)
|
55 |
+
else:
|
56 |
+
self.chat = ChatOpenAI(model=self.model)
|
57 |
+
# add api_base
|
58 |
+
self.messages = []
|
59 |
+
|
60 |
+
def initialize_message(self):
|
61 |
+
self.messages = []
|
62 |
+
|
63 |
+
def ai_message(self, payload):
|
64 |
+
self.messages.append(AIMessage(content=payload))
|
65 |
+
|
66 |
+
def system_message(self, payload):
|
67 |
+
self.messages.append(SystemMessage(content=payload))
|
68 |
+
|
69 |
+
def user_message(self, payload):
|
70 |
+
self.messages.append(HumanMessage(content=payload))
|
71 |
+
|
72 |
+
def get_response(self):
|
73 |
+
response = self.chat(self.messages)
|
74 |
+
return response.content
|
75 |
+
|
76 |
+
def print_prompt(self):
|
77 |
+
for message in self.messages:
|
78 |
+
print(message)
|
ChatHaruhi/PrintLLM.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
|
2 |
+
#
|
3 |
+
# ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
|
4 |
+
#
|
5 | |
6 |
+
#
|
7 |
+
# Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
|
8 |
+
# Weishi Mi is pursuing a job or a PhD position, which who will be available next year
|
9 |
+
#
|
10 |
+
# homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
|
11 |
+
#
|
12 |
+
# ChatHaruhi is a chatbot that can revive anime characters in reality.
|
13 |
+
# the 2.0 version was built by Cheng Li and Weishi Mi.
|
14 |
+
#
|
15 |
+
# Please cite our paper if you use this code for research:
|
16 |
+
#
|
17 |
+
# @misc{li2023chatharuhi,
|
18 |
+
# title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
|
19 |
+
# author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
|
20 |
+
# year={2023},
|
21 |
+
# eprint={2308.09597},
|
22 |
+
# archivePrefix={arXiv},
|
23 |
+
# primaryClass={cs.CL}
|
24 |
+
# }
|
25 |
+
#
|
26 |
+
# This PrintLLM.py is for debuging with any real-runing LLM
|
27 |
+
# so you can see full prompt and copy it into GPT or Claude to debug
|
28 |
+
#
|
29 |
+
|
30 |
+
from .BaseLLM import BaseLLM
|
31 |
+
|
32 |
+
class PrintLLM(BaseLLM):
|
33 |
+
|
34 |
+
def __init__(self ):
|
35 |
+
self.messages = []
|
36 |
+
self.messages.append("Noticing: This is a print LLM for debug.")
|
37 |
+
self.messages.append("But you can also copy the prompt into GPT or Claude to debugging")
|
38 |
+
|
39 |
+
def initialize_message(self):
|
40 |
+
self.messages = []
|
41 |
+
self.messages.append("Noticing: This is a print LLM for debug.")
|
42 |
+
self.messages.append("But you can also copy the prompt into GPT or Claude to debugging")
|
43 |
+
|
44 |
+
def ai_message(self, payload):
|
45 |
+
self.messages.append("AI: \n" + payload)
|
46 |
+
|
47 |
+
def system_message(self, payload):
|
48 |
+
self.messages.append("System: \n" + payload)
|
49 |
+
|
50 |
+
def user_message(self, payload):
|
51 |
+
self.messages.append("User: \n" + payload)
|
52 |
+
|
53 |
+
def get_response(self):
|
54 |
+
for message in self.messages:
|
55 |
+
print(message)
|
56 |
+
response = input("Please input your response: ")
|
57 |
+
return response
|
58 |
+
|
59 |
+
def print_prompt(self):
|
60 |
+
for message in self.messages:
|
61 |
+
print(message)
|
ChatHaruhi/Qwen118k2GPT.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .BaseLLM import BaseLLM
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
from peft import PeftModel
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
+
from transformers.generation import GenerationConfig
|
7 |
+
|
8 |
+
tokenizer_qwen = None
|
9 |
+
model_qwen = None
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
def initialize_Qwen2LORA(model):
|
14 |
+
global model_qwen, tokenizer_qwen
|
15 |
+
|
16 |
+
if model_qwen is None:
|
17 |
+
model_qwen = AutoModelForCausalLM.from_pretrained(
|
18 |
+
model,
|
19 |
+
# torch_dtype=torch.float16,
|
20 |
+
device_map="auto",
|
21 |
+
trust_remote_code=True
|
22 |
+
).half()
|
23 |
+
model_qwen = model_qwen.eval()
|
24 |
+
# model_qwen = PeftModel.from_pretrained(
|
25 |
+
# model_qwen,
|
26 |
+
# "silk-road/Chat-Haruhi-Fusion_B"
|
27 |
+
# )
|
28 |
+
|
29 |
+
if tokenizer_qwen is None:
|
30 |
+
tokenizer_qwen = AutoTokenizer.from_pretrained(
|
31 |
+
model,
|
32 |
+
# use_fast=True,
|
33 |
+
trust_remote_code=True
|
34 |
+
)
|
35 |
+
|
36 |
+
return model_qwen, tokenizer_qwen
|
37 |
+
|
38 |
+
def Qwen_tokenizer(text):
|
39 |
+
return len(tokenizer_qwen.encode(text))
|
40 |
+
|
41 |
+
class Qwen118k2GPT(BaseLLM):
|
42 |
+
def __init__(self, model):
|
43 |
+
super(Qwen118k2GPT, self).__init__()
|
44 |
+
global model_qwen, tokenizer_qwen
|
45 |
+
if model == "Qwen/Qwen-1_8B-Chat":
|
46 |
+
tokenizer_qwen = AutoTokenizer.from_pretrained(
|
47 |
+
"Qwen/Qwen-1_8B-Chat",
|
48 |
+
trust_remote_code=True
|
49 |
+
)
|
50 |
+
model_qwen = AutoModelForCausalLM.from_pretrained(
|
51 |
+
"Qwen/Qwen-1_8B-Chat",
|
52 |
+
device_map="auto",
|
53 |
+
trust_remote_code=True
|
54 |
+
).eval()
|
55 |
+
self.model = model_qwen
|
56 |
+
self.tokenizer = tokenizer_qwen
|
57 |
+
elif "silk-road/" in model :
|
58 |
+
self.model, self.tokenizer = initialize_Qwen2LORA(model)
|
59 |
+
else:
|
60 |
+
raise Exception("Unknown Qwen model")
|
61 |
+
self.messages = ""
|
62 |
+
|
63 |
+
def initialize_message(self):
|
64 |
+
self.messages = ""
|
65 |
+
|
66 |
+
def ai_message(self, payload):
|
67 |
+
self.messages = "AI: " + self.messages + "\n " + payload
|
68 |
+
|
69 |
+
def system_message(self, payload):
|
70 |
+
self.messages = "SYSTEM PROMPT: " + self.messages + "\n " + payload
|
71 |
+
|
72 |
+
def user_message(self, payload):
|
73 |
+
self.messages = "User: " + self.messages + "\n " + payload
|
74 |
+
|
75 |
+
def get_response(self):
|
76 |
+
with torch.no_grad():
|
77 |
+
response, history = self.model.chat(self.tokenizer, self.messages, history=[])
|
78 |
+
# print(response)
|
79 |
+
return response
|
80 |
+
|
81 |
+
def print_prompt(self):
|
82 |
+
print(type(self.messages))
|
83 |
+
print(self.messages)
|
84 |
+
|
85 |
+
|
ChatHaruhi/SparkApi.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 由讯飞提供的websocket接口,用于与星火机器人进行交互
|
2 |
+
|
3 |
+
import _thread as thread
|
4 |
+
import base64
|
5 |
+
import datetime
|
6 |
+
import hashlib
|
7 |
+
import hmac
|
8 |
+
import json
|
9 |
+
from urllib.parse import urlparse
|
10 |
+
import ssl
|
11 |
+
from datetime import datetime
|
12 |
+
from time import mktime
|
13 |
+
from urllib.parse import urlencode
|
14 |
+
from wsgiref.handlers import format_date_time
|
15 |
+
|
16 |
+
import websocket # 使用websocket_client
|
17 |
+
answer = ""
|
18 |
+
|
19 |
+
class Ws_Param(object):
|
20 |
+
# 初始化
|
21 |
+
def __init__(self, APPID, APIKey, APISecret, Spark_url):
|
22 |
+
self.APPID = APPID
|
23 |
+
self.APIKey = APIKey
|
24 |
+
self.APISecret = APISecret
|
25 |
+
self.host = urlparse(Spark_url).netloc
|
26 |
+
self.path = urlparse(Spark_url).path
|
27 |
+
self.Spark_url = Spark_url
|
28 |
+
|
29 |
+
# 生成url
|
30 |
+
def create_url(self):
|
31 |
+
# 生成RFC1123格式的时间戳
|
32 |
+
now = datetime.now()
|
33 |
+
date = format_date_time(mktime(now.timetuple()))
|
34 |
+
|
35 |
+
# 拼接字符串
|
36 |
+
signature_origin = "host: " + self.host + "\n"
|
37 |
+
signature_origin += "date: " + date + "\n"
|
38 |
+
signature_origin += "GET " + self.path + " HTTP/1.1"
|
39 |
+
|
40 |
+
# 进行hmac-sha256进行加密
|
41 |
+
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
42 |
+
digestmod=hashlib.sha256).digest()
|
43 |
+
|
44 |
+
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
45 |
+
|
46 |
+
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
47 |
+
|
48 |
+
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
49 |
+
|
50 |
+
# 将请求的鉴权参数组合为字典
|
51 |
+
v = {
|
52 |
+
"authorization": authorization,
|
53 |
+
"date": date,
|
54 |
+
"host": self.host
|
55 |
+
}
|
56 |
+
# 拼接鉴权参数,生成url
|
57 |
+
url = self.Spark_url + '?' + urlencode(v)
|
58 |
+
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
59 |
+
return url
|
60 |
+
|
61 |
+
|
62 |
+
# 收到websocket错误的处理
|
63 |
+
def on_error(ws, error):
|
64 |
+
print("### error:", error)
|
65 |
+
|
66 |
+
|
67 |
+
# 收到websocket关闭的处理
|
68 |
+
def on_close(ws,one,two):
|
69 |
+
print(" ")
|
70 |
+
|
71 |
+
|
72 |
+
# 收到websocket连接建立的处理
|
73 |
+
def on_open(ws):
|
74 |
+
thread.start_new_thread(run, (ws,))
|
75 |
+
|
76 |
+
|
77 |
+
def run(ws, *args):
|
78 |
+
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
|
79 |
+
ws.send(data)
|
80 |
+
|
81 |
+
|
82 |
+
# 收到websocket消息的处理
|
83 |
+
def on_message(ws, message):
|
84 |
+
# print(message)
|
85 |
+
data = json.loads(message)
|
86 |
+
code = data['header']['code']
|
87 |
+
if code != 0:
|
88 |
+
print(f'请求错误: {code}, {data}')
|
89 |
+
ws.close()
|
90 |
+
else:
|
91 |
+
choices = data["payload"]["choices"]
|
92 |
+
status = choices["status"]
|
93 |
+
content = choices["text"][0]["content"]
|
94 |
+
# print(content,end ="")
|
95 |
+
global answer
|
96 |
+
answer += content
|
97 |
+
# print(1)
|
98 |
+
if status == 2:
|
99 |
+
ws.close()
|
100 |
+
|
101 |
+
|
102 |
+
def gen_params(appid, domain,question):
|
103 |
+
"""
|
104 |
+
通过appid和用户的提问来生成请参数
|
105 |
+
"""
|
106 |
+
data = {
|
107 |
+
"header": {
|
108 |
+
"app_id": appid,
|
109 |
+
"uid": "1234"
|
110 |
+
},
|
111 |
+
"parameter": {
|
112 |
+
"chat": {
|
113 |
+
"domain": domain,
|
114 |
+
"random_threshold": 0.5,
|
115 |
+
"max_tokens": 2048,
|
116 |
+
"auditing": "default"
|
117 |
+
}
|
118 |
+
},
|
119 |
+
"payload": {
|
120 |
+
"message": {
|
121 |
+
"text": question
|
122 |
+
}
|
123 |
+
}
|
124 |
+
}
|
125 |
+
return data
|
126 |
+
|
127 |
+
|
128 |
+
def main(appid, api_key, api_secret, Spark_url,domain, question):
|
129 |
+
# print("星火:")
|
130 |
+
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
|
131 |
+
websocket.enableTrace(False)
|
132 |
+
wsUrl = wsParam.create_url()
|
133 |
+
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
|
134 |
+
ws.appid = appid
|
135 |
+
ws.question = question
|
136 |
+
ws.domain = domain
|
137 |
+
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
138 |
+
|
139 |
+
|
ChatHaruhi/SparkGPT.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SparkGPT.py
|
2 |
+
from . import SparkApi
|
3 |
+
#以下密钥信息从os环境获取
|
4 |
+
import os
|
5 |
+
|
6 |
+
appid = os.environ['APPID']
|
7 |
+
api_secret = os.environ['APISecret']
|
8 |
+
api_key = os.environ['APIKey']
|
9 |
+
|
10 |
+
from .BaseLLM import BaseLLM
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class SparkGPT(BaseLLM):
|
16 |
+
|
17 |
+
def __init__(self, model="Spark3.0"):
|
18 |
+
super(SparkGPT,self).__init__()
|
19 |
+
self.model_type = model
|
20 |
+
self.messages = []
|
21 |
+
if self.model_type == "Spark2.0":
|
22 |
+
self.domain = "generalv2" # v2.0版本
|
23 |
+
self.Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
|
24 |
+
elif self.model_type == "Spark1.5":
|
25 |
+
self.domain = "general" # v1.5版本
|
26 |
+
self.Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
|
27 |
+
elif self.model_type == "Spark3.0":
|
28 |
+
self.domain = "generalv3" # v3.0版本
|
29 |
+
self.Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat" # v3.0环境的地址
|
30 |
+
else:
|
31 |
+
raise Exception("Unknown Spark model")
|
32 |
+
|
33 |
+
def initialize_message(self):
|
34 |
+
self.messages = []
|
35 |
+
|
36 |
+
def ai_message(self, payload):
|
37 |
+
if len(self.messages) == 0:
|
38 |
+
self.user_message("请根据我的要求进行角色扮演:")
|
39 |
+
elif len(self.messages) % 2 == 1:
|
40 |
+
self.messages.append({"role":"assistant","content":payload})
|
41 |
+
elif len(self.messages)% 2 == 0:
|
42 |
+
self.messages[-1]["content"] += "\n"+ payload
|
43 |
+
|
44 |
+
def system_message(self, payload):
|
45 |
+
|
46 |
+
self.messages.append({"role":"user","content":payload})
|
47 |
+
|
48 |
+
|
49 |
+
def user_message(self, payload):
|
50 |
+
if len(self.messages) % 2 == 0:
|
51 |
+
self.messages.append({"role":"user","content":payload})
|
52 |
+
# self.messages[-1]["content"] +=
|
53 |
+
elif len(self.messages)% 2 == 1:
|
54 |
+
self.messages[-1]["content"] += "\n"+ payload
|
55 |
+
|
56 |
+
def get_response(self):
|
57 |
+
# question = checklen(getText("user",Input))
|
58 |
+
SparkApi.answer =""
|
59 |
+
if self.model_type == "Spark2.0":
|
60 |
+
self.domain = "generalv2" # v2.0版本
|
61 |
+
self.Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
|
62 |
+
elif self.model_type == "Spark1.5":
|
63 |
+
self.domain = "general" # v1.5版本
|
64 |
+
self.Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
|
65 |
+
elif self.model_type == "Spark3.0":
|
66 |
+
self.domain = "generalv3" # v3.0版本
|
67 |
+
self.Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat" # v3.0环境的地址
|
68 |
+
else:
|
69 |
+
raise Exception("Unknown Spark model")
|
70 |
+
SparkApi.main(appid,api_key,api_secret,self.Spark_url,self.domain,self.messages)
|
71 |
+
return SparkApi.answer
|
72 |
+
|
73 |
+
def print_prompt(self):
|
74 |
+
for message in self.messages:
|
75 |
+
print(f"{message['role']}: {message['content']}")
|
ChatHaruhi/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
|
2 |
+
#
|
3 |
+
# ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
|
4 |
+
#
|
5 | |
6 |
+
#
|
7 |
+
# Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
|
8 |
+
# Weishi Mi is pursuing a job or a PhD position, which who will be available next year
|
9 |
+
#
|
10 |
+
# homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
|
11 |
+
#
|
12 |
+
# ChatHaruhi is a chatbot that can revive anime characters in reality.
|
13 |
+
# the 2.0 version was built by Cheng Li and Weishi Mi.
|
14 |
+
#
|
15 |
+
# Please cite our paper if you use this code for research:
|
16 |
+
#
|
17 |
+
# @misc{li2023chatharuhi,
|
18 |
+
# title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
|
19 |
+
# author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
|
20 |
+
# year={2023},
|
21 |
+
# eprint={2308.09597},
|
22 |
+
# archivePrefix={arXiv},
|
23 |
+
# primaryClass={cs.CL}
|
24 |
+
# }
|
25 |
+
|
26 |
+
from .ChatHaruhi import ChatHaruhi
|
ChatHaruhi/role_name_to_file.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
|
2 |
+
#
|
3 |
+
# ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
|
4 |
+
#
|
5 | |
6 |
+
#
|
7 |
+
# Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
|
8 |
+
# Weishi Mi is pursuing a job or a PhD position, which who will be available next year
|
9 |
+
#
|
10 |
+
# homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
|
11 |
+
#
|
12 |
+
# ChatHaruhi is a chatbot that can revive anime characters in reality.
|
13 |
+
# the 2.0 version was built by Cheng Li and Weishi Mi.
|
14 |
+
#
|
15 |
+
# Please cite our paper if you use this code for research:
|
16 |
+
#
|
17 |
+
# @misc{li2023chatharuhi,
|
18 |
+
# title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
|
19 |
+
# author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
|
20 |
+
# year={2023},
|
21 |
+
# eprint={2308.09597},
|
22 |
+
# archivePrefix={arXiv},
|
23 |
+
# primaryClass={cs.CL}
|
24 |
+
# }
|
25 |
+
#
|
26 |
+
# if you have attempt to add a new character, please add the role name here
|
27 |
+
#
|
28 |
+
|
29 |
+
role_name_Haruhiu = {'汤师爷': 'tangshiye', 'tangshiye': 'tangshiye', 'Tangshiye': 'tangshiye',
|
30 |
+
'慕容复': 'murongfu', 'murongfu': 'murongfu', 'Murongfu': 'murongfu',
|
31 |
+
'李云龙': 'liyunlong', 'liyunlong': 'liyunlong', 'Liyunlong': 'liyunlong',
|
32 |
+
'Luna': 'Luna', '王多鱼': 'wangduoyu', 'wangduoyu': 'wangduoyu',
|
33 |
+
'Wangduoyu': 'wangduoyu', 'Ron': 'Ron', '鸠摩智': 'jiumozhi',
|
34 |
+
'jiumozhi': 'jiumozhi', 'Jiumozhi': 'jiumozhi', 'Snape': 'Snape',
|
35 |
+
'凉宫春日': 'haruhi', 'haruhi': 'haruhi', 'Haruhi': 'haruhi',
|
36 |
+
'Malfoy': 'Malfoy', '虚竹': 'xuzhu', 'xuzhu': 'xuzhu',
|
37 |
+
'Xuzhu': 'xuzhu', '萧峰': 'xiaofeng',
|
38 |
+
'xiaofeng': 'xiaofeng', 'Xiaofeng': 'xiaofeng', '段誉': 'duanyu',
|
39 |
+
'duanyu': 'duanyu', 'Duanyu': 'duanyu', 'Hermione': 'Hermione',
|
40 |
+
'Dumbledore': 'Dumbledore', '王语嫣': 'wangyuyan', 'wangyuyan':
|
41 |
+
'wangyuyan', 'Wangyuyan': 'wangyuyan', 'Harry': 'Harry',
|
42 |
+
'McGonagall': 'McGonagall', '白展堂': 'baizhantang',
|
43 |
+
'baizhantang': 'baizhantang', 'Baizhantang': 'baizhantang',
|
44 |
+
'佟湘玉': 'tongxiangyu', 'tongxiangyu': 'tongxiangyu',
|
45 |
+
'Tongxiangyu': 'tongxiangyu', '郭芙蓉': 'guofurong',
|
46 |
+
'guofurong': 'guofurong', 'Guofurong': 'guofurong', '流浪者': 'wanderer',
|
47 |
+
'wanderer': 'wanderer', 'Wanderer': 'wanderer', '钟离': 'zhongli',
|
48 |
+
'zhongli': 'zhongli', 'Zhongli': 'zhongli', '胡桃': 'hutao', 'hutao': 'hutao',
|
49 |
+
'Hutao': 'hutao', 'Sheldon': 'Sheldon', 'Raj': 'Raj',
|
50 |
+
'Penny': 'Penny', '韦小宝': 'weixiaobao', 'weixiaobao': 'weixiaobao',
|
51 |
+
'Weixiaobao': 'weixiaobao', '乔峰': 'qiaofeng', 'qiaofeng': 'qiaofeng',
|
52 |
+
'Qiaofeng': 'qiaofeng', '神里绫华': 'ayaka', 'ayaka': 'ayaka',
|
53 |
+
'Ayaka': 'ayaka', '雷电将军': 'raidenShogun', 'raidenShogun': 'raidenShogun',
|
54 |
+
'RaidenShogun': 'raidenShogun', '于谦': 'yuqian', 'yuqian': 'yuqian',
|
55 |
+
'Yuqian': 'yuqian', 'Professor McGonagall': 'McGonagall',
|
56 |
+
'Professor Dumbledore': 'Dumbledore'}
|
57 |
+
|
58 |
+
# input role_name , nick name is also allowed
|
59 |
+
# output folder_role_name and url url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
|
60 |
+
def get_folder_role_name(role_name):
|
61 |
+
if role_name in role_name_Haruhiu:
|
62 |
+
folder_role_name = role_name_Haruhiu[role_name]
|
63 |
+
url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{folder_role_name}.zip'
|
64 |
+
return folder_role_name, url
|
65 |
+
else:
|
66 |
+
print('role_name {} not found, using haruhi as default'.format(role_name))
|
67 |
+
return get_folder_role_name('haruhi')
|
ChatHaruhi/utils.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import Namespace
|
2 |
+
|
3 |
+
from openai import OpenAI
|
4 |
+
|
5 |
+
# client = OpenAI(api_key=<YOUR OPENAI API KEY>)
|
6 |
+
|
7 |
+
from transformers import AutoModel, AutoTokenizer
|
8 |
+
import torch
|
9 |
+
import random
|
10 |
+
|
11 |
+
import tiktoken
|
12 |
+
import re
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import base64
|
17 |
+
import struct
|
18 |
+
|
19 |
+
import os
|
20 |
+
|
21 |
+
import tqdm
|
22 |
+
|
23 |
+
import requests
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def get_access_token():
|
28 |
+
API_KEY = os.getenv("StoryAudit_API_AK")
|
29 |
+
SECRET_KEY = os.getenv("StoryAudit_API_SK")
|
30 |
+
|
31 |
+
"""
|
32 |
+
使用 AK,SK 生成鉴权签名(Access Token)
|
33 |
+
:return: access_token,或是None(如果错误)
|
34 |
+
"""
|
35 |
+
url = "https://aip.baidubce.com/oauth/2.0/token"
|
36 |
+
params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
|
37 |
+
return str(requests.post(url, params=params).json().get("access_token"))
|
38 |
+
|
39 |
+
'''
|
40 |
+
文本审核接口
|
41 |
+
'''
|
42 |
+
def text_censor(text):
|
43 |
+
request_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined"
|
44 |
+
|
45 |
+
params = {"text":text}
|
46 |
+
access_token = get_access_token()
|
47 |
+
request_url = request_url + "?access_token=" + access_token
|
48 |
+
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
49 |
+
response = requests.post(request_url, data=params, headers=headers)
|
50 |
+
return response.json()["conclusion"] == "合规"
|
51 |
+
|
52 |
+
def package_role( system_prompt, texts_path , embedding ):
|
53 |
+
datas = []
|
54 |
+
|
55 |
+
# 暂时只有一种embedding 'luotuo_openai'
|
56 |
+
embed_name = 'luotuo_openai'
|
57 |
+
|
58 |
+
datas.append({ 'text':system_prompt , embed_name:'system_prompt'})
|
59 |
+
datas.append({ 'text':'Reserve Config Setting Here' , embed_name:'config'})
|
60 |
+
|
61 |
+
|
62 |
+
# debug_count = 3
|
63 |
+
|
64 |
+
# for file in os.listdir(texts_path):
|
65 |
+
|
66 |
+
files = os.listdir(texts_path)
|
67 |
+
|
68 |
+
for i in tqdm.tqdm(range(len(files))):
|
69 |
+
file = files[i]
|
70 |
+
# if file name end with txt
|
71 |
+
if file.endswith(".txt"):
|
72 |
+
file_path = os.path.join(texts_path, file)
|
73 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
74 |
+
current_str = f.read()
|
75 |
+
current_vec = embedding(current_str)
|
76 |
+
encode_vec = float_array_to_base64(current_vec)
|
77 |
+
datas.append({ 'text':current_str , embed_name:encode_vec})
|
78 |
+
|
79 |
+
# debug_count -= 1
|
80 |
+
# if debug_count == 0:
|
81 |
+
# break
|
82 |
+
return datas
|
83 |
+
|
84 |
+
|
85 |
+
import struct
|
86 |
+
|
87 |
+
def string_to_base64(text):
|
88 |
+
byte_array = b''
|
89 |
+
for char in text:
|
90 |
+
num_bytes = char.encode('utf-8')
|
91 |
+
byte_array += num_bytes
|
92 |
+
|
93 |
+
base64_data = base64.b64encode(byte_array)
|
94 |
+
return base64_data.decode('utf-8')
|
95 |
+
|
96 |
+
def base64_to_string(base64_data):
|
97 |
+
byte_array = base64.b64decode(base64_data)
|
98 |
+
text = byte_array.decode('utf-8')
|
99 |
+
return text
|
100 |
+
|
101 |
+
|
102 |
+
def float_array_to_base64(float_arr):
|
103 |
+
|
104 |
+
byte_array = b''
|
105 |
+
|
106 |
+
for f in float_arr:
|
107 |
+
# 将每个浮点数打包为4字节
|
108 |
+
num_bytes = struct.pack('!f', f)
|
109 |
+
byte_array += num_bytes
|
110 |
+
|
111 |
+
# 将字节数组进行base64编码
|
112 |
+
base64_data = base64.b64encode(byte_array)
|
113 |
+
|
114 |
+
return base64_data.decode('utf-8')
|
115 |
+
|
116 |
+
def base64_to_float_array(base64_data):
|
117 |
+
|
118 |
+
byte_array = base64.b64decode(base64_data)
|
119 |
+
|
120 |
+
float_array = []
|
121 |
+
|
122 |
+
# 每 4 个字节解析为一个浮点数
|
123 |
+
for i in range(0, len(byte_array), 4):
|
124 |
+
num = struct.unpack('!f', byte_array[i:i+4])[0]
|
125 |
+
float_array.append(num)
|
126 |
+
|
127 |
+
return float_array
|
128 |
+
|
129 |
+
|
130 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
131 |
+
|
132 |
+
_luotuo_model = None
|
133 |
+
|
134 |
+
_luotuo_model_en = None
|
135 |
+
_luotuo_en_tokenizer = None
|
136 |
+
|
137 |
+
_enc_model = None
|
138 |
+
|
139 |
+
# ======== add bge_zh mmodel
|
140 |
+
# by Cheng Li
|
141 |
+
# 这一次我们试图一次性去适配更多的模型
|
142 |
+
|
143 |
+
_model_pool = {}
|
144 |
+
_tokenizer_pool = {}
|
145 |
+
|
146 |
+
# BAAI/bge-small-zh-v1.5
|
147 |
+
|
148 |
+
def get_general_embeddings( sentences , model_name = "BAAI/bge-small-zh-v1.5" ):
|
149 |
+
|
150 |
+
global _model_pool
|
151 |
+
global _tokenizer_pool
|
152 |
+
|
153 |
+
if model_name not in _model_pool:
|
154 |
+
from transformers import AutoTokenizer, AutoModel
|
155 |
+
_tokenizer_pool[model_name] = AutoTokenizer.from_pretrained(model_name)
|
156 |
+
_model_pool[model_name] = AutoModel.from_pretrained(model_name)
|
157 |
+
|
158 |
+
_model_pool[model_name].eval()
|
159 |
+
|
160 |
+
# Tokenize sentences
|
161 |
+
encoded_input = _tokenizer_pool[model_name](sentences, padding=True, truncation=True, return_tensors='pt', max_length = 512)
|
162 |
+
|
163 |
+
# Compute token embeddings
|
164 |
+
with torch.no_grad():
|
165 |
+
model_output = _model_pool[model_name](**encoded_input)
|
166 |
+
# Perform pooling. In this case, cls pooling.
|
167 |
+
sentence_embeddings = model_output[0][:, 0]
|
168 |
+
|
169 |
+
# normalize embeddings
|
170 |
+
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
|
171 |
+
return sentence_embeddings.cpu().tolist()
|
172 |
+
|
173 |
+
def get_general_embedding( text_or_texts , model_name = "BAAI/bge-small-zh-v1.5" ):
|
174 |
+
if isinstance(text_or_texts, str):
|
175 |
+
return get_general_embeddings([text_or_texts], model_name)[0]
|
176 |
+
else:
|
177 |
+
return get_general_embeddings_safe(text_or_texts, model_name)
|
178 |
+
|
179 |
+
general_batch_size = 16
|
180 |
+
|
181 |
+
import math
|
182 |
+
|
183 |
+
def get_general_embeddings_safe(sentences, model_name = "BAAI/bge-small-zh-v1.5"):
|
184 |
+
|
185 |
+
embeddings = []
|
186 |
+
|
187 |
+
num_batches = math.ceil(len(sentences) / general_batch_size)
|
188 |
+
|
189 |
+
for i in tqdm.tqdm( range(num_batches) ):
|
190 |
+
# print("run bge with batch ", i)
|
191 |
+
start_index = i * general_batch_size
|
192 |
+
end_index = min(len(sentences), start_index + general_batch_size)
|
193 |
+
batch = sentences[start_index:end_index]
|
194 |
+
embs = get_general_embeddings(batch, model_name)
|
195 |
+
embeddings.extend(embs)
|
196 |
+
|
197 |
+
return embeddings
|
198 |
+
|
199 |
+
def get_bge_zh_embedding( text_or_texts ):
|
200 |
+
return get_general_embedding(text_or_texts, "BAAI/bge-small-zh-v1.5")
|
201 |
+
|
202 |
+
## TODO: 重构bge_en部分的代码,复用general的函数
|
203 |
+
|
204 |
+
# ======== add bge model
|
205 |
+
# by Cheng Li
|
206 |
+
# for English only right now
|
207 |
+
|
208 |
+
_bge_model = None
|
209 |
+
_bge_tokenizer = None
|
210 |
+
|
211 |
+
def get_bge_embeddings( sentences ):
|
212 |
+
# unsafe ensure batch size by yourself
|
213 |
+
|
214 |
+
global _bge_model
|
215 |
+
global _bge_tokenizer
|
216 |
+
|
217 |
+
if _bge_model is None:
|
218 |
+
from transformers import AutoTokenizer, AutoModel
|
219 |
+
_bge_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
|
220 |
+
_bge_model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5')
|
221 |
+
|
222 |
+
_bge_model.eval()
|
223 |
+
|
224 |
+
# Tokenize sentences
|
225 |
+
encoded_input = _bge_tokenizer(sentences, padding=True, truncation=True, return_tensors='pt', max_length = 512)
|
226 |
+
|
227 |
+
# Compute token embeddings
|
228 |
+
with torch.no_grad():
|
229 |
+
model_output = _bge_model(**encoded_input)
|
230 |
+
# Perform pooling. In this case, cls pooling.
|
231 |
+
sentence_embeddings = model_output[0][:, 0]
|
232 |
+
# normalize embeddings
|
233 |
+
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
|
234 |
+
return sentence_embeddings.cpu().tolist()
|
235 |
+
|
236 |
+
def get_bge_embedding( text_or_texts ):
|
237 |
+
if isinstance(text_or_texts, str):
|
238 |
+
return get_bge_embeddings([text_or_texts])[0]
|
239 |
+
else:
|
240 |
+
return get_bge_embeddings_safe(text_or_texts)
|
241 |
+
|
242 |
+
bge_batch_size = 32
|
243 |
+
|
244 |
+
import math
|
245 |
+
# from tqdm import tqdm
|
246 |
+
|
247 |
+
def get_bge_embeddings_safe(sentences):
|
248 |
+
|
249 |
+
embeddings = []
|
250 |
+
|
251 |
+
num_batches = math.ceil(len(sentences) / bge_batch_size)
|
252 |
+
|
253 |
+
for i in tqdm.tqdm( range(num_batches) ):
|
254 |
+
# print("run bge with batch ", i)
|
255 |
+
start_index = i * bge_batch_size
|
256 |
+
end_index = min(len(sentences), start_index + bge_batch_size)
|
257 |
+
batch = sentences[start_index:end_index]
|
258 |
+
embs = get_bge_embeddings(batch)
|
259 |
+
embeddings.extend(embs)
|
260 |
+
|
261 |
+
return embeddings
|
262 |
+
|
263 |
+
# === add bge model
|
264 |
+
|
265 |
+
def tiktokenizer( text ):
|
266 |
+
global _enc_model
|
267 |
+
|
268 |
+
if _enc_model is None:
|
269 |
+
_enc_model = tiktoken.get_encoding("cl100k_base")
|
270 |
+
|
271 |
+
return len(_enc_model.encode(text))
|
272 |
+
|
273 |
+
def response_postprocess(text,dialogue_bra_token = '「',dialogue_ket_token = '」'):
|
274 |
+
lines = text.split('\n')
|
275 |
+
new_lines = ""
|
276 |
+
|
277 |
+
first_name = None
|
278 |
+
|
279 |
+
for line in lines:
|
280 |
+
line = line.strip(" ")
|
281 |
+
match = re.match(r'^(.*?)[::]' + dialogue_bra_token + r"(.*?)" + dialogue_ket_token + r"$", line)
|
282 |
+
|
283 |
+
|
284 |
+
if match:
|
285 |
+
curr_name = match.group(1)
|
286 |
+
# print(curr_name)
|
287 |
+
if first_name is None:
|
288 |
+
first_name = curr_name
|
289 |
+
new_lines += (match.group(2))
|
290 |
+
else:
|
291 |
+
if curr_name != first_name:
|
292 |
+
return first_name + ":" + dialogue_bra_token + new_lines + dialogue_ket_token
|
293 |
+
else:
|
294 |
+
new_lines += (match.group(2))
|
295 |
+
|
296 |
+
else:
|
297 |
+
if first_name == None:
|
298 |
+
return text
|
299 |
+
else:
|
300 |
+
return first_name + ":" + dialogue_bra_token + new_lines + dialogue_ket_token
|
301 |
+
return first_name + ":" + dialogue_bra_token + new_lines + dialogue_ket_token
|
302 |
+
|
303 |
+
def download_models():
|
304 |
+
print("正在下载Luotuo-Bert")
|
305 |
+
# Import our models. The package will take care of downloading the models automatically
|
306 |
+
model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
|
307 |
+
init_embeddings_model=None)
|
308 |
+
model = AutoModel.from_pretrained("silk-road/luotuo-bert-medium", trust_remote_code=True, model_args=model_args).to(
|
309 |
+
device)
|
310 |
+
print("Luotuo-Bert下载完毕")
|
311 |
+
return model
|
312 |
+
|
313 |
+
def get_luotuo_model():
|
314 |
+
global _luotuo_model
|
315 |
+
if _luotuo_model is None:
|
316 |
+
_luotuo_model = download_models()
|
317 |
+
return _luotuo_model
|
318 |
+
|
319 |
+
|
320 |
+
def luotuo_embedding(model, texts):
|
321 |
+
# Tokenize the texts_source
|
322 |
+
tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-medium")
|
323 |
+
inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
|
324 |
+
inputs = inputs.to(device)
|
325 |
+
# Extract the embeddings
|
326 |
+
# Get the embeddings
|
327 |
+
with torch.no_grad():
|
328 |
+
embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
|
329 |
+
return embeddings
|
330 |
+
|
331 |
+
def luotuo_en_embedding( texts ):
|
332 |
+
# this function implemented by Cheng
|
333 |
+
global _luotuo_model_en
|
334 |
+
global _luotuo_en_tokenizer
|
335 |
+
|
336 |
+
if _luotuo_model_en is None:
|
337 |
+
_luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
|
338 |
+
_luotuo_model_en = AutoModel.from_pretrained("silk-road/luotuo-bert-en").to(device)
|
339 |
+
|
340 |
+
if _luotuo_en_tokenizer is None:
|
341 |
+
_luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
|
342 |
+
|
343 |
+
inputs = _luotuo_en_tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
|
344 |
+
inputs = inputs.to(device)
|
345 |
+
|
346 |
+
with torch.no_grad():
|
347 |
+
embeddings = _luotuo_model_en(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
|
348 |
+
|
349 |
+
return embeddings
|
350 |
+
|
351 |
+
|
352 |
+
def get_embedding_for_chinese(model, texts):
|
353 |
+
model = model.to(device)
|
354 |
+
# str or strList
|
355 |
+
texts = texts if isinstance(texts, list) else [texts]
|
356 |
+
# 截断
|
357 |
+
for i in range(len(texts)):
|
358 |
+
if len(texts[i]) > 510:
|
359 |
+
texts[i] = texts[i][:510]
|
360 |
+
if len(texts) >= 64:
|
361 |
+
embeddings = []
|
362 |
+
chunk_size = 64
|
363 |
+
for i in range(0, len(texts), chunk_size):
|
364 |
+
embeddings.append(luotuo_embedding(model, texts[i: i + chunk_size]))
|
365 |
+
return torch.cat(embeddings, dim=0)
|
366 |
+
else:
|
367 |
+
return luotuo_embedding(model, texts)
|
368 |
+
|
369 |
+
|
370 |
+
def is_chinese_or_english(text):
|
371 |
+
# no longer use online openai api
|
372 |
+
return "chinese"
|
373 |
+
|
374 |
+
text = list(text)
|
375 |
+
is_chinese, is_english = 0, 0
|
376 |
+
|
377 |
+
for char in text:
|
378 |
+
# 判断字符的Unicode值是否在中文字符的Unicode范围内
|
379 |
+
if '\u4e00' <= char <= '\u9fa5':
|
380 |
+
is_chinese += 4
|
381 |
+
# 判断字符是否为英文字符(包括大小写字母和常见标点符号)
|
382 |
+
elif ('\u0041' <= char <= '\u005a') or ('\u0061' <= char <= '\u007a'):
|
383 |
+
is_english += 1
|
384 |
+
if is_chinese >= is_english:
|
385 |
+
return "chinese"
|
386 |
+
else:
|
387 |
+
return "english"
|
388 |
+
|
389 |
+
|
390 |
+
def get_embedding_openai(text, model="text-embedding-ada-002"):
|
391 |
+
text = text.replace("\n", " ")
|
392 |
+
return client.embeddings.create(input = [text], model=model).data[0].embedding
|
393 |
+
|
394 |
+
def get_embedding_for_english(text, model="text-embedding-ada-002"):
|
395 |
+
text = text.replace("\n", " ")
|
396 |
+
return client.embeddings.create(input = [text], model=model).data[0].embedding
|
397 |
+
|
398 |
+
import os
|
399 |
+
|
400 |
+
def luotuo_openai_embedding(texts, is_chinese= None ):
|
401 |
+
"""
|
402 |
+
when input is chinese, use luotuo_embedding
|
403 |
+
when input is english, use openai_embedding
|
404 |
+
texts can be a list or a string
|
405 |
+
when texts is a list, return a list of embeddings, using batch inference
|
406 |
+
when texts is a string, return a single embedding
|
407 |
+
"""
|
408 |
+
|
409 |
+
openai_key = os.environ.get("OPENAI_API_KEY")
|
410 |
+
|
411 |
+
if isinstance(texts, list):
|
412 |
+
index = random.randint(0, len(texts) - 1)
|
413 |
+
if openai_key is None or is_chinese_or_english(texts[index]) == "chinese":
|
414 |
+
return [embed.cpu().tolist() for embed in get_embedding_for_chinese(get_luotuo_model(), texts)]
|
415 |
+
else:
|
416 |
+
return [get_embedding_for_english(text) for text in texts]
|
417 |
+
else:
|
418 |
+
if openai_key is None or is_chinese_or_english(texts) == "chinese":
|
419 |
+
return get_embedding_for_chinese(get_luotuo_model(), texts)[0].cpu().tolist()
|
420 |
+
else:
|
421 |
+
return get_embedding_for_english(texts)
|
422 |
+
|
423 |
+
|
424 |
+
# compute cosine similarity between two vector
|
425 |
+
def get_cosine_similarity( v1, v2):
|
426 |
+
v1 = torch.tensor(v1).to(device)
|
427 |
+
v2 = torch.tensor(v2).to(device)
|
428 |
+
return torch.cosine_similarity(v1, v2, dim=0).item()
|
429 |
+
|
430 |
+
|
431 |
+
|