模型并行出错并给出修改方案

#54
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -952,7 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
952
 
953
  # Shift so that tokens < n predict n
954
  shift_logits = lm_logits[..., :-1, :].contiguous()
955
- shift_labels = labels[..., 1:].contiguous()
956
  # Flatten the tokens
957
  loss_fct = CrossEntropyLoss(ignore_index=-100)
958
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
952
 
953
  # Shift so that tokens < n predict n
954
  shift_logits = lm_logits[..., :-1, :].contiguous()
955
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
956
  # Flatten the tokens
957
  loss_fct = CrossEntropyLoss(ignore_index=-100)
958
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))