Error with custom inference loop with past_key_values

#48
by dimaischenko - opened

I tried to get an answer in the falcon-7b discussions, but no one has answered; maybe this community can help me.

Writing my own inference loop, I get errors when using past_key_values: either a tensor dimension error or complete nonsense in the generation. I debugged the default examples from the model description with the standard model.generate or pipeline and saw that they don't use past_key_values at all in the generation loop.

I would be very glad if you could tell me whether the use of past_key_values is implemented with a bug, is not supported, or whether I am misunderstanding something. After all, it speeds up inference by several times.

More details in the falcon-7b discussion: https://huggingface.co/tiiuae/falcon-7b/discussions/17

Hi!
I had the exact same problem, and found this: https://huggingface.co/tiiuae/falcon-40b/discussions/47

Hope this helps!

@cchudant thank you! I'm going to test it now and I'll be sure to write here about the results!

@dimaischenko
just to be clear, I have not tested it :)
it's one of:

  1. Either it is a bug, as I pointed out, and the line I referenced is wrong, plus other changes elsewhere.
    I would guess that _, _, kv_length = key_layer.shape would also need to change to _, kv_length, _ = key_layer.shape. Maybe some other stuff.

  2. Or the departure from the Bloom kv-cache shape is intended, meaning _convert_to_rw_cache and _convert_to_standard_cache need to be changed to swap the key dimensions (see the sketch after the comment blocks below).
    The comment

# concatenate along seq_length dimension:
#  - key: [batch_size * self.num_heads, head_dim, kv_length]
#  - value: [batch_size * self.num_heads, kv_length, head_dim]

would need to change to

# concatenate along seq_length dimension:
#  - key: [batch_size * self.num_heads, kv_length, head_dim]
#  - value: [batch_size * self.num_heads, kv_length, head_dim]
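
Roughly, I would expect the cache concatenation in option 2 to look something like this (shapes made up purely for illustration, untested):

import torch

# With option 2, both the key and value caches use the layout
# [batch_size * num_heads, kv_length, head_dim].
batch_heads, head_dim, past_len = 2 * 71, 64, 10
past_key = torch.randn(batch_heads, past_len, head_dim)
past_value = torch.randn(batch_heads, past_len, head_dim)
key_layer = torch.randn(batch_heads, 1, head_dim)    # new token's key
value_layer = torch.randn(batch_heads, 1, head_dim)  # new token's value

# Concatenate along the seq_length dimension (dim=1) for both tensors.
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

# kv_length then comes from the middle dimension of the key as well.
_, kv_length, _ = key_layer.shape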

I have not tested it :) I am currently only interested in benchmarking inference of these models, and I found this out because my benchmark broke with the kv-cache.
I have not actually checked whether it gives me valid outputs.

@cchudant I will carefully sort it out and report back. I came to similar conclusions and started swapping the tensor dimensions manually; this removed the dimension errors but led to "nonsense" in the model's output. I will now investigate everything carefully.

@cchudant Unfortunately, I still get "nonsense" when generating if I change the dimensions ... something is not right.

Do you happen to know which of the developers I can ask directly? I can't figure out whom to mention to help sort this out. Using past_key_values in decoding speeds up generation by 2-3 times, so it seems very important to make it work.

@cchudant I actually tested with the code from the falcon-7b model; it looks like the code differs slightly between 7b and 40b. I don't have a GPU on which I could test the 40b model, so if you could test this code on it (with the tensor-dimension corrections), that would be great!

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import random

device = torch.device("cuda")
model_id = "tiiuae/falcon-40b"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",  # device_map already places the weights; no extra .to(device) needed
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

text = "We are in the dark forest and want to find some mushrooms. We go to the nearest tree and"

inputs = tokenizer(text, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]

output = None
step = 0

# generation cycle with 20 steps
while step < 20:
    attention_mask = input_ids.new_ones(input_ids.shape)
    past_key_values = None
    
    if output is not None:
        past_key_values = output["past_key_values"]

    ids = model.prepare_inputs_for_generation(input_ids,
                                              past=past_key_values,
                                              attention_mask=attention_mask,
                                              use_cache=True)
                                 
    output = model(**ids)
    
    # get random of 3 most probable tokens and add to input_ids
    top_k = 3
    next_token = random.choice(torch.topk(output.logits[:, -1, :], top_k, dim=-1).indices[0])
    
    input_ids = torch.cat([input_ids, torch.tensor([[next_token]]).to(device)], dim=-1)
    
    step += 1

print(tokenizer.decode(input_ids[0]))

Hi, I am sorry - I don't have the machine anymore :(

No problem, we'll wait for an answer from someone on the @FalconLLM team. Many people already write that they have the same problems and questions with both the 7b ( https://huggingface.co/tiiuae/falcon-7b/discussions/17 ) and 40b models, so it seems to be an important thing.

It doesn't look right to me that your decode loop concats the new token onto the previous input_ids. For KV-cache inference, you should only pass in the new token.

@ColmanTT The thing is that in prepare_inputs_for_generation only the last token is kept:

...
   input_ids = input_ids[:, -1].unsqueeze(-1)
...

This loop works fine with any other model such as gpt2, gptj-6b, bloom, etc.
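
For reference, here is a minimal, self-contained version of that kind of cached decode loop with gpt2 and greedy decoding; after the first forward pass only the newest token id is fed to the model, and the cache supplies the rest of the context:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # stand-in model that handles past_key_values correctly
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer("We are in the dark forest and", return_tensors="pt")["input_ids"]
past_key_values = None

with torch.no_grad():
    for _ in range(20):
        if past_key_values is None:
            model_inputs = input_ids          # first step: the whole prompt
        else:
            model_inputs = input_ids[:, -1:]  # later steps: only the newest token
        output = model(input_ids=model_inputs,
                       past_key_values=past_key_values,
                       use_cache=True)
        past_key_values = output.past_key_values
        next_token = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))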

This comment has been hidden

@ColmanTT Wow! I will test tomorrow and will be sure to report on the results!

Apologies for hiding this; I realize I hadn't tested it well enough. Let me know how yours goes as well.

@dimaischenko I think I solved this problem by passing a different is_causal value to F.scaled_dot_product_attention depending on whether layer_past is set.

if layer_past is not None:
    # Decoding with a cache: the new query attends to all cached keys,
    # so no causal mask is needed for this single step.
    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
    )
else:
    # First (prefill) pass over the full prompt: apply the causal mask.
    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
    )

The generated text turned out to be just fine.

@siemon1996 I tested it, and it feels much better, but there still seems to be a lot of "nonsense" 🤔 It feels like I'm not running the 7b model but the very first small gpt2. I will keep testing, thank you for sharing!

@siemon1996 I don't think that really solves it. The problem may be caused by a bug in scaled_dot_product_attention: https://github.com/pytorch/pytorch/issues/103082

I changed it like this:

if layer_past is not None:
    L = query_layer_.shape[-2]
    S = key_layer_.shape[-2]
    # Explicit all-ones (non-causal) mask to work around the is_causal bug.
    attn_mask = torch.ones(L, S, dtype=torch.bool, device=query_layer_.device)
    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, attn_mask, 0.0, is_causal=False
    )
else:
    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
    )

It occurs to me that the first several output tokens are correct, compared with running falcon-40b without past_key_values. Then the 'nonsense' begins. I'll try to figure out the reason, just letting you know.

There is another issue besides the incorrect causal mask. The rotary embeddings are not position-indexed. Suppose that you are generating the next piece; then the new (non-cache) query/key representations of the last generated piece have shape [batch_size, n_heads, 1, head_dim]. The rotary embeddings are applied like this:

https://huggingface.co/tiiuae/falcon-7b/blob/2f5c3cd4eace6be6c0f12981f377fb35e5bf6ee5/modelling_RW.py#L257

However, that means that the query/key representations for this new piece get the rotary embedding applied for position 0, whereas they need the rotary embedding for position n+1 (where n is the number of preceding pieces, which may differ per batch item when there are padding pieces). Here is a correct use of rotary embeddings in this model when a cache is used:

https://github.com/explosion/curated-transformers/blob/b44a0fa24c64844909656a0fa9eb4d5acc6af142/curated_transformers/models/attention.py#L315
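
To illustrate the idea, here is a rough, self-contained sketch (my own simplified code, not the curated-transformers implementation); the essential part is that the positions start at the cache length rather than at 0:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, past_length: int, base: float = 10000.0) -> torch.Tensor:
    # x: query or key representations of shape [batch_size, n_heads, seq_len, head_dim].
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    # Absolute positions of the new pieces, offset by the number of cached pieces.
    positions = torch.arange(past_length, past_length + seq_len, device=x.device)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]   # [seq_len, head_dim / 2]
    angles = torch.cat((angles, angles), dim=-1)               # [seq_len, head_dim]
    cos, sin = angles.cos().to(x.dtype), angles.sin().to(x.dtype)
    return x * cos + rotate_half(x) * sin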

Example of text generated with the 7b instruction tuned model, with correct causal masks, correct indexing, caching, and deterministic decoding:

Prompt: What is the Rust programming language?

Answer: Rust is a programming language that is designed to be a safe, concurrent, and efficient replacement for C++. It is a statically-typed language that is designed to be memory-safe and thread-safe, making it a good choice for developing high-performance applications.

@siemon1996 @Tron2060 @dimaischenko @danieldk-explosion

I managed to compile a quick script incorporating all the changes mentioned in this thread: use of the correct past_key_values variable instead of past, the correct parameters in the scaled_dot_product_attention call, the right implementation of rotary embeddings from curated-transformers, and minor fixes for contrastive search. You can find the gist here. I've only tested with base 10000 and fraction 1 as the rotary embedding params; please feel free to experiment with different values. Additionally, as the modifications were done rather hastily, kindly overlook any hard-coded/non-parametrised values that I might have added.

This seems to generate the expected results, and generation is much faster with the adoption of past_key_values. As a result, the speed is no longer overly dependent on the value of max_new_tokens, as was the case with the default model. I've done these experiments on the 7b variant.

Hope this helps.

@sekharvth Trying your modelling_RW.py, I get a dtype error:

--> 509    attn_output = F.scaled_dot_product_attention(
    510        query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
    511    )
    512 #attn_output = F.scaled_dot_product_attention(
    513 #    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
    514 #)
    516 x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)

RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::BFloat16 instead.

Ahh, apologies, I forgot to mention that this doesn't work in fp16. When loading the model, make sure you're not using torch_dtype=torch.float16 or torch_dtype=torch.bfloat16.

@sekharvth is this some kind of fundamental limitation on fp16? Or do I just need to do something else to make it work?

@dimaischenko From what I remember, the precision errors came from the rotary embeddings implementation, specifically the sine-cosine transformation parts. The transformations didn't support half precision then, and I didn't experiment further to get it working in fp16, as my immediate goal was to get the whole thing working any way it could.

You could try some dtype transformations internally in the script and check whether it works in fp16.
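
For example, something along these lines might work (an untested sketch; the function and tensor names here are illustrative, not the ones in the gist): do the rotation in float32 and cast the result back to the model dtype so that query, key, and value agree again by the time they reach scaled_dot_product_attention.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_fp16_safe(query, key, cos, sin):
    # Hypothetical helper: do the sine-cosine rotation in float32 for precision,
    # then cast back to the original (fp16/bf16) dtype of the inputs.
    dtype = query.dtype
    q32, k32 = query.float(), key.float()
    cos32, sin32 = cos.float(), sin.float()
    query_rot = q32 * cos32 + rotate_half(q32) * sin32
    key_rot = k32 * cos32 + rotate_half(k32) * sin32
    return query_rot.to(dtype), key_rot.to(dtype)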
