phoebeklett committed
Commit 805432a
Parent(s): 646d067

latest file versions

Files changed:
- attention.py (+5 −6)
- blocks.py (+1 −1)
- configuration.py (+5 −0)
- modeling_mpt.py (+15 −19)
attention.py CHANGED

@@ -95,10 +95,10 @@ def scaled_multihead_dot_product_attention(
        )
        attn_weight = attn_weight + attn_bias

-    if needs_weights:
+    if needs_weights: #will return memory indices w/attention weights
        reshaped_idx = None
        if long_range_past_key_value is not None or faiss_indexes is not None:
-            if long_range_past_key_value is not None: #manual
+            if long_range_past_key_value is not None: #manual memories

                k_cache, v_cache = long_range_past_key_value
                s_cache = k_cache.size(-1)
@@ -134,15 +134,14 @@ def scaled_multihead_dot_product_attention(

                selected_k = rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:, :d], '(h s) d -> 1 h d s', h=32).to(q.device)
                selected_v = rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:, d:], '(h s) d -> 1 h s d', h=32).to(q.device)
-
+
            s_k_ae = selected_k.size(-1)
            s_k += s_k_ae
            attn_weight_cache = q.matmul(selected_k) * softmax_scale
            if mask_by_sim:
                attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)

-            if attn_bias_ae is not None:
-                # clamp to 0 necessary for torch 2.0 compile()
+            if attn_bias_ae is not None: #add alibi bias to memories
                _s_q = max(0, attn_bias_ae.size(2) - s_q)
                _s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
                attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
@@ -710,7 +709,7 @@ def build_attn_bias(
            for_ae=for_ae,
            topk=topk
        ))
-    else:
+    else: #for memories
        attn_bias = build_alibi_bias(
            n_heads,
            seq_len,
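For orientation, here is a minimal, self-contained sketch of the memory-scoring path these comments annotate: retrieved memory keys are scored against the queries, optionally masked by similarity and given an alibi-style bias, then attended jointly with the local context. The function name, the assumed shapes, and the joint softmax are illustrative simplifications, not the repository's exact code.

```python
import torch

def attend_with_memories(q, k, v, selected_k, selected_v, softmax_scale,
                         attn_bias_ae=None, sim_mask=None):
    """Sketch only. Assumed shapes: q (b, h, s_q, d); k/selected_k (b, h, d, s_k);
    v/selected_v (b, h, s_k, d), mirroring the layouts used in the diff above."""
    min_val = torch.finfo(q.dtype).min

    # Score queries against retrieved memory keys.
    attn_weight_cache = q.matmul(selected_k) * softmax_scale
    if sim_mask is not None:  # mask_by_sim: drop memories below a similarity threshold
        attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)
    if attn_bias_ae is not None:  # add alibi bias to memories, trimmed to the current sizes
        _s_q = max(0, attn_bias_ae.size(2) - q.size(2))
        _s_k = max(0, attn_bias_ae.size(3) - selected_k.size(-1))
        attn_weight_cache = attn_weight_cache + attn_bias_ae[:, :, _s_q:, _s_k:]

    # Score the local keys and attend over memories and local context jointly.
    attn_weight = q.matmul(k) * softmax_scale
    probs = torch.softmax(torch.cat([attn_weight_cache, attn_weight], dim=-1), dim=-1)
    return probs.matmul(torch.cat([selected_v, v], dim=-2))
```

In the actual module the local scores also receive their own bias and causal mask before the softmax; the sketch omits those steps.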
blocks.py CHANGED

@@ -7,7 +7,7 @@
 from typing import Dict, Optional, Tuple
 import torch
 import torch.nn as nn
-from .attention import ATTN_CLASS_REGISTRY
+from extended_mind_transformers.mpt.attention import ATTN_CLASS_REGISTRY
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

 class MPTMLP(nn.Module):
configuration.py CHANGED

@@ -165,6 +165,11 @@ class ExtendedMPTConfig(PretrainedConfig):
            init_config_defaults,
        )

+        if self.attn_config['memory_type']=='faiss' and self.attn_config['mask_by_sim'] is True:
+            raise ValueError(
+                'mask_by_sim is not supported for faiss memory type.'
+            )
+
        if self.d_model % self.n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads')
        if any(
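A small usage sketch of the new check. It assumes ExtendedMPTConfig merges a user-supplied attn_config dict with its defaults, as HF-style configs typically do; only the two keys named in the diff are shown.

```python
from extended_mind_transformers.mpt.configuration import ExtendedMPTConfig

# Combining faiss-backed memories with similarity masking is now rejected at construction time.
try:
    ExtendedMPTConfig(attn_config={'memory_type': 'faiss', 'mask_by_sim': True})
except ValueError as err:
    print(err)  # mask_by_sim is not supported for faiss memory type.
```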
modeling_mpt.py CHANGED

@@ -27,10 +27,10 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY

-from .configuration import ExtendedMPTConfig
-from .attention import attn_bias_shape, build_attn_bias
-from .blocks import MPTBlock
-from .utils import instantiate_from_config
+from extended_mind_transformers.mpt.configuration import ExtendedMPTConfig
+from extended_mind_transformers.mpt.attention import attn_bias_shape, build_attn_bias
+from extended_mind_transformers.mpt.blocks import MPTBlock
+from extended_mind_transformers.utils import instantiate_from_config

 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

@@ -111,7 +111,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
-        self._attn_bias_ae_initialized = False
+        self._attn_bias_ae_initialized = False #for active externalism
        self.attn_bias_ae = None

        if self.config.no_bias:
@@ -168,7 +168,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
            )
            self._attn_bias_initialized = True

-        if use_active_externalism:
+        if use_active_externalism: #for active externalism, init every time since seq_len changes
            self.attn_bias_ae = build_attn_bias(
                self.attn_impl,
                self.config.n_heads,
@@ -196,7 +196,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):

        attn_bias = self.attn_bias

-        if self.attn_bias_ae is not None:
+        if self.attn_bias_ae is not None: #for active externalism
            self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
            attn_bias_ae = self.attn_bias_ae

@@ -417,9 +417,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
        assert isinstance(self.emb_drop, nn.Module)  # pyright
        x = self.emb_drop(x_shrunk)

-
-
-        seq_len = S
+        seq_len = S #for active externalism
        if past_key_values is not None:
            past_position = past_key_values[0][0].size(-1)
            seq_len += past_position
@@ -493,7 +491,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
-            attentions=(all_self_attns, all_idx),
+            attentions=(all_self_attns, all_idx), #return reshaped_idx for active externalism
        )

    # Param Initialization, needed for device='meta' fast initialization
@@ -598,7 +596,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
        use_active_externalism: Optional[bool]=None,
        topk:int=None
    ):
-        if self._memories is not None and self.memories is None:
+        if self._memories is not None and self.memories is None: #init memories once on first call
            self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)

        return_dict = (return_dict
@@ -702,9 +700,8 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
        prev_end_loc=0
        long_range_past_key_values = None
        faiss_indexes= None
-        for b_idx in range(0, input_ids.size(-1), stride):
+        for b_idx in range(0, input_ids.size(-1), stride): #generate kv-pairs using stride
            end_loc = min(b_idx + max_len, input_ids.size(-1))
-
            trg_len = end_loc - prev_end_loc
            subseq = input_ids[:, b_idx:end_loc].to(self.device)
            with torch.no_grad():
@@ -734,7 +731,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
        if long_range_past_key_values is not None and faiss_indexes is not None:
            raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")

-        if cache_type=='faiss':
+        if cache_type=='faiss': #add one-hot encoding to match layer, head indices
            one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.n_layers))*10
            if faiss_indexes is None:
                faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
@@ -747,7 +744,6 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
                k= rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
                v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
                kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
-
        else:
            if long_range_past_key_values is None:
                long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
@@ -759,8 +755,8 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
                    )
                    for ind, kv in enumerate(long_range_past_key_values)
                ]
-        if long_range_past_key_values is not None:
-            if long_range_past_key_values[0][0].size(-1) > max_length_cache:
+        if long_range_past_key_values is not None: #set a limit on manual memory length
+            if long_range_past_key_values[0][0].size(-1) > max_length_cache:
                long_range_past_key_values = [
                    (
                        kv[0][:, :, :, -max_length_cache:],
@@ -816,7 +812,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
            'sequence_id': sequence_id,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache', True),
-            'use_active_externalism': kwargs.get('use_active_externalism'),
+            'use_active_externalism': kwargs.get('use_active_externalism'), #add a few more kwargs for active externalism
            'topk': kwargs.get('topk', None),
        }

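The generate_cache path annotated above walks the input in strides, keeps only the new positions from each forward pass, concatenates them into long-range key/value memories, and finally caps their length. The helper below is a hypothetical sketch of that flow; the stand-in `model`, the shape conventions (keys (b, h, d, s), values (b, h, s, d)), and the value-side truncation are assumptions drawn from the diff, not the actual method.

```python
import torch

def generate_cache_sketch(model, input_ids, stride=512, max_len=2048, max_length_cache=100_000):
    """Hypothetical sketch of strided kv-pair generation with a length cap."""
    long_range_past_key_values = None
    prev_end_loc = 0
    for b_idx in range(0, input_ids.size(-1), stride):  # generate kv-pairs using stride
        end_loc = min(b_idx + max_len, input_ids.size(-1))
        trg_len = end_loc - prev_end_loc                 # number of genuinely new positions
        subseq = input_ids[:, b_idx:end_loc]
        with torch.no_grad():
            outputs = model(subseq, use_cache=True)
        # Keep only the new positions: keys are (b, h, d, s), values are (b, h, s, d).
        to_cache = [(k[..., -trg_len:], v[:, :, -trg_len:]) for k, v in outputs.past_key_values]
        if long_range_past_key_values is None:
            long_range_past_key_values = to_cache
        else:
            long_range_past_key_values = [
                (torch.cat([kv[0], to_cache[ind][0]], dim=-1),
                 torch.cat([kv[1], to_cache[ind][1]], dim=-2))
                for ind, kv in enumerate(long_range_past_key_values)
            ]
        prev_end_loc = end_loc
        if end_loc == input_ids.size(-1):
            break
    if long_range_past_key_values is not None:           # set a limit on manual memory length
        if long_range_past_key_values[0][0].size(-1) > max_length_cache:
            long_range_past_key_values = [
                (kv[0][:, :, :, -max_length_cache:], kv[1][:, :, -max_length_cache:])
                for kv in long_range_past_key_values
            ]
    return long_range_past_key_values
```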
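The faiss branch stores every head's keys and values in shared flat inner-product indexes, appending a scaled one-hot (layer, head) code to each key so that retrieval stays within the matching head (the "add one-hot encoding to match layer, head indices" comment). Below is a hypothetical, self-contained sketch of that indexing step; the small dimensions, the `codes` construction, and the head-major flattening are assumptions consistent with the rearrange patterns in the diff.

```python
import faiss
import torch
import torch.nn.functional as F
from einops import rearrange

n_layers, n_heads, d_head, seq = 2, 4, 8, 16
one_hot_encodings = F.one_hot(torch.arange(0, n_heads * n_layers)) * 10  # scaled (layer, head) codes

# One index over keys + head/layer code (for retrieval), one over the stored (v, k) payloads.
k_index = faiss.IndexFlatIP(d_head + one_hot_encodings.size(-1))
kv_index = faiss.IndexFlatIP(d_head * 2)

for layer in range(n_layers):
    k = torch.randn(1, n_heads, d_head, seq)  # stand-in keys, (b, h, d, s)
    v = torch.randn(1, n_heads, seq, d_head)  # stand-in values, (b, h, s, d)
    k_flat = rearrange(k, 'b h d s -> b (h s) d', h=n_heads).squeeze(0)
    v_flat = rearrange(v, 'b h s d -> b (h s) d', h=n_heads).squeeze(0)
    # Tag each key with its (layer, head) code so inner-product search only matches its own head.
    codes = one_hot_encodings[layer * n_heads:(layer + 1) * n_heads].repeat_interleave(seq, dim=0)
    k_index.add(torch.cat([k_flat, codes.float()], dim=1).numpy())
    kv_index.add(torch.cat([v_flat, k_flat], dim=1).numpy())
```

At query time, attention.py reconstructs the top-k (v, k) payloads from the second index (the reconstruct_batch call in the hunk above) and folds them in as extra memories.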