Mayank Mishra
committed on
Commit
·
6bb0180
1
Parent(s):
448e236
update script
Browse files: modeling_granite.py +12 -10
modeling_granite.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import numbers
|
|
|
2 |
from enum import Enum
|
3 |
from typing import Optional, Tuple, Union
|
4 |
|
@@ -846,7 +848,7 @@ class GranitePreTrainedModel(PreTrainedModel):
|
|
846 |
self.initializer_range = config.initializer_range
|
847 |
|
848 |
def _init_weights(self, module: nn.Module) -> None:
|
849 |
-
if isinstance(module, (nn.LayerNorm, RMSNorm, RoPE)):
|
850 |
module.reset_parameters()
|
851 |
elif isinstance(module, nn.Linear):
|
852 |
nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
|
@@ -1104,15 +1106,15 @@ class GraniteModel(GranitePreTrainedModel):
|
|
1104 |
|
1105 |
def _prepare_a_bunch_of_stuff(
|
1106 |
self,
|
1107 |
-
input_ids: torch.Tensor
|
1108 |
-
past_key_values: DynamicCache
|
1109 |
-
attention_mask: torch.Tensor
|
1110 |
-
token_type_ids: torch.Tensor
|
1111 |
-
position_ids: torch.Tensor
|
1112 |
-
inputs_embeds: torch.Tensor
|
1113 |
-
use_cache: bool
|
1114 |
-
output_hidden_states: bool
|
1115 |
-
return_dict: bool
|
1116 |
) -> Tuple[
|
1117 |
bool,
|
1118 |
bool,
|
|
|
1 |
+
import math
|
2 |
import numbers
|
3 |
+
import warnings
|
4 |
from enum import Enum
|
5 |
from typing import Optional, Tuple, Union
|
6 |
|
|
|
848 |
self.initializer_range = config.initializer_range
|
849 |
|
850 |
def _init_weights(self, module: nn.Module) -> None:
|
851 |
+
if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)):
|
852 |
module.reset_parameters()
|
853 |
elif isinstance(module, nn.Linear):
|
854 |
nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
|
|
|
1106 |
|
1107 |
def _prepare_a_bunch_of_stuff(
|
1108 |
self,
|
1109 |
+
input_ids: torch.Tensor,
|
1110 |
+
past_key_values: DynamicCache,
|
1111 |
+
attention_mask: torch.Tensor,
|
1112 |
+
token_type_ids: torch.Tensor,
|
1113 |
+
position_ids: torch.Tensor,
|
1114 |
+
inputs_embeds: torch.Tensor,
|
1115 |
+
use_cache: bool,
|
1116 |
+
output_hidden_states: bool,
|
1117 |
+
return_dict: bool,
|
1118 |
) -> Tuple[
|
1119 |
bool,
|
1120 |
bool,
|