Flash attention NVCC requirements
Is there a way to run the code without nvcc? modeling_flash_llama.py uses flash attention, which needs nvcc, but the licensed LLAMA2 model has no such requirement.
Seems you've solved this yourself? For future reference, running with trust_remote_code=False disables the use of custom flash-attention and should work on any hardware. More recently, transformers has integrated support for flash-attention-2 via use_flash_attention_2=True. A loading sketch follows below. Hope this is helpful :)
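Here is a minimal loading sketch for both options. The model id and dtype are placeholders, not the exact repo; flash-attention-2 additionally requires the flash-attn package and a supported GPU.

    # Minimal sketch, assuming a causal LM checkpoint on the Hub (model id is a placeholder).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "your-org/your-llama2-model"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Option 1: skip the custom modeling_flash_llama.py code entirely (no nvcc needed).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=False,
        torch_dtype=torch.float16,
    )

    # Option 2: use the flash-attention-2 integration built into recent transformers.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    )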
Yes, thanks, that seems to work. I'll check out flash attention for faster inference.
Also, is there a way I could check the perplexity of the model to compare its performance in different use cases?
There will be more detailed perplexity evaluations in our paper. For now, though, I can share some code that should help you get started with perplexity evaluation. If you don't want to implement it yourself, check out something like text-generation-webui for prebuilt perplexity evals.
# Compute perplexity as exp of the mean per-sample negative log-likelihood
import torch

nlls = []
for sample in dataset:
    with torch.no_grad():
        input_ids = torch.tensor(sample['input_ids']).unsqueeze(0).to(model.device)
        attention_mask = torch.tensor(sample['attention_mask']).unsqueeze(0).to(model.device)
        # Passing the inputs as labels makes the model return the causal LM loss (mean NLL per token)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.clone())
        loss = outputs.loss
    nlls.append(loss.cpu().item())

nlls = torch.tensor(nlls)
perplexity = torch.exp(nlls.mean())
dataset here is a pretokenized dataset iterator. Hope this is helpful :)
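In case it helps, here is a rough sketch of how such a pretokenized dataset could be built with the Hugging Face datasets library. The corpus (wikitext-2), the text column, and the model id are illustrative assumptions, not what we used for the paper evaluations.

    # Sketch of building the pretokenized `dataset` used above (corpus and model id are placeholders).
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("your-org/your-llama2-model")  # hypothetical repo id

    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    def tokenize(example):
        # Truncate to a fixed context length so every sample fits on the GPU
        return tokenizer(example["text"], truncation=True, max_length=2048)

    dataset = raw.map(tokenize, remove_columns=raw.column_names)
    # Drop near-empty samples so the shifted LM loss is always defined
    dataset = dataset.filter(lambda ex: len(ex["input_ids"]) > 1)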