Upload modeling_kosmos2.py
Browse files- modeling_kosmos2.py +1 -1
modeling_kosmos2.py
CHANGED
@@ -1429,7 +1429,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel):
|
|
1429 |
batch_size, seq_len = input_ids.size()
|
1430 |
mask_len = img_attn_mask.size()[-1]
|
1431 |
img_attn_mask = torch.cat(
|
1432 |
-
(img_attn_mask, torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool)), dim=1
|
1433 |
)
|
1434 |
|
1435 |
return {
|
|
|
1429 |
batch_size, seq_len = input_ids.size()
|
1430 |
mask_len = img_attn_mask.size()[-1]
|
1431 |
img_attn_mask = torch.cat(
|
1432 |
+
(img_attn_mask, torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device)), dim=1
|
1433 |
)
|
1434 |
|
1435 |
return {
|