finalf0 committed
Commit ceec226
Parent: d1c49e2

Add vision_batch_size to avoid cuda OOM (#4)


- Add vision_batch_size to avoid cuda OOM (74e66369a97903f4922eaaa1f0c6d5d5a591faf9)

Files changed (2):
  1. configuration_minicpm.py (+3 -1)
  2. modeling_minicpmv.py (+23 -24)
configuration_minicpm.py CHANGED
@@ -69,6 +69,7 @@ class MiniCPMVConfig(Qwen2Config):
         slice_config=None,
         vision_config=None,
         use_image_id=True,
+        vision_batch_size=16,
         **kwargs,
     ):
         self.use_cache = use_cache
@@ -77,6 +78,7 @@ class MiniCPMVConfig(Qwen2Config):
         self.drop_vision_last_layer = drop_vision_last_layer
         self.batch_vision_input = batch_vision_input
         self.use_image_id = use_image_id
+        self.vision_batch_size = vision_batch_size
 
         if slice_config is None:
             self.slice_config = MiniCPMVSliceConfig(max_slice_nums=1)
@@ -95,4 +97,4 @@ class MiniCPMVConfig(Qwen2Config):
 
         self.patch_size = self.vision_config.patch_size
 
-        super().__init__(**kwargs)
+        super().__init__(**kwargs)
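For context, a minimal usage sketch of the new config field: if the default of 16 still triggers CUDA OOM, the value can be lowered before the model is loaded. The repo id below is a placeholder and `trust_remote_code` loading is assumed; only `vision_batch_size` itself comes from this commit.

```python
from transformers import AutoConfig, AutoModel

# Placeholder repo id; substitute the actual MiniCPM-V checkpoint you use.
repo_id = "openbmb/MiniCPM-V-2_6"

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.vision_batch_size = 8  # feed at most 8 image slices per vision forward pass

model = AutoModel.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,
)
```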
modeling_minicpmv.py CHANGED
@@ -92,31 +92,30 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
                 tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
 
-                if self.config.batch_vision_input:
-                    max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
-
-                    all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
-                                                                       padding_value=0.0)
-                    B, L, _ = all_pixel_values.shape
-                    all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
-
-                    patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
-                    for i in range(B):
-                        patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
-
-                    vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
-                    vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+                max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+                all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
+                                                                   padding_value=0.0)
+                B, L, _ = all_pixel_values.shape
+                all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+                patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+                for i in range(B):
+                    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+                vision_batch_size = self.config.vision_batch_size
+                all_pixel_values = all_pixel_values.type(dtype)
+                if B > vision_batch_size:
+                    hs = []
+                    for i in range(0, B, vision_batch_size):
+                        start_idx = i
+                        end_idx = i + vision_batch_size
+                        tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
+                        hs.append(tmp_hs)
+                    vision_embedding = torch.cat(hs, dim=0)
                 else:
-                    # get vision_embedding foreach
-                    vision_embedding = []
-                    for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
-                        single_pixel_values = single_pixel_values.unsqueeze(0)
-                        B, L, _ = single_pixel_values.shape
-                        single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
-                        single_vision_embedding = self.vpm(single_pixel_values.type(dtype), tgt_sizes=single_tgt_size.unsqueeze(0)).last_hidden_state
-                        single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
-                        vision_embedding.append(single_vision_embedding)
-                    vision_embedding = torch.vstack(vision_embedding)
+                    vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
+                vision_embedding = self.resampler(vision_embedding, tgt_sizes)
 
                 start = 0
                 for pixel_values in pixel_values_list:
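The memory saving comes from the micro-batching pattern in the new branch: when the number of padded image slices B exceeds vision_batch_size, the vision tower is run over chunks of at most vision_batch_size slices and the hidden states are concatenated afterwards, so peak activation memory scales with the chunk size rather than with B. A minimal, self-contained sketch of that pattern is below; the encoder and tensors are stand-ins, not the repo's actual objects.

```python
import torch

def chunked_encode(encoder, pixel_values, attn_mask, chunk_size=16):
    """Run `encoder` over at most `chunk_size` items at a time and
    concatenate the results, bounding peak activation memory.
    `encoder`, `pixel_values`, and `attn_mask` are placeholders."""
    B = pixel_values.shape[0]
    if B <= chunk_size:
        return encoder(pixel_values, attn_mask)
    chunks = []
    for start in range(0, B, chunk_size):
        end = start + chunk_size
        chunks.append(encoder(pixel_values[start:end], attn_mask[start:end]))
    return torch.cat(chunks, dim=0)

# Toy usage: an identity "encoder" over 40 fake slices, chunked 16 at a time.
toy_encoder = lambda x, m: torch.nn.functional.linear(x, torch.eye(x.shape[-1]))
out = chunked_encode(toy_encoder, torch.randn(40, 64), torch.ones(40, 1, dtype=torch.bool))
assert out.shape == (40, 64)
```

Note that, as in the commit, the chunks are processed sequentially, so the trade-off is lower peak memory for somewhat less GPU parallelism per forward pass.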