Update modeling_intern_vit.py
modeling_intern_vit.py  +36 −11  CHANGED
@@ -219,12 +219,11 @@ class InternAttention(nn.Module):
 
         attn = ((q * self.scale) @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (self.attn_drop(attn) @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
-        return x
+        return x, attn
 
     def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
         qkv = self.qkv(x)
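For context, a minimal sketch of the naive attention path after this change, showing where the returned weights come from. The class name, embedding size, head count, and dropout rates below are illustrative assumptions, not values from the file:

```python
import torch
import torch.nn as nn


class NaiveAttentionSketch(nn.Module):
    """Illustrative stand-in for InternAttention's naive path (not the real module)."""

    def __init__(self, dim: int = 64, num_heads: int = 4, attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)

        # Dropout is applied inline so the un-dropped softmax map can also be returned.
        x = (self.attn_drop(attn) @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


out, weights = NaiveAttentionSketch()(torch.randn(2, 10, 64))
print(out.shape, weights.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 4, 10, 10])
```

The weights returned this way are the pre-dropout softmax probabilities. The flash-attention path in `_flash_attn` is left unchanged and still returns a single tensor, so attention maps are only available when the naive path is taken.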
@@ -243,8 +242,11 @@ class InternAttention(nn.Module):
         outs = self.proj_drop(outs)
         return outs
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+    def forward(self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False,) -> torch.Tensor:
+        if not self.use_flash_attn:
+            x = self._naive_attn(hidden_states)
+        else:
+            x = self._flash_attn(hidden_states)
         return x
 
 
@@ -283,23 +285,37 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
             self,
             hidden_states: torch.Tensor,
+            output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
-        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
+        residual = hidden_states
 
-        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
+        hidden_states = self.norm1(hidden_states).to(hidden_states.dtype)
+        hidden_states, attn_weights = self.attn(
+            hidden_states=hidden_states,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + self.drop_path1(hidden_states * self.ls1)
 
-        return hidden_states
+        residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.drop_path2(hidden_states * self.ls2)
 
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
 
 class InternVisionEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
     [`InternEncoderLayer`].
-
     Args:
         config (`InternConfig`):
             The corresponding vision configuration for the `InternEncoder`.
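The one-line residual updates are unrolled here so the attention weights can be captured between the norm and the residual add. As a rough, self-contained illustration of the resulting pre-norm / layer-scale / drop-path structure (using `nn.MultiheadAttention` and `nn.Identity` as stand-ins for `InternAttention` and the drop-path module; none of this is the actual layer):

```python
import torch
import torch.nn as nn


class PreNormBlockSketch(nn.Module):
    """Simplified stand-in for InternVisionEncoderLayer's structure (not the real layer)."""

    def __init__(self, dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)  # stand-in for InternAttention
        self.ls1 = nn.Parameter(torch.ones(dim))   # layer scale
        self.drop_path1 = nn.Identity()            # stand-in for stochastic depth (drop path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.ls2 = nn.Parameter(torch.ones(dim))
        self.drop_path2 = nn.Identity()

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False):
        # Attention block: norm -> attention -> layer scale -> drop path -> residual add
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, attn_weights = self.attn(
            hidden_states, hidden_states, hidden_states,
            need_weights=output_attentions, average_attn_weights=False,
        )
        hidden_states = residual + self.drop_path1(hidden_states * self.ls1)

        # MLP block: same pattern, no weights to return
        residual = hidden_states
        hidden_states = residual + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


layer = PreNormBlockSketch()
out = layer(torch.randn(2, 10, 64), output_attentions=True)
print(out[0].shape, out[1].shape)  # torch.Size([2, 10, 64]) torch.Size([2, 4, 10, 10])
```

The returned tuple follows the usual Hugging Face convention: hidden states first, attention weights appended only when requested.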
@@ -318,6 +334,7 @@ class InternVisionEncoder(nn.Module):
             self,
             inputs_embeds,
             output_hidden_states: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
             return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
@@ -336,6 +353,8 @@ class InternVisionEncoder(nn.Module):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
         hidden_states = inputs_embeds
 
         for idx, encoder_layer in enumerate(self.layers):
@@ -348,8 +367,12 @@ class InternVisionEncoder(nn.Module):
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
+                    output_attentions=output_attentions,
                 )
-            hidden_states = layer_outputs
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
 
         if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)
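Because each layer now returns a tuple, the encoder unpacks `layer_outputs[0]` for the running hidden states and collects `layer_outputs[1]` per layer when attentions are requested. A stripped-down sketch of that accumulation pattern (the layer below is a trivial stand-in, not `InternVisionEncoderLayer`):

```python
import torch


def toy_layer(hidden_states: torch.Tensor, output_attentions: bool = False):
    """Trivial stand-in for an encoder layer returning (hidden_states[, attn_weights])."""
    attn = torch.softmax(hidden_states @ hidden_states.transpose(-2, -1), dim=-1)
    outputs = (hidden_states + 1.0,)
    if output_attentions:
        outputs += (attn,)
    return outputs


def run_layers(hidden_states, layers, output_attentions=False):
    all_attentions = () if output_attentions else None
    for layer in layers:
        layer_outputs = layer(hidden_states, output_attentions=output_attentions)
        hidden_states = layer_outputs[0]                           # hidden states are always position 0
        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)  # one entry per layer
    return hidden_states, all_attentions


last, attns = run_layers(torch.randn(2, 5, 8), [toy_layer] * 3, output_attentions=True)
print(last.shape, len(attns), attns[0].shape)  # torch.Size([2, 5, 8]) 3 torch.Size([2, 5, 5])
```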
@@ -357,7 +380,7 @@ class InternVisionEncoder(nn.Module):
         if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states] if v is not None)
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
         )
 
 
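Note that the collected weights are surfaced only through the `BaseModelOutput.attentions` field; the plain-tuple branch (`if not return_dict`) still returns only `[hidden_states, encoder_states]`, so callers who need the maps should use dict-style outputs. A small illustration of where the new field lives (tensor shapes are arbitrary):

```python
import torch
from transformers.modeling_outputs import BaseModelOutput

attns = (torch.randn(2, 4, 10, 10),)   # one (batch, heads, seq, seq) map per layer
out = BaseModelOutput(last_hidden_state=torch.randn(2, 10, 64), attentions=attns)
print(out.attentions[0].shape)          # torch.Size([2, 4, 10, 10])
```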
@@ -393,6 +416,7 @@ class InternVisionModel(PreTrainedModel):
             self,
             pixel_values: Optional[torch.FloatTensor] = None,
             output_hidden_states: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
             return_dict: Optional[bool] = None,
             pixel_embeds: Optional[torch.FloatTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
@@ -414,6 +438,7 @@ class InternVisionModel(PreTrainedModel):
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
             return_dict=return_dict,
         )
         last_hidden_state = encoder_outputs.last_hidden_state
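Taken together, the change lets callers request per-layer attention maps end to end. A hypothetical usage sketch: the checkpoint id is a placeholder, the `use_flash_attn=False` override assumes the config exposes that flag (weights only flow through the naive path), and it also assumes the model-level forward passes `encoder_outputs.attentions` through to its own output, which is not shown in this diff:

```python
import torch
from transformers import AutoModel

# Placeholder checkpoint id; any InternViT checkpoint shipping this modeling file would do.
model = AutoModel.from_pretrained(
    "OpenGVLab/InternViT-6B-448px-V1-5",
    trust_remote_code=True,
    use_flash_attn=False,   # assumption: attention weights are only returned on the naive path
)
model.eval()

pixel_values = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, output_attentions=True, return_dict=True)

# Expected: one (batch, num_heads, seq_len, seq_len) tensor per encoder layer,
# assuming the model forwards encoder_outputs.attentions to its own output.
print(len(outputs.attentions), outputs.attentions[0].shape)
```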