updates for new autoregressive
- api_new_autoregressive.py +1 -1
- models/new_autoregressive.py +9 -3
- models/xtransformers.py +53 -14
api_new_autoregressive.py
CHANGED
@@ -135,7 +135,7 @@ class TextToSpeech:
         download_models()

         self.autoregressive = AutoregressiveCodegen(1024, 16).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_autoregressive_codegen\\models\\
+        self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_autoregressive_codegen\\models\\17000_codegen_ema.pth'))

         self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                               text_seq_len=350, text_heads=8,
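For context, here is the same checkpoint-loading step sketched in isolation. This is a minimal sketch, not the repo's API: the `ckpt_path` variable is a hypothetical stand-in for the hard-coded Windows path above, the import path is assumed from this repo's layout, and `map_location='cpu'` is added only so the sketch runs without a GPU.

import torch

from models.new_autoregressive import AutoregressiveCodegen  # import path assumed from this repo's layout

# Hypothetical stand-in for the hard-coded path in the diff above.
ckpt_path = 'X:/dlas/experiments/train_autoregressive_codegen/models/17000_codegen_ema.pth'

model = AutoregressiveCodegen(1024, 16).cpu().eval()
# map_location keeps the tensors on CPU even if the checkpoint was saved from a GPU run.
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))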
models/new_autoregressive.py
CHANGED
@@ -85,7 +85,13 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-
+        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
+                                       use_cache=use_cache, expected_seq_len=100)
+        if use_cache:
+            hidden_states, present_key_values = out
+        else:
+            hidden_states = out
+            present_key_values = None
         logits = self.transformer.decoder.to_logits(hidden_states)

         if not return_dict:
@@ -94,7 +100,7 @@ class InferenceModel(GPT2PreTrainedModel):
         return CausalLMOutputWithCrossAttentions(
             loss=None,
             logits=logits,
-            past_key_values=
+            past_key_values=present_key_values,
             hidden_states=hidden_states,
             attentions=None,
             cross_attentions=None,
@@ -258,7 +264,7 @@ class AutoregressiveCodegen(nn.Module):
         inference_model.store_context(full_context)

         gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
-
+                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=False,
                                        **hf_generate_kwargs)
         return gen.sequences

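The forward() change above follows the standard Hugging Face caching contract: when use_cache is true the model returns present_key_values, and the caller feeds them back as past_key_values on the next step while passing only the newly generated token. Below is a minimal greedy-decoding sketch of that contract, assuming inference_model already has its context stored via store_context() and that the cached path behaves like a standard HF causal LM (note that the generate call in this commit still sets use_cache=False); the helper name greedy_decode is illustrative, not part of the repo.

import torch

def greedy_decode(inference_model, start_token, stop_token, max_tokens=200):
    tokens = torch.tensor([[start_token]], dtype=torch.long)
    past = None
    for _ in range(max_tokens):
        # With a cache, only the most recent token needs to be re-encoded.
        step_input = tokens if past is None else tokens[:, -1:]
        out = inference_model(input_ids=step_input, past_key_values=past,
                              use_cache=True, return_dict=True)
        past = out.past_key_values  # fed back on the next iteration
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
        if next_token.item() == stop_token:
            break
    return tokens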
models/xtransformers.py
CHANGED
@@ -24,7 +24,8 @@ Intermediates = namedtuple('Intermediates', [

 LayerIntermediates = namedtuple('Intermediates', [
     'hiddens',
-    'attn_intermediates'
+    'attn_intermediates',
+    'past_key_values',
 ])


@@ -589,7 +590,8 @@ class Attention(nn.Module):
             sinusoidal_emb=None,
             rotary_pos_emb=None,
             prev_attn=None,
-            mem=None
+            mem=None,
+            layer_past=None,
     ):
         b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
             context)
@@ -620,6 +622,13 @@ class Attention(nn.Module):
             k = rearrange(k, 'b n d -> b () n d')
             v = rearrange(v, 'b n (h d) -> b h n d', h=h)

+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat([past_key, k], dim=-2)
+            v = torch.cat([past_value, v], dim=-2)
+        k_cache = k
+        v_cache = v
+
         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
             (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
@@ -723,7 +732,7 @@ class Attention(nn.Module):
             post_softmax_attn=post_softmax_attn
         )

-        return self.to_out(out), intermediates
+        return self.to_out(out), intermediates, k_cache, v_cache


 class AttentionLayers(nn.Module):
@@ -770,6 +779,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.layers = nn.ModuleList([])
+        self.causal = causal

         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -911,6 +921,8 @@ class AttentionLayers(nn.Module):
             mems=None,
             return_hiddens=False,
             norm_scale_shift_inp=None,
+            past_key_values=None,
+            expected_seq_len=None,
     ):

         assert not (self.cross_attend ^ (exists(context) or exists(
@@ -929,9 +941,17 @@ class AttentionLayers(nn.Module):

         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
-
+            if not self.training and self.causal:
+                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+            elif expected_seq_len is None:
+                expected_seq_len = 0
+            seq_len = x.shape[1]
+            if past_key_values is not None:
+                seq_len += past_key_values[0][0].shape[-2]
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

+        present_key_values = []
         cross_attn_count = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
             if layer_type == 'a':
@@ -944,18 +964,28 @@ class AttentionLayers(nn.Module):
             if exists(pre_branch_norm):
                 x = pre_branch_norm(x, **norm_args)

+            if layer_type == 'a' or layer_type == 'c':
+                if past_key_values is not None:
+                    layer_kv = past_key_values.pop(0)
+                    layer_past = tuple(s.to(x.device) for s in layer_kv)
+                else:
+                    layer_past = None
+
             if layer_type == 'a':
-                out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                        prev_attn, layer_mem)
+                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                              prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                            None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                                  None, prev_attn, None, layer_past)
                 else:
-                    out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
                 out = checkpoint(block, x)

+            if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
+                present_key_values.append((k.detach(), v.detach()))
+
             if exists(post_branch_norm):
                 out = post_branch_norm(out, **norm_args)

@@ -981,7 +1011,8 @@ class AttentionLayers(nn.Module):
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens=hiddens,
-                attn_intermediates=intermediates
+                attn_intermediates=intermediates,
+                past_key_values=present_key_values
             )

         return x, intermediates
@@ -1115,6 +1146,7 @@ class TransformerWrapper(nn.Module):
             return_hiddens=False,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@@ -1147,11 +1179,14 @@ class TransformerWrapper(nn.Module):
             hiddens = intermediates.hiddens
             return out, hiddens

+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)

-        return
+        return res


 class ContinuousTransformerWrapper(nn.Module):
@@ -1191,6 +1226,7 @@ class ContinuousTransformerWrapper(nn.Module):
             mask=None,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, _, device = *x.shape, x.device
@@ -1204,11 +1240,14 @@ class ContinuousTransformerWrapper(nn.Module):

         out = self.project_out(x) if not return_embeddings else x

+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)

-        return
+        return tuple(res)


 class XTransformer(nn.Module):
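Taken together, the xtransformers changes implement per-layer key/value caching: each Attention call can be seeded with a layer_past pair, returns the concatenated k_cache/v_cache alongside its output, and AttentionLayers collects those pairs into LayerIntermediates.past_key_values so the wrappers can hand them back to the caller when use_cache is set. The following is a self-contained sketch of that caching pattern under the assumption of (batch, heads, seq, dim_head) tensors; the function name step_attention is illustrative rather than this repo's API, and causal masking is omitted because a single-token decode step may attend to its entire prefix.

import torch

def step_attention(q, k, v, layer_past=None):
    # q, k, v: (batch, heads, seq, dim_head) for the *new* tokens only.
    if layer_past is not None:
        past_key, past_value = layer_past
        # Prepend the cached keys/values along the sequence axis so the new
        # queries attend over the full prefix without recomputing it.
        k = torch.cat([past_key, k], dim=-2)
        v = torch.cat([past_value, v], dim=-2)
    present = (k, v)  # cached and fed back as layer_past on the next step

    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale, dim=-1)
    out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
    return out, present

During incremental decoding the caller keeps one (k, v) pair per attention or cross-attention layer, which is exactly the list AttentionLayers now builds in present_key_values and pops from past_key_values in layer order.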