|
import hydra |
|
from hydra.core.global_hydra import GlobalHydra |
|
from omegaconf import DictConfig, OmegaConf |
|
import streamlit as st |
|
from PIL import Image |
|
import os |
|
import sys |
|
sys.path.append(os.path.dirname(__file__)) |
|
from download_models import download_model |
|
|
|
|
|
@st.cache_resource
def load_simple_rag(config, used_lmdeploy=False):
    """Build the RAG pipeline used by the chat UI.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned object for calls made with the same arguments.

    Args:
        config: Mapping with the keys ``data_source_dir``,
            ``db_persist_directory``, ``llm_model``, ``embeddings_model``,
            ``reranker_model``, ``llm_system_prompt`` and
            ``rag_prompt_template``. ``cache_max_entry_count`` is optional
            and only read when ``used_lmdeploy`` is true (default ``0.2``).
        used_lmdeploy: When true the language model is an ``LmdeployLM``;
            otherwise it is an ``InternLM``.

    Returns:
        The ``WuleRAG`` instance built from the settings above.
    """
    # Read every required setting first, in a fixed order, so a missing key
    # fails before any of the heavy project imports below run.
    source_dir = config["data_source_dir"]
    persist_dir = config["db_persist_directory"]
    model_path = config["llm_model"]
    embedder_path = config["embeddings_model"]
    reranker_path = config["reranker_model"]
    system_prompt = config["llm_system_prompt"]
    prompt_template = config["rag_prompt_template"]

    # Project modules are imported lazily, inside the function.
    from rag.simple_rag import WuleRAG

    if used_lmdeploy:
        # GenerationConfig is imported alongside LmdeployLM but not used here.
        from deploy.lmdeploy_model import LmdeployLM, GenerationConfig  # noqa: F401

        backend = LmdeployLM(
            model_path=model_path,
            llm_system_prompt=system_prompt,
            cache_max_entry_count=config.get("cache_max_entry_count", 0.2),
        )
    else:
        from rag.simple_rag import InternLM

        backend = InternLM(model_path=model_path, llm_system_prompt=system_prompt)

    return WuleRAG(
        source_dir,
        persist_dir,
        backend,
        embedder_path,
        reranker_path,
        prompt_template,
    )
|
|
|
|
|
@st.cache_resource
def load_wulewule_agent(config):
    """Create the multi-modal agent together with its LLM and embedding model.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned assistant for calls made with the same arguments.

    Args:
        config: Mapping with the keys ``use_remote``, ``SiliconFlow_api`` and
            ``data_source_dir``. When ``use_remote`` is true, ``remote_llm``
            and ``remote_embeddings_model`` are read as well; otherwise
            ``llm_model`` and ``agent_embeddings_model`` are read.

    Returns:
        The ``MultiModalAssistant`` built from ``data_source_dir``, the
        selected LLM and the resolved API key.

    Side effects:
        Assigns ``Settings.llm`` and ``Settings.embed_model`` and prints two
        progress lines to stdout.
    """
    from agent.wulewule_agent import MultiModalAssistant, Settings

    use_remote = config["use_remote"]
    api_key = config["SiliconFlow_api"]
    data_source_dir = config["data_source_dir"]

    # A configured key shorter than this is replaced by the environment
    # variable when one is set.
    # NOTE(review): 51 is presumably the length of a real SiliconFlow key - confirm.
    min_key_length = 51
    env_key = os.environ.get("SiliconFlow_api", "")
    # `api_key or ""` keeps a null/empty config value from crashing len().
    if env_key and len(api_key or "") < min_key_length:
        api_key = env_key

    print("======= loading llm =======")
    if use_remote:
        from llama_index.llms.siliconflow import SiliconFlow
        from llama_index.embeddings.siliconflow import SiliconFlowEmbedding

        api_base_url = "https://api.siliconflow.cn/v1/chat/completions"
        remote_llm = config["remote_llm"]
        remote_embeddings_model = config["remote_embeddings_model"]
        llm = SiliconFlow(
            model=remote_llm,
            base_url=api_base_url,
            api_key=api_key,
            max_tokens=4096,
        )
        embed_model = SiliconFlowEmbedding(
            model=remote_embeddings_model,
            api_key=api_key,
        )
    else:
        from llama_index.embeddings.huggingface import HuggingFaceEmbedding
        from llama_index.llms.huggingface import HuggingFaceLLM

        local_llm = config["llm_model"]
        local_embeddings_model = config["agent_embeddings_model"]
        llm = HuggingFaceLLM(
            model_name=local_llm,
            tokenizer_name=local_llm,
            model_kwargs={"trust_remote_code": True},
            tokenizer_kwargs={"trust_remote_code": True},
        )
        embed_model = HuggingFaceEmbedding(model_name=local_embeddings_model)

    # Register both models on the Settings object exported by the agent module.
    Settings.llm = llm
    Settings.embed_model = embed_model

    wulewule_assistant = MultiModalAssistant(data_source_dir, llm, api_key)
    print("======= finished loading ! =======")
    return wulewule_assistant
|
|
|
|
|
# Reset Hydra's global state before @hydra.main initialises it again.
# NOTE(review): presumably required because Streamlit re-executes this script
# in the same process on every interaction - confirm.
GlobalHydra.instance().clear()


def _show_agent_media(image_url, audio_text):
    """Render the optional image and audio attachments of an agent reply."""
    if image_url:
        st.image(image_url, width=256)
    if audio_text:
        # The audio is always read from the fixed file "audio.mp3".
        st.audio("audio.mp3")
        st.write(f"语音内容为: \n\n{audio_text}")


@hydra.main(version_base=None, config_path="./configs", config_name="model_cfg")
def main(cfg):
    """Run the Streamlit chat page.

    Loads the back end selected by the configuration (agent, RAG or plain
    lmdeploy model), draws the sidebar and the stored conversation, then
    answers a newly submitted prompt and appends the exchange to
    ``st.session_state["messages"]``.

    Args:
        cfg: Hydra/OmegaConf configuration. The flags ``agent_mode``,
            ``use_rag``, ``use_lmdepoly`` and ``stream_response`` are read
            from it; ``use_rag`` and ``use_lmdepoly`` are set to ``False``
            on it when ``agent_mode`` is on.
    """
    config_dict = OmegaConf.to_container(cfg, resolve=True)

    # A local model is only needed when the remote API is not used; fetch it
    # when the configured path does not exist yet.
    if not config_dict["use_remote"] and not os.path.exists(config_dict["llm_model"]):
        download_model(llm_model_path=config_dict["llm_model"])

    if cfg.agent_mode:
        wulewule_assistant = load_wulewule_agent(config_dict)
        # Agent mode takes precedence, so the other back ends are switched off.
        cfg.use_rag = False
        cfg.use_lmdepoly = False

    if cfg.use_rag:
        wulewule_model = load_simple_rag(config_dict, used_lmdeploy=cfg.use_lmdepoly)
    elif cfg.use_lmdepoly:
        from deploy.lmdeploy_model import load_turbomind_model

        wulewule_model = load_turbomind_model(
            config_dict["llm_model"],
            config_dict["llm_system_prompt"],
            # Same optional key and default as load_simple_rag uses.
            config_dict.get("cache_max_entry_count", 0.2),
        )

    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    with st.sidebar:
        st.markdown("## 悟了悟了💡")
        logo_path = "assets/sd_wulewule.webp"
        if os.path.exists(logo_path):
            image = Image.open(logo_path)
            st.image(image, caption='wulewule')
        # Written explicitly instead of relying on Streamlit's bare-expression "magic".
        st.markdown("[InternLM](https://github.com/InternLM)")
        st.markdown("[悟了悟了](https://github.com/xzyun2011/wulewule.git)")

    st.title("悟了悟了:黑神话悟空AI助手🐒")

    # Replay the stored conversation.
    for msg in st.session_state.messages:
        st.chat_message("user").write(msg["user"])
        assistant_res = msg["assistant"]
        if isinstance(assistant_res, str):
            st.chat_message("assistant").write(assistant_res)
        elif cfg.agent_mode and isinstance(assistant_res, dict):
            image_url = assistant_res["image_url"]
            audio_text = assistant_res["audio_text"]
            st.chat_message("assistant").write(assistant_res["response"])
            _show_agent_media(image_url, audio_text)

    if prompt := st.chat_input("请输入你的问题,换行使用Shift+Enter。"):
        st.chat_message("user").write(prompt)

        full_answer = ""
        if cfg.agent_mode:
            with st.chat_message('robot'):
                message_placeholder = st.empty()
                response_dict = wulewule_assistant.chat(prompt)
                image_url = response_dict["image_url"]
                audio_text = response_dict["audio_text"]
                for cur_response in response_dict["response"]:
                    full_answer += cur_response
                    # Trailing block character acts as a typing cursor.
                    message_placeholder.markdown(full_answer + '▌')
                message_placeholder.markdown(full_answer)

            # Store the accumulated text rather than the original "response"
            # value, so the replay above shows the same answer even if the
            # assistant returned a one-shot iterable.
            # NOTE(review): the type of response_dict["response"] is not
            # visible from this file - confirm.
            stored_reply = {**response_dict, "response": full_answer}
            st.session_state.messages.append({"user": prompt, "assistant": stored_reply})
            _show_agent_media(image_url, audio_text)
        else:
            if cfg.stream_response:
                with st.chat_message('robot'):
                    message_placeholder = st.empty()
                    if cfg.use_rag:
                        for cur_response in wulewule_model.query_stream(prompt):
                            full_answer += cur_response
                            message_placeholder.markdown(full_answer + '▌')
                    elif cfg.use_lmdepoly:
                        messages = [{'role': 'user', 'content': f'{prompt}'}]
                        for response in wulewule_model.stream_infer(messages):
                            full_answer += response.text
                            message_placeholder.markdown(full_answer + '▌')
                    message_placeholder.markdown(full_answer)
            else:
                # use_rag is checked first, as in the loading and streaming
                # code: with both flags on, wulewule_model is the RAG wrapper.
                if cfg.use_rag:
                    full_answer = wulewule_model.query(prompt)
                elif cfg.use_lmdepoly:
                    messages = [{'role': 'user', 'content': f'{prompt}'}]
                    full_answer = wulewule_model(messages).text
                st.chat_message("assistant").write(full_answer)

            st.session_state.messages.append({"user": prompt, "assistant": full_answer})
|
|
|
|
|
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()