模型并行出错并给出修改方案
#54
by
yuanzhoulvpi
- opened
- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -952,7 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
952 |
|
953 |
# Shift so that tokens < n predict n
|
954 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
955 |
-
shift_labels = labels[..., 1:].contiguous()
|
956 |
# Flatten the tokens
|
957 |
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
958 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
|
952 |
|
953 |
# Shift so that tokens < n predict n
|
954 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
955 |
+
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
|
956 |
# Flatten the tokens
|
957 |
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
958 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|