# Page header captured together with the file (not part of the program):
# chatbot-demo / app.py
# surkovvv's picture
# 5 min timeout
# bb80087
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from functools import partial
# Hugging Face Hub id of the checkpoint. The tokenizer and the weights are
# loaded from this one repository.
MODEL_NAME = "IlyaGusev/saiga_llama3_8b"

# Both objects are created once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Weights are loaded as bfloat16. No explicit device placement is done here
# (the former `model = model` line was a no-op and has been dropped).
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
def transform_history(history):
    """Convert chat history into a list of role/content message dicts.

    Args:
        history: Previous turns, oldest first. Each item is either a
            ``[question, answer]`` pair or a mapping that already carries
            ``"role"`` and ``"content"`` keys. ``None`` is treated as an
            empty history.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts in
        conversation order. ``history`` itself is not modified.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Already in role/content form: copy just the two keys used here.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        question, answer = turn[0], turn[1]
        # Either side of a pair may be None (no message on that side);
        # emitting {"content": None} would hand the chat template a
        # non-string, so such sides are skipped.
        if question is not None:
            messages.append({"role": "user", "content": question})
        if answer is not None:
            messages.append({"role": "assistant", "content": answer})
    return messages
def predict(message, history):
    """Stream the model's reply to ``message``.

    Args:
        message: The user's new message.
        history: Earlier turns, e.g.
            ``[[question1, answer1], [question2, answer2], ...]``.

    Yields:
        The reply text accumulated so far; each yielded string extends the
        previous one by the newly streamed chunk.
    """
    messages = transform_history(history) + [{"role": "user", "content": message}]

    # add_generation_prompt=True asks the chat template to open the assistant
    # turn, so the model continues with the reply itself. The previous version
    # appended an empty assistant message instead and then discarded every
    # streamed chunk containing the word "assistant", which also threw away
    # genuine reply text that happened to contain that word.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)  # keep the prompt on the same device as the weights

    # timeout=300: waiting for the next chunk raises after 5 minutes instead
    # of blocking forever if generation stalls.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
    )

    # generate() blocks until the whole reply is produced, so it runs in a
    # worker thread while this generator drains the streamer.
    worker = Thread(target=model.generate, args=(model_inputs,), kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    # The streamer is exhausted, so generation is finishing; reap the thread.
    worker.join()
# Build the chat UI around predict() and start serving it.
# share=True requests a public share link from Gradio.
demo = gr.ChatInterface(predict)
demo.launch(share=True)