BK-Lee committed on
Commit e1361b1
1 Parent(s): 2acb8d8
Files changed (3)
  1. app.py +8 -15
  2. meteor/arch/modeling_internlm2.py +2 -2
  3. requirements.txt +2 -1
app.py CHANGED
@@ -8,13 +8,17 @@ from PIL import Image
 from utils.utils import *
 from threading import Thread
 import torch.nn.functional as F
+from accelerate import Accelerator
 from meteor.load_mmamba import load_mmamba
 from meteor.load_meteor import load_meteor
 from transformers import TextIteratorStreamer
 from torchvision.transforms.functional import pil_to_tensor
 
+# accel
+accel = Accelerator()
+
 # loading meteor model
-mmamba = load_mmamba('BK-Lee/Meteor-Mamba').cuda()
+mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
 meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=16)
 
 # freeze model
@@ -24,7 +28,6 @@ freeze_model(meteor)
 # previous length
 previous_length = 0
 
-@spaces.GPU
 def threading_function(inputs, image_token_number, streamer, device):
 
     # Meteor Mamba
@@ -49,24 +52,14 @@ def threading_function(inputs, image_token_number, streamer, device):
     generation_kwargs.update({'use_cache': True})
     return meteor.generate(**generation_kwargs)
 
-def add_message(history, message):
-    for x in message["files"]:
-        history.append(((x,), None))
-    if message["text"] is not None:
-        history.append((message["text"], None))
-    return history, gr.MultimodalTextbox(value=None, interactive=False)
-
 @spaces.GPU
 def bot_streaming(message, history):
 
-    # device
-    device = torch.cuda.current_device()
-
     # param
     for param in mmamba.parameters():
-        param.data = param.to(device)
+        param.data = param.to(accel.device)
     for param in meteor.parameters():
-        param.data = param.to(device)
+        param.data = param.to(accel.device)
 
     # prompt type -> input prompt
     image_token_number = int((490/14)**2)
@@ -83,7 +76,7 @@ def bot_streaming(message, history):
     streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
 
     # Threading generation
-    thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=device))
+    thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=accel.device))
    thread.start()
 
     # generated text
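The app.py changes above replace eager CUDA placement (.cuda() at import time and torch.cuda.current_device() inside the handler) with accelerate's Accelerator().device, moving parameters only when the @spaces.GPU-decorated handler actually runs; presumably this is so the demo also initializes cleanly on Spaces where CUDA is only guaranteed inside the decorated call. Below is a minimal, self-contained sketch of that pattern; the model and run() handler are illustrative stand-ins for mmamba/meteor and bot_streaming, not code from this repo.

    # Minimal sketch of the lazy device-placement pattern adopted above.
    # The model and run() below are illustrative assumptions, not repo code.
    import torch
    from accelerate import Accelerator

    accel = Accelerator()  # accel.device is cuda:* when a GPU is visible, otherwise cpu

    model = torch.nn.Linear(4, 4)  # stays on CPU at import time, like mmamba after loading

    def run(x: torch.Tensor) -> torch.Tensor:
        # move parameters lazily, inside the handler, mirroring bot_streaming above
        for param in model.parameters():
            param.data = param.to(accel.device)
        return model(x.to(accel.device))

    print(run(torch.randn(2, 4)).shape)  # torch.Size([2, 4])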
meteor/arch/modeling_internlm2.py CHANGED
@@ -277,8 +277,8 @@ def rotate_half(x):
 # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors."""
-    cos = cos.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
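This change drops the per-call .to(position_ids.device) copies of the rotary cos/sin caches, so they are indexed in place and must already live on the same device as position_ids (which the Accelerate-managed placement above provides). A small, self-contained check of the rewritten function is sketched below; the shapes and cache size are illustrative, not taken from the InternLM2 configuration.

    # Self-contained check of apply_rotary_pos_emb after the change: cos/sin are
    # indexed by position_ids directly, so both caches must already share a device
    # with q/k. Shapes below are illustrative, not the InternLM2 configuration.
    import torch

    def rotate_half(x):
        # rotate half the hidden dims of the input (LLaMA-style)
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # no cross-device copy per call
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    batch, heads, seq, head_dim = 1, 2, 5, 8
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    cos = torch.randn(32, head_dim)  # cached table indexed by absolute position
    sin = torch.randn(32, head_dim)
    position_ids = torch.arange(seq).unsqueeze(0)  # shape (batch, seq)

    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 2, 5, 8])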
requirements.txt CHANGED
@@ -13,4 +13,5 @@ timm
 shortuuid
 matplotlib
 gradio
-spaces
+spaces
+accelerate