TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'

#61 by smkhant - opened

I'm getting the error below while fine-tuning the model with a QLoRA configuration (a rough sketch of a comparable setup follows the traceback).


TypeError Traceback (most recent call last)
Cell In[34], line 3
1 import time
2 start = time.time()
----> 3 trainer.train()
4 print(time.time()- start)

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2171, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2169 hf_hub_utils.enable_progress_bars()
2170 else:
-> 2171 return inner_training_loop(
2172 args=args,
2173 resume_from_checkpoint=resume_from_checkpoint,
2174 trial=trial,
2175 ignore_keys_for_eval=ignore_keys_for_eval,
2176 )

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2531, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2524 context = (
2525 functools.partial(self.accelerator.no_sync, model=model)
2526 if i != len(batch_samples) - 1
2527 and self.accelerator.distributed_type != DistributedType.DEEPSPEED
2528 else contextlib.nullcontext
2529 )
2530 with context():
-> 2531 tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
2533 if (
2534 args.logging_nan_inf_filter
2535 and not is_torch_xla_available()
2536 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2537 ):
2538 # if loss is nan or inf simply add the average of previous logged losses
2539 tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3675, in Trainer.training_step(self, model, inputs, num_items_in_batch)
3672 return loss_mb.reduce_mean().detach().to(self.args.device)
3674 with self.compute_loss_context_manager():
-> 3675 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
3677 del inputs
3678 if (
3679 self.args.torch_empty_cache_steps is not None
3680 and self.state.global_step % self.args.torch_empty_cache_steps == 0
3681 ):

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
3734 if self.args.past_index >= 0:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
818 def forward(*args, **kwargs):
--> 819 return model_forward(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
806 def __call__(self, *args, **kwargs):
--> 807 return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
41 @functools.wraps(func)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/peft/peft_model.py:1719, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1717 with self._enable_peft_forward_hooks(**kwargs):
1718 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1719 return self.base_model(
1720 input_ids=input_ids,
1721 attention_mask=attention_mask,
1722 inputs_embeds=inputs_embeds,
1723 labels=labels,
1724 output_attentions=output_attentions,
1725 output_hidden_states=output_hidden_states,
1726 return_dict=return_dict,
1727 **kwargs,
1728 )
1730 batch_size = _get_batch_size(input_ids, inputs_embeds)
1731 if attention_mask is not None:
1732 # concat prompt attention mask

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()

File /usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
196 def forward(self, *args: Any, **kwargs: Any):
--> 197 return self.model.forward(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/transformers/models/gemma/modeling_gemma.py:832, in GemmaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **kwargs)
829 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
831 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 832 outputs = self.model(
833 input_ids=input_ids,
834 attention_mask=attention_mask,
835 position_ids=position_ids,
836 past_key_values=past_key_values,
837 inputs_embeds=inputs_embeds,
838 use_cache=use_cache,
839 output_attentions=output_attentions,
840 output_hidden_states=output_hidden_states,
841 return_dict=return_dict,
842 cache_position=cache_position,
843 **kwargs,
844 )
846 hidden_states = outputs[0]
847 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()

TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
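
For context, here is a minimal sketch of a comparable setup; the checkpoint name, dataset, and hyperparameters are illustrative placeholders rather than my exact code, but a configuration along these lines ends at the same trainer.train() call shown in the traceback:

```python
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "google/gemma-2b"  # placeholder: any Gemma checkpoint hits the same code path

# 4-bit quantization config for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters
lora_config = LoraConfig(
    r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

# Tiny illustrative dataset, tokenized for causal LM training
dataset = load_dataset("Abirate/english_quotes", split="train[:200]")
dataset = dataset.map(lambda ex: tokenizer(ex["quote"], truncation=True, max_length=128))

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="gemma-qlora-out", per_device_train_batch_size=2, max_steps=10
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)

trainer.train()  # with this transformers/peft combination, raises the TypeError above
```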

Hi @smkhant ,

I reproduced the same error; please refer to the gist file. The error occurs because the default Trainer passes num_items_in_batch through to the model's forward pass, and Gemma's forward() does not accept that argument. To avoid the error, create a CustomTrainer class that inherits from Trainer and override its compute_loss method so that it handles the inputs itself and does not forward the problematic num_items_in_batch parameter. For more details, please refer to the GitHub code.
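
For example, a minimal sketch of such a subclass could look like this (the class name and the filtering approach are illustrative and assume the batch already contains labels so the model computes its own loss; the gist may differ in details):

```python
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Accept num_items_in_batch from the training loop, but do not pass it on,
        # since this version of GemmaModel.forward() does not recognise it.
        inputs = {k: v for k, v in inputs.items() if k != "num_items_in_batch"}  # defensive filter
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
```

Then build the trainer with CustomTrainer(...) in place of Trainer(...); all other arguments stay the same.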

Thank you.
