crypto-code committed
Commit a000794
1 Parent(s): bb70f8e

Update llama/m2ugen.py

Files changed (1):
  1. llama/m2ugen.py +7 -8
llama/m2ugen.py CHANGED
@@ -152,7 +152,7 @@ class M2UGen(nn.Module):
 
         if torch.cuda.is_available():
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.llama = Transformer(self.model_args).to("cuda:1")
+            self.llama = Transformer(self.model_args).to("cuda:0")
             torch.set_default_tensor_type(torch.FloatTensor)
 
         if load_llama:
@@ -233,7 +233,7 @@ class M2UGen(nn.Module):
         # 4. prefix
         self.query_layer = 20
         self.query_len = 1
-        self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:1")
+        self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:0")
 
         # 5. knn
         self.knn = knn
@@ -489,8 +489,8 @@ class M2UGen(nn.Module):
     @torch.inference_mode()
     def forward_inference(self, tokens, start_pos: int, audio_feats=None, image_feats=None, video_feats=None):
         _bsz, seqlen = tokens.shape
-        h = self.llama.tok_embeddings(tokens).to("cuda:1")
-        freqs_cis = self.llama.freqs_cis.to("cuda:1")
+        h = self.llama.tok_embeddings(tokens).to("cuda:0")
+        freqs_cis = self.llama.freqs_cis.to("cuda:0")
         freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
 
         feats = torch.zeros((1, 1, 4096)).to("cuda:0")
@@ -500,10 +500,9 @@ class M2UGen(nn.Module):
             feats += video_feats
         if image_feats is not None:
             feats += image_feats
-        feats = feats.to("cuda:1")
 
         mask = None
-        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:1")
+        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
         mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
 
         music_output_embedding = []
@@ -669,10 +668,10 @@ class M2UGen(nn.Module):
 
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
 
-        tokens = torch.full((bsz, total_len), 0).to("cuda:1").long()
+        tokens = torch.full((bsz, total_len), 0).to("cuda:0").long()
 
         for k, t in enumerate(prompts):
-            tokens[k, : len(t)] = torch.tensor(t).to("cuda:1").long()
+            tokens[k, : len(t)] = torch.tensor(t).to("cuda:0").long()
         input_text_mask = tokens != 0
         start_pos = min_prompt_size
         prev_pos = 0
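
The substance of the commit: every module and tensor in the constructor, forward_inference, and the generation loop that was pinned to "cuda:1" now goes to "cuda:0", where feats already lived, so inference runs on a single GPU and the now-redundant feats = feats.to("cuda:1") copy is dropped. Below is a minimal sketch of the single-device pattern this converges toward; pick_device is a hypothetical helper for illustration, not part of M2UGen:

import torch

# Hypothetical helper (not in M2UGen): choose the device once instead of
# hard-coding "cuda:0"/"cuda:1" string literals at every call site.
def pick_device() -> torch.device:
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device = pick_device()

# All inference-time tensors then share one placement, e.g.:
tokens = torch.full((1, 16), 0, dtype=torch.long, device=device)
feats = torch.zeros((1, 1, 4096), device=device)
mask = torch.full((1, 1, 16, 16), float("-inf"), device=device)

Keeping the device in one place also means no .to(...) copies between GPUs are needed at inference time.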
 
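For context on the mask lines the diff retargets: the code builds a standard additive causal mask, filling a (seqlen, seqlen) grid with -inf and then using torch.triu with diagonal=start_pos + 1 to keep -inf only strictly above that diagonal (future positions) while zeroing the rest. A small self-contained illustration of the same two calls, assuming seqlen=4 and start_pos=0:

import torch

seqlen, start_pos = 4, 0
# Fill everything with -inf, then keep -inf only above the diagonal;
# torch.triu operates on the last two dimensions of a batched tensor.
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=start_pos + 1)
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

Adding this mask to the attention scores leaves past positions untouched and sends future positions to -inf before the softmax, which is why it must sit on the same device as h (hence the device change and the .type_as(h) cast).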