Unable to load the model for Torch versions starting from 2.0.1
Hello,
I am encountering an issue while attempting to load the Llama 3 8B model using the pipeline function with the bfloat16 dtype and the latest version of the transformers library. I get a runtime error when using the latest torch version, and the same problem persists for any torch version starting from 2.0.1.
Code for loading:
# Loading the model
import torch
from transformers import pipeline
pipe = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
RuntimeError (it seems related to flash attention):
RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
/databricks/python/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
This issue seems to be resolved by downgrading torch to a version earlier than 2.0.1, but then another issue arises during inference. Those earlier torch versions do not implement some CUDA operations (here, torch.triu) for the bfloat16 dtype, which results in the following error:
1094 if sequence_length != 1:
-> 1095 causal_mask = torch.triu(causal_mask, diagonal=1)
1096 causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1097 causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
Could anyone please help me resolve these issues or suggest a workaround? I would greatly appreciate any assistance. Thank you!
I am encountering the same issue, though I am using AutoModelForCausalLM; the error is the same as yours. My transformers package version is 4.39.3 and my torch version is 2.0.1.
You'll likely need to update your transformers package to version 4.40.0, which supports Llama 3.
However, you can get around the error mentioned above by downgrading PyTorch to version 2.0.1 and then loading Llama3-8B in float16 precision rather than bfloat16. This approach should bypass the error, since triu_tril_cuda_template is implemented for the float16 data type in PyTorch 2.0.1, but note that it doesn't take advantage of the bfloat16 format.
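For example, a minimal sketch of that workaround, reusing the pipeline call from the original post but requesting float16 (my own illustration, not tested on every setup):
import torch
from transformers import pipeline

# Same call as in the original post, but with float16 instead of bfloat16
pipe = pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)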
Exactly the same problem as the original post here, except that for me with torch==2.0.1 I hit the second bug (RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16').
I solved my problem by replacing causal_mask = torch.triu(causal_mask, diagonal=1) with causal_mask = custom_triu(causal_mask), where:
def custom_triu(input_tensor):
    # Equivalent to torch.triu(input_tensor, diagonal=1): zero out everything on and
    # below the main diagonal, keeping only the strict upper triangle.
    rows, cols = input_tensor.shape
    device = input_tensor.device  # build the index tensors on the same device as the input
    row_indices = torch.arange(rows, device=device).unsqueeze(1).expand(rows, cols)
    col_indices = torch.arange(cols, device=device).unsqueeze(0).expand(rows, cols)
    mask = row_indices >= col_indices
    output_tensor = input_tensor.clone()
    output_tensor[mask] = 0
    return output_tensor
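As a quick sanity check (my own snippet, not part of the original fix), you can compare it against torch.triu on a dtype where triu is implemented:
# Verify that custom_triu matches torch.triu(x, diagonal=1)
x = torch.randn(4, 6)  # float32, where torch.triu works everywhere
assert torch.equal(custom_triu(x), torch.triu(x, diagonal=1))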
This is a torch breaking-change issue, which seems to have been fixed since. Upgrading your torch version should be the best bet here 😉 This worked for me with torch 2.3.
I still encounter the issue with torch 2.3 and transformers 4.40.0.
Yes, I have the same issue. With torch 2.3.0 and transformers 4.40.1, flash attention throws a runtime error.
I happened to solve the issue by uninstalling torch and flash-attn, then reinstalling with pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu124 followed by pip install flash-attn --no-build-isolation. The issue was caused by a version incompatibility. You may want to try different versions of torch and CUDA depending on your required setup.
I use model_kwargs={"torch_dtype": torch.float16} instead of bfloat16.
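For reference, a minimal sketch of where that kwarg goes in the pipeline call (assuming the same model as the original post):
import torch
from transformers import pipeline

# The dtype is passed through model_kwargs instead of the top-level torch_dtype argument
pipe = pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    model_kwargs={"torch_dtype": torch.float16},
    device_map="auto",
)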
model_args['attn_implementation'] = 'flash_attention_2'
model = LlamaForCausalLM.from_pretrained(model_name, **model_args).eval()
Adding attn_implementation='flash_attention_2' works for me.
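For completeness, a self-contained sketch of that approach; model_name, the float16 dtype, and device_map here are my own assumptions rather than details from the post above:
import torch
from transformers import LlamaForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed checkpoint
model_args = {
    "torch_dtype": torch.float16,                 # assumption: a dtype supported by flash-attn
    "device_map": "auto",                         # assumption
    "attn_implementation": "flash_attention_2",   # requires a working flash-attn install
}
model = LlamaForCausalLM.from_pretrained(model_name, **model_args).eval()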