Update modeling_indictrans.py
Browse files- modeling_indictrans.py +1 -1
modeling_indictrans.py
CHANGED
@@ -1213,7 +1213,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1213 |
# move labels to the correct device to enable PP
|
1214 |
labels = labels.to(lm_logits.device)
|
1215 |
loss_fct = nn.CrossEntropyLoss()
|
1216 |
-
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.
|
1217 |
|
1218 |
if not return_dict:
|
1219 |
output = (lm_logits,) + outputs[1:]
|
|
|
1213 |
# move labels to the correct device to enable PP
|
1214 |
labels = labels.to(lm_logits.device)
|
1215 |
loss_fct = nn.CrossEntropyLoss()
|
1216 |
+
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
|
1217 |
|
1218 |
if not return_dict:
|
1219 |
output = (lm_logits,) + outputs[1:]
|