John6666 committed · verified
Commit 4a827c1 · Parent(s): 3203488

Upload app.py

Files changed (1): app.py (+49 -51)
app.py CHANGED
@@ -215,60 +215,58 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_tone
     for i in range(0, len(input_images), batch_size):
         batch = input_images[i:i+batch_size]
         # Preprocess image
-        try:
-            all_images = []
-            for input_image in batch:
+        for input_image in input_images:
+            try:
                 image = input_image.resize((384, 384), Image.LANCZOS)
                 pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
                 pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-                all_images.append(TVF.to_pil_image(pixel_values.squeeze()))
-            batch_pixel_values = clip_processor(images=all_images, return_tensors='pt', padding=True).pixel_values.to(device)
-        except ValueError as e:
-            print(f"Error processing image batch: {e}")
-            print("Skipping this batch and continuing...")
-            continue
-
-        # Embed image
-        with torch.amp.autocast_mode.autocast(device, enabled=True):
-            vision_outputs = clip_model(pixel_values=batch_pixel_values, output_hidden_states=True)
-            image_features = vision_outputs.hidden_states
-            embedded_images = image_adapter(image_features).to(device)
-
-        # Tokenize the prompt
-        prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
-        # Embed prompt
-        prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
-        assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
-        embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
-        eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
-
-        # Construct prompts
-        inputs_embeds = torch.cat([
-            embedded_bos.expand(embedded_images.shape[0], -1, -1),
-            embedded_images.to(dtype=embedded_bos.dtype),
-            prompt_embeds.expand(embedded_images.shape[0], -1, -1),
-            eot_embed.expand(embedded_images.shape[0], -1, -1),
-        ], dim=1)
-
-        input_ids = torch.cat([
-            torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
-            torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-            prompt,
-            torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
-        ], dim=1).to(device)
-        attention_mask = torch.ones_like(input_ids)
-
-        generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=True,
-                                           suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
-
-        # Trim off the prompt
-        generate_ids = generate_ids[:, input_ids.shape[1]:]
-        for ids in generate_ids:
-            caption = tokenizer.decode(ids[:] if ids[0] == tokenizer.eos_token_id or ids[0] == tokenizer.convert_tokens_to_ids("<|eot_id|>") else ids,
-                                       skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
-            all_captions.append(caption)
+                pixel_values = pixel_values.to(device)
+            except ValueError as e:
+                print(f"Error processing image: {e}")
+                print("Skipping this image and continuing...")
+                continue
+
+            # Embed image
+            with torch.amp.autocast_mode.autocast(device, enabled=True):
+                vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states
+                embedded_images = image_adapter(image_features).to(device)
+
+            # Tokenize the prompt
+            prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
+
+            # Embed prompt
+            prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
+            assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
+            embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
+            eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
+
+            # Construct prompts
+            inputs_embeds = torch.cat([
+                embedded_bos.expand(embedded_images.shape[0], -1, -1),
+                embedded_images.to(dtype=embedded_bos.dtype),
+                prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+                eot_embed.expand(embedded_images.shape[0], -1, -1),
+            ], dim=1)
+
+            input_ids = torch.cat([
+                torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                prompt,
+                torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
+            ], dim=1).to(device)
+            attention_mask = torch.ones_like(input_ids)
+
+            generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=True,
+                                               suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
+
+            # Trim off the prompt
+            generate_ids = generate_ids[:, input_ids.shape[1]:]
+            if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+                generate_ids = generate_ids[:, :-1]
+
+            caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+            all_captions.append(caption.strip())
 
     if pbar:
         pbar.update(len(batch))
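
In short, the commit drops the clip_processor batching path: each image is now resized and normalized by hand and its pixel values are fed directly to the CLIP vision tower, one image at a time, and the per-token decode loop becomes a single batch_decode call with trailing <eos>/<|eot_id|> trimming. The sketch below pulls the new path out of the diff into a standalone helper for readability; caption_one_image is a hypothetical name, not part of the commit, and clip_model, image_adapter, text_model, and tokenizer are assumed to be loaded the way app.py loads them elsewhere.

from PIL import Image
import torch
import torchvision.transforms.functional as TVF

def caption_one_image(input_image: Image.Image, prompt_str: str, device: str,
                      clip_model, image_adapter, text_model, tokenizer,
                      max_new_tokens: int, top_p: float, temperature: float) -> str:
    # Hypothetical helper mirroring the committed per-image path.
    # Resize and normalize by hand; the new code no longer round-trips
    # through clip_processor before hitting the vision tower.
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]).to(device)

    # Embed the image via CLIP hidden states and the adapter.
    with torch.amp.autocast_mode.autocast(device, enabled=True):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states).to(device)

    # Build <bos> + image embeddings + prompt + <|eot_id|> in embedding space,
    # with matching token ids (zeros stand in for the image positions).
    prompt = tokenizer.encode(prompt_str, return_tensors='pt', add_special_tokens=False)
    prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
    embedded_bos = text_model.model.embed_tokens(
        torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)

    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
        eot_embed.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
    ], dim=1).to(device)

    generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds,
                                       attention_mask=torch.ones_like(input_ids),
                                       do_sample=True, suppress_tokens=None,
                                       max_new_tokens=max_new_tokens, top_p=top_p,
                                       temperature=temperature)

    # Trim off the prompt, drop a trailing <eos>/<|eot_id|>, and decode in one call.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]
    return tokenizer.batch_decode(generate_ids, skip_special_tokens=False,
                                  clean_up_tokenization_spaces=False)[0].strip()

Worth noting when reading the diff: as committed, the new inner loop iterates over input_images rather than batch, so each pass of the outer batching loop walks the full image list again; a per-image helper like the one above would be called once per image regardless of how the outer loop slices batches.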