jupyterjazz commited on
Commit
c55e591
1 Parent(s): b27fa55

refactor: truncation fn

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. modeling_xlm_roberta.py +14 -9
modeling_xlm_roberta.py CHANGED
@@ -579,15 +579,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
579
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
580
 
581
  if truncate_dim:
582
- if not self.config.matryoshka_dimensions:
583
- logger.warning(
584
- 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
585
- )
586
- elif truncate_dim in self.config.matryoshka_dimensions:
587
- all_embeddings = [tensor[:truncate_dim] for tensor in all_embeddings]
588
- else:
589
- raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
590
- f'Supported dimensions are {self.config.matryoshka_dimensions}.')
591
 
592
  if convert_to_tensor:
593
  all_embeddings = torch.stack(all_embeddings)
@@ -600,6 +592,19 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
600
  self.train(is_training)
601
  return all_embeddings
602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
  def mean_pooling(
604
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
605
  ):
 
579
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
580
 
581
  if truncate_dim:
582
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
 
 
 
 
 
 
 
 
583
 
584
  if convert_to_tensor:
585
  all_embeddings = torch.stack(all_embeddings)
 
592
  self.train(is_training)
593
  return all_embeddings
594
 
595
+
596
+ def truncate_embeddings(self, embeddings, truncate_dim):
597
+ if not self.config.matryoshka_dimensions:
598
+ logger.warning(
599
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
600
+ )
601
+ return embeddings
602
+ elif truncate_dim in self.config.matryoshka_dimensions:
603
+ return [tensor[:truncate_dim] for tensor in embeddings]
604
+ else:
605
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
606
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
607
+
608
  def mean_pooling(
609
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
610
  ):