GPU training makes loss=nan

#37
by hidonbush - opened

While testing the model I ran into a problem: if I run it on GPU, the loss becomes NaN. On CPU everything is fine.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import torch

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['down_proj', 'o_proj', 'k_proj', 'q_proj', 'gate_proj', 'up_proj', 'v_proj']
)

temp = AutoModelForSequenceClassification.from_pretrained("gemma2b", num_labels=2, torch_dtype=torch.bfloat16, device_map='auto')
model = get_peft_model(temp, peft_config)
tokenizer = AutoTokenizer.from_pretrained("gemma2b")
testdata = tokenizer('I like it too', return_tensors='pt', padding='max_length', max_length=10).to('cuda')

print(model(**testdata, labels=torch.tensor(1).to('cuda')))

transformers: latest
PyTorch: 2.4
CUDA: 12.1
Python: 3.10

Hi @hidonbush,

I executed the provided code in Google Colab with the runtime type set to 'T4 GPU' using the following library versions, and I did not encounter a NaN loss when using the GPU. For more details, please refer to the attached Gist notebook: https://colab.research.google.com/gist/Gopi-Uppari/e285595fac24fc55ae09a688b8e9d9b9/gpu-training-makes-loss-nan.ipynb#scrollTo=pzy7UtnkTuwx

  • torch: 2.4.0+cu121
  • Python: 3.10.12
  • peft: 0.12.0
  • transformers: 4.44.2

Please let us know if the issue still persists.

Thank you.

Hi @GopiUppari,
I also tested the code on T4 and V100, and it worked there as well.
But if the GPU is Ampere or newer (e.g. 3090, 4090, RTX 6000), the problem occurs.
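For reference, here is a minimal diagnostic (my own addition, not part of the original report) to confirm whether the GPU is Ampere-class (compute capability >= 8.0) and reports bf16 support, since that is the dividing line between the cards that work and the ones that produce NaN:

import torch

# Hypothetical check, not from the original report: the NaN only shows up on
# Ampere-or-newer cards, which are the ones with native bf16 support.
major, minor = torch.cuda.get_device_capability()
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"compute capability: {major}.{minor}")          # >= 8.0 means Ampere or newer
print(f"bf16 supported: {torch.cuda.is_bf16_supported()}")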
Here are the packages in the virtual conda environment:

accelerate 0.34.2
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aiosignal 1.3.1
async-timeout 4.0.3
attrs 24.2.0
certifi 2024.8.30
charset-normalizer 3.3.2
filelock 3.13.1
frozenlist 1.4.1
fsspec 2024.2.0
huggingface-hub 0.24.6
idna 3.8
Jinja2 3.1.3
lightning 2.4.0
lightning-utilities 0.11.7
MarkupSafe 2.1.5
mpmath 1.3.0
multidict 6.0.5
networkx 3.2.1
numpy 1.26.3
nvidia-cublas-cu12 12.4.2.65
nvidia-cuda-cupti-cu12 12.4.99
nvidia-cuda-nvrtc-cu12 12.4.99
nvidia-cuda-runtime-cu12 12.4.99
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.0.44
nvidia-curand-cu12 10.3.5.119
nvidia-cusolver-cu12 11.6.0.99
nvidia-cusparse-cu12 12.3.0.142
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.4.99
nvidia-nvtx-cu12 12.4.99
packaging 24.1
peft 0.12.0
pillow 10.2.0
pip 24.2
psutil 6.0.0
pytorch-lightning 2.4.0
PyYAML 6.0.2
regex 2024.7.24
requests 2.32.3
safetensors 0.4.5
setuptools 72.1.0
sympy 1.12
tokenizers 0.19.1
torch 2.4.1+cu124
torchaudio 2.4.1+cu124
torchmetrics 1.4.1
torchvision 0.19.1+cu124
tqdm 4.66.5
transformers 4.44.2
triton 3.0.0
typing_extensions 4.9.0
urllib3 2.2.2
wheel 0.43.0
yarl 1.11.0

Your labels tensor has the wrong shape. This is what it should look like to avoid a NaN loss:

print(model(**testdata,labels=torch.tensor([1]).to('cuda')))

# SequenceClassifierOutputWithPast(loss=tensor(0.6836, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>), logits=tensor([[0.8047, 0.8203]], device='cuda:0', dtype=torch.bfloat16, grad_fn=<IndexBackward0>), past_key_values=None, hidden_states=None, attentions=None)
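As a side note (my own illustration, not part of the original reply): the sequence-classification head expects labels with a batch dimension, i.e. shape [batch_size], while torch.tensor(1) is a 0-dimensional scalar.

import torch

# torch.tensor(1) is a 0-d scalar; torch.tensor([1]) has the expected [batch_size] shape.
print(torch.tensor(1).shape)    # torch.Size([])
print(torch.tensor([1]).shape)  # torch.Size([1])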

Okay guys, here is the solution:
https://github.com/huggingface/transformers/issues/32390
There seems to be a problem with the default attention implementation when running in bf16 precision.
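As a sketch of the workaround (my wording, not from the linked issue; "eager" is one example of an alternative implementation), you can select a different attention implementation when loading the model via the attn_implementation argument of from_pretrained:

from transformers import AutoModelForSequenceClassification
import torch

# Sketch of the workaround: force the "eager" attention implementation instead of
# the default sdpa path. "gemma2b" is the local path used in the original report.
model = AutoModelForSequenceClassification.from_pretrained(
    "gemma2b",
    num_labels=2,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)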

This fixed it for me. It appears there is some bug with the sdpa attention implementation in bf16. The solution for now is to change to a different attention implementation. I was loading the model via:

model = AutoModelForSequenceClassification.from_pretrained(model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs)

When I set attn_implementation = "flash_attention_2" in model_kwargs, the loss stopped being NaN.
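For completeness, a minimal sketch of what that might look like (my own reconstruction based on the comment above; the model path and num_labels are placeholders, and flash_attention_2 requires the flash-attn package and an Ampere-or-newer GPU):

import torch
from transformers import AutoModelForSequenceClassification

# Sketch only: placeholder model path and kwargs, assumed from the comment above.
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
model = AutoModelForSequenceClassification.from_pretrained(
    "gemma2b",
    num_labels=1,
    **model_kwargs,
)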
