yuanzhoulvpi
committed on
Commit
•
81c9eae
1
Parent(s):
8eb45c8
Update modeling_chatglm.py
Browse files
在进行模型并行的时候,如果不加这行的代码,会报错,建议加上。
- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -952,7 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
952 |
|
953 |
# Shift so that tokens < n predict n
|
954 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
955 |
-
shift_labels = labels[..., 1:].contiguous()
|
956 |
# Flatten the tokens
|
957 |
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
958 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
|
952 |
|
953 |
# Shift so that tokens < n predict n
|
954 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
955 |
+
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
|
956 |
# Flatten the tokens
|
957 |
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
958 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|