Changes in modelling_RW.py to be able to handle past_key_values for faster model generations
Browse filesThe current code has missed out passing past_key_values in every forward pass for fast generation of tokens. This results in lot of recompute. This "modelling_RW.py" I am uploading deals with this in the way pytorch huggingface transformers package generation/utils.py wants. All the changes are basically around including past_key_values everywhere. I think this will apply on all falcon models These are the changes specifically
1) Class RotaryEmbedding forward method
Include past_seq_length in forward pass and apply rotary embedding according to the position of the query token ---- if else condition added (line number 100-103)
2) _make_causal_mask function
to give masking according to the way F.scaled dot product attention behaves. F.scaled_dot_product attention treats the attention_mask matrix as receiving attentions. For example if attention_mask is
[[True, False], [True, True]]. It would mean the first token is "receiving" attentions from first token and not second token. This is unlike what we generally end up thinking which is first token is giving attention to itself and not to the second one. Due to reason the past_key_values attentions are all True in make_causal mask function. Also I have reversed the inequality above that due to the same reason. ---- (line number 114 inequality, line number 117 attention mask to be True)
3) Class Attention forward method
a) past_key_value length is passed in rotary function ---- if,else loop added (line number 271-277)
b) concatenation of past key and current key is done after permuting the past key shape to match the current key shape ---- (line number 280-284)
c) to keep key_layer shape consistent with the output expectation which is (batch_size, head_dim, seq_length), another permutation done before creating "present" to return in the output ---- (line number 289-293)
4) RW Model prepare_attn_mask
Have removed src_length > 1 criteria for making causal mask (line number 554).
5) RW causal LM prepare inputs for generation
Read pastkey values from the input coming from huggingface generate method and dont call convert_to_rw_cache method (line number 740-748)
- modelling_RW.py +72 -36
@@ -11,7 +11,9 @@ import torch.utils.checkpoint
|
|
11 |
from torch import nn
|
12 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
13 |
from torch.nn import functional as F
|
14 |
-
|
|
|
|
|
15 |
from transformers.modeling_outputs import (
|
16 |
BaseModelOutputWithPastAndCrossAttentions,
|
17 |
CausalLMOutputWithCrossAttentions,
|
@@ -87,10 +89,19 @@ class RotaryEmbedding(torch.nn.Module):
|
|
87 |
|
88 |
return self.cos_cached, self.sin_cached
|
89 |
|
90 |
-
def forward(self, q, k):
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
92 |
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
|
93 |
-
|
|
|
|
|
|
|
|
|
94 |
|
95 |
|
96 |
def _make_causal_mask(
|
@@ -100,10 +111,10 @@ def _make_causal_mask(
|
|
100 |
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
|
101 |
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
|
102 |
seq_ids = torch.arange(target_length, device=device)
|
103 |
-
mask[:, past_key_values_length:] = seq_ids[:, None]
|
104 |
|
105 |
if past_key_values_length > 0:
|
106 |
-
mask[:, :past_key_values_length] =
|
107 |
|
108 |
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
|
109 |
return expanded_mask
|
@@ -150,6 +161,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
|
150 |
out = residual + out
|
151 |
return out
|
152 |
|
|
|
|
|
|
|
|
|
153 |
|
154 |
class Attention(nn.Module):
|
155 |
def __init__(self, config: RWConfig):
|
@@ -239,9 +254,8 @@ class Attention(nn.Module):
|
|
239 |
use_cache: bool = False,
|
240 |
output_attentions: bool = False,
|
241 |
):
|
|
|
242 |
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
243 |
-
|
244 |
-
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
245 |
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
246 |
|
247 |
batch_size, q_length, _, _ = query_layer.shape
|
@@ -254,20 +268,27 @@ class Attention(nn.Module):
|
|
254 |
)
|
255 |
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
|
256 |
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
if layer_past is not None:
|
260 |
past_key, past_value = layer_past
|
261 |
-
|
262 |
-
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
263 |
-
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
264 |
key_layer = torch.cat((past_key, key_layer), dim=1)
|
265 |
value_layer = torch.cat((past_value, value_layer), dim=1)
|
|
|
266 |
|
267 |
_, kv_length, _ = key_layer.shape
|
268 |
|
269 |
if use_cache is True:
|
270 |
-
|
|
|
271 |
else:
|
272 |
present = None
|
273 |
|
@@ -275,10 +296,16 @@ class Attention(nn.Module):
|
|
275 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
276 |
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
277 |
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
|
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
282 |
|
283 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
284 |
x = x.permute(0, 2, 1, 3)
|
@@ -475,8 +502,8 @@ class RWPreTrainedModel(PreTrainedModel):
|
|
475 |
def _convert_to_rw_cache(
|
476 |
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
477 |
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
478 |
-
batch_size,
|
479 |
-
batch_size_times_num_heads = batch_size
|
480 |
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
481 |
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
482 |
return tuple(
|
@@ -488,6 +515,7 @@ class RWPreTrainedModel(PreTrainedModel):
|
|
488 |
)
|
489 |
|
490 |
|
|
|
491 |
class RWModel(RWPreTrainedModel):
|
492 |
def __init__(self, config: RWConfig):
|
493 |
super().__init__(config)
|
@@ -522,10 +550,11 @@ class RWModel(RWPreTrainedModel):
|
|
522 |
device = attention_mask.device
|
523 |
_, src_length = input_shape
|
524 |
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
|
|
529 |
|
530 |
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
531 |
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
@@ -560,7 +589,7 @@ class RWModel(RWPreTrainedModel):
|
|
560 |
)
|
561 |
if len(deprecated_arguments) > 0:
|
562 |
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
563 |
-
|
564 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
565 |
output_hidden_states = (
|
566 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
@@ -616,6 +645,7 @@ class RWModel(RWPreTrainedModel):
|
|
616 |
input_shape=(batch_size, seq_length),
|
617 |
past_key_values_length=past_key_values_length,
|
618 |
)
|
|
|
619 |
|
620 |
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
621 |
|
@@ -646,16 +676,18 @@ class RWModel(RWPreTrainedModel):
|
|
646 |
)
|
647 |
else:
|
648 |
outputs = block(
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
)
|
|
|
657 |
|
658 |
hidden_states = outputs[0]
|
|
|
659 |
if use_cache is True:
|
660 |
presents = presents + (outputs[1],)
|
661 |
|
@@ -704,16 +736,20 @@ class RWForCausalLM(RWPreTrainedModel):
|
|
704 |
**kwargs,
|
705 |
) -> dict:
|
706 |
# only last token for input_ids if past is not None
|
707 |
-
if past
|
|
|
708 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
709 |
-
|
710 |
# the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
|
711 |
-
if
|
712 |
-
|
|
|
|
|
|
|
713 |
|
714 |
return {
|
715 |
"input_ids": input_ids,
|
716 |
-
"past_key_values":
|
717 |
"use_cache": kwargs.get("use_cache"),
|
718 |
"attention_mask": attention_mask,
|
719 |
}
|
|
|
11 |
from torch import nn
|
12 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
13 |
from torch.nn import functional as F
|
14 |
+
import pdb
|
15 |
+
import os
|
16 |
+
import pickle
|
17 |
from transformers.modeling_outputs import (
|
18 |
BaseModelOutputWithPastAndCrossAttentions,
|
19 |
CausalLMOutputWithCrossAttentions,
|
|
|
89 |
|
90 |
return self.cos_cached, self.sin_cached
|
91 |
|
92 |
+
def forward(self, q, k, past_seq_length=None):
|
93 |
+
if past_seq_length == None :
|
94 |
+
batch, seq_len, head_dim = q.shape
|
95 |
+
else :
|
96 |
+
# print("past_seq_length", past_seq_length)
|
97 |
+
batch, input_seq_len, head_dim = q.shape
|
98 |
+
seq_len = past_seq_length + input_seq_len
|
99 |
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
|
100 |
+
if past_seq_length != None :
|
101 |
+
return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
|
102 |
+
else :
|
103 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
104 |
+
|
105 |
|
106 |
|
107 |
def _make_causal_mask(
|
|
|
111 |
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
|
112 |
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
|
113 |
seq_ids = torch.arange(target_length, device=device)
|
114 |
+
mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
|
115 |
|
116 |
if past_key_values_length > 0:
|
117 |
+
mask[:, :past_key_values_length] = True
|
118 |
|
119 |
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
|
120 |
return expanded_mask
|
|
|
161 |
out = residual + out
|
162 |
return out
|
163 |
|
164 |
+
def dump_value(name, tensor) :
|
165 |
+
with open("/home/purushottam/inspect_falcon/{}".format(name), "wb") as f :
|
166 |
+
pickle.dump(tensor, f)
|
167 |
+
|
168 |
|
169 |
class Attention(nn.Module):
|
170 |
def __init__(self, config: RWConfig):
|
|
|
254 |
use_cache: bool = False,
|
255 |
output_attentions: bool = False,
|
256 |
):
|
257 |
+
|
258 |
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
|
|
|
|
259 |
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
260 |
|
261 |
batch_size, q_length, _, _ = query_layer.shape
|
|
|
268 |
)
|
269 |
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
|
270 |
|
271 |
+
if layer_past is not None :
|
272 |
+
past_key, past_value = layer_past
|
273 |
+
past_kv_length = past_key.shape[2]
|
274 |
+
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
275 |
+
else :
|
276 |
+
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
|
277 |
+
|
278 |
+
|
279 |
|
280 |
if layer_past is not None:
|
281 |
past_key, past_value = layer_past
|
282 |
+
past_key = past_key.permute(0, 2, 1)
|
|
|
|
|
283 |
key_layer = torch.cat((past_key, key_layer), dim=1)
|
284 |
value_layer = torch.cat((past_value, value_layer), dim=1)
|
285 |
+
|
286 |
|
287 |
_, kv_length, _ = key_layer.shape
|
288 |
|
289 |
if use_cache is True:
|
290 |
+
key_layer_permute = key_layer.permute(0, 2, 1)
|
291 |
+
present = (key_layer_permute, value_layer)
|
292 |
else:
|
293 |
present = None
|
294 |
|
|
|
296 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
297 |
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
298 |
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
299 |
+
|
300 |
|
301 |
+
if attention_mask is not None :
|
302 |
+
attn_output = F.scaled_dot_product_attention(
|
303 |
+
query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
|
304 |
+
)
|
305 |
+
else :
|
306 |
+
attn_output = F.scaled_dot_product_attention(
|
307 |
+
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
|
308 |
+
)
|
309 |
|
310 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
311 |
x = x.permute(0, 2, 1, 3)
|
|
|
502 |
def _convert_to_rw_cache(
|
503 |
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
504 |
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
505 |
+
batch_size, seq_length, head_dim = past_key_value[0][0].shape
|
506 |
+
batch_size_times_num_heads = batch_size
|
507 |
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
508 |
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
509 |
return tuple(
|
|
|
515 |
)
|
516 |
|
517 |
|
518 |
+
|
519 |
class RWModel(RWPreTrainedModel):
|
520 |
def __init__(self, config: RWConfig):
|
521 |
super().__init__(config)
|
|
|
550 |
device = attention_mask.device
|
551 |
_, src_length = input_shape
|
552 |
|
553 |
+
|
554 |
+
# if src_length > 1:
|
555 |
+
combined_attention_mask = _make_causal_mask(
|
556 |
+
input_shape, device=device, past_key_values_length=past_key_values_length
|
557 |
+
)
|
558 |
|
559 |
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
560 |
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
|
|
589 |
)
|
590 |
if len(deprecated_arguments) > 0:
|
591 |
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
592 |
+
# pdb.set_trace()
|
593 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
594 |
output_hidden_states = (
|
595 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
645 |
input_shape=(batch_size, seq_length),
|
646 |
past_key_values_length=past_key_values_length,
|
647 |
)
|
648 |
+
# print("causal_mask", causal_mask)
|
649 |
|
650 |
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
651 |
|
|
|
676 |
)
|
677 |
else:
|
678 |
outputs = block(
|
679 |
+
hidden_states,
|
680 |
+
layer_past=layer_past,
|
681 |
+
attention_mask=causal_mask,
|
682 |
+
head_mask=head_mask[i],
|
683 |
+
use_cache=use_cache,
|
684 |
+
output_attentions=output_attentions,
|
685 |
+
alibi=alibi,
|
686 |
+
)
|
687 |
+
|
688 |
|
689 |
hidden_states = outputs[0]
|
690 |
+
|
691 |
if use_cache is True:
|
692 |
presents = presents + (outputs[1],)
|
693 |
|
|
|
736 |
**kwargs,
|
737 |
) -> dict:
|
738 |
# only last token for input_ids if past is not None
|
739 |
+
# only last token for input_ids if past is not None
|
740 |
+
if kwargs.get("past_key_values", None) :
|
741 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
742 |
+
past_key_values = kwargs["past_key_values"]
|
743 |
# the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
|
744 |
+
# if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
|
745 |
+
# past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
|
746 |
+
# past_key_values = kwargs["past_key_values"]
|
747 |
+
else :
|
748 |
+
past_key_values = None
|
749 |
|
750 |
return {
|
751 |
"input_ids": input_ids,
|
752 |
+
"past_key_values": past_key_values,
|
753 |
"use_cache": kwargs.get("use_cache"),
|
754 |
"attention_mask": attention_mask,
|
755 |
}
|