prithivMLmods committed
Commit 34d2094 (verified)
Parent: 83a0174

Update app.py

Files changed (1):
  1. app.py +22 -29
app.py CHANGED

@@ -34,7 +34,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Text-only model and tokenizer
+# Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -53,7 +53,7 @@ TTS_VOICES = [
     "en-US-JasonNeural", # @tts6
 ]
 
-# Multimodal (OCR) model and processor
+# Load multimodal (OCR) model and processor
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
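
Note: both from_pretrained(...) calls above are truncated by the diff context. For orientation, a minimal sketch of how the two loads plausibly continue; the dtype and device arguments are assumptions, not taken from these hunks:

    import torch
    from transformers import (AutoModelForCausalLM, AutoProcessor,
                              AutoTokenizer, Qwen2VLForConditionalGeneration)

    # Text-only chat model
    model_id = "prithivMLmods/FastThink-0.5B-Tiny"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # assumption: dtype not shown in the hunk
        device_map="auto",           # assumption: device placement not shown
    ).eval()

    # Multimodal (OCR) model
    MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model_m = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,   # assumption
    ).to("cuda").eval()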
@@ -70,12 +70,11 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
 
 def clean_chat_history(chat_history):
     """
-    Filter out any entries whose content is not a string.
-    This avoids non-text objects (like tuples or Audio) from being concatenated.
+    Filter out any chat entries whose "content" is not a string.
+    This helps prevent errors when concatenating previous messages.
     """
     cleaned = []
     for msg in chat_history:
-        # Only keep dict messages that have a string 'content'
         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
             cleaned.append(msg)
     return cleaned
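
To make the filter concrete, a small self-contained check (the mixed-type history is invented for illustration):

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": ("output.mp3",)},  # non-string content, dropped
        {"role": "assistant", "content": "hi there"},
    ]
    assert clean_chat_history(history) == [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi there"},
    ]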
@@ -91,14 +90,13 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates a chatbot response and handles TTS requests with multimodal input support.
-    If the user’s query begins with an @tts command, previous chat history is ignored
-    (clearing any non-text outputs). Otherwise, the chat history is cleaned to include only text.
+    Generates chatbot responses with support for multimodal input and TTS.
+    If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # Determine if images are provided
+    # Process image files if provided
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
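
Review note: the file-count branches around this hunk boundary (many files / one file / none, continued in the next hunk) are equivalent to a single comprehension, since iterating a one-element or empty list is fine:

    images = [load_image(f) for f in files]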
@@ -106,25 +104,23 @@ def generate(
     else:
         images = []
 
-    # Check for TTS prefix
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 7))
     voice_index = next((i for i in range(1, 7) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
+
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear any previous chat history when using TTS to avoid type errors
+        # Clear any previous chat history to avoid concatenation issues
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
         text = text.replace(tts_prefix, "").strip()
-        # Clean the chat history to include only messages with string content
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
 
-    # Multimodal branch if images are provided
     if images:
+        # Multimodal branch using the OCR model
         messages = [{
             "role": "user",
             "content": [
@@ -134,9 +130,8 @@
         }]
         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()
 
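
On this hunk's change: dict(inputs, streamer=..., max_new_tokens=...) and {**inputs, ...} build the same merged kwargs dict for model_m.generate, the literal being the more idiomatic spelling. A minimal equivalence check, with placeholder values standing in for the processor's BatchEncoding:

    inputs = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}  # placeholder
    assert dict(inputs, max_new_tokens=64) == {**inputs, "max_new_tokens": 64}

Note also that this branch moves inputs to "cuda" directly, even though device is computed with a CPU fallback at the top of the file.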
@@ -154,19 +149,18 @@
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
-
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
+    generation_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "top_p": top_p,
+        "top_k": top_k,
+        "temperature": temperature,
+        "num_beams": 1,
+        "repetition_penalty": repetition_penalty,
+    }
     t = Thread(target=model.generate, kwargs=generation_kwargs)
     t.start()
 
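
This hunk carries the substantive fix of the commit: the old code built generate_kwargs but passed generation_kwargs to the Thread, so the text-only path would raise a NameError as soon as it tried to generate; the rewrite names the dict consistently. The threaded streaming pattern itself is the standard transformers idiom; a self-contained sketch, using a small model chosen purely for illustration:

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    input_ids = tok("Hello", return_tensors="pt").input_ids

    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 16}
    Thread(target=lm.generate, kwargs=generation_kwargs).start()

    for piece in streamer:  # tokens arrive while generate runs in the thread
        print(piece, end="", flush=True)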
@@ -176,7 +170,6 @@
         yield "".join(outputs)
 
     final_response = "".join(outputs)
-    # Yield text response first
     yield final_response
 
     if is_tts and voice:
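
After this hunk, the TTS branch hands final_response to text_to_speech. Given the en-US-*Neural voice names, that helper presumably wraps edge_tts; a sketch of such a helper, offered as an assumption rather than the file's actual implementation:

    import edge_tts

    async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
        # assumption: the unchanged helper uses edge_tts.Communicate like this
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(output_file)
        return output_file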