prithivMLmods committed on
Commit
a23a8fc
·
verified ·
1 Parent(s): 7a2c608

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -24
app.py CHANGED
@@ -69,14 +69,6 @@ def history_to_messages(history: List, system: str) -> List[Dict]:
69
  messages.append({'role': Role.ASSISTANT, 'content': h[1]})
70
  return messages
71
 
72
- def messages_to_history(messages: List[Dict]) -> Tuple[str, List]:
73
- assert messages[0]['role'] == Role.SYSTEM
74
- system = messages[0]['content']
75
- history = []
76
- for q, r in zip(messages[1::2], messages[2::2]):
77
- history.append([q['content'], r['content']])
78
- return system, history
79
-
80
  @spaces.GPU(duration=120)
81
  def generate(
82
  query: Optional[str],
@@ -97,26 +89,18 @@ def generate(
97
  messages = history_to_messages(history, system)
98
  messages.append({'role': Role.USER, 'content': query})
99
 
100
- # Apply chat template and get input_ids
101
- input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
102
-
103
- # Create attention mask
104
- attention_mask = torch.ones_like(input_ids)
105
-
106
- # Trim input if it exceeds the maximum token length
107
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
108
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
109
- attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
110
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
111
-
112
- input_ids = input_ids.to(model.device)
113
- attention_mask = attention_mask.to(model.device)
114
 
115
  # Set up the streamer for real-time text generation
116
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
117
  generate_kwargs = dict(
118
- input_ids=input_ids,
119
- attention_mask=attention_mask,
120
  streamer=streamer,
121
  max_new_tokens=max_new_tokens,
122
  do_sample=True,
 
69
  messages.append({'role': Role.ASSISTANT, 'content': h[1]})
70
  return messages
71
 
 
 
 
 
 
 
 
 
72
  @spaces.GPU(duration=120)
73
  def generate(
74
  query: Optional[str],
 
89
  messages = history_to_messages(history, system)
90
  messages.append({'role': Role.USER, 'content': query})
91
 
92
+ # Apply chat template and tokenize
93
+ text = tokenizer.apply_chat_template(
94
+ messages,
95
+ tokenize=False,
96
+ add_generation_prompt=True
97
+ )
98
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
99
 
100
  # Set up the streamer for real-time text generation
101
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
102
  generate_kwargs = dict(
103
+ **model_inputs,
 
104
  streamer=streamer,
105
  max_new_tokens=max_new_tokens,
106
  do_sample=True,