Jackmin108
commited on
Commit
•
814cbbb
1
Parent(s):
65e9690
some fixes and suggestions
Browse filesSigned-off-by: Meow <[email protected]>
- embedding.py +2 -2
- mha.py +6 -3
- mlp.py +2 -2
- modeling_lora.py +5 -3
- modeling_xlm_roberta.py +2 -1
embedding.py
CHANGED
@@ -48,7 +48,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
if adapter_mask is not None:
|
51 |
-
unique_tasks = torch.unique(adapter_mask)
|
52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
54 |
dtype=embedding_dtype, device=input_ids.device)
|
@@ -71,7 +71,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
71 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
72 |
|
73 |
if adapter_mask is not None:
|
74 |
-
unique_tasks = torch.unique(adapter_mask)
|
75 |
for task_id in unique_tasks:
|
76 |
task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
|
77 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
|
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
if adapter_mask is not None:
|
51 |
+
unique_tasks = torch.unique(adapter_mask)
|
52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
54 |
dtype=embedding_dtype, device=input_ids.device)
|
|
|
71 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
72 |
|
73 |
if adapter_mask is not None:
|
74 |
+
unique_tasks = torch.unique(adapter_mask)
|
75 |
for task_id in unique_tasks:
|
76 |
task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
|
77 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
mha.py
CHANGED
@@ -647,7 +647,7 @@ class MHA(nn.Module):
|
|
647 |
assert x_kv is None and mixer_subset is None
|
648 |
|
649 |
if cu_adapter_mask is not None:
|
650 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
653 |
dtype=qkv_dtype, device=x.device)
|
@@ -663,7 +663,10 @@ class MHA(nn.Module):
|
|
663 |
if not self.return_residual:
|
664 |
qkv = self.Wqkv(x)
|
665 |
else:
|
666 |
-
|
|
|
|
|
|
|
667 |
|
668 |
if self.dwconv:
|
669 |
qkv = rearrange(
|
@@ -752,7 +755,7 @@ class MHA(nn.Module):
|
|
752 |
|
753 |
inp = rearrange(context, "... h d -> ... (h d)")
|
754 |
if cu_adapter_mask is not None:
|
755 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
756 |
out_dtype = next(self.out_proj.parameters()).dtype
|
757 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
758 |
dtype=out_dtype, device=inp.device)
|
|
|
647 |
assert x_kv is None and mixer_subset is None
|
648 |
|
649 |
if cu_adapter_mask is not None:
|
650 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
653 |
dtype=qkv_dtype, device=x.device)
|
|
|
663 |
if not self.return_residual:
|
664 |
qkv = self.Wqkv(x)
|
665 |
else:
|
666 |
+
if hasattr(self.Wqkv, 'parametrizations'):
|
667 |
+
qkv, x = self.Wqkv(x, residual=True)
|
668 |
+
else:
|
669 |
+
qkv, x = self.Wqkv(x)
|
670 |
|
671 |
if self.dwconv:
|
672 |
qkv = rearrange(
|
|
|
755 |
|
756 |
inp = rearrange(context, "... h d -> ... (h d)")
|
757 |
if cu_adapter_mask is not None:
|
758 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
759 |
out_dtype = next(self.out_proj.parameters()).dtype
|
760 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
761 |
dtype=out_dtype, device=inp.device)
|
mlp.py
CHANGED
@@ -49,7 +49,7 @@ class Mlp(nn.Module):
|
|
49 |
|
50 |
def forward(self, x, cu_adapter_mask=None):
|
51 |
if cu_adapter_mask is not None:
|
52 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
55 |
dtype=fc1_dtype, device=x.device)
|
@@ -64,7 +64,7 @@ class Mlp(nn.Module):
|
|
64 |
y = self.activation(y)
|
65 |
|
66 |
if cu_adapter_mask is not None:
|
67 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
70 |
dtype=fc2_dtype, device=y.device)
|
|
|
49 |
|
50 |
def forward(self, x, cu_adapter_mask=None):
|
51 |
if cu_adapter_mask is not None:
|
52 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
55 |
dtype=fc1_dtype, device=x.device)
|
|
|
64 |
y = self.activation(y)
|
65 |
|
66 |
if cu_adapter_mask is not None:
|
67 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
70 |
dtype=fc2_dtype, device=y.device)
|
modeling_lora.py
CHANGED
@@ -355,7 +355,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
355 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
356 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
357 |
)
|
358 |
-
|
359 |
-
|
360 |
-
|
|
|
|
|
361 |
return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
|
|
|
355 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
356 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
357 |
)
|
358 |
+
adapter_mask = None
|
359 |
+
if task_type:
|
360 |
+
task_id = self._adaptation_map[task_type]
|
361 |
+
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
362 |
+
adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32, device=self.device)
|
363 |
return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
|
modeling_xlm_roberta.py
CHANGED
@@ -314,7 +314,7 @@ class XLMRobertaPooler(nn.Module):
|
|
314 |
# to the first token.
|
315 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
316 |
if adapter_mask is not None:
|
317 |
-
unique_tasks = torch.unique(adapter_mask)
|
318 |
pool_dtype = next(self.dense.parameters()).dtype
|
319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
320 |
dtype=pool_dtype, device=first_token_tensor.device)
|
@@ -465,6 +465,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
465 |
normalize_embeddings: bool = False,
|
466 |
truncate_dim: Optional[int] = None,
|
467 |
adapter_mask: Optional[torch.Tensor] = None,
|
|
|
468 |
**tokenizer_kwargs,
|
469 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
470 |
"""
|
|
|
314 |
# to the first token.
|
315 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
316 |
if adapter_mask is not None:
|
317 |
+
unique_tasks = torch.unique(adapter_mask)
|
318 |
pool_dtype = next(self.dense.parameters()).dtype
|
319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
320 |
dtype=pool_dtype, device=first_token_tensor.device)
|
|
|
465 |
normalize_embeddings: bool = False,
|
466 |
truncate_dim: Optional[int] = None,
|
467 |
adapter_mask: Optional[torch.Tensor] = None,
|
468 |
+
task_type: Optional[str] = None,
|
469 |
**tokenizer_kwargs,
|
470 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
471 |
"""
|