prithivMLmods committed
Commit c863607 · verified · 1 Parent(s): 70b8813

Update app.py

Files changed (1):
  1. app.py +57 -12
app.py CHANGED

@@ -5,6 +5,8 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from typing import List, Dict, Optional, Tuple
+from http import HTTPStatus
 
 DESCRIPTION = """
 # QwQ Distill
@@ -44,21 +46,60 @@ model.eval()
 if tokenizer.pad_token_id is None:
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
+# Define roles for the chat
+class Role:
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+# Default system message
+default_system = "You are a helpful assistant."
+
+def clear_session() -> List:
+    return "", []
+
+def modify_system_session(system: str) -> Tuple[str, str, List]:
+    if system is None or len(system) == 0:
+        system = default_system
+    return system, system, []
+
+def history_to_messages(history: List, system: str) -> List[Dict]:
+    messages = [{'role': Role.SYSTEM, 'content': system}]
+    for h in history:
+        messages.append({'role': Role.USER, 'content': h[0]})
+        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
+    return messages
+
+def messages_to_history(messages: List[Dict]) -> Tuple[str, List]:
+    assert messages[0]['role'] == Role.SYSTEM
+    system = messages[0]['content']
+    history = []
+    for q, r in zip(messages[1::2], messages[2::2]):
+        history.append([q['content'], r['content']])
+    return system, history
+
 @spaces.GPU(duration=120)
 def generate(
-    message: str,
-    chat_history: list[dict],
+    query: Optional[str],
+    history: Optional[List],
+    system: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = chat_history.copy()
-    conversation.append({"role": "user", "content": message})
+) -> Iterator[Tuple[str, List, str]]:
+    if query is None:
+        query = ''
+    if history is None:
+        history = []
+
+    # Convert history to messages
+    messages = history_to_messages(history, system)
+    messages.append({'role': Role.USER, 'content': query})
 
     # Apply chat template and get input_ids
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
 
     # Create attention mask
     attention_mask = torch.ones_like(input_ids)
@@ -94,12 +135,17 @@ def generate(
     outputs = []
     for text in streamer:
         outputs.append(text)
-        yield "".join(outputs)
+        response = "".join(outputs)
+        # Update history with the new response
+        new_messages = messages + [{'role': Role.ASSISTANT, 'content': response}]
+        system, new_history = messages_to_history(new_messages)
+        yield "", new_history, system
 
 
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Textbox(label="System Message", value=default_system, lines=2),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -138,13 +184,12 @@ demo = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
-        ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
-        ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
+        ["Write a Python function to reverses a string if it's length is a multiple of 4."],
+        ["What is the volume of a pyramid with a rectangular base?"],
+        ["Explain the difference between List comprehension and Lambda in Python."],
         ["What happens when the sun goes down?"],
     ],
     cache_examples=False,
-    type="messages",
     description=DESCRIPTION,
     css=css,
     fill_height=True,
@@ -152,4 +197,4 @@ demo = gr.ChatInterface(
 
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch(share=True)
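
For context, a minimal sketch of the round trip performed by the two new helpers, history_to_messages and messages_to_history. The helper definitions are copied verbatim from the diff above; the sample conversation is invented for illustration:

    from typing import List, Dict, Tuple

    class Role:
        SYSTEM = "system"
        USER = "user"
        ASSISTANT = "assistant"

    def history_to_messages(history: List, system: str) -> List[Dict]:
        messages = [{'role': Role.SYSTEM, 'content': system}]
        for h in history:
            messages.append({'role': Role.USER, 'content': h[0]})
            messages.append({'role': Role.ASSISTANT, 'content': h[1]})
        return messages

    def messages_to_history(messages: List[Dict]) -> Tuple[str, List]:
        assert messages[0]['role'] == Role.SYSTEM
        system = messages[0]['content']
        history = []
        for q, r in zip(messages[1::2], messages[2::2]):
            history.append([q['content'], r['content']])
        return system, history

    # A hypothetical session: one finished [user, assistant] pair, then a new
    # exchange, mirroring what generate() appends before and after streaming.
    history = [["Hi", "Hello! How can I help?"]]
    messages = history_to_messages(history, "You are a helpful assistant.")
    messages.append({'role': Role.USER, 'content': "What is 2 + 2?"})
    messages.append({'role': Role.ASSISTANT, 'content': "4."})
    system, new_history = messages_to_history(messages)
    assert system == "You are a helpful assistant."
    assert new_history == [["Hi", "Hello! How can I help?"], ["What is 2 + 2?", "4."]]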
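The reworked yield in the "for text in streamer:" hunk builds on the standard TextIteratorStreamer pattern from transformers, where model.generate runs in a background thread while the caller drains partial text. A minimal sketch of that pattern, assuming the small distilgpt2 checkpoint stands in for the Space's actual model:

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # stand-in model
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    input_ids = tokenizer("The sun sets and", return_tensors="pt").input_ids
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs in a worker thread; the streamer is an
    # iterator that yields decoded text chunks as tokens are produced.
    thread = Thread(target=model.generate,
                    kwargs=dict(input_ids=input_ids, max_new_tokens=32, streamer=streamer))
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        partial = "".join(outputs)  # the app yields a value like this each step
    thread.join()
    print(partial)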