JustinLin610 committed
Commit • 60964f8 • 1 Parent(s): d431257
Upload 2 files
Browse files
- README.md +19 -10
- modeling_qwen.py +67 -54
README.md CHANGED

@@ -129,16 +129,29 @@ config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat-Int4", trust_remote
 response, history = model.chat(tokenizer, "Hi", history=None, generation_config=config)
 ```

+### 效果评测
+
+我们对BF16和Int4模型在基准评测上做了测试,发现量化模型效果损失较小,结果如下所示:
+
+We illustrate the model performance of both BF16 and Int4 models on the benchmark, and we find that the quantized model does not suffer from significant performance degradation. Results are shown below:
+
+| Quantization | MMLU | CEval (val) | GSM8K | Humaneval |
+| ------------- | :--------: | :----------: | :----: | :--------: |
+| BF16 | 53.9 | 54.2 | 41.1 | 24.4 |
+| Int4 | 52.6 | 52.9 | 38.1 | 23.8 |
+
+
+
 ### 推理速度 (Inference Speed)

-我们测算了BF16和Int4模型生成8192个token的平均推理速度。如图所示:
+我们测算了BF16和Int4模型生成2048和8192个token的平均推理速度。如图所示:

-We measured the average inference speed of generating 8192 tokens under BF16 precision and Int4 quantization level, respectively.
+We measured the average inference speed of generating 2048 and 8192 tokens under BF16 precision and Int4 quantization level, respectively.

-
-
-
-
+| Quantization | Speed (2048 tokens) | Speed (8192 tokens) |
+| ------------- | :------------------:| :------------------:|
+| BF16 | 30.53 | 28.51 |
+| Int4 | 45.60 | 33.83 |

 具体而言,我们记录在长度为1的上下文的条件下生成8192个token的性能。评测运行于单张A100-SXM4-80G GPU,使用PyTorch 2.0.1和CUDA 11.4。推理速度是生成8192个token的速度均值。


@@ -159,8 +172,6 @@ We also profile the peak GPU memory usage for encoding 2048 tokens as context (a

 The above speed and memory profiling are conducted using [this script](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py).

-
-
 ## Tokenizer

 > 注:作为术语的“tokenization”在中文中尚无共识的概念对应,本文档采用英文表达以利说明。

@@ -345,7 +356,6 @@ Qwen-7B-Chat also has the capability to be used as a [HuggingFace Agent](https:/
 | StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
 | **Qwen-7B** | 90.74 | 92.59 | 74.07 |

-
 ## FAQ

 如遇到问题,敬请查阅[FAQ](https://github.com/QwenLM/Qwen-7B/blob/main/FAQ_zh.md)以及issue区,如仍无法解决再提交issue。

@@ -364,4 +374,3 @@ Our code and checkpoints are open to research purpose, and they are allowed for

 If you are interested to leave a message to either our research team or product team, feel free to send an email to [email protected].

-
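The profiling setup quoted above (single A100-SXM4-80G, PyTorch 2.0.1, CUDA 11.4, a one-token context, 8192 generated tokens) comes from the linked profile.py script, which remains the authoritative benchmark. As a rough illustration of how such a number can be measured, the sketch below times `model.generate` on the Int4 checkpoint named in this README; the prompt, generation settings, and the `tokens_per_second` helper are assumptions for illustration, not part of the official script.

```python
# Illustrative sketch only -- the linked profile.py is the authoritative benchmark.
# Assumes a CUDA GPU, transformers with trust_remote_code=True, and auto-gptq
# installed for the Int4 weights; the settings below are simplified assumptions.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-7B-Chat-Int4"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0", trust_remote_code=True
).eval()

# One context token, as in the README's profiling setting.
input_ids = tokenizer("Hi", return_tensors="pt").input_ids[:, :1].to(model.device)


def tokens_per_second(n_new_tokens: int) -> float:
    """Time a single greedy generation of exactly n_new_tokens tokens."""
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids,
        min_new_tokens=n_new_tokens,
        max_new_tokens=n_new_tokens,
        do_sample=False,
    )
    torch.cuda.synchronize()
    return (out.shape[1] - input_ids.shape[1]) / (time.time() - start)


for n in (2048, 8192):
    print(f"{n} new tokens: {tokens_per_second(n):.2f} tokens/s")
```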
modeling_qwen.py CHANGED

@@ -131,6 +131,7 @@ class FlashSelfAttention(torch.nn.Module):
         assert all((i.is_cuda for i in (q, k, v)))
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
+
         q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
         cu_seqlens_q = torch.arange(
             0,

@@ -155,6 +156,7 @@ class FlashSelfAttention(torch.nn.Module):
                 device=q.device,
             )
             self.dropout_p = 0
+
         output = flash_attn_unpadded_func(
             q,
             k,

@@ -168,7 +170,8 @@ class FlashSelfAttention(torch.nn.Module):
             causal=is_causal,
         )

-
+        new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
+        output = output.view(new_shape)
         return output



@@ -220,19 +223,6 @@ class QWenAttention(nn.Module):

         self.bf16 = config.bf16

-        if config.rotary_pct == 1.0:
-            self.rotary_ndims = None
-        else:
-            assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                self.hidden_size_per_attention_head * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else self.hidden_size_per_attention_head
-        )
-        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn

@@ -242,7 +232,6 @@ class QWenAttention(nn.Module):
             for i in range(1, 32768)
         ]
         self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
-        self._ntk_cached = 1.0

         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)


@@ -351,6 +340,7 @@ class QWenAttention(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,

@@ -361,43 +351,19 @@ class QWenAttention(nn.Module):
     ):

         mixed_x_layer = self.c_attn(hidden_states)
+
         query, key, value = mixed_x_layer.split(self.split_size, dim=2)

         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)

-        kv_seq_len = hidden_states.size()[1]
-        if layer_past:
-            # layer past[0] shape: bs * seq_len * head_num * dim
-            kv_seq_len += layer_past[0].shape[1]
-        if (
-            self.use_dynamic_ntk
-            and kv_seq_len == hidden_states.size()[1]
-            and not self.training
-        ):
-            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
-            ntk_alpha = 2 ** math.ceil(context_value) - 1
-            ntk_alpha = max(ntk_alpha, 1)
-            self._ntk_cached = ntk_alpha
-        else:
-            ntk_alpha = self._ntk_cached
-        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
-            hidden_states.device
-        )
-
-        if rotary_pos_emb is not None:
-            if isinstance(rotary_pos_emb, tuple):
-                rotary_pos_emb = rotary_pos_emb
-            else:
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-
         if rotary_pos_emb is not None:
+            cur_len = query.shape[1]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
             q_pos_emb, k_pos_emb = rotary_pos_emb
             # Slice the pos emb for current inference
-            cur_len = query.shape[1]
-            q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
-            k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
             query = apply_rotary_pos_emb(query, q_pos_emb)
             key = apply_rotary_pos_emb(key, k_pos_emb)


@@ -428,9 +394,9 @@ class QWenAttention(nn.Module):
             q, k, v = query, key, value
             context_layer = self.core_attention_flash(q, k, v)

-
-
-
+            # b s h d -> b s (h d)
+            context_layer = context_layer.flatten(2,3).contiguous()
+
         else:
             query = query.permute(0, 2, 1, 3)
             key = key.permute(0, 2, 1, 3)

@@ -443,6 +409,7 @@ class QWenAttention(nn.Module):
             )

         attn_output = self.c_proj(context_layer)
+
         outputs = (attn_output, present)
         if output_attentions:
             if (

@@ -476,7 +443,6 @@ class QWenMLP(nn.Module):
         output = self.c_proj(intermediate_parallel)
         return output

-
 class QWenBlock(nn.Module):
     def __init__(self, config):
         super().__init__()

@@ -498,6 +464,7 @@ class QWenBlock(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,

@@ -510,6 +477,7 @@ class QWenBlock(nn.Module):

         attn_outputs = self.attn(
             layernorm_output,
+            rotary_pos_emb,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,

@@ -585,10 +553,28 @@ class QWenModel(QWenPreTrainedModel):
         self.embed_dim = config.hidden_size

         self.gradient_checkpointing = False
+        self.use_dynamic_ntk = config.use_dynamic_ntk
+        self.seq_length = config.seq_length

         self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

         self.drop = nn.Dropout(config.emb_dropout_prob)
+
+
+        if config.rotary_pct == 1.0:
+            self.rotary_ndims = None
+        else:
+            assert config.rotary_pct < 1
+            self.rotary_ndims = int(
+                config.kv_channels * config.rotary_pct
+            )
+        dim = (
+            self.rotary_ndims
+            if self.rotary_ndims is not None
+            else config.kv_channels
+        )
+        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
+
         self.h = nn.ModuleList(
             [
                 QWenBlock(

@@ -692,6 +678,25 @@ class QWenModel(QWenPreTrainedModel):
             inputs_embeds = self.wte(input_ids)
         hidden_states = inputs_embeds

+        kv_seq_len = hidden_states.size()[1]
+        if past_key_values[0] is not None:
+            # past key values[0][0] shape: bs * seq_len * head_num * dim
+            kv_seq_len += past_key_values[0][0].shape[1]
+        if (
+            self.use_dynamic_ntk
+            and kv_seq_len == hidden_states.size()[1]
+            and not self.training
+        ):
+            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
+            ntk_alpha = 2 ** math.ceil(context_value) - 1
+            ntk_alpha = max(ntk_alpha, 1)
+        else:
+            ntk_alpha = self.rotary_emb._ntk_alpha_cached
+
+        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+        for idx in range(len(rotary_pos_emb)):
+            rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
+
         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)


@@ -722,6 +727,7 @@ class QWenModel(QWenPreTrainedModel):
                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
+                    rotary_pos_emb,
                     None,
                     attention_mask,
                     head_mask[i],

@@ -732,6 +738,7 @@ class QWenModel(QWenPreTrainedModel):
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
+                    rotary_pos_emb=rotary_pos_emb,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,

@@ -1142,14 +1149,19 @@ class RotaryEmbedding(torch.nn.Module):
             self._ntk_alpha_cached = ntk_alpha
             seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
             freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+
             emb = torch.cat((freqs, freqs), dim=-1)
             from einops import rearrange

-
+            emb = rearrange(emb, "n d -> 1 n 1 d")
+
+            cos, sin = emb.cos(), emb.sin()
+            self._rotary_pos_emb_cache = [cos, sin]

     def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
         self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
-
+        cos, sin = self._rotary_pos_emb_cache
+        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]


 def _rotate_half(x):

@@ -1161,19 +1173,20 @@ def _rotate_half(x):


 def apply_rotary_pos_emb(t, freqs):
+    cos, sin = freqs
     if apply_rotary_emb_func is not None and t.is_cuda:
         t_ = t.float()
-
-
-        sin = freqs[:, : freqs.shape[-1] // 2].sin()
+        cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
+        sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
         output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
         return output
     else:
-        rot_dim = freqs.shape[-1]
+        rot_dim = freqs[0].shape[-1]
+        cos, sin = freqs
         t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
         t_ = t_.float()
         t_pass_ = t_pass_.float()
-        t_ = (t_ *
+        t_ = (t_ * cos) + (_rotate_half(t_) * sin)
         return torch.cat((t_, t_pass_), dim=-1).type_as(t)
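Taken together, the modeling_qwen.py changes move rotary-embedding construction and dynamic NTK scaling out of QWenAttention and into QWenModel: the model computes one [cos, sin] pair per forward pass, sized by the full key/value length including past_key_values, and hands it down through every QWenBlock to the attention layers, which only slice and apply it. The self-contained sketch below reproduces that pattern in miniature; the names SimpleRotaryEmbedding, dynamic_ntk_alpha, and apply_rope are illustrative stand-ins, not the actual modeling_qwen.py API.

```python
# Standalone sketch of the pattern introduced in this commit: an NTK-scaled rotary
# cache that returns [cos, sin], which each attention layer slices and applies.
# Names are illustrative; this is not the exact modeling_qwen.py implementation.
import math

import torch


class SimpleRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0
        self._cache = None

    def forward(self, max_seq_len: int, ntk_alpha: float = 1.0):
        if max_seq_len > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            # NTK-aware scaling: enlarge the base so the low-frequency bands
            # stretch to cover contexts longer than the training length.
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            seq = torch.arange(max_seq_len).float()
            freqs = torch.outer(seq, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)[None, :, None, :]  # 1 n 1 d
            self._cache = [emb.cos(), emb.sin()]
            self._seq_len_cached = max_seq_len
            self._ntk_alpha_cached = ntk_alpha
        cos, sin = self._cache
        return [cos[:, :max_seq_len], sin[:, :max_seq_len]]


def dynamic_ntk_alpha(kv_seq_len: int, train_seq_len: int) -> int:
    # Same rule as the diff: alpha = 2**ceil(log2(len / train_len) + 1) - 1, floored at 1.
    context_value = math.log(kv_seq_len / train_seq_len, 2) + 1
    return max(2 ** math.ceil(context_value) - 1, 1)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(t, cos, sin):
    # t: (batch, seq, heads, head_dim); cos/sin broadcast over batch and heads.
    return (t.float() * cos) + (rotate_half(t.float()) * sin)


# Usage: query and key of shape (b, s, h, d) share one sliced [cos, sin] pair.
rope = SimpleRotaryEmbedding(dim=64)
alpha = dynamic_ntk_alpha(kv_seq_len=16384, train_seq_len=8192)  # -> 3
cos, sin = rope(16384, ntk_alpha=alpha)
q = torch.randn(1, 16384, 4, 64)
print(apply_rope(q, cos, sin).shape)  # torch.Size([1, 16384, 4, 64])
```

Computing the cache once per forward pass in QWenModel, rather than once per attention layer, avoids recomputing cos/sin in every layer and replaces each layer's private _ntk_cached state with the single _ntk_alpha_cached kept by RotaryEmbedding.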