Upload app.py
app.py CHANGED
@@ -215,60 +215,58 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_tone
     for i in range(0, len(input_images), batch_size):
         batch = input_images[i:i+batch_size]
         # Preprocess image
-        for input_image in batch:
-            image = input_image.resize((384, 384), Image.LANCZOS)
-            pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-            pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-            # … remainder of the old loop body (old lines 224-270) not rendered in the diff view …
-            all_captions.append(caption)
+        for input_image in input_images:
+            try:
+                image = input_image.resize((384, 384), Image.LANCZOS)
+                pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+                pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+                pixel_values = pixel_values.to(device)
+            except ValueError as e:
+                print(f"Error processing image: {e}")
+                print("Skipping this image and continuing...")
+                continue
+
+            # Embed image
+            with torch.amp.autocast_mode.autocast(device, enabled=True):
+                vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states
+                embedded_images = image_adapter(image_features).to(device)
+
+            # Tokenize the prompt
+            prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
+
+            # Embed prompt
+            prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
+            assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
+            embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
+            eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
+
+            # Construct prompts
+            inputs_embeds = torch.cat([
+                embedded_bos.expand(embedded_images.shape[0], -1, -1),
+                embedded_images.to(dtype=embedded_bos.dtype),
+                prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+                eot_embed.expand(embedded_images.shape[0], -1, -1),
+            ], dim=1)
+
+            input_ids = torch.cat([
+                torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                prompt,
+                torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
+            ], dim=1).to(device)
+            attention_mask = torch.ones_like(input_ids)
+
+            generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=True,
+                                               suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
+
+            # Trim off the prompt
+            generate_ids = generate_ids[:, input_ids.shape[1]:]
+            if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+                generate_ids = generate_ids[:, :-1]
+
+            caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+            all_captions.append(caption.strip())
 
         if pbar:
             pbar.update(len(batch))