Spaces:
Runtime error
Runtime error
crypto-code
commited on
Commit
•
a000794
1
Parent(s):
bb70f8e
Update llama/m2ugen.py
Browse files- llama/m2ugen.py +7 -8
llama/m2ugen.py
CHANGED
@@ -152,7 +152,7 @@ class M2UGen(nn.Module):
|
|
152 |
|
153 |
if torch.cuda.is_available():
|
154 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
155 |
-
self.llama = Transformer(self.model_args).to("cuda:
|
156 |
torch.set_default_tensor_type(torch.FloatTensor)
|
157 |
|
158 |
if load_llama:
|
@@ -233,7 +233,7 @@ class M2UGen(nn.Module):
|
|
233 |
# 4. prefix
|
234 |
self.query_layer = 20
|
235 |
self.query_len = 1
|
236 |
-
self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:
|
237 |
|
238 |
# 5. knn
|
239 |
self.knn = knn
|
@@ -489,8 +489,8 @@ class M2UGen(nn.Module):
|
|
489 |
@torch.inference_mode()
|
490 |
def forward_inference(self, tokens, start_pos: int, audio_feats=None, image_feats=None, video_feats=None):
|
491 |
_bsz, seqlen = tokens.shape
|
492 |
-
h = self.llama.tok_embeddings(tokens).to("cuda:
|
493 |
-
freqs_cis = self.llama.freqs_cis.to("cuda:
|
494 |
freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
|
495 |
|
496 |
feats = torch.zeros((1, 1, 4096)).to("cuda:0")
|
@@ -500,10 +500,9 @@ class M2UGen(nn.Module):
|
|
500 |
feats += video_feats
|
501 |
if image_feats is not None:
|
502 |
feats += image_feats
|
503 |
-
feats = feats.to("cuda:1")
|
504 |
|
505 |
mask = None
|
506 |
-
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:
|
507 |
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
508 |
|
509 |
music_output_embedding = []
|
@@ -669,10 +668,10 @@ class M2UGen(nn.Module):
|
|
669 |
|
670 |
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
|
671 |
|
672 |
-
tokens = torch.full((bsz, total_len), 0).to("cuda:
|
673 |
|
674 |
for k, t in enumerate(prompts):
|
675 |
-
tokens[k, : len(t)] = torch.tensor(t).to("cuda:
|
676 |
input_text_mask = tokens != 0
|
677 |
start_pos = min_prompt_size
|
678 |
prev_pos = 0
|
|
|
152 |
|
153 |
if torch.cuda.is_available():
|
154 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
155 |
+
self.llama = Transformer(self.model_args).to("cuda:0")
|
156 |
torch.set_default_tensor_type(torch.FloatTensor)
|
157 |
|
158 |
if load_llama:
|
|
|
233 |
# 4. prefix
|
234 |
self.query_layer = 20
|
235 |
self.query_len = 1
|
236 |
+
self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:0")
|
237 |
|
238 |
# 5. knn
|
239 |
self.knn = knn
|
|
|
489 |
@torch.inference_mode()
|
490 |
def forward_inference(self, tokens, start_pos: int, audio_feats=None, image_feats=None, video_feats=None):
|
491 |
_bsz, seqlen = tokens.shape
|
492 |
+
h = self.llama.tok_embeddings(tokens).to("cuda:0")
|
493 |
+
freqs_cis = self.llama.freqs_cis.to("cuda:0")
|
494 |
freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
|
495 |
|
496 |
feats = torch.zeros((1, 1, 4096)).to("cuda:0")
|
|
|
500 |
feats += video_feats
|
501 |
if image_feats is not None:
|
502 |
feats += image_feats
|
|
|
503 |
|
504 |
mask = None
|
505 |
+
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
|
506 |
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
507 |
|
508 |
music_output_embedding = []
|
|
|
668 |
|
669 |
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
|
670 |
|
671 |
+
tokens = torch.full((bsz, total_len), 0).to("cuda:0").long()
|
672 |
|
673 |
for k, t in enumerate(prompts):
|
674 |
+
tokens[k, : len(t)] = torch.tensor(t).to("cuda:0").long()
|
675 |
input_text_mask = tokens != 0
|
676 |
start_pos = min_prompt_size
|
677 |
prev_pos = 0
|