
Included gradient checkpointing

#1 by FJFehr - opened
Files changed (1)
  1. modeling_codesage.py +19 -6
modeling_codesage.py CHANGED
@@ -156,6 +156,7 @@ class CodeSageBlock(nn.Module):
 class CodeSagePreTrainedModel(PreTrainedModel):
     config_class = CodeSageConfig
     base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -183,6 +184,8 @@ class CodeSageModel(CodeSagePreTrainedModel):
         self.h = nn.ModuleList([CodeSageBlock(config) for _ in range(config.num_hidden_layers)])
         self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
+        self.gradient_checkpointing = False
+
         self.init_weights()
 
     def get_input_embeddings(self):
@@ -247,12 +250,22 @@ class CodeSageModel(CodeSagePreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            outputs = block(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                head_mask=head_mask[i],
-                output_attentions=output_attentions,
-            )
+            # Gradient checkpointing
+            if self.gradient_checkpointing and self.training:
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
+                    hidden_states,
+                    extended_attention_mask,
+                    head_mask[i],
+                    output_attentions,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    head_mask=head_mask[i],
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = outputs[0]
             if output_attentions:
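
For anyone picking this up: with `supports_gradient_checkpointing = True` and the `self.gradient_checkpointing` flag in place, checkpointing is toggled through the standard `PreTrainedModel` API rather than by setting the flag by hand. A minimal usage sketch; the repo id `codesage/codesage-small` is an assumption for illustration, substitute whichever CodeSage checkpoint you are loading:

import torch
from transformers import AutoModel

# Repo id is illustrative only; point this at the actual CodeSage checkpoint.
model = AutoModel.from_pretrained("codesage/codesage-small", trust_remote_code=True)

# gradient_checkpointing_enable() flips self.gradient_checkpointing to True and
# assigns self._gradient_checkpointing_func (a wrapper around
# torch.utils.checkpoint.checkpoint), which the patched forward pass calls.
model.gradient_checkpointing_enable()
model.train()  # the new branch only runs when self.training is True

input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
hidden_states = model(input_ids)[0]
hidden_states.mean().backward()  # block activations are recomputed during backward

The trade-off is the usual one: activations inside each CodeSageBlock are discarded on the forward pass and recomputed during backward, cutting peak memory at the cost of extra compute.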