Daemontatox committed on
Commit c8e2710 · verified · 1 Parent(s): 22e29b1

Update app.py

Files changed (1)
  1. app.py +152 -273
app.py CHANGED
@@ -1,4 +1,3 @@
-
  import os
  import re
  import time
@@ -10,334 +9,214 @@ from transformers import (
  AutoModelForCausalLM,
  AutoTokenizer,
  BitsAndBytesConfig,
- TextIteratorStreamer
  )

  # Configuration Constants
- MODEL_ID="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
-

- # Understand]: Analyze the question to identify key details and clarify the goal.
- # [Plan]: Outline a logical, step-by-step approach to address the question or problem.
- # [Reason]: Execute the plan, applying logical reasoning, calculations, or analysis to reach a conclusion. Document each step clearly.
- # [Reflect]: Review the reasoning and the final answer to ensure it is accurate, complete, and adheres to the principle of openness.
- # [Respond]: Present a well-structured and transparent answer, enriched with supporting details as needed.
- # Use these tags as headers in your response to make your thought process easy to follow and aligned with the principle of openness.

- DEFAULT_SYSTEM_PROMPT ="""
- You are an intelligent assistant , You should think Step by Step.

- """
  # UI Configuration
- TITLE = "<h1><center>AI Reasoning Assistant</center></h1>"
- PLACEHOLDER = "Ask me anything! I'll think through it step by step."
-
  CSS = """
- .duplicate-button {
- margin: auto !important;
- color: white !important;
- background: black !important;
- border-radius: 100vh !important;
- }
- h3 {
- text-align: center;
- }
- .message-wrap {
- overflow-x: auto;
- }
- .message-wrap p {
- margin-bottom: 1em;
- }
- .message-wrap pre {
- background-color: #f6f8fa;
- border-radius: 3px;
- padding: 16px;
- overflow-x: auto;
- }
- .message-wrap code {
- background-color: rgba(175,184,193,0.2);
- border-radius: 3px;
- padding: 0.2em 0.4em;
- font-family: monospace;
- }
- .custom-tag {
- color: #0066cc;
- font-weight: bold;
- }
- .chat-area {
- height: 500px !important;
- overflow-y: auto !important;
- }
  """

  def initialize_model():
- """Initialize the model with appropriate configurations"""
  quantization_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_compute_dtype=torch.bfloat16,
  bnb_4bit_quant_type="nf4",
  bnb_4bit_use_double_quant=True,
- #llm_int8_enable_fp32_cpu_offload=True
  )

- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID , trust_remote_code=True)
- if tokenizer.pad_token_id is None:
- tokenizer.pad_token_id = tokenizer.eos_token_id

  model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
- torch_dtype="auto",
- device_map="cuda",
- # attn_implementation="flash_attention_2",
- trust_remote_code=True,
- quantization_config=quantization_config
-
  )

  return model, tokenizer

- def format_text(text):
- """Format text with proper spacing and tag highlighting (but keep tags visible)"""
- tag_patterns = [
- (r'<Thinking>', '\n<Thinking>\n'),
- (r'</Thinking>', '\n</Thinking>\n'),
- (r'<Critique>', '\n<Critique>\n'),
- (r'</Critique>', '\n</Critique>\n'),
- (r'<Revising>', '\n<Revising>\n'),
- (r'</Revising>', '\n</Revising>\n'),
- (r'<Final>', '\n<Final>\n'),
- (r'</Final>', '\n</Final>\n')
- ]
-
- formatted = text
- for pattern, replacement in tag_patterns:
- formatted = re.sub(pattern, replacement, formatted)
-
- formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
-
  return formatted

- def format_chat_history(history):
- """Format chat history for display, keeping tags visible"""
- formatted = []
- for user_msg, assistant_msg in history:
- formatted.append(f"User: {user_msg}")
- if assistant_msg:
- formatted.append(f"Assistant: {assistant_msg}")
- return "\n\n".join(formatted)
-
- def create_examples():
- """Create example queries for the UI"""
- return [
- "Explain the concept of artificial intelligence.",
- "How does photosynthesis work?",
- "What are the main causes of climate change?",
- "Describe the process of protein synthesis.",
- "What are the key features of a democratic government?",
- "Explain the theory of relativity.",
- "How do vaccines work to prevent diseases?",
- "What are the major events of World War II?",
- "Describe the structure of a human cell.",
- "What is the role of DNA in genetics?"
- ]
-
  @spaces.GPU(duration=120)
  def chat_response(
  message: str,
  history: list,
- chat_display: str,
  system_prompt: str,
  temperature: float = 0.3,
- max_new_tokens: int =4096 ,
- top_p: float = 0.1,
- top_k: int = 45,
- penalty: float = 1.5,
  ):
- """Generate chat responses, keeping tags visible in the output"""
- conversation = [
- {"role": "system", "content": system_prompt}
- ]
-
- for prompt, answer in history:
- conversation.extend([
- {"role": "user", "content": prompt},
- {"role": "assistant", "content": answer}
- ])
-
- conversation.append({"role": "user", "content": message})
-
- input_ids = tokenizer.apply_chat_template(
- conversation,
- add_generation_prompt=True,
- return_tensors="pt"
- ).to(model.device)
-
- streamer = TextIteratorStreamer(
- tokenizer,
- timeout=60.0,
- skip_prompt=True,
- skip_special_tokens=True
- )
-
- generate_kwargs = dict(
- input_ids=input_ids,
- max_new_tokens=max_new_tokens,
- do_sample=False if temperature == 0 else True,
- top_p=top_p,
- top_k=top_k,
- temperature=temperature,
- repetition_penalty=penalty,
- streamer=streamer,
- )
-
- buffer = ""
-
- with torch.no_grad():
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
  thread.start()
-
- history = history + [[message, ""]]
-
  for new_text in streamer:
- buffer += new_text
- formatted_buffer = format_text(buffer)
- history[-1][1] = formatted_buffer
- chat_display = format_chat_history(history)

- yield history, chat_display

- def process_example(example: str) -> tuple:
- """Process example query and return empty history and updated display"""
- return [], f"User: {example}\n\n"

  def main():
- """Main function to set up and launch the Gradio interface"""
  global model, tokenizer
  model, tokenizer = initialize_model()
-
- with gr.Blocks(css=CSS, theme="soft") as demo:
  gr.HTML(TITLE)
- gr.DuplicateButton(
- value="Duplicate Space for private use",
- elem_classes="duplicate-button"
- )

  with gr.Row():
- with gr.Column():
- chat_history = gr.State([])
- chat_display = gr.TextArea(
- value="",
- label="Chat History",
- interactive=False,
- elem_classes=["chat-area"],
  )
-
- message = gr.TextArea(
- placeholder=PLACEHOLDER,
- label="Your message",
- lines=3
  )
-
  with gr.Row():
- submit = gr.Button("Send")
- clear = gr.Button("Clear")

- with gr.Accordion("⚙️ Advanced Settings", open=False):
  system_prompt = gr.TextArea(
  value=DEFAULT_SYSTEM_PROMPT,
- label="System Prompt",
- lines=5,
- )
- temperature = gr.Slider(
- minimum=0,
- maximum=1,
- step=0.1,
- value=0.3,
- label="Temperature",
- )
- max_tokens = gr.Slider(
- minimum=128,
- maximum=32000,
- step=128,
- value=4096,
- label="Max Tokens",
- )
- top_p = gr.Slider(
- minimum=0.1,
- maximum=1.0,
- step=0.1,
- value=0.8,
- label="Top-p",
  )
- top_k = gr.Slider(
- minimum=1,
- maximum=100,
- step=1,
- value=45,
- label="Top-k",
- )
- penalty = gr.Slider(
- minimum=1.0,
- maximum=2.0,
- step=0.1,
- value=1.5,
- label="Repetition Penalty",
- )
-
- examples = gr.Examples(
- examples=create_examples(),
- inputs=[message],
- outputs=[chat_history, chat_display],
- fn=process_example,
- cache_examples=False,
- )
-
- # Set up event handlers
- submit_click = submit.click(
  chat_response,
- inputs=[
- message,
- chat_history,
- chat_display,
- system_prompt,
- temperature,
- max_tokens,
- top_p,
- top_k,
- penalty,
- ],
- outputs=[chat_history, chat_display],
- show_progress=True,
- )
-
- message.submit(
  chat_response,
- inputs=[
- message,
- chat_history,
- chat_display,
- system_prompt,
- temperature,
- max_tokens,
- top_p,
- top_k,
- penalty,
- ],
- outputs=[chat_history, chat_display],
- show_progress=True,
- )
-
- clear.click(
- lambda: ([], ""),
- outputs=[chat_history, chat_display],
- show_progress=True,
- )
-
- submit_click.then(lambda: "", outputs=message)
- message.submit(lambda: "", outputs=message)
-
  return demo

  if __name__ == "__main__":
  demo = main()
- demo.launch()
  import os
  import re
  import time

  AutoModelForCausalLM,
  AutoTokenizer,
  BitsAndBytesConfig,
+ TextIteratorStreamer,
+ StoppingCriteria,
+ StoppingCriteriaList
  )

  # Configuration Constants
+ MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

+ # Enhanced System Prompt
+ DEFAULT_SYSTEM_PROMPT = """You are an Expert Reasoning Assistant. Follow these steps:
+ [Understand]: Analyze key elements and clarify objectives
+ [Plan]: Outline step-by-step methodology
+ [Reason]: Execute plan with detailed analysis
+ [Verify]: Check logic and evidence
+ [Conclude]: Present structured conclusion

+ Use these section headers and maintain technical accuracy with clear explanations."""

  # UI Configuration
+ TITLE = """
+ <h1 align="center" style="color: #2d3436; margin-bottom: 0">🧠 AI Reasoning Assistant</h1>
+ <p align="center" style="color: #636e72; margin-top: 0">DeepSeek-R1-Distill-Qwen-14B</p>
+ """
  CSS = """
+ .gr-chatbot { min-height: 500px !important; border-radius: 15px !important; }
+ .message-wrap pre { background: #f8f9fa !important; padding: 15px !important; }
+ .thinking-tag { color: #2ecc71; font-weight: 600; }
+ .plan-tag { color: #e67e22; font-weight: 600; }
+ .conclude-tag { color: #3498db; font-weight: 600; }
+ .control-panel { background: #f8f9fa !important; padding: 20px !important; }
+ footer { visibility: hidden !important; }
  """

+ class StopOnTokens(StoppingCriteria):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ stop_ids = [0] # Add custom stop tokens here
+ return input_ids[0][-1] in stop_ids
+
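The new StopOnTokens criterion is attached to generation through StoppingCriteriaList further down, with stop_ids = [0] left as a placeholder per its inline comment. A minimal sketch, not part of this commit, of deriving the stop ids from the tokenizer's eos_token_id instead:

```python
# Illustrative sketch only (not from the commit): take stop ids from the tokenizer.
import torch
from transformers import StoppingCriteria

class StopOnEOS(StoppingCriteria):
    def __init__(self, tokenizer):
        # eos_token_id may be a single int or a list, depending on the model config
        ids = tokenizer.eos_token_id
        self.stop_ids = set(ids if isinstance(ids, (list, tuple)) else [ids])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is a stop id
        return int(input_ids[0, -1]) in self.stop_ids
```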
  def initialize_model():
+ """Initialize model with safety checks"""
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA is required for this application")
+
  quantization_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_compute_dtype=torch.bfloat16,
  bnb_4bit_quant_type="nf4",
  bnb_4bit_use_double_quant=True,
  )

+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token

  model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
+ device_map="auto",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True
  )

  return model, tokenizer
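A quick way to sanity-check the 4-bit NF4 load above (illustrative only; assumes this file is saved as app.py on the import path and a CUDA GPU is available):

```python
# Hypothetical check, not part of the commit.
from app import initialize_model

model, tokenizer = initialize_model()
print(model.hf_device_map)            # device placement chosen by device_map="auto"
print(model.get_memory_footprint())   # bytes resident; 4-bit NF4 is roughly a quarter of the fp16 footprint
print(tokenizer.pad_token == tokenizer.eos_token)  # True after the pad_token assignment above
```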

+ def format_response(text):
+ """Enhanced formatting with syntax highlighting for reasoning steps"""
+ formatted = text.replace("[Understand]", '\n<strong class="thinking-tag">[Understand]</strong>\n')
+ formatted = formatted.replace("[Plan]", '\n<strong class="plan-tag">[Plan]</strong>\n')
+ formatted = formatted.replace("[Conclude]", '\n<strong class="conclude-tag">[Conclude]</strong>\n')
  return formatted
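For reference, a worked example of what these replacements produce (derived from the code above, not from commit output; assumes format_response as defined in app.py):

```python
# The section marker is wrapped in a styled <strong> tag and set off on its own line.
out = format_response("[Plan] Outline the approach")
assert out == '\n<strong class="plan-tag">[Plan]</strong>\n Outline the approach'
```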

  @spaces.GPU(duration=120)
  def chat_response(
  message: str,
  history: list,
  system_prompt: str,
  temperature: float = 0.3,
+ max_new_tokens: int = 2048,
+ top_p: float = 0.9,
+ top_k: int = 50,
+ penalty: float = 1.2,
  ):
+ """Improved streaming generator with error handling"""
+ try:
+ conversation = [{"role": "system", "content": system_prompt}]
+ for user, assistant in history:
+ conversation.extend([
+ {"role": "user", "content": user},
+ {"role": "assistant", "content": assistant}
+ ])
+ conversation.append({"role": "user", "content": message})
+
+ input_ids = tokenizer.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ return_tensors="pt"
+ ).to(model.device)
+
+ streamer = TextIteratorStreamer(
+ tokenizer,
+ timeout=30,
+ skip_prompt=True,
+ skip_special_tokens=True
+ )
+
+ generate_kwargs = dict(
+ input_ids=input_ids,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ repetition_penalty=penalty,
+ streamer=streamer,
+ stopping_criteria=StoppingCriteriaList([StopOnTokens()])
+ )
+
+ buffer = []
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
  thread.start()
+
  for new_text in streamer:
+ buffer.append(new_text)
+ partial_result = "".join(buffer)

+ # Check for complete sections
+ if any(tag in partial_result for tag in ["[Understand]", "[Plan]", "[Conclude]"]):
+ yield format_response(partial_result)
+ else:
+ yield format_response(partial_result + " ▌")
+
+ # Final formatting pass
+ yield format_response("".join(buffer))

+ except Exception as e:
+ yield f"⚠️ Error generating response: {str(e)}"
+
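chat_response is a generator: every yield carries the full formatted response so far, so a consumer only needs to keep the latest value. A rough usage sketch, assuming the module-level model and tokenizer globals have already been set up by initialize_model() inside main():

```python
# Rough usage sketch, not part of the commit.
latest = ""
for latest in chat_response(
    message="Explain quantum entanglement in simple terms",
    history=[],
    system_prompt=DEFAULT_SYSTEM_PROMPT,
    temperature=0.7,
):
    pass  # a UI would re-render with `latest` after each yield
print(latest)  # final, fully formatted answer
```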
+ def create_examples():
+ """Enhanced examples with diverse use cases"""
+ return [
+ ["Explain quantum entanglement in simple terms"],
+ ["Design a study plan for learning machine learning"],
+ ["Compare blockchain and traditional databases"],
+ ["How would you optimize AWS costs for a startup?"],
+ ["Explain the ethical implications of CRISPR technology"]
+ ]

  def main():
+ """Improved UI layout and interactions"""
  global model, tokenizer
  model, tokenizer = initialize_model()
+
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
  gr.HTML(TITLE)

  with gr.Row():
+ with gr.Column(scale=3):
+ chatbot = gr.Chatbot(
+ elem_id="chatbot",
+ bubble_full_width=False,
+ show_copy_button=True,
+ render=False
  )
+ msg = gr.Textbox(
+ placeholder="Enter your question...",
+ label="Ask the Expert",
+ container=False
  )
  with gr.Row():
+ submit_btn = gr.Button("Send", variant="primary")
+ clear_btn = gr.Button("Clear", variant="secondary")
+
+ with gr.Column(scale=1, elem_classes="control-panel"):
+ gr.Examples(
+ examples=create_examples(),
+ inputs=msg,
+ label="Example Queries",
+ examples_per_page=5
+ )

+ with gr.Accordion("⚙️ Generation Parameters", open=False):
  system_prompt = gr.TextArea(
  value=DEFAULT_SYSTEM_PROMPT,
+ label="System Instructions",
+ lines=5
  )
+ temperature = gr.Slider(0, 2, value=0.7, label="Creativity")
+ max_tokens = gr.Slider(128, 4096, value=2048, step=128, label="Max Tokens")
+ top_p = gr.Slider(0, 1, value=0.9, step=0.05, label="Focus (Top-p)")
+ penalty = gr.Slider(1, 2, value=1.2, step=0.1, label="Repetition Control")
+
+ # Event handling
+ msg.submit(
  chat_response,
+ [msg, chatbot, system_prompt, temperature, max_tokens, top_p, penalty],
+ [msg, chatbot],
+ show_progress="hidden"
+ ).then(lambda: "", None, msg)
+
+ submit_btn.click(
  chat_response,
+ [msg, chatbot, system_prompt, temperature, max_tokens, top_p, penalty],
+ [msg, chatbot],
+ show_progress="hidden"
+ ).then(lambda: "", None, msg)
+
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
+
  return demo

  if __name__ == "__main__":
  demo = main()
+ demo.queue(max_size=20).launch()
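A note on the event wiring above: both handlers list outputs=[msg, chatbot], and Gradio expects one value per listed output component from the handler on each streaming step. A hypothetical adapter with that shape (not part of this commit), reusing chat_response and the (user, assistant) tuple history that gr.Chatbot renders:

```python
# Hypothetical adapter, not in the commit: yields (textbox_value, chat_history) pairs
# so that outputs=[msg, chatbot] each receive a value on every streaming step.
def respond(message, chat_history, system_prompt, temperature, max_tokens, top_p, penalty):
    chat_history = (chat_history or []) + [(message, "")]
    for partial in chat_response(
        message,
        chat_history[:-1],          # prior turns only
        system_prompt,
        temperature,
        max_tokens,
        top_p=top_p,
        penalty=penalty,
    ):
        chat_history[-1] = (message, partial)
        yield "", chat_history      # clear the textbox, update the chat window
```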