0%| | 0/20000 [00:00 main() File "/fsx/sanchit/distil-zephyr-1.5b-ssft-ultrachat/run_sft.py", line 172, in main train_result = trainer.train(resume_from_checkpoint=checkpoint) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 361, in train output = super().train(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 1849, in train return inner_training_loop( ^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 2202, in _inner_training_loop tr_loss_step = self.training_step(model, inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 3137, in training_step loss = self.compute_loss(model, inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 3160, in compute_loss outputs = model(**inputs) ^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1608, in forward else self._run_ddp_forward(*inputs, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1426, in _run_ddp_forward return self.module(*inputs, **kwargs) # type: ignore[index] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 825, in forward return model_forward(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 813, in __call__ return convert_to_fp32(self.model_forward(*args, **kwargs)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/transformers/src/transformers/models/mistral/modeling_mistral.py", line 1184, in forward loss = loss_fct(shift_logits, shift_labels) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/loss.py", line 1185, in forward return F.cross_entropy(input, target, weight=self.weight, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/functional.py", line 3088, in cross_entropy return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 15.62 GiB. GPU [rank0]: Traceback (most recent call last): [rank0]: File "/fsx/sanchit/distil-zephyr-1.5b-ssft-ultrachat/run_sft.py", line 217, in [rank0]: main() [rank0]: File "/fsx/sanchit/distil-zephyr-1.5b-ssft-ultrachat/run_sft.py", line 172, in main [rank0]: train_result = trainer.train(resume_from_checkpoint=checkpoint) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 361, in train [rank0]: output = super().train(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 1849, in train [rank0]: return inner_training_loop( [rank0]: ^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 2202, in _inner_training_loop [rank0]: tr_loss_step = self.training_step(model, inputs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 3137, in training_step [rank0]: loss = self.compute_loss(model, inputs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/transformers/src/transformers/trainer.py", line 3160, in compute_loss [rank0]: outputs = model(**inputs) [rank0]: ^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1608, in forward [rank0]: else self._run_ddp_forward(*inputs, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1426, in _run_ddp_forward [rank0]: return self.module(*inputs, **kwargs) # type: ignore[index] [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 825, in forward [rank0]: return model_forward(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 813, in __call__ [rank0]: return convert_to_fp32(self.model_forward(*args, **kwargs)) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast [rank0]: return func(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/transformers/src/transformers/models/mistral/modeling_mistral.py", line 1184, in forward [rank0]: loss = loss_fct(shift_logits, shift_labels) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/modules/loss.py", line 1185, in forward [rank0]: return F.cross_entropy(input, target, weight=self.weight, [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/sanchit/miniconda3/envs/venv/lib/python3.11/site-packages/torch/nn/functional.py", line 3088, in cross_entropy [rank0]: return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 15.62 GiB. GPU