lokidev committed (verified)
Commit fa2391a · 1 Parent(s): 82a8632

Update app.py

Files changed (1):
  1. app.py +25 -112
app.py CHANGED
@@ -1,121 +1,34 @@
- # app.py
  import gradio as gr
- import onnxruntime
- from transformers import AutoTokenizer
- import logging
- import os
-
- # Set up logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
- # Model and Tokenizer paths - Define ONNX model path
- MODEL_NAME = "microsoft/DialoGPT-small"
- ONNX_MODEL_PATH = "dialogpt-small.onnx"  # Path to the ONNX model
- TOKENIZER_NAME = MODEL_NAME  # Use the same name for tokenizer
-
- # Fallback message in case of errors
- FALLBACK_MESSAGE = "Sorry, I am having trouble processing your request. Please try again later."
-
- # Global variables to hold loaded model and tokenizer
- ort_session = None
- tokenizer = None
-
- # --- Model Loading and Preprocessing ---
- def load_model_and_tokenizer():
-     global ort_session, tokenizer
-     try:
-         logging.info(f"Checking for ONNX model at: {ONNX_MODEL_PATH}")
-         if not os.path.exists(ONNX_MODEL_PATH):
-             logging.warning(f"ONNX model not found at {ONNX_MODEL_PATH}. Please ensure it exists. Refer to README for conversion instructions.")
-             return False  # Model file missing - indicate failure
-
-         logging.info("Loading ONNX Runtime session...")
-         ort_session = onnxruntime.InferenceSession(ONNX_MODEL_PATH, providers=['CPUExecutionProvider'])  # Explicitly using CPU provider for simplicity for this example, you could expand providers
-
-         logging.info("Loading tokenizer...")
-         tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
-         logging.info(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
-
-         logging.info("Model and Tokenizer loaded successfully using ONNX Runtime.")
-         return True  # Success
-
-     except Exception as e:
-         logging.error(f"Error during model or tokenizer loading: {e}")
-         if ort_session:
-             del ort_session  # Attempt to cleanup in case of failure during loading
-         ort_session = None
-         tokenizer = None
-         return False  # Failure
-
- # Pre-load model and tokenizer on app startup
- model_loaded_successfully = load_model_and_tokenizer()
-
-
- # --- Inference Function ---
- def predict(message, history):
-     if not model_loaded_successfully:
-         logging.warning("Model not loaded, returning fallback message.")
-         return FALLBACK_MESSAGE
-
-     if ort_session is None or tokenizer is None:  # Double check after global check, for robustness.
-         logging.error("ONNX Session or Tokenizer is unexpectedly None in predict function.")
-         return FALLBACK_MESSAGE
-
-     try:
-         # Reconstruct conversation history for DialoGPT input
-         input_text = ""
-         for human_msg, bot_response in history:  # History comes as list of lists [user_msg, bot_response] pairs
-             input_text += human_msg + tokenizer.eos_token
-         input_text += message + tokenizer.eos_token
-
-         inputs = tokenizer(input_text, return_tensors="np")
-
-         # Get input and output names - essential for ONNX Runtime
-         input_name = ort_session.get_inputs()[0].name
-         output_name = ort_session.get_outputs()[0].name  # Assuming output is logits for generation
-
-         ort_inputs = {input_name: inputs['input_ids']}  # Only input_ids typically needed for simple generation with DialoGPT
-
-         ort_outputs = ort_session.run([output_name], ort_inputs)  # Run inference
-
-         logits = ort_outputs[0]  # logits from the model
-
-         # Basic generation - argmax for simplicity. For better responses, consider more sophisticated decoding (sampling)
-         predicted_token_ids = logits.argmax(axis=-1)  # Pick token with highest probability. Very simple decoding.
-
-         # Decode ONLY the last generated turn to get the bot's response.
-         # Find the EOS token indices to split the input_text (which includes history) and extract only the NEW response
-         generated_text = tokenizer.decode(predicted_token_ids[0, inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
-
-         if not generated_text.strip():  # Handle empty responses from the model.
-             logging.info("Model returned an empty response, using default.")
-             return "I'm not sure what to say."  # Or a more specific fallback.
-
-         return generated_text.strip()
-
-     except Exception as e:
-         logging.error(f"Error during inference: {e}")
-         return FALLBACK_MESSAGE
-
-
- # --- Gradio Interface ---
- if __name__ == "__main__":
-     iface = gr.ChatInterface(
-         fn=predict,
-         textbox=gr.Textbox(placeholder="Type your message here...", label="User Input"),
-         chatbot=gr.Chatbot(label="Chatbot Response"),
-         title="DialoGPT Chatbot (ONNX Runtime)",
-         description="Chat with a simple DialoGPT-small chatbot powered by ONNX Runtime for faster inference. This is a basic demonstration. For better performance in a real-world setting, ensure you have a properly quantized and optimized ONNX model. **Note:** For this Space to work, you must upload a `dialogpt-small.onnx` file. Refer to the README on how to convert and optimize the model.",
-         examples=[
-             ["Hello, how are you today?"],
-             ["Tell me a joke"],
-             ["What is the weather like in London?"]
-         ],
-         # Removed retry_btn, undo_btn, clear_btn to resolve TypeError
-         theme="default"
-     )
-     iface.launch()

  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
+
+ def chatbot(message, history):
+     # Rebuild the conversation from Gradio's history of (user, bot) message pairs,
+     # separating every turn with the eos_token as DialoGPT expects
+     input_text = ""
+     for user_msg, bot_msg in history:
+         input_text += user_msg + tokenizer.eos_token + bot_msg + tokenizer.eos_token
+     input_text += message + tokenizer.eos_token
+
+     # Encode the full conversation and return a PyTorch tensor
+     bot_input_ids = tokenizer.encode(input_text, return_tensors='pt')
+
+     # Generate a response while limiting the total conversation to 1000 tokens
+     chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+
+     # Decode only the newly generated tokens as the bot's reply
+     output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+
+     # gr.ChatInterface tracks the conversation history itself, so only the reply is returned
+     return output
+
+ iface = gr.ChatInterface(
+     fn=chatbot,
+     title="DialoGPT Chatbot (Small)",
+     description="Simple chat application using the microsoft/DialoGPT-small model. Try it out!",
+     examples=["Hello", "How are you?", "Tell me a joke"]
+ )
+
+ iface.launch()
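
For a quick local sanity check outside the Space, the same encode/generate/decode round trip can be run directly with transformers. This is a minimal sketch, not part of the commit; the file name and prompt are illustrative, and it assumes transformers and torch are installed:

    # quick_check.py - illustrative only, not part of this commit
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

    # Single-turn prompt, terminated with eos_token as in app.py
    prompt_ids = tokenizer.encode("Hello, how are you?" + tokenizer.eos_token, return_tensors="pt")
    reply_ids = model.generate(prompt_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # Decode only the tokens generated after the prompt
    print(tokenizer.decode(reply_ids[:, prompt_ids.shape[-1]:][0], skip_special_tokens=True))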