Maple728 committed on
Commit
3d1e09d
1 Parent(s): 83dec66

Update ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +7 -4
ts_generation_mixin.py CHANGED
@@ -28,6 +28,8 @@ class TSGenerationMixin(GenerationMixin):
28
  streamer: Optional["BaseStreamer"] = None,
29
  **model_kwargs,
30
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
 
 
31
  if len(input_ids.shape) == 2:
32
  batch_size, cur_len = input_ids.shape
33
  else:
@@ -169,6 +171,7 @@ class TSGenerationMixin(GenerationMixin):
169
  if streamer is not None:
170
  streamer.end()
171
 
 
172
  if return_dict_in_generate:
173
  if self.config.is_encoder_decoder:
174
  return GenerateEncoderDecoderOutput(
@@ -192,7 +195,7 @@ class TSGenerationMixin(GenerationMixin):
192
  past_key_values=model_kwargs.get("past_key_values"),
193
  )
194
  else:
195
- return input_ids.squeeze(dim=-1)
196
 
197
  def _update_model_kwargs_for_generation(
198
  self,
@@ -226,12 +229,12 @@ class TSGenerationMixin(GenerationMixin):
226
  if "decoder_attention_mask" in model_kwargs:
227
  decoder_attention_mask = model_kwargs["decoder_attention_mask"]
228
  model_kwargs["decoder_attention_mask"] = torch.cat(
229
- [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
230
  dim=-1,
231
  )
232
 
233
  if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
234
- # model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
235
- model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
236
 
237
  return model_kwargs
 
28
  streamer: Optional["BaseStreamer"] = None,
29
  **model_kwargs,
30
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
31
+ input_ids_origin_device = input_ids.device
32
+ input_ids = input_ids.to(self.device)
33
  if len(input_ids.shape) == 2:
34
  batch_size, cur_len = input_ids.shape
35
  else:
 
171
  if streamer is not None:
172
  streamer.end()
173
 
174
+ input_ids.squeeze_(dim=-1).to(input_ids_origin_device)
175
  if return_dict_in_generate:
176
  if self.config.is_encoder_decoder:
177
  return GenerateEncoderDecoderOutput(
 
195
  past_key_values=model_kwargs.get("past_key_values"),
196
  )
197
  else:
198
+ return input_ids
199
 
200
  def _update_model_kwargs_for_generation(
201
  self,
 
229
  if "decoder_attention_mask" in model_kwargs:
230
  decoder_attention_mask = model_kwargs["decoder_attention_mask"]
231
  model_kwargs["decoder_attention_mask"] = torch.cat(
232
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], horizon_length))],
233
  dim=-1,
234
  )
235
 
236
  if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
237
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
238
+ # model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
239
 
240
  return model_kwargs