davinwang committed
Commit 5bc5aff
1 Parent(s): f3822a7

compatible with DirectML/ROCm


Tensor.new is a deprecated constructor and does not support the PrivateUse1 backend in PyTorch 1.13.1/2.0.0, so use torch.ones() instead. Please refer to https://github.com/microsoft/DirectML/issues/400, https://github.com/pytorch/pytorch/issues/95734, and https://huggingface.co/THUDM/chatglm2-6b/discussions/71 for more detail. This change should also fix ROCm compatibility in this file.
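
For context, a minimal standalone sketch of the substitution described above (the input_ids tensor here is an illustrative stand-in; the actual change lands in the generation loop shown in the diff below):

import torch

# Illustrative stand-in for the input_ids tensor seen at generation time.
input_ids = torch.randint(0, 100, (4, 16))

# Deprecated path: Tensor.new() goes through a legacy constructor that has no
# PrivateUse1 (DirectML) dispatch in PyTorch 1.13.1/2.0.0.
# unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

# Replacement: torch.ones() with an explicit device and dtype dispatches
# correctly on any backend, including DirectML (PrivateUse1) and ROCm.
unfinished_sequences = torch.ones(
    input_ids.shape[0], device=input_ids.device, dtype=input_ids.dtype
)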

Files changed (1)
  1. modeling_chatglm.py +2 -1
modeling_chatglm.py CHANGED
@@ -16,6 +16,7 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
+
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.generation.logits_process import LogitsProcessor
@@ -1138,7 +1139,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
         logits_warper = self._get_logits_warper(generation_config)

-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        unfinished_sequences = torch.ones(input_ids.shape[0], device=input_ids.device, dtype=input_ids.dtype)
         scores = None
         while True:
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)