jupyterjazz
commited on
Commit
•
943cec2
1
Parent(s):
c55e591
feat: truncation option during init
Browse filesSigned-off-by: jupyterjazz <[email protected]>
- configuration_xlm_roberta.py +2 -0
- modeling_xlm_roberta.py +1 -0
configuration_xlm_roberta.py
CHANGED
@@ -32,6 +32,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
32 |
torch_dtype=None,
|
33 |
emb_pooler=None,
|
34 |
matryoshka_dimensions=None,
|
|
|
35 |
**kwargs,
|
36 |
):
|
37 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
@@ -61,6 +62,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
61 |
self.use_flash_attn = use_flash_attn
|
62 |
self.emb_pooler = emb_pooler
|
63 |
self.matryoshka_dimensions = matryoshka_dimensions
|
|
|
64 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
65 |
self.torch_dtype = getattr(torch, torch_dtype)
|
66 |
else:
|
|
|
32 |
torch_dtype=None,
|
33 |
emb_pooler=None,
|
34 |
matryoshka_dimensions=None,
|
35 |
+
truncate_dim=None,
|
36 |
**kwargs,
|
37 |
):
|
38 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
|
62 |
self.use_flash_attn = use_flash_attn
|
63 |
self.emb_pooler = emb_pooler
|
64 |
self.matryoshka_dimensions = matryoshka_dimensions
|
65 |
+
self.truncate_dim = truncate_dim
|
66 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
67 |
self.torch_dtype = getattr(torch, torch_dtype)
|
68 |
else:
|
modeling_xlm_roberta.py
CHANGED
@@ -578,6 +578,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
578 |
|
579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
580 |
|
|
|
581 |
if truncate_dim:
|
582 |
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
583 |
|
|
|
578 |
|
579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
580 |
|
581 |
+
truncate_dim = truncate_dim or self.config.truncate_dim
|
582 |
if truncate_dim:
|
583 |
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
584 |
|