prithivMLmods commited on
Commit
f8af0ad
·
verified ·
1 Parent(s): 1d74de7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -57,19 +57,22 @@ def generate(
57
  conversation = chat_history.copy()
58
  conversation.append({"role": "user", "content": message})
59
 
60
- # Apply chat template and get input_ids and attention_mask
61
- inputs = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
62
- input_ids = inputs["input_ids"]
63
- attention_mask = inputs["attention_mask"]
64
 
 
 
 
 
65
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
66
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
67
  attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
68
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
69
-
70
  input_ids = input_ids.to(model.device)
71
  attention_mask = attention_mask.to(model.device)
72
 
 
73
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
74
  generate_kwargs = dict(
75
  input_ids=input_ids,
@@ -87,6 +90,7 @@ def generate(
87
  t = Thread(target=model.generate, kwargs=generate_kwargs)
88
  t.start()
89
 
 
90
  outputs = []
91
  for text in streamer:
92
  outputs.append(text)
@@ -148,4 +152,4 @@ demo = gr.ChatInterface(
148
 
149
 
150
  if __name__ == "__main__":
151
- demo.queue(max_size=20).launch()
 
57
  conversation = chat_history.copy()
58
  conversation.append({"role": "user", "content": message})
59
 
60
+ # Apply chat template and get input_ids
61
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
 
62
 
63
+ # Create attention mask
64
+ attention_mask = torch.ones_like(input_ids)
65
+
66
+ # Trim input if it exceeds the maximum token length
67
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
68
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
69
  attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
70
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
71
+
72
  input_ids = input_ids.to(model.device)
73
  attention_mask = attention_mask.to(model.device)
74
 
75
+ # Set up the streamer for real-time text generation
76
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
77
  generate_kwargs = dict(
78
  input_ids=input_ids,
 
90
  t = Thread(target=model.generate, kwargs=generate_kwargs)
91
  t.start()
92
 
93
+ # Stream the output tokens
94
  outputs = []
95
  for text in streamer:
96
  outputs.append(text)
 
152
 
153
 
154
  if __name__ == "__main__":
155
+ demo.queue(max_size=20).launch(share=True) # Set `share=True` for a public link