File size: 1,543 Bytes
f8515e0
fa2391a
 
f8515e0
fa2391a
 
f8515e0
fa2391a
 
 
d1c4f7a
fa2391a
 
d1c4f7a
fa2391a
 
d1c4f7a
fa2391a
 
d1c4f7a
fa2391a
 
 
d1c4f7a
fa2391a
 
 
 
 
 
 
 
d1c4f7a
fa2391a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

_CHECKPOINT = "microsoft/DialoGPT-small"

# Load the pretrained tokenizer and causal-LM once at import time; the chat
# handler below reads these module-level objects on every turn.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Upper bound on prompt + generated tokens passed to ``model.generate``.
_MAX_TOTAL_TOKENS = 1000
# Longest prompt (history + new message) kept before generation. The remaining
# _MAX_TOTAL_TOKENS - _MAX_PROMPT_TOKENS tokens are left for the reply. Without
# this cap the history grows every turn until ``generate`` has no room left to
# answer (or the prompt exceeds the model's context window — 1024 positions for
# GPT-2-style models; TODO confirm via model.config.n_positions).
_MAX_PROMPT_TOKENS = 800


def chatbot(input_text, chat_history):
    """Generate one DialoGPT reply and return it with the updated token history.

    Args:
        input_text: The user's new message.
        chat_history: Token ids of the conversation so far, exactly as returned
            by the previous call (a nested list produced by ``Tensor.tolist()``),
            or a falsy value (``None`` / empty list) for a new conversation.

    Returns:
        A ``(reply, history)`` tuple: the decoded bot reply with special tokens
        stripped, and the full token-id history including that reply as a
        nested list, suitable for storing in Gradio state for the next turn.
    """
    # Encode the new message with the EOS token appended (presumably DialoGPT's
    # end-of-turn marker), as a PyTorch tensor.
    new_user_input_ids = tokenizer.encode(
        input_text + tokenizer.eos_token, return_tensors="pt"
    )

    # Append the new user tokens to the previous conversation, if any.
    if chat_history:
        bot_input_ids = torch.cat(
            [torch.tensor(chat_history), new_user_input_ids], dim=-1
        )
    else:
        bot_input_ids = new_user_input_ids

    # Keep only the most recent tokens so the prompt never crowds out the reply.
    # Older turns are dropped; this may cut a turn mid-way, which the model
    # tolerates better than having no room to answer.
    bot_input_ids = bot_input_ids[:, -_MAX_PROMPT_TOKENS:]

    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=_MAX_TOTAL_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The generated sequence starts with the prompt; decode only what follows it.
    prompt_length = bot_input_ids.shape[-1]
    output = tokenizer.decode(
        chat_history_ids[0, prompt_length:], skip_special_tokens=True
    )

    # Return the full (prompt + reply) token history as plain lists so Gradio
    # state can hold it between turns.
    return output, chat_history_ids.tolist()

# ``chatbot`` threads its own token-id history through a "state" input/output
# pair and returns ``(reply, state)``. That is the gr.Interface contract;
# gr.ChatInterface accepts no ``inputs``/``outputs`` arguments (constructing it
# with them raises TypeError) and expects ``fn(message, history) -> reply``.
iface = gr.Interface(
    fn=chatbot,
    inputs=["text", "state"],  # "state" carries the token-id history between turns
    outputs=["text", "state"],
    title="DialoGPT Chatbot (Small)",
    description="Simple chat application using microsoft/DialoGPT-small model. Try it out!",
    # Examples only fill the text box (state is not part of an example), so
    # don't pre-run them — that would call ``chatbot`` without its state.
    examples=["Hello", "How are you?", "Tell me a joke"],
    cache_examples=False,
)

iface.launch()