Jackmin108
committed on
Commit
•
65e9690
1
Parent(s):
4ee2970
fix: device
Browse files
Signed-off-by: Meow <[email protected]>
- embedding.py +1 -1
- mha.py +2 -2
- mlp.py +2 -2
- modeling_xlm_roberta.py +1 -1
embedding.py
CHANGED
@@ -51,7 +51,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
51 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
54 |
-
dtype=embedding_dtype)
|
55 |
for task_id in unique_tasks:
|
56 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
57 |
task_input_ids = input_ids[task_indices]
|
|
|
51 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
54 |
+
dtype=embedding_dtype, device=input_ids.device)
|
55 |
for task_id in unique_tasks:
|
56 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
57 |
task_input_ids = input_ids[task_indices]
|
mha.py
CHANGED
@@ -650,7 +650,7 @@ class MHA(nn.Module):
|
|
650 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
653 |
-
dtype=qkv_dtype)
|
654 |
for task_id in unique_tasks:
|
655 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
656 |
task_tensor = x[task_indices]
|
@@ -755,7 +755,7 @@ class MHA(nn.Module):
|
|
755 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
756 |
out_dtype = next(self.out_proj.parameters()).dtype
|
757 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
758 |
-
dtype=out_dtype)
|
759 |
for task_id in unique_tasks:
|
760 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
761 |
task_tensor = inp[task_indices]
|
|
|
650 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
653 |
+
dtype=qkv_dtype, device=x.device)
|
654 |
for task_id in unique_tasks:
|
655 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
656 |
task_tensor = x[task_indices]
|
|
|
755 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
756 |
out_dtype = next(self.out_proj.parameters()).dtype
|
757 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
758 |
+
dtype=out_dtype, device=inp.device)
|
759 |
for task_id in unique_tasks:
|
760 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
761 |
task_tensor = inp[task_indices]
|
mlp.py
CHANGED
@@ -52,7 +52,7 @@ class Mlp(nn.Module):
|
|
52 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
55 |
-
dtype=fc1_dtype)
|
56 |
for task_id in unique_tasks:
|
57 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
58 |
task_tensor = x[task_indices]
|
@@ -67,7 +67,7 @@ class Mlp(nn.Module):
|
|
67 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
70 |
-
dtype=fc2_dtype)
|
71 |
for task_id in unique_tasks:
|
72 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
73 |
task_tensor = y[task_indices]
|
|
|
52 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
55 |
+
dtype=fc1_dtype, device=x.device)
|
56 |
for task_id in unique_tasks:
|
57 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
58 |
task_tensor = x[task_indices]
|
|
|
67 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
70 |
+
dtype=fc2_dtype, device=y.device)
|
71 |
for task_id in unique_tasks:
|
72 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
73 |
task_tensor = y[task_indices]
|
modeling_xlm_roberta.py
CHANGED
@@ -317,7 +317,7 @@ class XLMRobertaPooler(nn.Module):
|
|
317 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
318 |
pool_dtype = next(self.dense.parameters()).dtype
|
319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
320 |
-
dtype=pool_dtype)
|
321 |
for task_id in unique_tasks:
|
322 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
323 |
task_first_token_tensor = first_token_tensor[task_indices]
|
|
|
317 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
318 |
pool_dtype = next(self.dense.parameters()).dtype
|
319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
320 |
+
dtype=pool_dtype, device=first_token_tensor.device)
|
321 |
for task_id in unique_tasks:
|
322 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
323 |
task_first_token_tensor = first_token_tensor[task_indices]
|