Commit ab85772
Parent(s): 7c4a80c
alibi (#19)
feat: support alibi (b4903887bd152b045811fb78a5edc369a8db7cb5)
Co-authored-by: Jack Min Ong <[email protected]>
- embedding.py +1 -1
- modeling_xlm_roberta.py +2 -1
embedding.py CHANGED
@@ -50,7 +50,7 @@ class XLMRobertaEmbeddings(nn.Module):
         embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids =create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
+                position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
             # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = embeddings + position_embeddings
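For context on the rewritten line: create_position_ids_from_input_ids is the XLM-RoBERTa helper that derives position ids from the token ids themselves. A minimal sketch of its usual behaviour (the _sketch suffix and the example values are illustrative, not part of this commit): non-pad tokens receive consecutive positions starting at padding_idx + 1, while pad tokens keep padding_idx.

import torch

def create_position_ids_from_input_ids_sketch(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # 1 for real tokens, 0 for padding
    mask = input_ids.ne(padding_idx).int()
    # cumulative count of real tokens, zeroed out again at pad positions
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

# XLM-RoBERTa uses padding_idx = 1; a padded batch of one sequence:
ids = torch.tensor([[0, 5, 7, 1, 1]])
print(create_position_ids_from_input_ids_sketch(ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])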
modeling_xlm_roberta.py CHANGED
@@ -109,6 +109,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         fused_bias_fc=fused_bias_fc,
         use_flash_attn=use_flash_attn,
         return_residual=return_residual,
+        use_alibi=config.position_embedding_type == 'alibi',
         **rotary_kwargs,
     )
     return mixer_cls
@@ -429,7 +430,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.embeddings = XLMRobertaEmbeddings(
             config.hidden_size,
             config.vocab_size,
-            config.max_position_embeddings,
+            config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
             config.type_vocab_size,
             padding_idx=config.pad_token_id,
         )
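Taken together, the two hunks make ALiBi a pure config switch: the attention mixer receives use_alibi=True when config.position_embedding_type == 'alibi', and in that case the embedding layer is built with max_position_embeddings = -1, so the `if self.max_position_embeddings > 0:` branch in embedding.py is skipped and no learned position embeddings are added. A hedged usage sketch (the checkpoint name is illustrative; only the two config-driven expressions come from this diff):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("xlm-roberta-base")
config.position_embedding_type = "alibi"

# The expressions introduced by this commit:
use_alibi = config.position_embedding_type == "alibi"
max_pos = config.max_position_embeddings if config.position_embedding_type == "absolute" else -1
print(use_alibi, max_pos)  # True -1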