shilinxu committed
Commit 43b3637 · verified · 1 Parent(s): 4846136

Update modeling_intern_vit.py

Files changed (1)
  1. modeling_intern_vit.py +36 -11
modeling_intern_vit.py CHANGED
@@ -219,12 +219,11 @@ class InternAttention(nn.Module):
 
         attn = ((q * self.scale) @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (self.attn_drop(attn) @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
-        return x
+        return x, attn
 
     def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
         qkv = self.qkv(x)
@@ -243,8 +242,11 @@ class InternAttention(nn.Module):
         outs = self.proj_drop(outs)
         return outs
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+    def forward(self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False,) -> torch.Tensor:
+        if not self.use_flash_attn:
+            x = self._naive_attn(hidden_states)
+        else:
+            x = self._flash_attn(hidden_states)
         return x
 
 
@@ -283,23 +285,37 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
             self,
             hidden_states: torch.Tensor,
+            output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
-        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
+        residual = hidden_states
 
-        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
+        hidden_states = self.norm1(hidden_states).to(hidden_states.dtype)
+        hidden_states, attn_weights = self.attn(
+            hidden_states=hidden_states,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + self.drop_path1(hidden_states * self.ls1)
 
-        return hidden_states
+        residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.drop_path2(hidden_states * self.ls2)
 
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
 
 class InternVisionEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
     [`InternEncoderLayer`].
-
     Args:
         config (`InternConfig`):
             The corresponding vision configuration for the `InternEncoder`.
@@ -318,6 +334,7 @@ class InternVisionEncoder(nn.Module):
             self,
             inputs_embeds,
             output_hidden_states: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
             return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
@@ -336,6 +353,8 @@ class InternVisionEncoder(nn.Module):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
         hidden_states = inputs_embeds
 
         for idx, encoder_layer in enumerate(self.layers):
@@ -348,8 +367,12 @@ class InternVisionEncoder(nn.Module):
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
+                    output_attentions=output_attentions,
                 )
-            hidden_states = layer_outputs
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
 
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
@@ -357,7 +380,7 @@ class InternVisionEncoder(nn.Module):
         if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states] if v is not None)
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
         )
 
 
@@ -393,6 +416,7 @@ class InternVisionModel(PreTrainedModel):
             self,
             pixel_values: Optional[torch.FloatTensor] = None,
             output_hidden_states: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
             return_dict: Optional[bool] = None,
             pixel_embeds: Optional[torch.FloatTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
@@ -414,6 +438,7 @@ class InternVisionModel(PreTrainedModel):
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
             return_dict=return_dict,
         )
         last_hidden_state = encoder_outputs.last_hidden_state
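
For context, the sketch below shows how the `output_attentions` flag added by this commit can be exercised once the file is loaded with `trust_remote_code`. It is a minimal illustration, not part of the commit: the repository id, input resolution, and the `getattr` fallback are assumptions, and attention maps are only produced on the naive (non-flash) attention path, since `_flash_attn` does not return weights.

import torch
from transformers import AutoModel

# Hypothetical repository id; substitute a checkpoint that ships this modeling_intern_vit.py.
model = AutoModel.from_pretrained("your-namespace/model-with-this-commit", trust_remote_code=True)
model.eval()

# Dummy batch; in practice build pixel_values with the matching image processor,
# and use the input resolution from the checkpoint's config (448 is an assumption here).
pixel_values = torch.randn(1, 3, 448, 448)

with torch.no_grad():
    outputs = model(pixel_values=pixel_values, output_attentions=True, return_dict=True)

# After this commit the encoder collects one attention map per layer (roughly
# (batch, num_heads, seq_len, seq_len) from _naive_attn) and attaches it to its
# BaseModelOutput. Whether the top-level pooled output forwards them as
# `outputs.attentions` is not shown in this diff, hence the defensive access.
attentions = getattr(outputs, "attentions", None)
if attentions is not None:
    print(len(attentions), attentions[0].shape)

Threading the flag through every level (model → encoder → layer → attention) mirrors the convention of other Hugging Face vision encoders, which is presumably why the commit plumbs an argument rather than exposing the maps some other way.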