Included gradient checkpointing #1
by FJFehr - opened

modeling_codesage.py  +19 -6  CHANGED
@@ -156,6 +156,7 @@ class CodeSageBlock(nn.Module):
 class CodeSagePreTrainedModel(PreTrainedModel):
     config_class = CodeSageConfig
     base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -183,6 +184,8 @@ class CodeSageModel(CodeSagePreTrainedModel):
         self.h = nn.ModuleList([CodeSageBlock(config) for _ in range(config.num_hidden_layers)])
         self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
+        self.gradient_checkpointing = False
+
         self.init_weights()
 
     def get_input_embeddings(self):
@@ -247,12 +250,22 @@ class CodeSageModel(CodeSagePreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            outputs = block(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                head_mask=head_mask[i],
-                output_attentions=output_attentions,
-            )
+            # Gradient checkpointing
+            if self.gradient_checkpointing and self.training:
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
+                    hidden_states,
+                    extended_attention_mask,
+                    head_mask[i],
+                    output_attentions,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    head_mask=head_mask[i],
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = outputs[0]
             if output_attentions:
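
For context, here is a minimal usage sketch of what the change enables; it is not part of the diff. It assumes the codesage/codesage-small checkpoint on the Hub and a recent transformers release where PreTrainedModel.gradient_checkpointing_enable() sets gradient_checkpointing = True and installs _gradient_checkpointing_func on modules that declare supports_gradient_checkpointing = True.

# Minimal sketch, not part of this PR. Assumes the "codesage/codesage-small"
# checkpoint name and a recent transformers release; adjust to your setup.
from transformers import AutoModel, AutoTokenizer

model_id = "codesage/codesage-small"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

# Because the class now sets supports_gradient_checkpointing = True, this call
# flips self.gradient_checkpointing to True and installs
# self._gradient_checkpointing_func, so the new branch in forward() runs
# while the model is in training mode.
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
hidden = model(**inputs)[0]   # last hidden states
hidden.sum().backward()       # dummy backward pass to exercise checkpointing

The trade-off is the usual one for activation checkpointing: block activations are recomputed during the backward pass instead of being stored, cutting peak memory at the cost of extra compute.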