Update app.py
app.py
CHANGED
@@ -1,121 +1,34 @@
-# app.py
 import gradio as gr
-import
-import logging
-import os
-
-TOKENIZER_NAME = MODEL_NAME  # Use the same name for tokenizer
-
-#
-ort_session = None
-tokenizer = None
-
-#
-def load_model_and_tokenizer():
-    global ort_session, tokenizer
-    try:
-        logging.info(f"Checking for ONNX model at: {ONNX_MODEL_PATH}")
-        if not os.path.exists(ONNX_MODEL_PATH):
-            logging.warning(f"ONNX model not found at {ONNX_MODEL_PATH}. Please ensure it exists. Refer to README for conversion instructions.")
-            return False  # Model file missing - indicate failure
-
-        return True  # Success
-
-    except Exception as e:
-        logging.error(f"Error during model or tokenizer loading: {e}")
-        if ort_session:
-            del ort_session  # Attempt to cleanup in case of failure during loading
-        ort_session = None
-        tokenizer = None
-        return False  # Failure
-
-# Pre-load model and tokenizer on app startup
-model_loaded_successfully = load_model_and_tokenizer()
-
-# --- Inference Function ---
-def predict(message, history):
-    if not model_loaded_successfully:
-        logging.warning("Model not loaded, returning fallback message.")
-        return FALLBACK_MESSAGE
-
-    if ort_session is None or tokenizer is None:  # Double check after global check, for robustness.
-        logging.error("ONNX Session or Tokenizer is unexpectedly None in predict function.")
-        return FALLBACK_MESSAGE
-
-    try:
-        # Reconstruct conversation history for DialoGPT input
-        input_text = ""
-        for human_msg, bot_response in history:  # History comes as a list of [user_msg, bot_response] pairs
-            input_text += human_msg + tokenizer.eos_token
-        input_text += message + tokenizer.eos_token
-
-        inputs = tokenizer(input_text, return_tensors="np")
-
-        # Get input and output names - essential for ONNX Runtime
-        input_name = ort_session.get_inputs()[0].name
-        output_name = ort_session.get_outputs()[0].name  # Assuming output is logits for generation
-
-        ort_inputs = {input_name: inputs['input_ids']}  # Only input_ids typically needed for simple generation with DialoGPT
-
-        ort_outputs = ort_session.run([output_name], ort_inputs)  # Run inference
-
-        logits = ort_outputs[0]  # logits from the model
-
-        # Basic generation - argmax for simplicity. For better responses, consider more sophisticated decoding (sampling)
-        predicted_token_ids = logits.argmax(axis=-1)  # Pick token with highest probability. Very simple decoding.
-
-        # Decode ONLY the last generated turn to get the bot's response.
-        # Find the EOS token indices to split the input_text (which includes history) and extract only the NEW response
-        generated_text = tokenizer.decode(predicted_token_ids[0, inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
-
-        if not generated_text.strip():  # Handle empty responses from the model.
-            logging.info("Model returned an empty response, using default.")
-            return "I'm not sure what to say."  # Or a more specific fallback.
-
-        return generated_text.strip()
-
-    except Exception as e:
-        logging.error(f"Error during inference: {e}")
-        return FALLBACK_MESSAGE
-
-# --- Gradio Interface ---
-if __name__ == "__main__":
-    iface = gr.ChatInterface(
-        fn=predict,
-        textbox=gr.Textbox(placeholder="Type your message here...", label="User Input"),
-        chatbot=gr.Chatbot(label="Chatbot Response"),
-        title="DialoGPT Chatbot (ONNX Runtime)",
-        description="Chat with a simple DialoGPT-small chatbot powered by ONNX Runtime for faster inference. This is a basic demonstration. For better performance in a real-world setting, ensure you have a properly quantized and optimized ONNX model. **Note:** For this Space to work, you must upload a `dialogpt-small.onnx` file. Refer to the README on how to convert and optimize the model.",
-        examples=[
-            ["Hello, how are you today?"],
-            ["Tell me a joke"],
-            ["What is the weather like in London?"]
-        ],
-        # Removed retry_btn, undo_btn, clear_btn to resolve TypeError
-        theme="default"
-    )
-    iface.launch()
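The removed ONNX version is only partially recoverable from this view: the definitions of MODEL_NAME, ONNX_MODEL_PATH, and FALLBACK_MESSAGE, together with the body of load_model_and_tokenizer() that actually created the onnxruntime session and tokenizer, did not survive. For orientation, a minimal sketch of how such a loader is conventionally written; the constant values are assumptions (only the filename dialogpt-small.onnx is confirmed by the old description), not the Space's original code:

# Sketch of the elided loader; constant values are assumed, not original.
import logging
import os

import onnxruntime as ort
from transformers import AutoTokenizer

MODEL_NAME = "microsoft/DialoGPT-small"       # assumed
ONNX_MODEL_PATH = "dialogpt-small.onnx"       # filename mentioned in the old description
FALLBACK_MESSAGE = "Sorry, the model could not be loaded."  # assumed wording

ort_session = None
tokenizer = None

def load_model_and_tokenizer():
    global ort_session, tokenizer
    try:
        if not os.path.exists(ONNX_MODEL_PATH):
            logging.warning(f"ONNX model not found at {ONNX_MODEL_PATH}.")
            return False
        # Create the inference session and the matching tokenizer.
        ort_session = ort.InferenceSession(ONNX_MODEL_PATH)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        logging.error(f"Error during model or tokenizer loading: {e}")
        ort_session = None
        tokenizer = None
        return False

Even with the session in place, the removed predict() could not really generate text: a single ort_session.run call yields logits with the same sequence length as the input, so argmax produces next-token predictions for existing positions only and the slice predicted_token_ids[0, input_len:] is always empty. That single-pass decoding, rather than ONNX Runtime itself, is presumably what this commit is working around.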
 import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
+model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
+
+def chatbot(input_text, chat_history):
+    # Encode the new user input, add the eos_token and return a tensor in PyTorch
+    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
+
+    # Append the new user input tokens to the chat history
+    bot_input_ids = torch.cat([torch.tensor(chat_history), new_user_input_ids], dim=-1) if chat_history else new_user_input_ids
+
+    # Generate a response while limiting the total chat history to 1000 tokens
+    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+
+    # Decode the last output tokens from bot
+    output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+
+    # Append to chat history for next turn. Important: Return the *full* chat history tensor to Gradio
+    chat_history_tensor = chat_history_ids.tolist()
+    return output, chat_history_tensor
+
+iface = gr.ChatInterface(
+    fn=chatbot,
+    inputs=["text", "state"],  # "state" will hold the chat history as a tensor list
+    outputs=["text", "state"],
+    title="DialoGPT Chatbot (Small)",
+    description="Simple chat application using microsoft/DialoGPT-small model. Try it out!",
+    examples=["Hello", "How are you?", "Tell me a joke"]
+)
+
+iface.launch()
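As committed, this version will also fail at startup, for the same reason the old retry_btn arguments did: gr.ChatInterface accepts a single fn plus display options (title, description, examples, ...), not the inputs=/outputs= keyword arguments of gr.Interface, so this constructor raises a TypeError. ChatInterface also manages conversation state itself: it calls fn(message, history) and expects a plain string back, so returning (output, chat_history_tensor) and threading token ids through "state" is unnecessary (and torch.cat over Gradio's history pairs would not type-check anyway). A minimal sketch of the same DialoGPT loop adapted to that contract, assuming a Gradio version whose history is a list of [user, bot] string pairs:

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

def chatbot(message, history):
    # Re-encode the visible conversation each turn instead of carrying a
    # token-id tensor in gr.State; ChatInterface supplies [user, bot] pairs.
    context = ""
    for user_msg, bot_msg in history:
        context += user_msg + tokenizer.eos_token + bot_msg + tokenizer.eos_token
    context += message + tokenizer.eos_token

    input_ids = tokenizer.encode(context, return_tensors="pt")
    # max_length counts prompt + response, mirroring the committed call.
    output_ids = model.generate(input_ids, max_length=1000,
                                pad_token_id=tokenizer.eos_token_id)
    # Decode only the tokens generated after the prompt.
    return tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0],
                            skip_special_tokens=True)

iface = gr.ChatInterface(
    fn=chatbot,
    title="DialoGPT Chatbot (Small)",
    description="Simple chat application using microsoft/DialoGPT-small model. Try it out!",
    examples=["Hello", "How are you?", "Tell me a joke"],
)

iface.launch()

Re-tokenizing the history on every turn costs a little extra compute, but it keeps the callback stateless, which is the contract ChatInterface is built around.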