Update modeling_time_moe.py
modeling_time_moe.py (+3, -6)
@@ -25,6 +25,7 @@ try:
 except:
     pass
 
+
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
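For context, the helper whose first lines appear in this hunk normally follows the standard flash-attention unpadding pattern used in Hugging Face model files. A minimal sketch of the full helper, assuming that pattern rather than quoting the rest of this file:

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # Number of real (non-padding) tokens per sequence, and flat indices of those tokens.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading zero, the layout flash-attn varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(_get_unpad_data(mask))
# (tensor([0, 1, 2, 4, 5]), tensor([0, 3, 5], dtype=torch.int32), 3)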
@@ -66,7 +67,7 @@ def load_balancing_loss_func(
         The auxiliary loss.
     """
     if gate_logits is None or not isinstance(gate_logits, (tuple, list)) or gate_logits[0] is None:
-        return
+        return 0.0
 
     compute_device = gate_logits[0].device
     concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
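Returning 0.0 instead of a bare return keeps the auxiliary loss usable as a number when no gate logits are collected. A small self-contained sketch of why this matters; the helper and coefficient below are hypothetical and only illustrate how such a loss is typically folded into the total loss:

import torch

def combine_losses(prediction_loss, aux_loss, aux_coef=0.02):
    # Hypothetical helper: mirrors how MoE models usually add the
    # load-balancing term to the main prediction loss.
    return prediction_loss + aux_coef * aux_loss

prediction_loss = torch.tensor(1.25)

# With this commit, the load-balancing function returns 0.0 rather than None
# when gate logits are missing, so the sum below stays well defined.
print(combine_losses(prediction_loss, 0.0))   # tensor(1.2500)
# Before the fix, the bare return produced None, and the same call would
# raise a TypeError inside combine_losses.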
@@ -293,7 +294,7 @@ class TimeMoeSparseExpertsLayer(nn.Module):
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits
+        # router_logits -> (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
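The new comment documents the shape of router_logits. As a standalone sketch of the routing step around it, with made-up sizes (the real values come from TimeMoeConfig) and assuming the usual sparse-MoE continuation of top-k selection after the softmax:

import torch
import torch.nn.functional as F

# Standalone sketch; batch/hidden/expert sizes are made up for illustration.
batch_size, sequence_length, hidden_dim = 2, 8, 16
num_experts, top_k = 4, 2

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits -> (batch * sequence_length, n_experts), as the new comment states
router_logits = gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# Typical continuation in sparse-MoE layers: keep the top-k experts per token
# and renormalize their weights before dispatching tokens to the experts.
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
print(routing_weights.shape, selected_experts.shape)   # torch.Size([16, 2]) torch.Size([16, 2])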
@@ -764,8 +765,6 @@ class TimeMoeModel(TimeMoePreTrainedModel):
 
     def __init__(self, config: TimeMoeConfig):
         super().__init__(config)
-        # self.padding_idx = config.pad_token_id
-
         self.embed_layer = TimeMoeInputEmbedding(config)
         self.layers = nn.ModuleList(
             [TimeMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -1096,12 +1095,10 @@ class TimeMoeForPrediction(TimeMoePreTrainedModel, TSGenerationMixin):
         shift_labels = labels
 
         # Calculate loss with mask
-        # losses = self.loss_function(shift_predictions.to(torch.float32), shift_labels.to(torch.float32))
         losses = self.loss_function(shift_predictions, shift_labels)
 
         if loss_masks is not None:
             losses = losses * loss_masks
-
             loss = losses.sum() / loss_masks.sum()
         else:
             loss = torch.mean(losses)
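The masked reduction in this hunk averages only over valid positions. A self-contained numeric sketch, using toy tensors rather than model outputs, of how the sum-over-mask differs from a plain mean:

import torch

# Element-wise loss per predicted value, with 1.0 marking valid positions in the mask.
losses = torch.tensor([[0.5, 1.0, 2.0],
                       [4.0, 0.0, 1.0]])
loss_masks = torch.tensor([[1.0, 1.0, 0.0],
                           [1.0, 0.0, 0.0]])

# Masked positions contribute nothing; divide by the number of valid positions,
# not the total element count, so padding does not dilute the loss.
masked_loss = (losses * loss_masks).sum() / loss_masks.sum()
print(masked_loss)           # tensor(1.8333)
print(torch.mean(losses))    # tensor(1.4167) -- what an unmasked mean would give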