Vision-CAIR committed on
Commit f42441f · 1 Parent(s): b7bd78d

Upload folder using huggingface_hub

Files changed (21)
  1. Qformer.py +1216 -0
  2. __init__.py +206 -0
  3. base_model.py +249 -0
  4. base_processor.py +26 -0
  5. blip2.py +220 -0
  6. blip2_outputs.py +110 -0
  7. blip_processors.py +164 -0
  8. clip_vision_encoder.py +83 -0
  9. config.py +474 -0
  10. conversation.py +224 -0
  11. dist_utils.py +146 -0
  12. eva_vit.py +443 -0
  13. gradcam.py +24 -0
  14. logger.py +194 -0
  15. mini_gpt4_llama_v2.py +10 -9
  16. modeling_llama_v2.py +136 -0
  17. modeling_mistral.py +1388 -0
  18. optims.py +119 -0
  19. randaugment.py +398 -0
  20. registry.py +267 -0
  21. utils.py +470 -0
Qformer.py ADDED
@@ -0,0 +1,1216 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warning(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ To behave as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ if query_embeds is not None:
1192
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1193
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
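Editorial note: Qformer.py above is a BLIP-2 style Q-Former, i.e. a BERT stack whose learnable query embeddings cross-attend to frozen image features in every layer where `layer_num % cross_attention_freq == 0`. The following is a minimal usage sketch, not code from this commit: the import path, the 32-query / 1408-dim numbers, and the dummy tensors are illustrative assumptions.

import torch
from transformers import BertConfig

from Qformer import BertLMHeadModel  # import path is an assumption

# Illustrative hyper-parameters (not taken from this commit): 32 query tokens,
# a 1408-dim frozen vision encoder, cross-attention inserted every 2nd layer.
config = BertConfig.from_pretrained("bert-base-uncased")
config.encoder_width = 1408
config.add_cross_attention = True
config.cross_attention_freq = 2
config.query_length = 32

qformer = BertLMHeadModel(config=config)
query_tokens = torch.nn.Parameter(
    torch.zeros(1, config.query_length, config.hidden_size)
)

image_embeds = torch.randn(2, 257, 1408)  # dummy ViT features: (batch, patches, dim)
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

# With input_ids=None, the embeddings are the query tokens alone, and each
# cross-attention layer attends from the queries to the image features.
out = qformer.bert(
    query_embeds=query_tokens.expand(image_embeds.shape[0], -1, -1),
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    return_dict=True,
)
print(out.last_hidden_state.shape)  # torch.Size([2, 32, 768])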
__init__.py ADDED
@@ -0,0 +1,206 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+
12
+ from .registry import registry
13
+ from .base_model import BaseModel
14
+ from .base_processor import BaseProcessor
15
+ from .blip_processors import *
16
+ from .blip2 import Blip2Base
17
+ from .clip_vision_encoder import *
18
+ from .config import *
19
+ from .eva_vit import *
20
+ from .mini_gpt4_llama_v2 import MiniGPT4_Video
21
+
22
+
23
+
24
+ __all__ = [
25
+ "load_model",
26
+ "BaseModel",
27
+ "Blip2Base",
28
+ "MiniGPT4_Video",
29
+
30
+ ]
31
+
32
+
33
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
34
+ """
35
+ Load supported models.
36
+
37
+ To list all available models and types in registry:
38
+ >>> from minigpt4.models import model_zoo
39
+ >>> print(model_zoo)
40
+
41
+ Args:
42
+ name (str): name of the model.
43
+ model_type (str): type of the model.
44
+ is_eval (bool): whether the model is in eval mode. Default: False.
45
+ device (str): device to use. Default: "cpu".
46
+ checkpoint (str): path or to checkpoint. Default: None.
47
+ Note that the checkpoint is expected to have the same state_dict keys as the model.
48
+
49
+ Returns:
50
+ model (torch.nn.Module): model.
51
+ """
52
+
53
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
54
+
55
+ if checkpoint is not None:
56
+ model.load_checkpoint(checkpoint)
57
+
58
+ if is_eval:
59
+ model.eval()
60
+
61
+ if device == "cpu":
62
+ model = model.float()
63
+
64
+ return model.to(device)
65
+
66
+
67
+ def load_preprocess(config):
68
+ """
69
+ Load preprocessor configs and construct preprocessors.
70
+
71
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
72
+
73
+ Args:
74
+ config (dict): preprocessor configs.
75
+
76
+ Returns:
77
+ vis_processors (dict): preprocessors for visual inputs.
78
+ txt_processors (dict): preprocessors for text inputs.
79
+
80
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
81
+ """
82
+
83
+ def _build_proc_from_cfg(cfg):
84
+ return (
85
+ registry.get_processor_class(cfg.name).from_config(cfg)
86
+ if cfg is not None
87
+ else BaseProcessor()
88
+ )
89
+
90
+ vis_processors = dict()
91
+ txt_processors = dict()
92
+
93
+ vis_proc_cfg = config.get("vis_processor")
94
+ txt_proc_cfg = config.get("text_processor")
95
+
96
+ if vis_proc_cfg is not None:
97
+ vis_train_cfg = vis_proc_cfg.get("train")
98
+ vis_eval_cfg = vis_proc_cfg.get("eval")
99
+ else:
100
+ vis_train_cfg = None
101
+ vis_eval_cfg = None
102
+
103
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
104
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
105
+
106
+ if txt_proc_cfg is not None:
107
+ txt_train_cfg = txt_proc_cfg.get("train")
108
+ txt_eval_cfg = txt_proc_cfg.get("eval")
109
+ else:
110
+ txt_train_cfg = None
111
+ txt_eval_cfg = None
112
+
113
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
114
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
115
+
116
+ return vis_processors, txt_processors
117
+
118
+
119
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
120
+ """
121
+ Load model and its related preprocessors.
122
+
123
+ List all available models and types in registry:
124
+ >>> from minigpt4.models import model_zoo
125
+ >>> print(model_zoo)
126
+
127
+ Args:
128
+ name (str): name of the model.
129
+ model_type (str): type of the model.
130
+ is_eval (bool): whether the model is in eval mode. Default: False.
131
+ device (str): device to use. Default: "cpu".
132
+
133
+ Returns:
134
+ model (torch.nn.Module): model.
135
+ vis_processors (dict): preprocessors for visual inputs.
136
+ txt_processors (dict): preprocessors for text inputs.
137
+ """
138
+ model_cls = registry.get_model_class(name)
139
+
140
+ # load model
141
+ model = model_cls.from_pretrained(model_type=model_type)
142
+
143
+ if is_eval:
144
+ model.eval()
145
+
146
+ # load preprocess
147
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
148
+ if cfg is not None:
149
+ preprocess_cfg = cfg.preprocess
150
+
151
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
152
+ else:
153
+ vis_processors, txt_processors = None, None
154
+ logging.info(
155
+ f"""No default preprocess for model {name} ({model_type}).
156
+ This can happen if the model is not finetuned on downstream datasets,
157
+ or it is not intended for direct use without finetuning.
158
+ """
159
+ )
160
+
161
+ if device == "cpu" or device == torch.device("cpu"):
162
+ model = model.float()
163
+
164
+ return model.to(device), vis_processors, txt_processors
165
+
166
+
167
+ class ModelZoo:
168
+ """
169
+ A utility class to create string representation of available model architectures and types.
170
+
171
+ >>> from minigpt4.models import model_zoo
172
+ >>> # list all available models
173
+ >>> print(model_zoo)
174
+ >>> # show total number of models
175
+ >>> print(len(model_zoo))
176
+ """
177
+
178
+ def __init__(self) -> None:
179
+ self.model_zoo = {
180
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
181
+ for k, v in registry.mapping["model_name_mapping"].items()
182
+ }
183
+
184
+ def __str__(self) -> str:
185
+ return (
186
+ "=" * 50
187
+ + "\n"
188
+ + f"{'Architectures':<30} {'Types'}\n"
189
+ + "=" * 50
190
+ + "\n"
191
+ + "\n".join(
192
+ [
193
+ f"{name:<30} {', '.join(types)}"
194
+ for name, types in self.model_zoo.items()
195
+ ]
196
+ )
197
+ )
198
+
199
+ def __iter__(self):
200
+ return iter(self.model_zoo.items())
201
+
202
+ def __len__(self):
203
+ return sum([len(v) for v in self.model_zoo.values()])
204
+
205
+
206
+ model_zoo = ModelZoo()
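Editorial note: __init__.py exposes the registry-backed helpers documented in its docstrings. A minimal usage sketch follows; the `minigpt4.models` package path comes from the docstrings above, while the architecture/type names are placeholders to be replaced by whatever `print(model_zoo)` lists.

import torch
from minigpt4.models import model_zoo, load_model_and_preprocess

# List the registered architectures and their pretrained model types.
print(model_zoo)

# "minigpt4_video" / "default" are placeholder names; pick a pair printed above.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="minigpt4_video",
    model_type="default",
    is_eval=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# vis_processors["eval"] / txt_processors["eval"] then preprocess inputs for the model.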
base_model.py ADDED
@@ -0,0 +1,249 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from .dist_utils import download_cached_file, is_dist_avail_and_initialized
15
+ from .utils import get_abs_path, is_url
16
+ from omegaconf import OmegaConf
17
+
18
+ from huggingface_hub import PyTorchModelHubMixin
19
+
20
+ class BaseModel(nn.Module,PyTorchModelHubMixin):
21
+ """Base class for models."""
22
+
23
+ def __init__(self):
24
+ PyTorchModelHubMixin.__init__(self)
25
+ nn.Module.__init__(self)
26
+
27
+ @property
28
+ def device(self):
29
+ return list(self.parameters())[0].device
30
+
31
+ def load_checkpoint(self, url_or_filename):
32
+ """
33
+ Load from a finetuned checkpoint.
34
+
35
+ This should expect no mismatch in the model keys and the checkpoint keys.
36
+ """
37
+
38
+ if is_url(url_or_filename):
39
+ cached_file = download_cached_file(
40
+ url_or_filename, check_hash=False, progress=True
41
+ )
42
+ checkpoint = torch.load(cached_file, map_location="cpu")
43
+ elif os.path.isfile(url_or_filename):
44
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
45
+ else:
46
+ raise RuntimeError("checkpoint url or path is invalid")
47
+
48
+ if "model" in checkpoint.keys():
49
+ state_dict = checkpoint["model"]
50
+ else:
51
+ state_dict = checkpoint
52
+
53
+ msg = self.load_state_dict(state_dict, strict=False)
54
+
55
+ logging.info("Missing keys {}".format(msg.missing_keys))
56
+ logging.info("load checkpoint from %s" % url_or_filename)
57
+
58
+ return msg
59
+
60
+ @classmethod
61
+ # def from_pretrained(cls, model_type):
62
+ # """
63
+ # Build a pretrained model from default configuration file, specified by model_type.
64
+
65
+ # Args:
66
+ # - model_type (str): model type, specifying architecture and checkpoints.
67
+
68
+ # Returns:
69
+ # - model (nn.Module): pretrained or finetuned model, depending on the configuration.
70
+ # """
71
+ # model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
72
+ # model = cls.from_config(model_cfg)
73
+
74
+ # return model
75
+
76
+ @classmethod
77
+ def default_config_path(cls, model_type):
78
+ assert (
79
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
80
+ ), "Unknown model type {}".format(model_type)
81
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
82
+
83
+ def load_checkpoint_from_config(self, cfg, **kwargs):
84
+ """
85
+ Load checkpoint as specified in the config file.
86
+
87
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
88
+ When loading the pretrained model, each task-specific architecture may define their
89
+ own load_from_pretrained() method.
90
+ """
91
+ load_finetuned = cfg.get("load_finetuned", True)
92
+ if load_finetuned:
93
+ finetune_path = cfg.get("finetuned", None)
94
+ assert (
95
+ finetune_path is not None
96
+ ), "Found load_finetuned is True, but finetune_path is None."
97
+ self.load_checkpoint(url_or_filename=finetune_path)
98
+ else:
99
+ # load pre-trained weights
100
+ pretrain_path = cfg.get("pretrained", None)
101
+ assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
102
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
103
+
104
+ def before_evaluation(self, **kwargs):
105
+ pass
106
+
107
+ def show_n_params(self, return_str=True):
108
+ tot = 0
109
+ for p in self.parameters():
110
+ w = 1
111
+ for x in p.shape:
112
+ w *= x
113
+ tot += w
114
+ if return_str:
115
+ if tot >= 1e6:
116
+ return "{:.1f}M".format(tot / 1e6)
117
+ else:
118
+ return "{:.1f}K".format(tot / 1e3)
119
+ else:
120
+ return tot
121
+
122
+
123
+ class BaseEncoder(nn.Module):
124
+ """
125
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
126
+ """
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+
131
+ def forward_features(self, samples, **kwargs):
132
+ raise NotImplementedError
133
+
134
+ @property
135
+ def device(self):
136
+ return list(self.parameters())[0].device
137
+
138
+
139
+ class SharedQueueMixin:
140
+ @torch.no_grad()
141
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
142
+ # gather keys before updating queue
143
+ image_feats = concat_all_gather(image_feat)
144
+ text_feats = concat_all_gather(text_feat)
145
+
146
+ batch_size = image_feats.shape[0]
147
+
148
+ ptr = int(self.queue_ptr)
149
+ assert self.queue_size % batch_size == 0 # for simplicity
150
+
151
+ # replace the keys at ptr (dequeue and enqueue)
152
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
153
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
154
+
155
+ if idxs is not None:
156
+ idxs = concat_all_gather(idxs)
157
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
158
+
159
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
160
+ self.queue_ptr[0] = ptr
161
+
162
+
163
+ class MomentumDistilationMixin:
164
+ @torch.no_grad()
165
+ def copy_params(self):
166
+ for model_pair in self.model_pairs:
167
+ for param, param_m in zip(
168
+ model_pair[0].parameters(), model_pair[1].parameters()
169
+ ):
170
+ param_m.data.copy_(param.data) # initialize
171
+ param_m.requires_grad = False # not update by gradient
172
+
173
+ @torch.no_grad()
174
+ def _momentum_update(self):
175
+ for model_pair in self.model_pairs:
176
+ for param, param_m in zip(
177
+ model_pair[0].parameters(), model_pair[1].parameters()
178
+ ):
179
+ param_m.data = param_m.data * self.momentum + param.data * (
180
+ 1.0 - self.momentum
181
+ )
182
+
183
+
184
+ class GatherLayer(torch.autograd.Function):
185
+ """
186
+ Gather tensors from all workers with support for backward propagation:
187
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
188
+ """
189
+
190
+ @staticmethod
191
+ def forward(ctx, x):
192
+ output = [
193
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
194
+ ]
195
+ torch.distributed.all_gather(output, x)
196
+ return tuple(output)
197
+
198
+ @staticmethod
199
+ def backward(ctx, *grads):
200
+ all_gradients = torch.stack(grads)
201
+ torch.distributed.all_reduce(all_gradients)
202
+ return all_gradients[torch.distributed.get_rank()]
203
+
204
+
205
+ def all_gather_with_grad(tensors):
206
+ """
207
+ Performs all_gather operation on the provided tensors.
208
+ Graph remains connected for backward grad computation.
209
+ """
210
+ # Queue the gathered tensors
211
+ world_size = torch.distributed.get_world_size()
212
+ # There is no need for reduction in the single-proc case
213
+ if world_size == 1:
214
+ return tensors
215
+
216
+ # tensor_all = GatherLayer.apply(tensors)
217
+ tensor_all = GatherLayer.apply(tensors)
218
+
219
+ return torch.cat(tensor_all, dim=0)
220
+
221
+
222
+ @torch.no_grad()
223
+ def concat_all_gather(tensor):
224
+ """
225
+ Performs all_gather operation on the provided tensors.
226
+ *** Warning ***: torch.distributed.all_gather has no gradient.
227
+ """
228
+ # if use distributed training
229
+ if not is_dist_avail_and_initialized():
230
+ return tensor
231
+
232
+ tensors_gather = [
233
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
234
+ ]
235
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
236
+
237
+ output = torch.cat(tensors_gather, dim=0)
238
+ return output
239
+
240
+
241
+ def tile(x, dim, n_tile):
242
+ init_dim = x.size(dim)
243
+ repeat_idx = [1] * x.dim()
244
+ repeat_idx[dim] = n_tile
245
+ x = x.repeat(*(repeat_idx))
246
+ order_index = torch.LongTensor(
247
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
248
+ )
249
+ return torch.index_select(x, dim, order_index.to(x.device))
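
A quick sketch of the tile() helper above: it repeats each slice along dim n_tile times while preserving the original order, which matches torch.repeat_interleave:

    import torch

    x = torch.tensor([[1, 2], [3, 4]])
    y = tile(x, dim=0, n_tile=2)
    # y is [[1, 2], [1, 2], [3, 4], [3, 4]]: each row repeated twice, order preserved
    assert torch.equal(y, x.repeat_interleave(2, dim=0))
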
base_processor.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ class BaseProcessor:
12
+ def __init__(self):
13
+ self.transform = lambda x: x
14
+ return
15
+
16
+ def __call__(self, item):
17
+ return self.transform(item)
18
+
19
+ @classmethod
20
+ def from_config(cls, cfg=None):
21
+ return cls()
22
+
23
+ def build(self, **kwargs):
24
+ cfg = OmegaConf.create(kwargs)
25
+
26
+ return self.from_config(cfg)
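
BaseProcessor is an identity transform by default; subclasses override from_config and transform. A small sketch of the build() path shown above:

    proc = BaseProcessor()
    assert proc("hello") == "hello"     # default transform is the identity
    proc2 = proc.build(image_size=224)  # kwargs are wrapped into an OmegaConf node and passed to from_config
    assert isinstance(proc2, BaseProcessor)
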
blip2.py ADDED
@@ -0,0 +1,220 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ from .dist_utils import download_cached_file,get_world_size,get_rank,is_dist_avail_and_initialized
19
+ from .utils import is_url
20
+ from .logger import MetricLogger
21
+ from .base_model import BaseModel
22
+ from .Qformer import BertConfig, BertLMHeadModel
23
+ from .eva_vit import create_eva_vit_g
24
+ from transformers import BertTokenizer
25
+
26
+
27
+ class Blip2Base(BaseModel):
28
+ @classmethod
29
+ def init_tokenizer(cls):
30
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
31
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
32
+ return tokenizer
33
+
34
+ def maybe_autocast(self, dtype=torch.float16):
35
+ # if on cpu, don't use autocast
36
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
37
+ enable_autocast = self.device != torch.device("cpu")
38
+
39
+ if enable_autocast:
40
+ return torch.cuda.amp.autocast(dtype=dtype)
41
+ else:
42
+ return contextlib.nullcontext()
43
+
44
+ @classmethod
45
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
46
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
47
+ encoder_config.encoder_width = vision_width
48
+ # insert cross-attention layer every other block
49
+ encoder_config.add_cross_attention = True
50
+ encoder_config.cross_attention_freq = cross_attention_freq
51
+ encoder_config.query_length = num_query_token
52
+ Qformer = BertLMHeadModel(config=encoder_config)
53
+ query_tokens = nn.Parameter(
54
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
55
+ )
56
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
57
+ return Qformer, query_tokens
58
+
59
+ @classmethod
60
+ def init_vision_encoder(
61
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
62
+ ):
63
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
64
+ visual_encoder = create_eva_vit_g(
65
+ img_size, drop_path_rate, use_grad_checkpoint, precision
66
+ )
67
+
68
+ ln_vision = LayerNorm(visual_encoder.num_features)
69
+ return visual_encoder, ln_vision
70
+
71
+ def load_from_pretrained(self, url_or_filename):
72
+ if is_url(url_or_filename):
73
+ cached_file = download_cached_file(
74
+ url_or_filename, check_hash=False, progress=True
75
+ )
76
+ checkpoint = torch.load(cached_file, map_location="cpu")
77
+ elif os.path.isfile(url_or_filename):
78
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
79
+ else:
80
+ raise RuntimeError("checkpoint url or path is invalid")
81
+
82
+ state_dict = checkpoint["model"]
83
+
84
+ msg = self.load_state_dict(state_dict, strict=False)
85
+
86
+ # logging.info("Missing keys {}".format(msg.missing_keys))
87
+ logging.info("load checkpoint from %s" % url_or_filename)
88
+
89
+ return msg
90
+
91
+
92
+ def disabled_train(self, mode=True):
93
+ """Overwrite model.train with this function to make sure train/eval mode
94
+ does not change anymore."""
95
+ return self
96
+
97
+
98
+ class LayerNorm(nn.LayerNorm):
99
+ """Subclass torch's LayerNorm to handle fp16."""
100
+
101
+ def forward(self, x: torch.Tensor):
102
+ orig_type = x.dtype
103
+ ret = super().forward(x.type(torch.float32))
104
+ return ret.type(orig_type)
105
+
106
+
107
+ def compute_sim_matrix(model, data_loader, **kwargs):
108
+ k_test = kwargs.pop("k_test")
109
+
110
+ metric_logger = MetricLogger(delimiter=" ")
111
+ header = "Evaluation:"
112
+
113
+ logging.info("Computing features for evaluation...")
114
+ start_time = time.time()
115
+
116
+ texts = data_loader.dataset.text
117
+ num_text = len(texts)
118
+ text_bs = 256
119
+ text_ids = []
120
+ text_embeds = []
121
+ text_atts = []
122
+ for i in range(0, num_text, text_bs):
123
+ text = texts[i : min(num_text, i + text_bs)]
124
+ text_input = model.tokenizer(
125
+ text,
126
+ padding="max_length",
127
+ truncation=True,
128
+ max_length=35,
129
+ return_tensors="pt",
130
+ ).to(model.device)
131
+ text_feat = model.forward_text(text_input)
132
+ text_embed = F.normalize(model.text_proj(text_feat))
133
+ text_embeds.append(text_embed)
134
+ text_ids.append(text_input.input_ids)
135
+ text_atts.append(text_input.attention_mask)
136
+
137
+ text_embeds = torch.cat(text_embeds, dim=0)
138
+ text_ids = torch.cat(text_ids, dim=0)
139
+ text_atts = torch.cat(text_atts, dim=0)
140
+
141
+ vit_feats = []
142
+ image_embeds = []
143
+ for samples in data_loader:
144
+ image = samples["image"]
145
+
146
+ image = image.to(model.device)
147
+ image_feat, vit_feat = model.forward_image(image)
148
+ image_embed = model.vision_proj(image_feat)
149
+ image_embed = F.normalize(image_embed, dim=-1)
150
+
151
+ vit_feats.append(vit_feat.cpu())
152
+ image_embeds.append(image_embed)
153
+
154
+ vit_feats = torch.cat(vit_feats, dim=0)
155
+ image_embeds = torch.cat(image_embeds, dim=0)
156
+
157
+ sims_matrix = []
158
+ for image_embed in image_embeds:
159
+ sim_q2t = image_embed @ text_embeds.t()
160
+ sim_i2t, _ = sim_q2t.max(0)
161
+ sims_matrix.append(sim_i2t)
162
+ sims_matrix = torch.stack(sims_matrix, dim=0)
163
+
164
+ score_matrix_i2t = torch.full(
165
+ (len(data_loader.dataset.image), len(texts)), -100.0
166
+ ).to(model.device)
167
+
168
+ num_tasks = get_world_size()
169
+ rank = get_rank()
170
+ step = sims_matrix.size(0) // num_tasks + 1
171
+ start = rank * step
172
+ end = min(sims_matrix.size(0), start + step)
173
+
174
+ for i, sims in enumerate(
175
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
176
+ ):
177
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
178
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
179
+ score = model.compute_itm(
180
+ image_inputs=image_inputs,
181
+ text_ids=text_ids[topk_idx],
182
+ text_atts=text_atts[topk_idx],
183
+ ).float()
184
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
185
+
186
+ sims_matrix = sims_matrix.t()
187
+ score_matrix_t2i = torch.full(
188
+ (len(texts), len(data_loader.dataset.image)), -100.0
189
+ ).to(model.device)
190
+
191
+ step = sims_matrix.size(0) // num_tasks + 1
192
+ start = rank * step
193
+ end = min(sims_matrix.size(0), start + step)
194
+
195
+ for i, sims in enumerate(
196
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
197
+ ):
198
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
199
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
200
+ score = model.compute_itm(
201
+ image_inputs=image_inputs,
202
+ text_ids=text_ids[start + i].repeat(k_test, 1),
203
+ text_atts=text_atts[start + i].repeat(k_test, 1),
204
+ ).float()
205
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
206
+
207
+ if is_dist_avail_and_initialized():
208
+ dist.barrier()
209
+ torch.distributed.all_reduce(
210
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
211
+ )
212
+ torch.distributed.all_reduce(
213
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
214
+ )
215
+
216
+ total_time = time.time() - start_time
217
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
218
+ logging.info("Evaluation time {}".format(total_time_str))
219
+
220
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
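
A minimal sketch of the Blip2Base helpers above; the vision width of 1408 (EVA-CLIP ViT-g) is an assumption, and bert-base-uncased is downloaded on first use:

    # build a Q-Former with 32 learned query tokens over 1408-dim visual features
    Qformer, query_tokens = Blip2Base.init_Qformer(num_query_token=32, vision_width=1408)
    print(query_tokens.shape)              # (1, 32, 768) with the default bert-base hidden size
    tokenizer = Blip2Base.init_tokenizer() # BERT tokenizer with "[DEC]" added as the bos_token
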
blip2_outputs.py ADDED
@@ -0,0 +1,110 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from transformers.modeling_outputs import (
13
+ ModelOutput,
14
+ BaseModelOutputWithPoolingAndCrossAttentions,
15
+ CausalLMOutputWithCrossAttentions,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class BlipSimilarity(ModelOutput):
21
+ sim_i2t: torch.FloatTensor = None
22
+ sim_t2i: torch.FloatTensor = None
23
+
24
+ sim_i2t_m: Optional[torch.FloatTensor] = None
25
+ sim_t2i_m: Optional[torch.FloatTensor] = None
26
+
27
+ sim_i2t_targets: Optional[torch.FloatTensor] = None
28
+ sim_t2i_targets: Optional[torch.FloatTensor] = None
29
+
30
+
31
+ @dataclass
32
+ class BlipIntermediateOutput(ModelOutput):
33
+ """
34
+ Data class for intermediate outputs of BLIP models.
35
+
36
+ image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37
+ text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38
+
39
+ image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40
+ text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41
+
42
+ encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43
+ encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44
+
45
+ decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46
+ decoder_labels (torch.LongTensor): labels for the captioning loss.
47
+
48
+ itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49
+ itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50
+
51
+ """
52
+
53
+ # uni-modal features
54
+ image_embeds: torch.FloatTensor = None
55
+ text_embeds: Optional[torch.FloatTensor] = None
56
+
57
+ image_embeds_m: Optional[torch.FloatTensor] = None
58
+ text_embeds_m: Optional[torch.FloatTensor] = None
59
+
60
+ # intermediate outputs of multimodal encoder
61
+ encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62
+ encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63
+
64
+ itm_logits: Optional[torch.FloatTensor] = None
65
+ itm_labels: Optional[torch.LongTensor] = None
66
+
67
+ # intermediate outputs of multimodal decoder
68
+ decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69
+ decoder_labels: Optional[torch.LongTensor] = None
70
+
71
+
72
+ @dataclass
73
+ class BlipOutput(ModelOutput):
74
+ # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75
+ sims: Optional[BlipSimilarity] = None
76
+
77
+ intermediate_output: BlipIntermediateOutput = None
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+
81
+ loss_itc: Optional[torch.FloatTensor] = None
82
+
83
+ loss_itm: Optional[torch.FloatTensor] = None
84
+
85
+ loss_lm: Optional[torch.FloatTensor] = None
86
+
87
+
88
+ @dataclass
89
+ class BlipOutputFeatures(ModelOutput):
90
+ """
91
+ Data class of features from BlipFeatureExtractor.
92
+
93
+ Args:
94
+ image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95
+ image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96
+ text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97
+ text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98
+
99
+ The first embedding or feature is for the [CLS] token.
100
+
101
+ Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102
+ """
103
+
104
+ image_embeds: Optional[torch.FloatTensor] = None
105
+ image_embeds_proj: Optional[torch.FloatTensor] = None
106
+
107
+ text_embeds: Optional[torch.FloatTensor] = None
108
+ text_embeds_proj: Optional[torch.FloatTensor] = None
109
+
110
+ multimodal_embeds: Optional[torch.FloatTensor] = None
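
These dataclasses inherit from transformers' ModelOutput, so populated fields are reachable both as attributes and as keys; a tiny sketch:

    import torch

    feats = BlipOutputFeatures(image_embeds=torch.randn(2, 257, 768))
    print(feats.image_embeds.shape)     # attribute access
    print(feats["image_embeds"].shape)  # dict-style access provided by ModelOutput
    print(feats.text_embeds)            # unset optional fields stay None
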
blip_processors.py ADDED
@@ -0,0 +1,164 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import re
9
+
10
+ from .registry import registry
11
+ from .base_processor import BaseProcessor
12
+ from .randaugment import RandomAugment
13
+ from omegaconf import OmegaConf
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+
17
+
18
+ class BlipImageBaseProcessor(BaseProcessor):
19
+ def __init__(self, mean=None, std=None):
20
+ if mean is None:
21
+ mean = (0.48145466, 0.4578275, 0.40821073)
22
+ if std is None:
23
+ std = (0.26862954, 0.26130258, 0.27577711)
24
+
25
+
26
+ segment_mean = (0.485, 0.456, 0.406)
27
+ segment_std = (0.229, 0.224, 0.225)
28
+
29
+ self.normalize = transforms.Normalize(segment_mean, segment_std)
30
+
31
+
32
+ @registry.register_processor("blip_caption")
33
+ class BlipCaptionProcessor(BaseProcessor):
34
+ def __init__(self, prompt="", max_words=50):
35
+ self.prompt = prompt
36
+ self.max_words = max_words
37
+
38
+ def __call__(self, caption):
39
+ caption = self.prompt + self.pre_caption(caption)
40
+
41
+ return caption
42
+
43
+ @classmethod
44
+ def from_config(cls, cfg=None):
45
+ if cfg is None:
46
+ cfg = OmegaConf.create()
47
+
48
+ prompt = cfg.get("prompt", "")
49
+ max_words = cfg.get("max_words", 50)
50
+
51
+ return cls(prompt=prompt, max_words=max_words)
52
+
53
+ def pre_caption(self, caption):
54
+ caption = re.sub(
55
+ r"([.!\"()*#:;~])",
56
+ " ",
57
+ caption.lower(),
58
+ )
59
+ caption = re.sub(
60
+ r"\s{2,}",
61
+ " ",
62
+ caption,
63
+ )
64
+ caption = caption.rstrip("\n")
65
+ caption = caption.strip(" ")
66
+
67
+ # truncate caption
68
+ caption_words = caption.split(" ")
69
+ if len(caption_words) > self.max_words:
70
+ caption = " ".join(caption_words[: self.max_words])
71
+
72
+ return caption
73
+
74
+
75
+ @registry.register_processor("blip2_image_train")
76
+ class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
77
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
78
+ super().__init__(mean=mean, std=std)
79
+
80
+ # self.transform = transforms.Compose(
81
+ # [
82
+ # transforms.RandomResizedCrop(
83
+ # image_size,
84
+ # scale=(min_scale, max_scale),
85
+ # interpolation=InterpolationMode.BICUBIC,
86
+ # ),
87
+ # transforms.ToTensor(),
88
+ # self.normalize,
89
+ # ]
90
+ # )
91
+ self.transform = transforms.Compose([
92
+ transforms.Resize(
93
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
94
+ ),
95
+ transforms.ToTensor(),
96
+ self.normalize,
97
+ ]
98
+ )
99
+
100
+ # ### segment anything
101
+ # '''
102
+ # x = (x - self.pixel_mean) / self.pixel_std
103
+
104
+ # # Pad
105
+ # h, w = x.shape[-2:]
106
+ # padh = self.image_encoder.img_size - h
107
+ # padw = self.image_encoder.img_size - w
108
+ # x = F.pad(x, (0, padw, 0, padh))
109
+ # '''
110
+
111
+ def __call__(self, item):
112
+ return self.transform(item)
113
+
114
+ @classmethod
115
+ def from_config(cls, cfg=None):
116
+ if cfg is None:
117
+ cfg = OmegaConf.create()
118
+
119
+ image_size = cfg.get("image_size", 224)
120
+
121
+ mean = cfg.get("mean", None)
122
+ std = cfg.get("std", None)
123
+
124
+ min_scale = cfg.get("min_scale", 0.5)
125
+ max_scale = cfg.get("max_scale", 1.0)
126
+
127
+ return cls(
128
+ image_size=image_size,
129
+ mean=mean,
130
+ std=std,
131
+ min_scale=min_scale,
132
+ max_scale=max_scale,
133
+ )
134
+
135
+
136
+ @registry.register_processor("blip2_image_eval")
137
+ class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
138
+ def __init__(self, image_size=224, mean=None, std=None):
139
+ super().__init__(mean=mean, std=std)
140
+
141
+ self.transform = transforms.Compose(
142
+ [
143
+ transforms.Resize(
144
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
145
+ ),
146
+ transforms.ToTensor(),
147
+ self.normalize,
148
+ ]
149
+ )
150
+
151
+ def __call__(self, item):
152
+ return self.transform(item)
153
+
154
+ @classmethod
155
+ def from_config(cls, cfg=None):
156
+ if cfg is None:
157
+ cfg = OmegaConf.create()
158
+
159
+ image_size = cfg.get("image_size", 224)
160
+
161
+ mean = cfg.get("mean", None)
162
+ std = cfg.get("std", None)
163
+
164
+ return cls(image_size=image_size, mean=mean, std=std)
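
A short sketch of building the eval processor registered above from a config node (image_size 448 is just an example value):

    from PIL import Image
    from omegaconf import OmegaConf

    proc = Blip2ImageEvalProcessor.from_config(OmegaConf.create({"image_size": 448}))
    img = Image.new("RGB", (640, 480))
    print(proc(img).shape)  # torch.Size([3, 448, 448]); resize -> ToTensor -> ImageNet-style normalize
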
clip_vision_encoder.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionEncoder(nn.Module):
8
+ def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_encoder_name = encoder_name
14
+ # self.select_layer = args.mm_vision_select_layer
15
+ # self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+ self.select_layer = -1
17
+ self.select_feature = "patch"
18
+ if not delay_load:
19
+ self.load_model()
20
+ else:
21
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
22
+
23
+ def load_model(self):
24
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
25
+ self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name)
26
+ self.vision_encoder.requires_grad_(False)
27
+
28
+ self.is_loaded = True
29
+
30
+ def feature_select(self, image_forward_outs):
31
+ image_features = image_forward_outs.hidden_states[self.select_layer]
32
+ if self.select_feature == 'patch':
33
+ image_features = image_features[:, :]
34
+ elif self.select_feature == 'cls_patch':
35
+ image_features = image_features
36
+ else:
37
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
38
+ return image_features
39
+
40
+ @torch.no_grad()
41
+ def forward(self, images):
42
+ if type(images) is list:
43
+ image_features = []
44
+ for image in images:
45
+ image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
46
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
47
+ image_features.append(image_feature)
48
+ else:
49
+ image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
50
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
51
+ # print("image feature shape", image_features.shape)
52
+ # print(type(image_forward_outs))
53
+ # print(type(image_forward_outs.shape))
54
+ # image_features = image_forward_outs.to(images.dtype)
55
+
56
+ return image_features
57
+
58
+ @property
59
+ def dummy_feature(self):
60
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
61
+
62
+ @property
63
+ def dtype(self):
64
+ return self.vision_encoder.dtype
65
+
66
+ @property
67
+ def device(self):
68
+ return self.vision_encoder.device
69
+
70
+ @property
71
+ def config(self):
72
+ if self.is_loaded:
73
+ return self.vision_encoder.config
74
+ else:
75
+ return self.cfg_only
76
+
77
+ @property
78
+ def hidden_size(self):
79
+ return self.config.hidden_size
80
+
81
+ @property
82
+ def num_patches(self):
83
+ return (self.config.image_size // self.config.patch_size) ** 2
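
A small sketch of the CLIP encoder wrapper above; with delay_load=True only the config is fetched, so the derived properties can be inspected without downloading weights:

    enc = CLIPVisionEncoder(encoder_name="openai/clip-vit-large-patch14", delay_load=True)
    print(enc.hidden_size)  # 1024 for ViT-L/14
    print(enc.num_patches)  # (224 // 14) ** 2 = 256
    enc.load_model()        # downloads weights and freezes them (requires_grad_(False))
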
config.py ADDED
@@ -0,0 +1,474 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from .registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ print("--------------")
72
+ print("model arch",model.arch)
73
+ print("model cls",model_cls)
74
+
75
+ model_config_path = model_cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]
76
+
77
+ model_config = OmegaConf.create()
78
+ # hierarchy override, customized config > default config
79
+ model_config = OmegaConf.merge(
80
+ model_config,
81
+ OmegaConf.load(model_config_path),
82
+ {"model": config["model"]},
83
+ )
84
+
85
+ return model_config
86
+
87
+ @staticmethod
88
+ def build_runner_config(config):
89
+ return {"run": config.run}
90
+
91
+ @staticmethod
92
+ def build_dataset_config(config):
93
+ datasets = config.get("datasets", None)
94
+ if datasets is None:
95
+ raise KeyError(
96
+ "Expecting 'datasets' as the root key for dataset configuration."
97
+ )
98
+
99
+ dataset_config = OmegaConf.create()
100
+
101
+ for dataset_name in datasets:
102
+
103
+ print("dataset name", dataset_name)
104
+ builder_cls = registry.get_builder_class(dataset_name)
105
+
106
+ dataset_config_type = datasets[dataset_name].get("type", "default")
107
+ dataset_config_path = builder_cls.default_config_path(
108
+ type=dataset_config_type
109
+ )
110
+
111
+ # hierarchy override, customized config > default config
112
+ dataset_config = OmegaConf.merge(
113
+ dataset_config,
114
+ OmegaConf.load(dataset_config_path),
115
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
116
+ )
117
+
118
+ return dataset_config
119
+
120
+ def _convert_to_dot_list(self, opts):
121
+ if opts is None:
122
+ opts = []
123
+
124
+ if len(opts) == 0:
125
+ return opts
126
+
127
+ has_equal = opts[0].find("=") != -1
128
+
129
+ if has_equal:
130
+ return opts
131
+
132
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
133
+
134
+ def get_config(self):
135
+ return self.config
136
+
137
+ @property
138
+ def run_cfg(self):
139
+ return self.config.run
140
+
141
+ @property
142
+ def datasets_cfg(self):
143
+ return self.config.datasets
144
+
145
+ @property
146
+ def model_cfg(self):
147
+ return self.config.model
148
+
149
+ def pretty_print(self):
150
+ logging.info("\n===== Running Parameters =====")
151
+ logging.info(self._convert_node_to_json(self.config.run))
152
+
153
+ logging.info("\n====== Dataset Attributes ======")
154
+ datasets = self.config.datasets
155
+
156
+ for dataset in datasets:
157
+ if dataset in self.config.datasets:
158
+ logging.info(f"\n======== {dataset} =======")
159
+ dataset_config = self.config.datasets[dataset]
160
+ logging.info(self._convert_node_to_json(dataset_config))
161
+ else:
162
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
163
+
164
+ logging.info(f"\n====== Model Attributes ======")
165
+ logging.info(self._convert_node_to_json(self.config.model))
166
+
167
+ def _convert_node_to_json(self, node):
168
+ container = OmegaConf.to_container(node, resolve=True)
169
+ return json.dumps(container, indent=4, sort_keys=True)
170
+
171
+ def to_dict(self):
172
+ return OmegaConf.to_container(self.config)
173
+
174
+
175
+ def node_to_dict(node):
176
+ return OmegaConf.to_container(node)
177
+
178
+
179
+ class ConfigValidator:
180
+ """
181
+ This is a preliminary implementation to centralize and validate the configuration.
182
+ May be altered in the future.
183
+
184
+ A helper class to validate configurations from yaml file.
185
+
186
+ This serves the following purposes:
187
+ 1. Ensure all the options in the yaml are defined, raise error if not.
188
+ 2. when type mismatches are found, the validator will raise an error.
189
+ 3. a central place to store and display helpful messages for supported configurations.
190
+
191
+ """
192
+
193
+ class _Argument:
194
+ def __init__(self, name, choices=None, type=None, help=None):
195
+ self.name = name
196
+ self.val = None
197
+ self.choices = choices
198
+ self.type = type
199
+ self.help = help
200
+
201
+ def __str__(self):
202
+ s = f"{self.name}={self.val}"
203
+ if self.type is not None:
204
+ s += f", ({self.type})"
205
+ if self.choices is not None:
206
+ s += f", choices: {self.choices}"
207
+ if self.help is not None:
208
+ s += f", ({self.help})"
209
+ return s
210
+
211
+ def __init__(self, description):
212
+ self.description = description
213
+
214
+ self.arguments = dict()
215
+
216
+ self.parsed_args = None
217
+
218
+ def __getitem__(self, key):
219
+ assert self.parsed_args is not None, "No arguments parsed yet."
220
+
221
+ return self.parsed_args[key]
222
+
223
+ def __str__(self) -> str:
224
+ return self.format_help()
225
+
226
+ def add_argument(self, *args, **kwargs):
227
+ """
228
+ Assume the first argument is the name of the argument.
229
+ """
230
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
231
+
232
+ def validate(self, config=None):
233
+ """
234
+ Convert yaml config (dict-like) to list, required by argparse.
235
+ """
236
+ for k, v in config.items():
237
+ assert (
238
+ k in self.arguments
239
+ ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""
240
+
241
+ if self.arguments[k].type is not None:
242
+ try:
243
+ self.arguments[k].val = self.arguments[k].type(v)
244
+ except ValueError:
245
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
246
+
247
+ if self.arguments[k].choices is not None:
248
+ assert (
249
+ v in self.arguments[k].choices
250
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
251
+
252
+ return config
253
+
254
+ def format_arguments(self):
255
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
256
+
257
+ def format_help(self):
258
+ # description + key-value pair string for each argument
259
+ help_msg = str(self.description)
260
+ return help_msg + ", available arguments: " + self.format_arguments()
261
+
262
+ def print_help(self):
263
+ # display help message
264
+ print(self.format_help())
265
+
266
+
267
+ def create_runner_config_validator():
268
+ validator = ConfigValidator(description="Runner configurations")
269
+
270
+ validator.add_argument(
271
+ "runner",
272
+ type=str,
273
+ choices=["runner_base", "runner_iter"],
274
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
275
+ runner runs based on iters. Default: runner_base""",
276
+ )
277
+ # add arguments for training dataset ratios
278
+ validator.add_argument(
279
+ "train_dataset_ratios",
280
+ type=Dict[str, float],
281
+ help="""Ratios of training dataset. This is used in iteration-based runner.
282
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
283
+ Default: None""",
284
+ )
285
+ validator.add_argument(
286
+ "max_iters",
287
+ type=float,
288
+ help="Maximum number of iterations to run.",
289
+ )
290
+ validator.add_argument(
291
+ "max_epoch",
292
+ type=int,
293
+ help="Maximum number of epochs to run.",
294
+ )
295
+ # add arguments for iters_per_inner_epoch
296
+ validator.add_argument(
297
+ "iters_per_inner_epoch",
298
+ type=float,
299
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
300
+ )
301
+ lr_scheds_choices = registry.list_lr_schedulers()
302
+ validator.add_argument(
303
+ "lr_sched",
304
+ type=str,
305
+ choices=lr_scheds_choices,
306
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
307
+ )
308
+ task_choices = registry.list_tasks()
309
+ validator.add_argument(
310
+ "task",
311
+ type=str,
312
+ choices=task_choices,
313
+ help="Task to use, from {}".format(task_choices),
314
+ )
315
+ # add arguments for init_lr
316
+ validator.add_argument(
317
+ "init_lr",
318
+ type=float,
319
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
320
+ )
321
+ # add arguments for min_lr
322
+ validator.add_argument(
323
+ "min_lr",
324
+ type=float,
325
+ help="Minimum learning rate (after decay).",
326
+ )
327
+ # add arguments for warmup_lr
328
+ validator.add_argument(
329
+ "warmup_lr",
330
+ type=float,
331
+ help="Starting learning rate for warmup.",
332
+ )
333
+ # add arguments for learning rate decay rate
334
+ validator.add_argument(
335
+ "lr_decay_rate",
336
+ type=float,
337
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
338
+ )
339
+ # add arguments for weight decay
340
+ validator.add_argument(
341
+ "weight_decay",
342
+ type=float,
343
+ help="Weight decay rate.",
344
+ )
345
+ # add arguments for training batch size
346
+ validator.add_argument(
347
+ "batch_size_train",
348
+ type=int,
349
+ help="Training batch size.",
350
+ )
351
+ # add arguments for evaluation batch size
352
+ validator.add_argument(
353
+ "batch_size_eval",
354
+ type=int,
355
+ help="Evaluation batch size, including validation and testing.",
356
+ )
357
+ # add arguments for number of workers for data loading
358
+ validator.add_argument(
359
+ "num_workers",
360
+ help="Number of workers for data loading.",
361
+ )
362
+ # add arguments for warm up steps
363
+ validator.add_argument(
364
+ "warmup_steps",
365
+ type=int,
366
+ help="Number of warmup steps. Required if a warmup schedule is used.",
367
+ )
368
+ # add arguments for random seed
369
+ validator.add_argument(
370
+ "seed",
371
+ type=int,
372
+ help="Random seed.",
373
+ )
374
+ # add arguments for output directory
375
+ validator.add_argument(
376
+ "output_dir",
377
+ type=str,
378
+ help="Output directory to save checkpoints and logs.",
379
+ )
380
+ # add arguments for whether only use evaluation
381
+ validator.add_argument(
382
+ "evaluate",
383
+ help="Whether to only evaluate the model. If true, training will not be performed.",
384
+ )
385
+ # add arguments for splits used for training, e.g. ["train", "val"]
386
+ validator.add_argument(
387
+ "train_splits",
388
+ type=list,
389
+ help="Splits to use for training.",
390
+ )
391
+ # add arguments for splits used for validation, e.g. ["val"]
392
+ validator.add_argument(
393
+ "valid_splits",
394
+ type=list,
395
+ help="Splits to use for validation. If not provided, will skip the validation.",
396
+ )
397
+ # add arguments for splits used for testing, e.g. ["test"]
398
+ validator.add_argument(
399
+ "test_splits",
400
+ type=list,
401
+ help="Splits to use for testing. If not provided, will skip the testing.",
402
+ )
403
+ # add arguments for accumulating gradient for iterations
404
+ validator.add_argument(
405
+ "accum_grad_iters",
406
+ type=int,
407
+ help="Number of iterations to accumulate gradient for.",
408
+ )
409
+
410
+ # ====== distributed training ======
411
+ validator.add_argument(
412
+ "device",
413
+ type=str,
414
+ choices=["cpu", "cuda"],
415
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
416
+ )
417
+ validator.add_argument(
418
+ "world_size",
419
+ type=int,
420
+ help="Number of processes participating in the job.",
421
+ )
422
+ validator.add_argument("dist_url", type=str)
423
+ validator.add_argument("distributed", type=bool)
424
+ # add argument for whether to use a distributed sampler during evaluation
425
+ validator.add_argument(
426
+ "use_dist_eval_sampler",
427
+ type=bool,
428
+ help="Whether to use distributed sampler during evaluation or not.",
429
+ )
430
+
431
+ # ====== task specific ======
432
+ # generation task specific arguments
433
+ # add arguments for maximal length of text output
434
+ validator.add_argument(
435
+ "max_len",
436
+ type=int,
437
+ help="Maximal length of text output.",
438
+ )
439
+ # add arguments for minimal length of text output
440
+ validator.add_argument(
441
+ "min_len",
442
+ type=int,
443
+ help="Minimal length of text output.",
444
+ )
445
+ # add arguments number of beams
446
+ validator.add_argument(
447
+ "num_beams",
448
+ type=int,
449
+ help="Number of beams used for beam search.",
450
+ )
451
+
452
+ # vqa task specific arguments
453
+ # add arguments for number of answer candidates
454
+ validator.add_argument(
455
+ "num_ans_candidates",
456
+ type=int,
457
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
458
+ )
459
+ # add arguments for inference method
460
+ validator.add_argument(
461
+ "inference_method",
462
+ type=str,
463
+ choices=["generate", "rank"],
464
+ help="""Inference method to use for question answering. If rank, requires an answer list.""",
465
+ )
466
+
467
+ # ====== model specific ======
468
+ validator.add_argument(
469
+ "k_test",
470
+ type=int,
471
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
472
+ )
473
+
474
+ return validator
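
Config expects an argparse-style object with cfg_path and options attributes; a rough sketch (the YAML path and the override key below are hypothetical):

    import argparse

    args = argparse.Namespace(
        cfg_path="eval_configs/minigpt4_eval.yaml",  # hypothetical config file
        options=["run.seed=42"],                     # dot-list overrides, merged last
    )
    cfg = Config(args)
    print(cfg.model_cfg.arch)  # architecture name resolved from the merged model config
    cfg.pretty_print()         # logs run / dataset / model sections as JSON
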
conversation.py ADDED
@@ -0,0 +1,224 @@
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from .registry import registry
14
+
15
+
16
+ class SeparatorStyle(Enum):
17
+ """Different separator style."""
18
+ SINGLE = auto()
19
+ TWO = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ # system_img: List[Image.Image] = []
30
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31
+ sep: str = "<s>"
32
+ sep2: str = "</s>"
33
+
34
+ skip_next: bool = False
35
+ conv_id: Any = None
36
+
37
+ def get_prompt(self):
38
+ if self.sep_style == SeparatorStyle.SINGLE:
39
+ # ret = self.system + self.sep
40
+ ret = self.system +"<s>"
41
+ for role, message in self.messages:
42
+ if message:
43
+ # ret += role + ": " + message + self.sep
44
+ ret+= role + message
45
+ # ret+= role + message
46
+ else:
47
+ # ret += role + ":"
48
+ # ret += self.sep2 + role
49
+ ret += role
50
+ return ret
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ # ret = self.system + seps[0]
54
+ ret = self.system+"<s>"
55
+ for i, (role, message) in enumerate(self.messages):
56
+ if message:
57
+ # ret += role + ": " + message + seps[i % 2]
58
+ ret += role+message+seps[i%2]
59
+ else:
60
+ # ret += role + ":"
61
+ ret += role
62
+ return ret
63
+ else:
64
+ raise ValueError(f"Invalid style: {self.sep_style}")
65
+
66
+ def append_message(self, role, message):
67
+ self.messages.append([role, message])
68
+
69
+ def to_gradio_chatbot(self):
70
+ ret = []
71
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
72
+ if i % 2 == 0:
73
+ ret.append([msg, None])
74
+ else:
75
+ ret[-1][-1] = msg
76
+ return ret
77
+
78
+ def copy(self):
79
+ return Conversation(
80
+ system=self.system,
81
+ # system_img=self.system_img,
82
+ roles=self.roles,
83
+ messages=[[x, y] for x, y in self.messages],
84
+ offset=self.offset,
85
+ sep_style=self.sep_style,
86
+ sep=self.sep,
87
+ sep2=self.sep2,
88
+ conv_id=self.conv_id)
89
+
90
+ def dict(self):
91
+ return {
92
+ "system": self.system,
93
+ # "system_img": self.system_img,
94
+ "roles": self.roles,
95
+ "messages": self.messages,
96
+ "offset": self.offset,
97
+ "sep": self.sep,
98
+ "sep2": self.sep2,
99
+ "conv_id": self.conv_id,
100
+ }
101
+
102
+
103
+ class StoppingCriteriaSub(StoppingCriteria):
104
+
105
+ def __init__(self, stops=[], encounters=1):
106
+ super().__init__()
107
+ self.stops = stops
108
+
109
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
110
+ for stop in self.stops:
111
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
112
+ return True
113
+
114
+ return False
115
+
116
+
117
+ CONV_VISION = Conversation(
118
+ # system="Give the following image: <Img>ImageContent</Img>. "
119
+ # "You will be able to see the image once I provide it to you. Please answer my questions.",
120
+ system = "",
121
+ roles = (r"[INST] ",r" [/INST]"),
122
+ messages=[],
123
+ offset=2,
124
+ sep_style=SeparatorStyle.SINGLE,
125
+ sep="<s>",
126
+ )
127
+
128
+
129
+ class Chat:
130
+ def __init__(self, model, vis_processor, device='cuda:0'):
131
+ self.device = device
132
+ self.model = model
133
+ self.vis_processor = vis_processor
134
+
135
+ self.conv = CONV_VISION.copy()
136
+ self.img_list = []
137
+ self.raw_answers = []
138
+
139
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
140
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
141
+
142
+ def reset(self):
143
+ self.conv.messages = []
144
+ self.img_list = []
145
+ # self.img_list = [img for img in self.conv.system_img]
146
+ self.raw_answers = []
147
+
148
+ def ask(self, text, conv):
149
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
150
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
151
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
152
+ else:
153
+ conv.append_message(conv.roles[0], text)
154
+
155
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
156
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
157
+ conv.append_message(conv.roles[1], None)
158
+ embs = self.get_context_emb(conv, img_list)
159
+
160
+ current_max_len = embs.shape[1] + max_new_tokens
161
+ if current_max_len - max_length > 0:
162
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
163
+ 'The model will not see the contexts outside the range.')
164
+ begin_idx = max(0, current_max_len - max_length)
165
+
166
+ embs = embs[:, begin_idx:]
167
+
168
+ outputs = self.model.llama_model.generate(
169
+ inputs_embeds=embs,
170
+ max_new_tokens=max_new_tokens,
171
+ stopping_criteria=self.stopping_criteria,
172
+ num_beams=num_beams,
173
+ min_length=min_length,
174
+ top_p=top_p,
175
+ repetition_penalty=repetition_penalty,
176
+ length_penalty=length_penalty,
177
+ temperature=temperature,
178
+ do_sample=False,
179
+ )
180
+ output_token = outputs[0]
181
+ if output_token[0] == 0:
182
+ output_token = output_token[1:]
183
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
184
+ self.raw_answers.append(output_text)
185
+ output_text = output_text.split('</s>')[0] # drop everything after the end-of-sequence token '</s>'
186
+ output_text = output_text.replace("<s>", "")
187
+ output_text = output_text.split(r'[/INST]')[-1].strip()
188
+ self.conv.messages[-1][1] = output_text
189
+ return output_text, output_token.cpu().numpy()
190
+
191
+ def upload_img(self, image):
192
+ if isinstance(image, str): # is an image path
193
+ raw_image = Image.open(image).convert('RGB')
194
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
195
+ elif isinstance(image, Image.Image):
196
+ raw_image = image
197
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
198
+ elif isinstance(image, torch.Tensor):
199
+ if len(image.shape) == 3:
200
+ image = image.unsqueeze(0)
201
+ image = image.to(self.device)
202
+
203
+ image_emb, _ = self.model.encode_img(image)
204
+ self.img_list.append(image_emb)
205
+ self.conv.append_message(self.conv.roles[0], "<Img><ImageHere></Img>")
206
+ msg = "Received."
207
+ # self.conv.append_message(self.conv.roles[1], msg)
208
+ return msg
209
+
210
+ def get_context_emb(self, conv, img_list):
211
+ prompt = conv.get_prompt()
212
+ prompt_segs = prompt.split('<ImageHere>')
213
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
214
+ seg_tokens = [
215
+ self.model.llama_tokenizer(
216
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
217
+ # only add bos to the first seg
218
+ for i, seg in enumerate(prompt_segs)
219
+ ]
220
+
221
+ seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
222
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
223
+ mixed_embs = torch.cat(mixed_embs, dim=1)
224
+ return mixed_embs
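
A rough sketch of the intended chat loop built on the Chat class above (model and vis_processor construction is omitted, and the image path is hypothetical):

    chat = Chat(model, vis_processor, device="cuda:0")
    chat.upload_img("example.jpg")                       # encodes the image and appends "<Img><ImageHere></Img>"
    chat.ask("What is shown in this image?", chat.conv)
    answer, _ = chat.answer(chat.conv, chat.img_list, max_new_tokens=300)
    print(answer)
    chat.reset()                                         # clear history before starting a new conversation
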
dist_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if args.distributed is False:
59
+ print("Not using distributed mode")
60
+ args.rank = 0
61
+ return
62
+
63
+ if 'LOCAL_RANK' not in os.environ:
64
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
65
+
66
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
67
+ args.rank = int(os.environ["RANK"])
68
+ args.world_size = int(os.environ["WORLD_SIZE"])
69
+ args.gpu = int(os.environ["LOCAL_RANK"])
70
+ elif "SLURM_PROCID" in os.environ:
71
+ args.rank = int(os.environ["SLURM_PROCID"])
72
+ args.gpu = args.rank % torch.cuda.device_count()
73
+ else:
74
+ print("Not using distributed mode")
75
+ args.distributed = False
76
+ args.rank = 0
77
+ return
78
+
79
+ args.distributed = True
80
+
81
+ torch.cuda.set_device(args.gpu)
82
+ args.dist_backend = "nccl"
83
+ print(
84
+ "| distributed init (rank {}, world {}): {}".format(
85
+ args.rank, args.world_size, args.dist_url
86
+ ),
87
+ flush=True,
88
+ )
89
+ torch.distributed.init_process_group(
90
+ backend=args.dist_backend,
91
+ init_method=args.dist_url,
92
+ world_size=args.world_size,
93
+ rank=args.rank,
94
+ timeout=datetime.timedelta(
95
+ days=365
96
+ ), # allow auto-downloading and de-compressing
97
+ )
98
+ torch.distributed.barrier()
99
+ setup_for_distributed(args.rank == 0)
100
+
101
+
102
+ def get_dist_info():
103
+ if torch.__version__ < "1.0":
104
+ initialized = dist._initialized
105
+ else:
106
+ initialized = dist.is_initialized()
107
+ if initialized:
108
+ rank = dist.get_rank()
109
+ world_size = dist.get_world_size()
110
+ else: # non-distributed training
111
+ rank = 0
112
+ world_size = 1
113
+ return rank, world_size
114
+
115
+
116
+ def main_process(func):
117
+ @functools.wraps(func)
118
+ def wrapper(*args, **kwargs):
119
+ rank, _ = get_dist_info()
120
+ if rank == 0:
121
+ return func(*args, **kwargs)
122
+
123
+ return wrapper
124
+
125
+
126
+ def download_cached_file(url, check_hash=True, progress=False):
127
+ """
128
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
129
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
130
+ """
131
+
132
+ def get_cached_file_path():
133
+ # a hack to sync the file path across processes
134
+ parts = torch.hub.urlparse(url)
135
+ filename = os.path.basename(parts.path)
136
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
137
+
138
+ return cached_file
139
+
140
+ if is_main_process():
141
+ timm_hub.download_cached_file(url, check_hash, progress)
142
+
143
+ if is_dist_avail_and_initialized():
144
+ dist.barrier()
145
+
146
+ return get_cached_file_path()
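
A small sketch of the @main_process decorator above: the wrapped function runs only on rank 0, and in a single, non-distributed process it always runs:

    @main_process
    def log_stats(stats):
        print(stats)

    log_stats({"loss": 0.42})  # executes on rank 0 only; returns None on other ranks
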
eva_vit.py ADDED
@@ -0,0 +1,443 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+
18
+ from .dist_utils import download_cached_file
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+
30
+ class DropPath(nn.Module):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ """
33
+ def __init__(self, drop_prob=None):
34
+ super(DropPath, self).__init__()
35
+ self.drop_prob = drop_prob
36
+
37
+ def forward(self, x):
38
+ return drop_path(x, self.drop_prob, self.training)
39
+
40
+ def extra_repr(self) -> str:
41
+ return 'p={}'.format(self.drop_prob)
42
+
43
+
44
+ class Mlp(nn.Module):
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ # x = self.drop(x)
58
+ # commented out to match the original BERT implementation
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(
66
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
67
+ proj_drop=0., window_size=None, attn_head_dim=None):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ if attn_head_dim is not None:
72
+ head_dim = attn_head_dim
73
+ all_head_dim = head_dim * self.num_heads
74
+ self.scale = qk_scale or head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
77
+ if qkv_bias:
78
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
80
+ else:
81
+ self.q_bias = None
82
+ self.v_bias = None
83
+
84
+ if window_size:
85
+ self.window_size = window_size
86
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87
+ self.relative_position_bias_table = nn.Parameter(
88
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
89
+ # cls to token & token 2 cls & cls to cls
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(window_size[0])
93
+ coords_w = torch.arange(window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
101
+ relative_position_index = \
102
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
103
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
105
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
106
+ relative_position_index[0, 0] = self.num_relative_distance - 1
107
+
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+ else:
110
+ self.window_size = None
111
+ self.relative_position_bias_table = None
112
+ self.relative_position_index = None
113
+
114
+ self.attn_drop = nn.Dropout(attn_drop)
115
+ self.proj = nn.Linear(all_head_dim, dim)
116
+ self.proj_drop = nn.Dropout(proj_drop)
117
+
118
+ def forward(self, x, rel_pos_bias=None):
119
+ B, N, C = x.shape
120
+ qkv_bias = None
121
+ if self.q_bias is not None:
122
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
123
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.relative_position_bias_table is not None:
132
+ relative_position_bias = \
133
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1] + 1,
135
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if rel_pos_bias is not None:
140
+ attn = attn + rel_pos_bias
141
+
142
+ attn = attn.softmax(dim=-1)
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
154
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
155
+ window_size=None, attn_head_dim=None):
156
+ super().__init__()
157
+ self.norm1 = norm_layer(dim)
158
+ self.attn = Attention(
159
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
160
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
+
167
+ if init_values is not None and init_values > 0:
168
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
170
+ else:
171
+ self.gamma_1, self.gamma_2 = None, None
172
+
173
+ def forward(self, x, rel_pos_bias=None):
174
+ if self.gamma_1 is None:
175
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
177
+ else:
178
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
179
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
180
+ return x
181
+
182
+
183
+ class PatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
187
+ super().__init__()
188
+ img_size = to_2tuple(img_size)
189
+ patch_size = to_2tuple(patch_size)
190
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
191
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
192
+ self.img_size = img_size
193
+ self.patch_size = patch_size
194
+ self.num_patches = num_patches
195
+
196
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, x, **kwargs):
199
+ B, C, H, W = x.shape
200
+ # FIXME look at relaxing size constraints
201
+ assert H == self.img_size[0] and W == self.img_size[1], \
202
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
203
+ x = self.proj(x).flatten(2).transpose(1, 2)
204
+ return x
205
+
206
+
207
+ class RelativePositionBias(nn.Module):
208
+
209
+ def __init__(self, window_size, num_heads):
210
+ super().__init__()
211
+ self.window_size = window_size
212
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
213
+ self.relative_position_bias_table = nn.Parameter(
214
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
215
+ # cls to token & token 2 cls & cls to cls
216
+
217
+ # get pair-wise relative position index for each token inside the window
218
+ coords_h = torch.arange(window_size[0])
219
+ coords_w = torch.arange(window_size[1])
220
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
221
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
222
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
223
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
224
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
225
+ relative_coords[:, :, 1] += window_size[1] - 1
226
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
227
+ relative_position_index = \
228
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
229
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
230
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
231
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
232
+ relative_position_index[0, 0] = self.num_relative_distance - 1
233
+
234
+ self.register_buffer("relative_position_index", relative_position_index)
235
+
236
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
237
+
238
+ def forward(self):
239
+ relative_position_bias = \
240
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
241
+ self.window_size[0] * self.window_size[1] + 1,
242
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
243
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ """ Vision Transformer with support for patch or hybrid CNN input stage
248
+ """
249
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
250
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
251
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
252
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
253
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
254
+ super().__init__()
255
+ self.image_size = img_size
256
+ self.num_classes = num_classes
257
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
261
+ num_patches = self.patch_embed.num_patches
262
+
263
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
264
+ if use_abs_pos_emb:
265
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
266
+ else:
267
+ self.pos_embed = None
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ if use_shared_rel_pos_bias:
271
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
272
+ else:
273
+ self.rel_pos_bias = None
274
+ self.use_checkpoint = use_checkpoint
275
+
276
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
277
+ self.use_rel_pos_bias = use_rel_pos_bias
278
+ self.blocks = nn.ModuleList([
279
+ Block(
280
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
281
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
282
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
283
+ for i in range(depth)])
284
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
285
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
286
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287
+
288
+ if self.pos_embed is not None:
289
+ trunc_normal_(self.pos_embed, std=.02)
290
+ trunc_normal_(self.cls_token, std=.02)
291
+ # trunc_normal_(self.mask_token, std=.02)
292
+ # if isinstance(self.head, nn.Linear):
293
+ # trunc_normal_(self.head.weight, std=.02)
294
+ self.apply(self._init_weights)
295
+ self.fix_init_weight()
296
+ # if isinstance(self.head, nn.Linear):
297
+ # self.head.weight.data.mul_(init_scale)
298
+ # self.head.bias.data.mul_(init_scale)
299
+
300
+ def fix_init_weight(self):
301
+ def rescale(param, layer_id):
302
+ param.div_(math.sqrt(2.0 * layer_id))
303
+
304
+ for layer_id, layer in enumerate(self.blocks):
305
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
306
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, nn.Linear):
310
+ trunc_normal_(m.weight, std=.02)
311
+ if isinstance(m, nn.Linear) and m.bias is not None:
312
+ nn.init.constant_(m.bias, 0)
313
+ elif isinstance(m, nn.LayerNorm):
314
+ nn.init.constant_(m.bias, 0)
315
+ nn.init.constant_(m.weight, 1.0)
316
+
317
+ def get_classifier(self):
318
+ return self.head
319
+
320
+ def reset_classifier(self, num_classes, global_pool=''):
321
+ self.num_classes = num_classes
322
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
323
+
324
+ def forward_features(self, x):
325
+ x = self.patch_embed(x)
326
+ batch_size, seq_len, _ = x.size()
327
+
328
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
329
+ x = torch.cat((cls_tokens, x), dim=1)
330
+ if self.pos_embed is not None:
331
+ x = x + self.pos_embed
332
+ x = self.pos_drop(x)
333
+
334
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335
+ for blk in self.blocks:
336
+ if self.use_checkpoint:
337
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338
+ else:
339
+ x = blk(x, rel_pos_bias)
340
+ return x
341
+ # x = self.norm(x)
342
+
343
+ # if self.fc_norm is not None:
344
+ # t = x[:, 1:, :]
345
+ # return self.fc_norm(t.mean(1))
346
+ # else:
347
+ # return x[:, 0]
348
+
349
+ def forward(self, x):
350
+ x = self.forward_features(x)
351
+ # x = self.head(x)
352
+ return x
353
+
354
+ def get_intermediate_layers(self, x):
355
+ x = self.patch_embed(x)
356
+ batch_size, seq_len, _ = x.size()
357
+
358
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
359
+ x = torch.cat((cls_tokens, x), dim=1)
360
+ if self.pos_embed is not None:
361
+ x = x + self.pos_embed
362
+ x = self.pos_drop(x)
363
+
364
+ features = []
365
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
366
+ for blk in self.blocks:
367
+ x = blk(x, rel_pos_bias)
368
+ features.append(x)
369
+
370
+ return features
371
+
372
+
373
+ def interpolate_pos_embed(model, checkpoint_model):
374
+ if 'pos_embed' in checkpoint_model:
375
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
376
+ embedding_size = pos_embed_checkpoint.shape[-1]
377
+ num_patches = model.patch_embed.num_patches
378
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
379
+ # height (== width) for the checkpoint position embedding
380
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
381
+ # height (== width) for the new position embedding
382
+ new_size = int(num_patches ** 0.5)
383
+ # class_token and dist_token are kept unchanged
384
+ if orig_size != new_size:
385
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
386
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
387
+ # only the position tokens are interpolated
388
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
389
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
390
+ pos_tokens = torch.nn.functional.interpolate(
391
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
392
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
393
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
394
+ checkpoint_model['pos_embed'] = new_pos_embed
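+
+ # Worked example (illustrative): a checkpoint trained at 224px with 14px patches stores a
+ # (1, 16*16 + 1, dim) pos_embed; loading at 448px needs a 32x32 patch grid, so the 16x16 block of
+ # position tokens is resized bicubically to 32x32 while the leading cls/extra tokens are kept as is.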
395
+
396
+
397
+ def convert_weights_to_fp16(model: nn.Module):
398
+ """Convert applicable model parameters to fp16"""
399
+
400
+ def _convert_weights_to_fp16(l):
401
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
402
+ l.weight.data = l.weight.data.half()
403
+ if l.bias is not None:
404
+ l.bias.data = l.bias.data.half()
405
+
406
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
407
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
408
+ # tensor = getattr(l, attr)
409
+ # if tensor is not None:
410
+ # tensor.data = tensor.data.half()
411
+
412
+ model.apply(_convert_weights_to_fp16)
413
+
414
+
415
+ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
416
+ model = VisionTransformer(
417
+ img_size=img_size,
418
+ patch_size=14,
419
+ use_mean_pooling=False,
420
+ embed_dim=1408,
421
+ depth=39,
422
+ # depth = 37,
423
+ num_heads=1408//88,
424
+ mlp_ratio=4.3637,
425
+ qkv_bias=True,
426
+ drop_path_rate=drop_path_rate,
427
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
428
+ use_checkpoint=use_checkpoint,
429
+ )
430
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
431
+ cached_file = download_cached_file(
432
+ url, check_hash=False, progress=True
433
+ )
434
+ state_dict = torch.load(cached_file, map_location="cpu")
435
+ interpolate_pos_embed(model,state_dict)
436
+
437
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
438
+ # print(incompatible_keys)
439
+
440
+ if precision == "fp16":
441
+ # model.to("cuda")
442
+ convert_weights_to_fp16(model)
443
+ return model
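+
+ # Illustrative usage sketch: build the EVA ViT-g encoder and run a dummy batch through it.
+ # vit = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision="fp32")
+ # feats = vit(torch.randn(2, 3, 224, 224))  # -> (2, 257, 1408): cls token + 16*16 patch tokens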
gradcam.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
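+
+ # Illustrative usage sketch (inputs are hypothetical): overlay a low-resolution attention map on
+ # an RGB image whose values are already normalized to [0, 1].
+ # img = np.random.rand(224, 224, 3)
+ # overlay = getAttMap(img, np.random.rand(7, 7), blur=True, overlap=True)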
logger.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from .dist_utils import is_dist_avail_and_initialized, is_main_process
17
+
18
+ class SmoothedValue(object):
19
+ """Track a series of values and provide access to smoothed values over a
20
+ window or the global series average.
21
+ """
22
+
23
+ def __init__(self, window_size=20, fmt=None):
24
+ if fmt is None:
25
+ fmt = "{median:.4f} ({global_avg:.4f})"
26
+ self.deque = deque(maxlen=window_size)
27
+ self.total = 0.0
28
+ self.count = 0
29
+ self.fmt = fmt
30
+
31
+ def update(self, value, n=1):
32
+ self.deque.append(value)
33
+ self.count += n
34
+ self.total += value * n
35
+
36
+ def synchronize_between_processes(self):
37
+ """
38
+ Warning: does not synchronize the deque!
39
+ """
40
+ if not is_dist_avail_and_initialized():
41
+ return
42
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
43
+ dist.barrier()
44
+ dist.all_reduce(t)
45
+ t = t.tolist()
46
+ self.count = int(t[0])
47
+ self.total = t[1]
48
+
49
+ @property
50
+ def median(self):
51
+ d = torch.tensor(list(self.deque))
52
+ return d.median().item()
53
+
54
+ @property
55
+ def avg(self):
56
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
57
+ return d.mean().item()
58
+
59
+ @property
60
+ def global_avg(self):
61
+ return self.total / self.count
62
+
63
+ @property
64
+ def max(self):
65
+ return max(self.deque)
66
+
67
+ @property
68
+ def value(self):
69
+ return self.deque[-1]
70
+
71
+ def __str__(self):
72
+ return self.fmt.format(
73
+ median=self.median,
74
+ avg=self.avg,
75
+ global_avg=self.global_avg,
76
+ max=self.max,
77
+ value=self.value,
78
+ )
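+
+ # Illustrative usage sketch: record per-iteration losses and read back smoothed statistics.
+ # meter = SmoothedValue(window_size=20)
+ # meter.update(0.52); meter.update(0.31)
+ # print(str(meter))  # "median (global_avg)" over the tracked values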
79
+
80
+
81
+ class MetricLogger(object):
82
+ def __init__(self, delimiter="\t"):
83
+ self.meters = defaultdict(SmoothedValue)
84
+ self.delimiter = delimiter
85
+
86
+ def update(self, **kwargs):
87
+ for k, v in kwargs.items():
88
+ if isinstance(v, torch.Tensor):
89
+ v = v.item()
90
+ assert isinstance(v, (float, int))
91
+ self.meters[k].update(v)
92
+
93
+ def __getattr__(self, attr):
94
+ if attr in self.meters:
95
+ return self.meters[attr]
96
+ if attr in self.__dict__:
97
+ return self.__dict__[attr]
98
+ raise AttributeError(
99
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
100
+ )
101
+
102
+ def __str__(self):
103
+ loss_str = []
104
+ for name, meter in self.meters.items():
105
+ loss_str.append("{}: {}".format(name, str(meter)))
106
+ return self.delimiter.join(loss_str)
107
+
108
+ def global_avg(self):
109
+ loss_str = []
110
+ for name, meter in self.meters.items():
111
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
112
+ return self.delimiter.join(loss_str)
113
+
114
+ def synchronize_between_processes(self):
115
+ for meter in self.meters.values():
116
+ meter.synchronize_between_processes()
117
+
118
+ def add_meter(self, name, meter):
119
+ self.meters[name] = meter
120
+
121
+ def log_every(self, iterable, print_freq, header=None):
122
+ i = 0
123
+ if not header:
124
+ header = ""
125
+ start_time = time.time()
126
+ end = time.time()
127
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
128
+ data_time = SmoothedValue(fmt="{avg:.4f}")
129
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
130
+ log_msg = [
131
+ header,
132
+ "[{0" + space_fmt + "}/{1}]",
133
+ "eta: {eta}",
134
+ "{meters}",
135
+ "time: {time}",
136
+ "data: {data}",
137
+ ]
138
+ if torch.cuda.is_available():
139
+ log_msg.append("max mem: {memory:.0f}")
140
+ log_msg = self.delimiter.join(log_msg)
141
+ MB = 1024.0 * 1024.0
142
+ for obj in iterable:
143
+ data_time.update(time.time() - end)
144
+ yield obj
145
+ iter_time.update(time.time() - end)
146
+ if i % print_freq == 0 or i == len(iterable) - 1:
147
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
148
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149
+ if torch.cuda.is_available():
150
+ print(
151
+ log_msg.format(
152
+ i,
153
+ len(iterable),
154
+ eta=eta_string,
155
+ meters=str(self),
156
+ time=str(iter_time),
157
+ data=str(data_time),
158
+ memory=torch.cuda.max_memory_allocated() / MB,
159
+ )
160
+ )
161
+ else:
162
+ print(
163
+ log_msg.format(
164
+ i,
165
+ len(iterable),
166
+ eta=eta_string,
167
+ meters=str(self),
168
+ time=str(iter_time),
169
+ data=str(data_time),
170
+ )
171
+ )
172
+ i += 1
173
+ end = time.time()
174
+ total_time = time.time() - start_time
175
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
176
+ print(
177
+ "{} Total time: {} ({:.4f} s / it)".format(
178
+ header, total_time_str, total_time / len(iterable)
179
+ )
180
+ )
181
+
182
+
183
+ class AttrDict(dict):
184
+ def __init__(self, *args, **kwargs):
185
+ super(AttrDict, self).__init__(*args, **kwargs)
186
+ self.__dict__ = self
187
+
188
+
189
+ def setup_logger():
190
+ logging.basicConfig(
191
+ level=logging.INFO if is_main_process() else logging.WARN,
192
+ format="%(asctime)s [%(levelname)s] %(message)s",
193
+ handlers=[logging.StreamHandler()],
194
+ )
mini_gpt4_llama_v2.py CHANGED
@@ -16,9 +16,10 @@ import torch
16
  from torch.cuda.amp import autocast as autocast
17
  import torch.nn as nn
18
 
19
- from minigpt4_video.registry import registry
20
- from minigpt4_video.blip2 import Blip2Base, disabled_train
21
- from minigpt4_video.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
 
22
  from transformers import LlamaTokenizer
23
  from transformers import BitsAndBytesConfig
24
  from transformers import AutoConfig, AutoTokenizer
@@ -34,7 +35,7 @@ import numpy as np
34
  import os
35
  from transformers import PretrainedConfig
36
  from transformers import PreTrainedModel
37
- from minigpt4_video.conversation import CONV_VISION
38
  import cv2
39
  def extract_audio(video_path, audio_path):
40
  video_clip = mp.VideoFileClip(video_path)
@@ -102,14 +103,15 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
102
  Blip2Base.__init__(self)
103
 
104
  vis_processor_cfg = {"name": "blip2_image_train","image_size": 224}
105
- self.vis_processor = registry.get_processor_class(vis_processor_cfg["name"]).from_config(vis_processor_cfg)
 
106
  self.CONV_VISION = CONV_VISION
107
  if "Mistral" in self.llama_model:
108
- from minigpt4_video.modeling_mistral import MistralForCausalLM as llm_model
109
  print("Mistral model")
110
  self.model_type = "Mistral"
111
  else:
112
- from minigpt4_video.modeling_llama_v2 import LlamaForCausalLM as llm_model
113
  print("Llama model")
114
  self.model_type = "Llama"
115
  self.tokenizer = self.init_tokenizer()
@@ -643,8 +645,7 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
643
  temperature=temperature,
644
  repetition_penalty=repetition_penalty,
645
  # stopping_criteria=stopping_criteria,
646
- # use_fastv=False,
647
- use_cache=True,
648
  )
649
 
650
  answers = []
 
16
  from torch.cuda.amp import autocast as autocast
17
  import torch.nn as nn
18
 
19
+ from .registry import registry
20
+ from .blip_processors import *
21
+ from .blip2 import Blip2Base, disabled_train
22
+ from .conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
23
  from transformers import LlamaTokenizer
24
  from transformers import BitsAndBytesConfig
25
  from transformers import AutoConfig, AutoTokenizer
 
35
  import os
36
  from transformers import PretrainedConfig
37
  from transformers import PreTrainedModel
38
+ from .conversation import CONV_VISION
39
  import cv2
40
  def extract_audio(video_path, audio_path):
41
  video_clip = mp.VideoFileClip(video_path)
 
103
  Blip2Base.__init__(self)
104
 
105
  vis_processor_cfg = {"name": "blip2_image_train","image_size": 224}
106
+ self.vis_processor = registry.get_processor_class(vis_processor_cfg["name"])
107
+ self.vis_processor = self.vis_processor.from_config(vis_processor_cfg)
108
  self.CONV_VISION = CONV_VISION
109
  if "Mistral" in self.llama_model:
110
+ from .modeling_mistral import MistralForCausalLM as llm_model
111
  print("Mistral model")
112
  self.model_type = "Mistral"
113
  else:
114
+ from .modeling_llama_v2 import LlamaForCausalLM as llm_model
115
  print("Llama model")
116
  self.model_type = "Llama"
117
  self.tokenizer = self.init_tokenizer()
 
645
  temperature=temperature,
646
  repetition_penalty=repetition_penalty,
647
  # stopping_criteria=stopping_criteria,
648
+ use_fastv=False,
 
649
  )
650
 
651
  answers = []
modeling_llama_v2.py ADDED
@@ -0,0 +1,136 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.nn import CrossEntropyLoss
7
+
8
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
11
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
12
+
13
+ class LlamaForCausalLM(LlamaForCausalLMOrig):
14
+
15
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
16
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
17
+ def forward(
18
+ self,
19
+ input_ids: torch.LongTensor = None,
20
+ attention_mask: Optional[torch.Tensor] = None,
21
+ position_ids: Optional[torch.LongTensor] = None,
22
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
23
+ inputs_embeds: Optional[torch.FloatTensor] = None,
24
+ labels: Optional[torch.LongTensor] = None,
25
+ use_cache: Optional[bool] = None,
26
+ output_attentions: Optional[bool] = None,
27
+ output_hidden_states: Optional[bool] = None,
28
+ return_dict: Optional[bool] = None,
29
+ cache_position: Optional[torch.LongTensor] = None,
30
+ reduction: Optional[str] = "mean",
31
+ use_fastv: Optional[bool] = False,
32
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
33
+ r"""
34
+ Args:
35
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
36
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
37
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
38
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
39
+
40
+ Returns:
41
+
42
+ Example:
43
+
44
+ ```python
45
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
46
+
47
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
48
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
49
+
50
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
51
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
52
+
53
+ >>> # Generate
54
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
55
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
56
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
57
+ ```"""
58
+
59
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
60
+ output_hidden_states = (
61
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
62
+ )
63
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
64
+
65
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
66
+ if use_fastv :
67
+ fastv_config = {
68
+ "use_fastv": True,
69
+ "fastv_k": 3,
70
+ "fastv_r": 0.75,
71
+ "image_token_start_index": 5,
72
+ "image_token_length": 576
73
+ }
74
+ print(f"Using fastv :{fastv_config}")
75
+ outputs = self.model.fastv_forward(
76
+ input_ids=input_ids,
77
+ attention_mask=attention_mask,
78
+ position_ids=position_ids,
79
+ past_key_values=past_key_values,
80
+ inputs_embeds=inputs_embeds,
81
+ use_cache=use_cache,
82
+ output_attentions=output_attentions,
83
+ output_hidden_states=output_hidden_states,
84
+ return_dict=return_dict,
85
+ fastv_config=fastv_config,
86
+ cache_position=cache_position,
87
+ )
88
+ else:
89
+ outputs = self.model(
90
+ input_ids=input_ids,
91
+ attention_mask=attention_mask,
92
+ position_ids=position_ids,
93
+ past_key_values=past_key_values,
94
+ inputs_embeds=inputs_embeds,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ output_hidden_states=output_hidden_states,
98
+ return_dict=return_dict,
99
+ # cache_position=cache_position,
100
+ )
101
+
102
+ hidden_states = outputs[0]
103
+ if self.config.pretraining_tp > 1:
104
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
105
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
106
+ logits = torch.cat(logits, dim=-1)
107
+ else:
108
+ logits = self.lm_head(hidden_states)
109
+ logits = logits.float()
110
+
111
+ loss = None
112
+ if labels is not None:
113
+ # Shift so that tokens < n predict n
114
+ shift_logits = logits[..., :-1, :].contiguous()
115
+ shift_labels = labels[..., 1:].contiguous()
116
+ # Flatten the tokens
117
+ loss_fct = CrossEntropyLoss(reduction=reduction)
118
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
119
+ shift_labels = shift_labels.view(-1)
120
+ # Enable model parallelism
121
+ shift_labels = shift_labels.to(shift_logits.device)
122
+ loss = loss_fct(shift_logits, shift_labels)
123
+ if reduction == "none":
124
+ loss = loss.view(logits.size(0), -1).mean(1)
125
+
126
+ if not return_dict:
127
+ output = (logits,) + outputs[1:]
128
+ return (loss,) + output if loss is not None else output
129
+
130
+ return CausalLMOutputWithPast(
131
+ loss=loss,
132
+ logits=logits,
133
+ past_key_values=outputs.past_key_values,
134
+ hidden_states=outputs.hidden_states,
135
+ attentions=outputs.attentions,
136
+ )
modeling_mistral.py ADDED
@@ -0,0 +1,1388 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Mistral model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from transformers.models.mistral.configuration_mistral import MistralConfig
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CONFIG_FOR_DOC = "MistralConfig"
58
+
59
+
60
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
61
+ def _get_unpad_data(attention_mask):
62
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
63
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
64
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
65
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
66
+ return (
67
+ indices,
68
+ cu_seqlens,
69
+ max_seqlen_in_batch,
70
+ )
71
+
72
+
73
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
74
+ class MistralRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ MistralRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ input_dtype = hidden_states.dtype
85
+ hidden_states = hidden_states.to(torch.float32)
86
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
87
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
88
+ return self.weight * hidden_states.to(input_dtype)
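+
+ # In formula form: y = w * x / sqrt(mean(x^2, dim=-1) + eps), computed in float32 and cast back
+ # to the input dtype; unlike LayerNorm there is no mean subtraction and no bias term.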
89
+
90
+
91
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
92
+ # TODO @Arthur no longer copied from LLama after static cache
93
+ class MistralRotaryEmbedding(nn.Module):
94
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
+ super().__init__()
96
+
97
+ self.dim = dim
98
+ self.max_position_embeddings = max_position_embeddings
99
+ self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
+
103
+ # Build here to make `torch.jit.trace` work.
104
+ self._set_cos_sin_cache(
105
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106
+ )
107
+
108
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
109
+ self.max_seq_len_cached = seq_len
110
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
111
+
112
+ freqs = torch.outer(t, self.inv_freq)
113
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
114
+ emb = torch.cat((freqs, freqs), dim=-1)
115
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
116
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
117
+
118
+ def forward(self, x, seq_len=None):
119
+ # x: [bs, num_attention_heads, seq_len, head_size]
120
+ if seq_len > self.max_seq_len_cached:
121
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
122
+
123
+ return (
124
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
125
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
126
+ )
127
+
128
+
129
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
130
+ def rotate_half(x):
131
+ """Rotates half the hidden dims of the input."""
132
+ x1 = x[..., : x.shape[-1] // 2]
133
+ x2 = x[..., x.shape[-1] // 2 :]
134
+ return torch.cat((-x2, x1), dim=-1)
135
+
136
+
137
+ # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
138
+ # TODO @Arthur no longer copied from LLama after static cache
139
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
140
+ """Applies Rotary Position Embedding to the query and key tensors.
141
+
142
+ Args:
143
+ q (`torch.Tensor`): The query tensor.
144
+ k (`torch.Tensor`): The key tensor.
145
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
146
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
147
+ position_ids (`torch.Tensor`):
148
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
149
+ used to pass offset position ids when working with a KV-cache.
150
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
151
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
152
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
153
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
154
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
155
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
156
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
157
+ Returns:
158
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
159
+ """
160
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
161
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
162
+ q_embed = (q * cos) + (rotate_half(q) * sin)
163
+ k_embed = (k * cos) + (rotate_half(k) * sin)
164
+ return q_embed, k_embed
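+
+ # Shape sketch (illustrative): q and k are (bsz, num_heads, seq_len, head_dim), and cos/sin indexed
+ # by position_ids are (bsz, seq_len, head_dim); unsqueezing dim 1 broadcasts the rotation across the
+ # head dimension, so q_embed and k_embed keep the same shapes as q and k.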
165
+
166
+
167
+ class MistralMLP(nn.Module):
168
+ def __init__(self, config):
169
+ super().__init__()
170
+ self.config = config
171
+ self.hidden_size = config.hidden_size
172
+ self.intermediate_size = config.intermediate_size
173
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
174
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
175
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
176
+ self.act_fn = ACT2FN[config.hidden_act]
177
+
178
+ def forward(self, x):
179
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
180
+
181
+
182
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
183
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
184
+ """
185
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
186
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
187
+ """
188
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
189
+ if n_rep == 1:
190
+ return hidden_states
191
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
192
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
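+
+ # Illustrative example: with 32 query heads and 8 key/value heads, n_rep = 4 and a KV tensor of
+ # shape (bsz, 8, seq_len, head_dim) becomes (bsz, 32, seq_len, head_dim), so each group of query
+ # heads attends over its own copy of the shared key/value head.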
193
+
194
+
195
+ class MistralAttention(nn.Module):
196
+ """
197
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
198
+ and "Generating Long Sequences with Sparse Transformers".
199
+ """
200
+
201
+ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
202
+ super().__init__()
203
+ self.config = config
204
+ self.layer_idx = layer_idx
205
+ if layer_idx is None:
206
+ logger.warning_once(
207
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
208
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
209
+ "when creating this class."
210
+ )
211
+
212
+ self.hidden_size = config.hidden_size
213
+ self.num_heads = config.num_attention_heads
214
+ self.head_dim = self.hidden_size // self.num_heads
215
+ self.num_key_value_heads = config.num_key_value_heads
216
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
217
+ self.max_position_embeddings = config.max_position_embeddings
218
+ self.rope_theta = config.rope_theta
219
+ self.is_causal = True
220
+ self.attention_dropout = config.attention_dropout
221
+
222
+ if (self.head_dim * self.num_heads) != self.hidden_size:
223
+ raise ValueError(
224
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
225
+ f" and `num_heads`: {self.num_heads})."
226
+ )
227
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
228
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
229
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
230
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
231
+
232
+ self.rotary_emb = MistralRotaryEmbedding(
233
+ self.head_dim,
234
+ max_position_embeddings=self.max_position_embeddings,
235
+ base=self.rope_theta,
236
+ )
237
+
238
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
239
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
240
+
241
+ def forward(
242
+ self,
243
+ hidden_states: torch.Tensor,
244
+ attention_mask: Optional[torch.Tensor] = None,
245
+ position_ids: Optional[torch.LongTensor] = None,
246
+ past_key_value: Optional[Cache] = None,
247
+ output_attentions: bool = False,
248
+ use_cache: bool = False,
249
+ **kwargs,
250
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
251
+ if "padding_mask" in kwargs:
252
+ warnings.warn(
253
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
254
+ )
255
+ bsz, q_len, _ = hidden_states.size()
256
+
257
+ query_states = self.q_proj(hidden_states)
258
+ key_states = self.k_proj(hidden_states)
259
+ value_states = self.v_proj(hidden_states)
260
+
261
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
262
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
263
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
264
+
265
+ kv_seq_len = key_states.shape[-2]
266
+ if past_key_value is not None:
267
+ if self.layer_idx is None:
268
+ raise ValueError(
269
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
270
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
271
+ "with a layer index."
272
+ )
273
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
274
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
275
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
276
+
277
+ if past_key_value is not None:
278
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
279
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
280
+
281
+ # repeat k/v heads if n_kv_heads < n_heads
282
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
283
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
284
+
285
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
286
+
287
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
288
+ raise ValueError(
289
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
290
+ f" {attn_weights.size()}"
291
+ )
292
+
293
+ if attention_mask is not None:
294
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
295
+ raise ValueError(
296
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
297
+ )
298
+
299
+ attn_weights = attn_weights + attention_mask
300
+
301
+ # upcast attention to fp32
302
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
303
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
304
+ attn_output = torch.matmul(attn_weights, value_states)
305
+
306
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
307
+ raise ValueError(
308
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
309
+ f" {attn_output.size()}"
310
+ )
311
+
312
+ attn_output = attn_output.transpose(1, 2).contiguous()
313
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
314
+
315
+ attn_output = self.o_proj(attn_output)
316
+
317
+ if not output_attentions:
318
+ attn_weights = None
319
+
320
+ return attn_output, attn_weights, past_key_value
321
+
322
+
323
+ class MistralFlashAttention2(MistralAttention):
324
+ """
325
+ Mistral flash attention module. This module inherits from `MistralAttention`, as the weights of the module stay
326
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
327
+ flash attention and deal with padding tokens in case the input contains any of them.
328
+ """
329
+
330
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
331
+ def __init__(self, *args, **kwargs):
332
+ super().__init__(*args, **kwargs)
333
+
334
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
335
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
336
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
337
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ attention_mask: Optional[torch.Tensor] = None,
343
+ position_ids: Optional[torch.LongTensor] = None,
344
+ past_key_value: Optional[Cache] = None,
345
+ output_attentions: bool = False,
346
+ use_cache: bool = False,
347
+ **kwargs,
348
+ ):
349
+ if "padding_mask" in kwargs:
350
+ warnings.warn(
351
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
352
+ )
353
+
354
+ # overwrite attention_mask with padding_mask
355
+ attention_mask = kwargs.pop("padding_mask")
356
+ bsz, q_len, _ = hidden_states.size()
357
+
358
+ query_states = self.q_proj(hidden_states)
359
+ key_states = self.k_proj(hidden_states)
360
+ value_states = self.v_proj(hidden_states)
361
+
362
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
363
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
364
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
365
+
366
+ kv_seq_len = key_states.shape[-2]
367
+ if past_key_value is not None:
368
+ if self.layer_idx is None:
369
+ raise ValueError(
370
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
371
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
372
+ "with a layer index."
373
+ )
374
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
375
+
376
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
377
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
378
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
379
+
380
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
381
+
382
+ use_sliding_windows = (
383
+ _flash_supports_window_size
384
+ and getattr(self.config, "sliding_window", None) is not None
385
+ and kv_seq_len > self.config.sliding_window
386
+ )
387
+
388
+ if not _flash_supports_window_size:
389
+ logger.warning_once(
390
+ "The current flash attention version does not support sliding window attention; for a more memory-efficient implementation,"
391
+ " make sure to upgrade the flash-attn library."
392
+ )
393
+
394
+ if past_key_value is not None:
395
+ # Activate cache slicing only if the config has a `sliding_window` attribute
396
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
397
+ if (
398
+ getattr(self.config, "sliding_window", None) is not None
399
+ and kv_seq_len > self.config.sliding_window
400
+ and cache_has_contents
401
+ ):
402
+ slicing_tokens = 1 - self.config.sliding_window
403
+
404
+ past_key = past_key_value[self.layer_idx][0]
405
+ past_value = past_key_value[self.layer_idx][1]
406
+
407
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
408
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
409
+
410
+ if past_key.shape[-2] != self.config.sliding_window - 1:
411
+ raise ValueError(
412
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
413
+ f" {past_key.shape}"
414
+ )
415
+
416
+ if attention_mask is not None:
417
+ attention_mask = attention_mask[:, slicing_tokens:]
418
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
419
+
420
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
421
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
422
+
423
+ # repeat k/v heads if n_kv_heads < n_heads
424
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
425
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
426
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
427
+
428
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
429
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
430
+ # cast them back to float16 just to be sure everything works as expected.
431
+ input_dtype = query_states.dtype
432
+ if input_dtype == torch.float32:
433
+ if torch.is_autocast_enabled():
434
+ target_dtype = torch.get_autocast_gpu_dtype()
435
+ # Handle the case where the model is quantized
436
+ elif hasattr(self.config, "_pre_quantization_dtype"):
437
+ target_dtype = self.config._pre_quantization_dtype
438
+ else:
439
+ target_dtype = self.q_proj.weight.dtype
440
+
441
+ logger.warning_once(
442
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
443
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
444
+ f" {target_dtype}."
445
+ )
446
+
447
+ query_states = query_states.to(target_dtype)
448
+ key_states = key_states.to(target_dtype)
449
+ value_states = value_states.to(target_dtype)
450
+
451
+ # Reshape to the expected shape for Flash Attention
452
+ query_states = query_states.transpose(1, 2)
453
+ key_states = key_states.transpose(1, 2)
454
+ value_states = value_states.transpose(1, 2)
455
+
456
+ attn_output = self._flash_attention_forward(
457
+ query_states,
458
+ key_states,
459
+ value_states,
460
+ attention_mask,
461
+ q_len,
462
+ dropout=dropout_rate,
463
+ use_sliding_windows=use_sliding_windows,
464
+ )
465
+
466
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
467
+ attn_output = self.o_proj(attn_output)
468
+
469
+ if not output_attentions:
470
+ attn_weights = None
471
+
472
+ return attn_output, attn_weights, past_key_value
473
+
474
+ def _flash_attention_forward(
475
+ self,
476
+ query_states,
477
+ key_states,
478
+ value_states,
479
+ attention_mask,
480
+ query_length,
481
+ dropout=0.0,
482
+ softmax_scale=None,
483
+ use_sliding_windows=False,
484
+ ):
485
+ """
486
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
487
+ first unpad the input, then compute the attention scores and pad the final attention scores.
488
+
489
+ Args:
490
+ query_states (`torch.Tensor`):
491
+ Input query states to be passed to Flash Attention API
492
+ key_states (`torch.Tensor`):
493
+ Input key states to be passed to Flash Attention API
494
+ value_states (`torch.Tensor`):
495
+ Input value states to be passed to Flash Attention API
496
+ attention_mask (`torch.Tensor`):
497
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
498
+ position of padding tokens and 1 for the position of non-padding tokens.
499
+ dropout (`float`, *optional*):
500
+ Attention dropout
501
+ softmax_scale (`float`, *optional*):
502
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
503
+ use_sliding_windows (`bool`, *optional*):
504
+ Whether to activate sliding window attention.
505
+ """
506
+ if not self._flash_attn_uses_top_left_mask:
507
+ causal = self.is_causal
508
+ else:
509
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
510
+ causal = self.is_causal and query_length != 1
511
+
512
+ # Contains at least one padding token in the sequence
513
+ if attention_mask is not None:
514
+ batch_size = query_states.shape[0]
515
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
516
+ query_states, key_states, value_states, attention_mask, query_length
517
+ )
518
+
519
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
520
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
521
+
522
+ if not use_sliding_windows:
523
+ attn_output_unpad = flash_attn_varlen_func(
524
+ query_states,
525
+ key_states,
526
+ value_states,
527
+ cu_seqlens_q=cu_seqlens_q,
528
+ cu_seqlens_k=cu_seqlens_k,
529
+ max_seqlen_q=max_seqlen_in_batch_q,
530
+ max_seqlen_k=max_seqlen_in_batch_k,
531
+ dropout_p=dropout,
532
+ softmax_scale=softmax_scale,
533
+ causal=causal,
534
+ )
535
+ else:
536
+ attn_output_unpad = flash_attn_varlen_func(
537
+ query_states,
538
+ key_states,
539
+ value_states,
540
+ cu_seqlens_q=cu_seqlens_q,
541
+ cu_seqlens_k=cu_seqlens_k,
542
+ max_seqlen_q=max_seqlen_in_batch_q,
543
+ max_seqlen_k=max_seqlen_in_batch_k,
544
+ dropout_p=dropout,
545
+ softmax_scale=softmax_scale,
546
+ causal=causal,
547
+ window_size=(self.config.sliding_window, self.config.sliding_window),
548
+ )
549
+
550
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
551
+ else:
552
+ if not use_sliding_windows:
553
+ attn_output = flash_attn_func(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ dropout,
558
+ softmax_scale=softmax_scale,
559
+ causal=causal,
560
+ )
561
+ else:
562
+ attn_output = flash_attn_func(
563
+ query_states,
564
+ key_states,
565
+ value_states,
566
+ dropout,
567
+ softmax_scale=softmax_scale,
568
+ causal=causal,
569
+ window_size=(self.config.sliding_window, self.config.sliding_window),
570
+ )
571
+
572
+ return attn_output
573
+
574
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
575
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
576
+
577
+ # On the first iteration we need to properly re-create the padding mask
578
+ # by slicing it on the proper place
579
+ if kv_seq_len != attention_mask.shape[-1]:
580
+ attention_mask_num_tokens = attention_mask.shape[-1]
581
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
582
+
583
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
584
+
585
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
586
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
587
+
588
+ if query_length == kv_seq_len:
589
+ query_layer = index_first_axis(
590
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
591
+ )
592
+ cu_seqlens_q = cu_seqlens_k
593
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
594
+ indices_q = indices_k
595
+ elif query_length == 1:
596
+ max_seqlen_in_batch_q = 1
597
+ cu_seqlens_q = torch.arange(
598
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
599
+ ) # There is a memcpy here, that is very bad.
600
+ indices_q = cu_seqlens_q[:-1]
601
+ query_layer = query_layer.squeeze(1)
602
+ else:
603
+ # The -q_len: slice assumes left padding.
604
+ attention_mask = attention_mask[:, -query_length:]
605
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q,
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
616
+
617
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
618
+ # TODO @Arthur no longer copied from LLama after static cache
619
+ class MistralSdpaAttention(MistralAttention):
620
+ """
621
+ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
622
+ `MistralAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
623
+ SDPA API.
624
+ """
625
+
626
+ # Adapted from MistralAttention.forward
627
+ def forward(
628
+ self,
629
+ hidden_states: torch.Tensor,
630
+ attention_mask: Optional[torch.Tensor] = None,
631
+ position_ids: Optional[torch.LongTensor] = None,
632
+ past_key_value: Optional[Cache] = None,
633
+ output_attentions: bool = False,
634
+ use_cache: bool = False,
635
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
636
+ if output_attentions:
637
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
638
+ logger.warning_once(
639
+ "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
640
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
641
+ )
642
+ return super().forward(
643
+ hidden_states=hidden_states,
644
+ attention_mask=attention_mask,
645
+ position_ids=position_ids,
646
+ past_key_value=past_key_value,
647
+ output_attentions=output_attentions,
648
+ use_cache=use_cache,
649
+ )
650
+
651
+ bsz, q_len, _ = hidden_states.size()
652
+
653
+ query_states = self.q_proj(hidden_states)
654
+ key_states = self.k_proj(hidden_states)
655
+ value_states = self.v_proj(hidden_states)
656
+
657
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
658
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
659
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
660
+
661
+ kv_seq_len = key_states.shape[-2]
662
+ if past_key_value is not None:
663
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
664
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
665
+
666
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
667
+
668
+ if past_key_value is not None:
669
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
670
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
671
+
672
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
673
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
674
+
675
+ if attention_mask is not None:
676
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
677
+ raise ValueError(
678
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
679
+ )
680
+
681
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
682
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
683
+ if query_states.device.type == "cuda" and attention_mask is not None:
684
+ query_states = query_states.contiguous()
685
+ key_states = key_states.contiguous()
686
+ value_states = value_states.contiguous()
687
+
688
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ attn_mask=attention_mask,
693
+ dropout_p=self.attention_dropout if self.training else 0.0,
694
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
695
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
696
+ )
697
+
698
+ attn_output = attn_output.transpose(1, 2).contiguous()
699
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
700
+
701
+ attn_output = self.o_proj(attn_output)
702
+
703
+ return attn_output, None, past_key_value
704
+
705
+
706
+ MISTRAL_ATTENTION_CLASSES = {
707
+ "eager": MistralAttention,
708
+ "flash_attention_2": MistralFlashAttention2,
709
+ "sdpa": MistralSdpaAttention,
710
+ }
711
+
712
+
713
+ class MistralDecoderLayer(nn.Module):
714
+ def __init__(self, config: MistralConfig, layer_idx: int):
715
+ super().__init__()
716
+ self.hidden_size = config.hidden_size
717
+
718
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
719
+
720
+ self.mlp = MistralMLP(config)
721
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
722
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
723
+
724
+ def forward(
725
+ self,
726
+ hidden_states: torch.Tensor,
727
+ attention_mask: Optional[torch.Tensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
730
+ output_attentions: Optional[bool] = False,
731
+ use_cache: Optional[bool] = False,
732
+ **kwargs,
733
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
734
+ if "padding_mask" in kwargs:
735
+ warnings.warn(
736
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
737
+ )
738
+ """
739
+ Args:
740
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
741
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
742
+ `(batch, sequence_length)` where padding elements are indicated by 0.
743
+ output_attentions (`bool`, *optional*):
744
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
745
+ returned tensors for more detail.
746
+ use_cache (`bool`, *optional*):
747
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
748
+ (see `past_key_values`).
749
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
750
+ """
751
+
752
+ residual = hidden_states
753
+
754
+ hidden_states = self.input_layernorm(hidden_states)
755
+
756
+ # Self Attention
757
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
758
+ hidden_states=hidden_states,
759
+ attention_mask=attention_mask,
760
+ position_ids=position_ids,
761
+ past_key_value=past_key_value,
762
+ output_attentions=output_attentions,
763
+ use_cache=use_cache,
764
+ )
765
+ hidden_states = residual + hidden_states
766
+
767
+ # Fully Connected
768
+ residual = hidden_states
769
+ hidden_states = self.post_attention_layernorm(hidden_states)
770
+ hidden_states = self.mlp(hidden_states)
771
+ hidden_states = residual + hidden_states
772
+
773
+ outputs = (hidden_states,)
774
+
775
+ if output_attentions:
776
+ outputs += (self_attn_weights,)
777
+
778
+ if use_cache:
779
+ outputs += (present_key_value,)
780
+
781
+ return outputs
782
+
783
+
784
+ MISTRAL_START_DOCSTRING = r"""
785
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
786
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
787
+ etc.)
788
+
789
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
790
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
791
+ and behavior.
792
+
793
+ Parameters:
794
+ config ([`MistralConfig`]):
795
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
796
+ load the weights associated with the model, only the configuration. Check out the
797
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
798
+ """
799
+
800
+
801
+ @add_start_docstrings(
802
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
803
+ MISTRAL_START_DOCSTRING,
804
+ )
805
+ class MistralPreTrainedModel(PreTrainedModel):
806
+ config_class = MistralConfig
807
+ base_model_prefix = "model"
808
+ supports_gradient_checkpointing = True
809
+ _no_split_modules = ["MistralDecoderLayer"]
810
+ _skip_keys_device_placement = "past_key_values"
811
+ _supports_flash_attn_2 = True
812
+ _supports_sdpa = True
813
+ _supports_cache_class = True
814
+
815
+ def _init_weights(self, module):
816
+ std = self.config.initializer_range
817
+ if isinstance(module, nn.Linear):
818
+ module.weight.data.normal_(mean=0.0, std=std)
819
+ if module.bias is not None:
820
+ module.bias.data.zero_()
821
+ elif isinstance(module, nn.Embedding):
822
+ module.weight.data.normal_(mean=0.0, std=std)
823
+ if module.padding_idx is not None:
824
+ module.weight.data[module.padding_idx].zero_()
825
+
826
+
827
+ MISTRAL_INPUTS_DOCSTRING = r"""
828
+ Args:
829
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
830
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
831
+ it.
832
+
833
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
834
+ [`PreTrainedTokenizer.__call__`] for details.
835
+
836
+ [What are input IDs?](../glossary#input-ids)
837
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
838
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
839
+
840
+ - 1 for tokens that are **not masked**,
841
+ - 0 for tokens that are **masked**.
842
+
843
+ [What are attention masks?](../glossary#attention-mask)
844
+
845
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
846
+ [`PreTrainedTokenizer.__call__`] for details.
847
+
848
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
849
+ `past_key_values`).
850
+
851
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
852
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
853
+ information on the default strategy.
854
+
855
+ - 1 indicates the head is **not masked**,
856
+ - 0 indicates the head is **masked**.
857
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
858
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
859
+ config.n_positions - 1]`.
860
+
861
+ [What are position IDs?](../glossary#position-ids)
862
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
863
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
864
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
865
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
866
+
867
+ Two formats are allowed:
868
+ - a [`~cache_utils.Cache`] instance;
869
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
870
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
871
+ cache format.
872
+
873
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
874
+ legacy cache format will be returned.
875
+
876
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
877
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
878
+ of shape `(batch_size, sequence_length)`.
879
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
880
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
881
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
882
+ model's internal embedding lookup matrix.
883
+ use_cache (`bool`, *optional*):
884
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
885
+ `past_key_values`).
886
+ output_attentions (`bool`, *optional*):
887
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
888
+ tensors for more detail.
889
+ output_hidden_states (`bool`, *optional*):
890
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
891
+ more detail.
892
+ return_dict (`bool`, *optional*):
893
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
894
+ """
895
+
896
+
897
+ @add_start_docstrings(
898
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
899
+ MISTRAL_START_DOCSTRING,
900
+ )
901
+ class MistralModel(MistralPreTrainedModel):
902
+ """
903
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
904
+
905
+ Args:
906
+ config: MistralConfig
907
+ """
908
+
909
+ def __init__(self, config: MistralConfig):
910
+ super().__init__(config)
911
+ self.padding_idx = config.pad_token_id
912
+ self.vocab_size = config.vocab_size
913
+
914
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
915
+ self.layers = nn.ModuleList(
916
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
917
+ )
918
+ self._attn_implementation = config._attn_implementation
919
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
920
+
921
+ self.gradient_checkpointing = False
922
+ # Initialize weights and apply final processing
923
+ self.post_init()
924
+
925
+ def get_input_embeddings(self):
926
+ return self.embed_tokens
927
+
928
+ def set_input_embeddings(self, value):
929
+ self.embed_tokens = value
930
+
931
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
932
+ def forward(
933
+ self,
934
+ input_ids: torch.LongTensor = None,
935
+ attention_mask: Optional[torch.Tensor] = None,
936
+ position_ids: Optional[torch.LongTensor] = None,
937
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
938
+ inputs_embeds: Optional[torch.FloatTensor] = None,
939
+ use_cache: Optional[bool] = None,
940
+ output_attentions: Optional[bool] = None,
941
+ output_hidden_states: Optional[bool] = None,
942
+ return_dict: Optional[bool] = None,
943
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
944
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
945
+ output_hidden_states = (
946
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
947
+ )
948
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
949
+
950
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
951
+
952
+ # retrieve input_ids and inputs_embeds
953
+ if input_ids is not None and inputs_embeds is not None:
954
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
955
+ elif input_ids is not None:
956
+ batch_size, seq_length = input_ids.shape
957
+ elif inputs_embeds is not None:
958
+ batch_size, seq_length, _ = inputs_embeds.shape
959
+ else:
960
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
961
+
962
+ if self.gradient_checkpointing and self.training:
963
+ if use_cache:
964
+ logger.warning_once(
965
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
966
+ )
967
+ use_cache = False
968
+
969
+ past_key_values_length = 0
970
+
971
+ if use_cache:
972
+ use_legacy_cache = not isinstance(past_key_values, Cache)
973
+ if use_legacy_cache:
974
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
975
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
976
+
977
+ if position_ids is None:
978
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
979
+ position_ids = torch.arange(
980
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
981
+ )
982
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
983
+ else:
984
+ position_ids = position_ids.view(-1, seq_length).long()
985
+
986
+ if inputs_embeds is None:
987
+ inputs_embeds = self.embed_tokens(input_ids)
988
+
989
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
990
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
991
+ if is_padding_right:
992
+ raise ValueError(
993
+ "You are attempting to perform batched generation with padding_side='right'"
994
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
995
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
996
+ )
997
+
998
+ if self._attn_implementation == "flash_attention_2":
999
+ # 2d mask is passed through the layers
1000
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1001
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1002
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1003
+ # the manual implementation that requires a 4D causal mask in all cases.
1004
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1005
+ attention_mask,
1006
+ (batch_size, seq_length),
1007
+ inputs_embeds,
1008
+ past_key_values_length,
1009
+ )
1010
+ else:
1011
+ # 4d mask is passed through the layers
1012
+ attention_mask = _prepare_4d_causal_attention_mask(
1013
+ attention_mask,
1014
+ (batch_size, seq_length),
1015
+ inputs_embeds,
1016
+ past_key_values_length,
1017
+ sliding_window=self.config.sliding_window,
1018
+ )
1019
+
1020
+ hidden_states = inputs_embeds
1021
+
1022
+ # decoder layers
1023
+ all_hidden_states = () if output_hidden_states else None
1024
+ all_self_attns = () if output_attentions else None
1025
+ next_decoder_cache = None
1026
+
1027
+ for decoder_layer in self.layers:
1028
+ if output_hidden_states:
1029
+ all_hidden_states += (hidden_states,)
1030
+
1031
+ if self.gradient_checkpointing and self.training:
1032
+ layer_outputs = self._gradient_checkpointing_func(
1033
+ decoder_layer.__call__,
1034
+ hidden_states,
1035
+ attention_mask,
1036
+ position_ids,
1037
+ past_key_values,
1038
+ output_attentions,
1039
+ use_cache,
1040
+ )
1041
+ else:
1042
+ layer_outputs = decoder_layer(
1043
+ hidden_states,
1044
+ attention_mask=attention_mask,
1045
+ position_ids=position_ids,
1046
+ past_key_value=past_key_values,
1047
+ output_attentions=output_attentions,
1048
+ use_cache=use_cache,
1049
+ )
1050
+
1051
+ hidden_states = layer_outputs[0]
1052
+
1053
+ if use_cache:
1054
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1055
+
1056
+ if output_attentions:
1057
+ all_self_attns += (layer_outputs[1],)
1058
+
1059
+ hidden_states = self.norm(hidden_states)
1060
+
1061
+ # add hidden states from the last decoder layer
1062
+ if output_hidden_states:
1063
+ all_hidden_states += (hidden_states,)
1064
+
1065
+ next_cache = None
1066
+ if use_cache:
1067
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1068
+
1069
+ if not return_dict:
1070
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1071
+ return BaseModelOutputWithPast(
1072
+ last_hidden_state=hidden_states,
1073
+ past_key_values=next_cache,
1074
+ hidden_states=all_hidden_states,
1075
+ attentions=all_self_attns,
1076
+ )
1077
+
1078
+
1079
+ class MistralForCausalLM(MistralPreTrainedModel):
1080
+ _tied_weights_keys = ["lm_head.weight"]
1081
+
1082
+ def __init__(self, config):
1083
+ super().__init__(config)
1084
+ self.model = MistralModel(config)
1085
+ self.vocab_size = config.vocab_size
1086
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1087
+
1088
+ # Initialize weights and apply final processing
1089
+ self.post_init()
1090
+
1091
+ def get_input_embeddings(self):
1092
+ return self.model.embed_tokens
1093
+
1094
+ def set_input_embeddings(self, value):
1095
+ self.model.embed_tokens = value
1096
+
1097
+ def get_output_embeddings(self):
1098
+ return self.lm_head
1099
+
1100
+ def set_output_embeddings(self, new_embeddings):
1101
+ self.lm_head = new_embeddings
1102
+
1103
+ def set_decoder(self, decoder):
1104
+ self.model = decoder
1105
+
1106
+ def get_decoder(self):
1107
+ return self.model
1108
+
1109
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1110
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1111
+ def forward(
1112
+ self,
1113
+ input_ids: torch.LongTensor = None,
1114
+ attention_mask: Optional[torch.Tensor] = None,
1115
+ position_ids: Optional[torch.LongTensor] = None,
1116
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1117
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1118
+ labels: Optional[torch.LongTensor] = None,
1119
+ use_cache: Optional[bool] = None,
1120
+ output_attentions: Optional[bool] = None,
1121
+ output_hidden_states: Optional[bool] = None,
1122
+ return_dict: Optional[bool] = None,
1123
+ reduction: Optional[str] = "mean",
1124
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1125
+ r"""
1126
+ Args:
1127
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1128
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1129
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1130
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1131
+
1132
+ Returns:
1133
+
1134
+ Example:
1135
+
1136
+ ```python
1137
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
1138
+
1139
+ >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
1140
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
1141
+
1142
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1143
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1144
+
1145
+ >>> # Generate
1146
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1147
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1148
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1149
+ ```"""
1150
+
1151
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1152
+ output_hidden_states = (
1153
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1154
+ )
1155
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1156
+
1157
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1158
+ outputs = self.model(
1159
+ input_ids=input_ids,
1160
+ attention_mask=attention_mask,
1161
+ position_ids=position_ids,
1162
+ past_key_values=past_key_values,
1163
+ inputs_embeds=inputs_embeds,
1164
+ use_cache=use_cache,
1165
+ output_attentions=output_attentions,
1166
+ output_hidden_states=output_hidden_states,
1167
+ return_dict=return_dict,
1168
+ )
1169
+
1170
+ hidden_states = outputs[0]
1171
+ logits = self.lm_head(hidden_states)
1172
+ logits = logits.float()
1173
+
1174
+ loss = None
1175
+ if labels is not None:
1176
+ # Shift so that tokens < n predict n
1177
+ shift_logits = logits[..., :-1, :].contiguous()
1178
+ shift_labels = labels[..., 1:].contiguous()
1179
+ # Flatten the tokens
1180
+ loss_fct = CrossEntropyLoss(reduction=reduction)
1181
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1182
+ shift_labels = shift_labels.view(-1)
1183
+ # Enable model parallelism
1184
+ shift_labels = shift_labels.to(shift_logits.device)
1185
+ loss = loss_fct(shift_logits, shift_labels)
1186
+ if reduction == "none":
1187
+ loss = loss.view(logits.size(0), -1).mean(1)
1188
+ if not return_dict:
1189
+ output = (logits,) + outputs[1:]
1190
+ return (loss,) + output if loss is not None else output
1191
+
1192
+ return CausalLMOutputWithPast(
1193
+ loss=loss,
1194
+ logits=logits,
1195
+ past_key_values=outputs.past_key_values,
1196
+ hidden_states=outputs.hidden_states,
1197
+ attentions=outputs.attentions,
1198
+ )
1199
+
1200
+ def prepare_inputs_for_generation(
1201
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1202
+ ):
1203
+ # Omit tokens covered by past_key_values
1204
+ if past_key_values is not None:
1205
+ if isinstance(past_key_values, Cache):
1206
+ cache_length = past_key_values.get_seq_length()
1207
+ past_length = past_key_values.seen_tokens
1208
+ max_cache_length = past_key_values.get_max_length()
1209
+ else:
1210
+ cache_length = past_length = past_key_values[0][0].shape[2]
1211
+ max_cache_length = None
1212
+
1213
+ # Keep only the unprocessed tokens:
1214
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1215
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1216
+ # input)
1217
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1218
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1219
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1220
+ # input_ids based on the past_length.
1221
+ elif past_length < input_ids.shape[1]:
1222
+ input_ids = input_ids[:, past_length:]
1223
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1224
+
1225
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1226
+ if (
1227
+ max_cache_length is not None
1228
+ and attention_mask is not None
1229
+ and cache_length + input_ids.shape[1] > max_cache_length
1230
+ ):
1231
+ attention_mask = attention_mask[:, -max_cache_length:]
1232
+
1233
+ position_ids = kwargs.get("position_ids", None)
1234
+ if attention_mask is not None and position_ids is None:
1235
+ # create position_ids on the fly for batch generation
1236
+ position_ids = attention_mask.long().cumsum(-1) - 1
1237
+ position_ids.masked_fill_(attention_mask == 0, 1)
1238
+ if past_key_values:
1239
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1240
+
1241
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1242
+ if inputs_embeds is not None and past_key_values is None:
1243
+ model_inputs = {"inputs_embeds": inputs_embeds}
1244
+ else:
1245
+ model_inputs = {"input_ids": input_ids}
1246
+
1247
+ model_inputs.update(
1248
+ {
1249
+ "position_ids": position_ids,
1250
+ "past_key_values": past_key_values,
1251
+ "use_cache": kwargs.get("use_cache"),
1252
+ "attention_mask": attention_mask,
1253
+ }
1254
+ )
1255
+ return model_inputs
1256
+
1257
+ @staticmethod
1258
+ def _reorder_cache(past_key_values, beam_idx):
1259
+ reordered_past = ()
1260
+ for layer_past in past_key_values:
1261
+ reordered_past += (
1262
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1263
+ )
1264
+ return reordered_past
1265
+
1266
+
1267
+ @add_start_docstrings(
1268
+ """
1269
+ The Mistral Model transformer with a sequence classification head on top (linear layer).
1270
+
1271
+ [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1272
+ (e.g. GPT-2) do.
1273
+
1274
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1275
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1276
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1277
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1278
+ each row of the batch).
1279
+ """,
1280
+ MISTRAL_START_DOCSTRING,
1281
+ )
1282
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1283
+ class MistralForSequenceClassification(MistralPreTrainedModel):
1284
+ def __init__(self, config):
1285
+ super().__init__(config)
1286
+ self.num_labels = config.num_labels
1287
+ self.model = MistralModel(config)
1288
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1289
+
1290
+ # Initialize weights and apply final processing
1291
+ self.post_init()
1292
+
1293
+ def get_input_embeddings(self):
1294
+ return self.model.embed_tokens
1295
+
1296
+ def set_input_embeddings(self, value):
1297
+ self.model.embed_tokens = value
1298
+
1299
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1300
+ def forward(
1301
+ self,
1302
+ input_ids: torch.LongTensor = None,
1303
+ attention_mask: Optional[torch.Tensor] = None,
1304
+ position_ids: Optional[torch.LongTensor] = None,
1305
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1306
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1307
+ labels: Optional[torch.LongTensor] = None,
1308
+ use_cache: Optional[bool] = None,
1309
+ output_attentions: Optional[bool] = None,
1310
+ output_hidden_states: Optional[bool] = None,
1311
+ return_dict: Optional[bool] = None,
1312
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1313
+ r"""
1314
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1315
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1316
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1317
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1318
+ """
1319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1320
+
1321
+ transformer_outputs = self.model(
1322
+ input_ids,
1323
+ attention_mask=attention_mask,
1324
+ position_ids=position_ids,
1325
+ past_key_values=past_key_values,
1326
+ inputs_embeds=inputs_embeds,
1327
+ use_cache=use_cache,
1328
+ output_attentions=output_attentions,
1329
+ output_hidden_states=output_hidden_states,
1330
+ return_dict=return_dict,
1331
+ )
1332
+ hidden_states = transformer_outputs[0]
1333
+ logits = self.score(hidden_states)
1334
+
1335
+ if input_ids is not None:
1336
+ batch_size = input_ids.shape[0]
1337
+ else:
1338
+ batch_size = inputs_embeds.shape[0]
1339
+
1340
+ if self.config.pad_token_id is None and batch_size != 1:
1341
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1342
+ if self.config.pad_token_id is None:
1343
+ sequence_lengths = -1
1344
+ else:
1345
+ if input_ids is not None:
1346
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1347
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1348
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1349
+ sequence_lengths = sequence_lengths.to(logits.device)
1350
+ else:
1351
+ sequence_lengths = -1
1352
+
1353
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1354
+
1355
+ loss = None
1356
+ if labels is not None:
1357
+ labels = labels.to(logits.device)
1358
+ if self.config.problem_type is None:
1359
+ if self.num_labels == 1:
1360
+ self.config.problem_type = "regression"
1361
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1362
+ self.config.problem_type = "single_label_classification"
1363
+ else:
1364
+ self.config.problem_type = "multi_label_classification"
1365
+
1366
+ if self.config.problem_type == "regression":
1367
+ loss_fct = MSELoss()
1368
+ if self.num_labels == 1:
1369
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1370
+ else:
1371
+ loss = loss_fct(pooled_logits, labels)
1372
+ elif self.config.problem_type == "single_label_classification":
1373
+ loss_fct = CrossEntropyLoss()
1374
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1375
+ elif self.config.problem_type == "multi_label_classification":
1376
+ loss_fct = BCEWithLogitsLoss()
1377
+ loss = loss_fct(pooled_logits, labels)
1378
+ if not return_dict:
1379
+ output = (pooled_logits,) + transformer_outputs[1:]
1380
+ return ((loss,) + output) if loss is not None else output
1381
+
1382
+ return SequenceClassifierOutputWithPast(
1383
+ loss=loss,
1384
+ logits=pooled_logits,
1385
+ past_key_values=transformer_outputs.past_key_values,
1386
+ hidden_states=transformer_outputs.hidden_states,
1387
+ attentions=transformer_outputs.attentions,
1388
+ )
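For orientation (not part of the diff itself), here is a minimal usage sketch of the `MistralForCausalLM` defined above. The checkpoint name and dtype are illustrative assumptions; `attn_implementation` (available in recent transformers releases) selects an entry from `MISTRAL_ATTENTION_CLASSES`, and the `reduction` keyword is the per-sample loss option added in this file.

```python
# Minimal sketch, assuming transformers >= 4.36; checkpoint name and dtype are illustrative.
import torch
from transformers import AutoTokenizer

from modeling_mistral import MistralForCausalLM  # the class defined in this file

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # routes to MistralSdpaAttention; "eager" and "flash_attention_2" also work
)

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, labels=inputs["input_ids"], reduction="none")
print(out.loss.shape)  # (batch_size,): one mean token-level loss per sample
```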
optims.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from .registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=total_cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
randaugment.py ADDED
@@ -0,0 +1,398 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch
12
+
13
+
14
+ ## aug functions
15
+ def identity_func(img):
16
+ return img
17
+
18
+
19
+ def autocontrast_func(img, cutoff=0):
20
+ """
21
+ same output as PIL.ImageOps.autocontrast
22
+ """
23
+ n_bins = 256
24
+
25
+ def tune_channel(ch):
26
+ n = ch.size
27
+ cut = cutoff * n // 100
28
+ if cut == 0:
29
+ high, low = ch.max(), ch.min()
30
+ else:
31
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
32
+ low = np.argwhere(np.cumsum(hist) > cut)
33
+ low = 0 if low.shape[0] == 0 else low[0]
34
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
35
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
36
+ if high <= low:
37
+ table = np.arange(n_bins)
38
+ else:
39
+ scale = (n_bins - 1) / (high - low)
40
+ offset = -low * scale
41
+ table = np.arange(n_bins) * scale + offset
42
+ table[table < 0] = 0
43
+ table[table > n_bins - 1] = n_bins - 1
44
+ table = table.clip(0, 255).astype(np.uint8)
45
+ return table[ch]
46
+
47
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
48
+ out = cv2.merge(channels)
49
+ return out
50
+
51
+
52
+ def equalize_func(img):
53
+ """
54
+ same output as PIL.ImageOps.equalize
55
+ PIL's implementation is different from cv2.equalize
56
+ """
57
+ n_bins = 256
58
+
59
+ def tune_channel(ch):
60
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
61
+ non_zero_hist = hist[hist != 0].reshape(-1)
62
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
63
+ if step == 0:
64
+ return ch
65
+ n = np.empty_like(hist)
66
+ n[0] = step // 2
67
+ n[1:] = hist[:-1]
68
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
69
+ return table[ch]
70
+
71
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
72
+ out = cv2.merge(channels)
73
+ return out
74
+
75
+
76
+ def rotate_func(img, degree, fill=(0, 0, 0)):
77
+ """
78
+ like PIL, rotate by degree, not radians
79
+ """
80
+ H, W = img.shape[0], img.shape[1]
81
+ center = W / 2, H / 2
82
+ M = cv2.getRotationMatrix2D(center, degree, 1)
83
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84
+ return out
85
+
86
+
87
+ def solarize_func(img, thresh=128):
88
+ """
89
+ same output as PIL.ImageOps.posterize
90
+ """
91
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
92
+ table = table.clip(0, 255).astype(np.uint8)
93
+ out = table[img]
94
+ return out
95
+
96
+
97
+ def color_func(img, factor):
98
+ """
99
+ same output as PIL.ImageEnhance.Color
100
+ """
101
+ ## implementation according to PIL definition, quite slow
102
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103
+ # out = blend(degenerate, img, factor)
104
+ # M = (
105
+ # np.eye(3) * factor
106
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107
+ # )[np.newaxis, np.newaxis, :]
108
+ M = np.float32(
109
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
111
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112
+ return out
113
+
114
+
115
+ def contrast_func(img, factor):
116
+ """
117
+ same output as PIL.ImageEnhance.Contrast
118
+ """
119
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120
+ table = (
121
+ np.array([(el - mean) * factor + mean for el in range(256)])
122
+ .clip(0, 255)
123
+ .astype(np.uint8)
124
+ )
125
+ out = table[img]
126
+ return out
127
+
128
+
129
+ def brightness_func(img, factor):
130
+ """
131
+ same output as PIL.ImageEnhance.Contrast
132
+ """
133
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134
+ out = table[img]
135
+ return out
136
+
137
+
138
+ def sharpness_func(img, factor):
139
+ """
140
+ The differences the this result and PIL are all on the 4 boundaries, the center
141
+ areas are same
142
+ """
143
+ kernel = np.ones((3, 3), dtype=np.float32)
144
+ kernel[1][1] = 5
145
+ kernel /= 13
146
+ degenerate = cv2.filter2D(img, -1, kernel)
147
+ if factor == 0.0:
148
+ out = degenerate
149
+ elif factor == 1.0:
150
+ out = img
151
+ else:
152
+ out = img.astype(np.float32)
153
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155
+ out = out.astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
160
+ H, W = img.shape[0], img.shape[1]
161
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
162
+ out = cv2.warpAffine(
163
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164
+ ).astype(np.uint8)
165
+ return out
166
+
167
+
168
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
169
+ """
170
+ same output as PIL.Image.transform
171
+ """
172
+ H, W = img.shape[0], img.shape[1]
173
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
174
+ out = cv2.warpAffine(
175
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176
+ ).astype(np.uint8)
177
+ return out
178
+
179
+
180
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
181
+ """
182
+ same output as PIL.Image.transform
183
+ """
184
+ H, W = img.shape[0], img.shape[1]
185
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
186
+ out = cv2.warpAffine(
187
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188
+ ).astype(np.uint8)
189
+ return out
190
+
191
+
192
+ def posterize_func(img, bits):
193
+ """
194
+ same output as PIL.ImageOps.posterize
195
+ """
196
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197
+ return out
198
+
199
+
200
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
201
+ H, W = img.shape[0], img.shape[1]
202
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
203
+ out = cv2.warpAffine(
204
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
205
+ ).astype(np.uint8)
206
+ return out
207
+
208
+
209
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
210
+ replace = np.array(replace, dtype=np.uint8)
211
+ H, W = img.shape[0], img.shape[1]
212
+ rh, rw = np.random.random(2)
213
+ pad_size = pad_size // 2
214
+ ch, cw = int(rh * H), int(rw * W)
215
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
216
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
217
+ out = img.copy()
218
+ out[x1:x2, y1:y2, :] = replace
219
+ return out
220
+
221
+
222
+ ### level to args
223
+ def enhance_level_to_args(MAX_LEVEL):
224
+ def level_to_args(level):
225
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
226
+
227
+ return level_to_args
228
+
229
+
230
+ def shear_level_to_args(MAX_LEVEL, replace_value):
231
+ def level_to_args(level):
232
+ level = (level / MAX_LEVEL) * 0.3
233
+ if np.random.random() > 0.5:
234
+ level = -level
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
241
+ def level_to_args(level):
242
+ level = (level / MAX_LEVEL) * float(translate_const)
243
+ if np.random.random() > 0.5:
244
+ level = -level
245
+ return (level, replace_value)
246
+
247
+ return level_to_args
248
+
249
+
250
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
251
+ def level_to_args(level):
252
+ level = int((level / MAX_LEVEL) * cutout_const)
253
+ return (level, replace_value)
254
+
255
+ return level_to_args
256
+
257
+
258
+ def solarize_level_to_args(MAX_LEVEL):
259
+ def level_to_args(level):
260
+ level = int((level / MAX_LEVEL) * 256)
261
+ return (level,)
262
+
263
+ return level_to_args
264
+
265
+
266
+ def none_level_to_args(level):
267
+ return ()
268
+
269
+
270
+ def posterize_level_to_args(MAX_LEVEL):
271
+ def level_to_args(level):
272
+ level = int((level / MAX_LEVEL) * 4)
273
+ return (level,)
274
+
275
+ return level_to_args
276
+
277
+
278
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
279
+ def level_to_args(level):
280
+ level = (level / MAX_LEVEL) * 30
281
+ if np.random.random() < 0.5:
282
+ level = -level
283
+ return (level, replace_value)
284
+
285
+ return level_to_args
286
+
287
+
288
+ func_dict = {
289
+ "Identity": identity_func,
290
+ "AutoContrast": autocontrast_func,
291
+ "Equalize": equalize_func,
292
+ "Rotate": rotate_func,
293
+ "Solarize": solarize_func,
294
+ "Color": color_func,
295
+ "Contrast": contrast_func,
296
+ "Brightness": brightness_func,
297
+ "Sharpness": sharpness_func,
298
+ "ShearX": shear_x_func,
299
+ "TranslateX": translate_x_func,
300
+ "TranslateY": translate_y_func,
301
+ "Posterize": posterize_func,
302
+ "ShearY": shear_y_func,
303
+ }
304
+
305
+ translate_const = 10
306
+ MAX_LEVEL = 10
307
+ replace_value = (128, 128, 128)
308
+ arg_dict = {
309
+ "Identity": none_level_to_args,
310
+ "AutoContrast": none_level_to_args,
311
+ "Equalize": none_level_to_args,
312
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
313
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
314
+ "Color": enhance_level_to_args(MAX_LEVEL),
315
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
316
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
317
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
318
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
319
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
321
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
322
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
323
+ }
324
+
325
+
326
+ class RandomAugment(object):
327
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
328
+ self.N = N
329
+ self.M = M
330
+ self.isPIL = isPIL
331
+ if augs:
332
+ self.augs = augs
333
+ else:
334
+ self.augs = list(arg_dict.keys())
335
+
336
+ def get_random_ops(self):
337
+ sampled_ops = np.random.choice(self.augs, self.N)
338
+ return [(op, 0.5, self.M) for op in sampled_ops]
339
+
340
+ def __call__(self, img):
341
+ if self.isPIL:
342
+ img = np.array(img)
343
+ ops = self.get_random_ops()
344
+ for name, prob, level in ops:
345
+ if np.random.random() > prob:
346
+ continue
347
+ args = arg_dict[name](level)
348
+ img = func_dict[name](img, *args)
349
+ return img
350
+
351
+
352
+ class VideoRandomAugment(object):
353
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
354
+ self.N = N
355
+ self.M = M
356
+ self.p = p
357
+ self.tensor_in_tensor_out = tensor_in_tensor_out
358
+ if augs:
359
+ self.augs = augs
360
+ else:
361
+ self.augs = list(arg_dict.keys())
362
+
363
+ def get_random_ops(self):
364
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
365
+ return [(op, self.M) for op in sampled_ops]
366
+
367
+ def __call__(self, frames):
368
+ assert (
369
+ frames.shape[-1] == 3
370
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
371
+
372
+ if self.tensor_in_tensor_out:
373
+ frames = frames.numpy().astype(np.uint8)
374
+
375
+ num_frames = frames.shape[0]
376
+
377
+ ops = num_frames * [self.get_random_ops()]
378
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
379
+
380
+ frames = torch.stack(
381
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
382
+ ).float()
383
+
384
+ return frames
385
+
386
+ def _aug(self, img, ops, apply_or_not):
387
+ for i, (name, level) in enumerate(ops):
388
+ if not apply_or_not[i]:
389
+ continue
390
+ args = arg_dict[name](level)
391
+ img = func_dict[name](img, *args)
392
+ return torch.from_numpy(img)
393
+
394
+
395
+ if __name__ == "__main__":
396
+ a = RandomAugment()
397
+ img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
398
+ a(img)
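
A minimal usage sketch of RandomAugment follows; the module name randaugment, the Pillow dependency, and the input path example.jpg are illustrative assumptions, not part of the file above.

    from PIL import Image
    from randaugment import RandomAugment  # assumed import path for this file

    # Sample 2 ops per image at magnitude 7; each sampled op is applied with probability 0.5.
    aug = RandomAugment(
        N=2, M=7, isPIL=True,
        augs=["Identity", "AutoContrast", "Equalize", "Brightness", "Sharpness"],
    )

    img = Image.open("example.jpg").convert("RGB")  # placeholder input image
    out = aug(img)  # returns a numpy uint8 array of shape (H, W, 3), not a PIL image
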
registry.py ADDED
@@ -0,0 +1,267 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+
22
+ @classmethod
23
+ def register_model(cls, name):
24
+ r"""Register a task to registry with key 'name'
25
+
26
+ Args:
27
+ name: Key with which the task will be registered.
28
+
29
+ """
30
+
31
+ def wrap(model_cls):
32
+ from .base_model import BaseModel
33
+
34
+ assert issubclass(
35
+ model_cls, BaseModel
36
+ ), "All models must inherit BaseModel class"
37
+
38
+ if name in cls.mapping["model_name_mapping"]:
39
+ raise KeyError(
40
+ "Name '{}' already registered for {}.".format(
41
+ name, cls.mapping["model_name_mapping"][name]
42
+ )
43
+ )
44
+ cls.mapping["model_name_mapping"][name] = model_cls
45
+ return model_cls
46
+
47
+ return wrap
48
+
49
+ @classmethod
50
+ def register_processor(cls, name):
51
+ r"""Register a processor to registry with key 'name'
52
+
53
+ Args:
54
+ name: Key with which the processor will be registered.
55
+
56
+ Usage:
57
+
58
+ from .registry import registry
59
+ """
60
+
61
+ def wrap(processor_cls):
62
+ from .base_processor import BaseProcessor
63
+
64
+ # assert issubclass(
65
+ # processor_cls, BaseProcessor
66
+ # ), "All processors must inherit BaseProcessor class"
67
+ # if name in cls.mapping["processor_name_mapping"]:
68
+ # raise KeyError(
69
+ # "Name '{}' already registered for {}.".format(
70
+ # name, cls.mapping["processor_name_mapping"][name]
71
+ # )
72
+ # )
73
+ cls.mapping["processor_name_mapping"][name] = processor_cls
74
+ return processor_cls
75
+
76
+ return wrap
77
+
78
+ @classmethod
79
+ def register_lr_scheduler(cls, name):
80
+ r"""Register a model to registry with key 'name'
81
+
82
+ Args:
83
+ name: Key with which the task will be registered.
84
+
85
+ Usage:
86
+
87
+ from .registry import registry
88
+ """
89
+
90
+ def wrap(lr_sched_cls):
91
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
92
+ raise KeyError(
93
+ "Name '{}' already registered for {}.".format(
94
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
95
+ )
96
+ )
97
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
98
+ return lr_sched_cls
99
+
100
+ return wrap
101
+
102
+ @classmethod
103
+ def register_runner(cls, name):
104
+ r"""Register a model to registry with key 'name'
105
+
106
+ Args:
107
+ name: Key with which the task will be registered.
108
+
109
+ Usage:
110
+
111
+ from .registry import registry
112
+ """
113
+
114
+ def wrap(runner_cls):
115
+ if name in cls.mapping["runner_name_mapping"]:
116
+ raise KeyError(
117
+ "Name '{}' already registered for {}.".format(
118
+ name, cls.mapping["runner_name_mapping"][name]
119
+ )
120
+ )
121
+ cls.mapping["runner_name_mapping"][name] = runner_cls
122
+ return runner_cls
123
+
124
+ return wrap
125
+
126
+ @classmethod
127
+ def register_path(cls, name, path):
128
+ r"""Register a path to registry with key 'name'
129
+
130
+ Args:
131
+ name: Key with which the path will be registered.
132
+
133
+ Usage:
134
+
135
+ from .registry import registry
136
+ """
137
+ assert isinstance(path, str), "All path must be str."
138
+ if name in cls.mapping["paths"]:
139
+ raise KeyError("Name '{}' already registered.".format(name))
140
+ cls.mapping["paths"][name] = path
141
+
142
+ @classmethod
143
+ def register(cls, name, obj):
144
+ r"""Register an item to registry with key 'name'
145
+
146
+ Args:
147
+ name: Key with which the item will be registered.
148
+
149
+ Usage::
150
+
151
+ from .registry import registry
152
+
153
+ registry.register("config", {})
154
+ """
155
+ path = name.split(".")
156
+ current = cls.mapping["state"]
157
+
158
+ for part in path[:-1]:
159
+ if part not in current:
160
+ current[part] = {}
161
+ current = current[part]
162
+
163
+ current[path[-1]] = obj
164
+
165
+ # @classmethod
166
+ # def get_trainer_class(cls, name):
167
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
168
+
169
+ @classmethod
170
+ def get_builder_class(cls, name):
171
+ return cls.mapping["builder_name_mapping"].get(name, None)
172
+
173
+ @classmethod
174
+ def get_model_class(cls, name):
175
+ return cls.mapping["model_name_mapping"].get(name, None)
176
+
177
+ @classmethod
178
+ def get_task_class(cls, name):
179
+ return cls.mapping["task_name_mapping"].get(name, None)
180
+
181
+ @classmethod
182
+ def get_processor_class(cls, name):
183
+ return cls.mapping["processor_name_mapping"].get(name, None)
184
+
185
+ @classmethod
186
+ def get_lr_scheduler_class(cls, name):
187
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
188
+
189
+ @classmethod
190
+ def get_runner_class(cls, name):
191
+ return cls.mapping["runner_name_mapping"].get(name, None)
192
+
193
+ @classmethod
194
+ def list_runners(cls):
195
+ return sorted(cls.mapping["runner_name_mapping"].keys())
196
+
197
+ @classmethod
198
+ def list_models(cls):
199
+ return sorted(cls.mapping["model_name_mapping"].keys())
200
+
201
+ @classmethod
202
+ def list_tasks(cls):
203
+ return sorted(cls.mapping["task_name_mapping"].keys())
204
+
205
+ @classmethod
206
+ def list_processors(cls):
207
+ return sorted(cls.mapping["processor_name_mapping"].keys())
208
+
209
+ @classmethod
210
+ def list_lr_schedulers(cls):
211
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
212
+
213
+ @classmethod
214
+ def list_datasets(cls):
215
+ return sorted(cls.mapping["builder_name_mapping"].keys())
216
+
217
+ @classmethod
218
+ def get_path(cls, name):
219
+ return cls.mapping["paths"].get(name, None)
220
+
221
+ @classmethod
222
+ def get(cls, name, default=None, no_warning=False):
223
+ r"""Get an item from registry with key 'name'
224
+
225
+ Args:
226
+ name (string): Key whose value needs to be retrieved.
227
+ default: If passed and key is not in registry, default value will
228
+ be returned with a warning. Default: None
229
+ no_warning (bool): If passed as True, warning when key doesn't exist
230
+ will not be generated. Useful for MMF's
231
+ internal operations. Default: False
232
+ """
233
+ original_name = name
234
+ name = name.split(".")
235
+ value = cls.mapping["state"]
236
+ for subname in name:
237
+ value = value.get(subname, default)
238
+ if value is default:
239
+ break
240
+
241
+ if (
242
+ "writer" in cls.mapping["state"]
243
+ and value == default
244
+ and no_warning is False
245
+ ):
246
+ cls.mapping["state"]["writer"].warning(
247
+ "Key {} is not present in registry, returning default value "
248
+ "of {}".format(original_name, default)
249
+ )
250
+ return value
251
+
252
+ @classmethod
253
+ def unregister(cls, name):
254
+ r"""Remove an item from registry with key 'name'
255
+
256
+ Args:
257
+ name: Key which needs to be removed.
258
+ Usage::
259
+
260
+ from registry import registry
261
+
262
+ config = registry.unregister("config")
263
+ """
264
+ return cls.mapping["state"].pop(name, None)
265
+
266
+
267
+ registry = Registry()
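
A short usage sketch of the Registry follows; the key "my_model", the MyModel stub, and the cache path are hypothetical, and the relative imports assume this code lives in a module of the same package (matching the Usage notes in the docstrings above).

    from .registry import registry
    from .base_model import BaseModel  # register_model asserts this base class

    @registry.register_model("my_model")  # hypothetical key
    class MyModel(BaseModel):
        pass

    model_cls = registry.get_model_class("my_model")  # -> MyModel

    # Arbitrary shared state is stored under mapping["state"] and supports dotted lookup.
    registry.register("config", {"lr": 1e-4})
    lr = registry.get("config.lr")  # -> 0.0001

    registry.register_path("cache_root", "/tmp/lavis_cache")  # placeholder path
    cache_root = registry.get_path("cache_root")
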
utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from .registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading a URL from google drive requires confirmation when
112
+ the file size is too big (google drive notifies that
113
+ anti-virus checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
425
+
426
+ from typing import Dict, List, Protocol, Tuple
427
+
428
+ import torch
429
+ from torch.func import functional_call
430
+
431
+ from vllm.multimodal import BatchedTensors
432
+ from vllm.utils import is_pin_memory_available
433
+
434
+
435
+ def merge_vision_embeddings(input_ids: torch.Tensor,
436
+ inputs_embeds: torch.Tensor,
437
+ vision_embeddings: BatchedTensors,
438
+ image_token_id: int) -> torch.Tensor:
439
+ """
440
+ Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
441
+ in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
442
+
443
+ Note:
444
+ This updates `inputs_embeds` in place.
445
+ """
446
+ mask = (input_ids == image_token_id)
447
+ num_expected_tokens = mask.sum()
448
+
449
+ if isinstance(vision_embeddings, torch.Tensor):
450
+ batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
451
+ total_tokens = batch_size * batch_tokens
452
+ if num_expected_tokens != total_tokens:
453
+ expr = f"{batch_size} x {batch_tokens}"
454
+ raise ValueError(
455
+ f"Attempted to assign {expr} = {total_tokens} "
456
+ f"image tokens to {num_expected_tokens} placeholders")
457
+
458
+ inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
459
+ else:
460
+ size_per_batch = [t.shape[0] for t in vision_embeddings]
461
+ total_tokens = sum(size_per_batch)
462
+ if num_expected_tokens != total_tokens:
463
+ expr = ' + '.join(map(str, size_per_batch))
464
+ raise ValueError(
465
+ f"Attempted to assign {expr} = {total_tokens} "
466
+ f"image tokens to {num_expected_tokens} placeholders")
467
+
468
+ inputs_embeds[mask] = torch.cat(vision_embeddings)
469
+
470
+ return inputs_embeds
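
Finally, a toy, self-contained sketch of how merge_vision_embeddings is intended to be called; the image-token id, the sequence, and the tensor shapes are invented for illustration, and it assumes this utils module (and its vllm dependency) imports cleanly.

    import torch

    IMAGE_TOKEN_ID = 32000  # placeholder image-token id
    hidden = 8

    # One 6-token sequence whose positions 1..4 are image placeholders.
    input_ids = torch.tensor([1] + [IMAGE_TOKEN_ID] * 4 + [2])
    inputs_embeds = torch.zeros(6, hidden)

    # Vision encoder output: 1 image x 4 tokens x hidden, matching the 4 placeholders.
    vision_embeddings = torch.randn(1, 4, hidden)

    out = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings, IMAGE_TOKEN_ID)
    assert torch.equal(out[1:5], vision_embeddings.view(4, hidden))  # rows overwritten in place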