Update ts_generation_mixin.py
ts_generation_mixin.py  CHANGED  (+7 -4)
@@ -28,6 +28,8 @@ class TSGenerationMixin(GenerationMixin):
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+        input_ids_origin_device = input_ids.device
+        input_ids = input_ids.to(self.device)
         if len(input_ids.shape) == 2:
             batch_size, cur_len = input_ids.shape
         else:
@@ -169,6 +171,7 @@ class TSGenerationMixin(GenerationMixin):
         if streamer is not None:
             streamer.end()
 
+        input_ids.squeeze_(dim=-1).to(input_ids_origin_device)
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GenerateEncoderDecoderOutput(
@@ -192,7 +195,7 @@ class TSGenerationMixin(GenerationMixin):
                     past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
-            return input_ids
+            return input_ids
 
     def _update_model_kwargs_for_generation(
         self,
@@ -226,12 +229,12 @@ class TSGenerationMixin(GenerationMixin):
         if "decoder_attention_mask" in model_kwargs:
             decoder_attention_mask = model_kwargs["decoder_attention_mask"]
             model_kwargs["decoder_attention_mask"] = torch.cat(
-                [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], horizon_length))],
                 dim=-1,
             )
 
         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
-
-            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
+            # model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
 
         return model_kwargs
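The first three hunks make _greedy_search device-safe: the caller's tensor is moved to the model's device on entry, and the finished sequence is squeezed back to shape [batch, length] on exit. Below is a minimal sketch of that round trip; DummyModel and greedy_like are illustrative stand-ins, not code from the repo.

import torch

class DummyModel(torch.nn.Module):
    # Stand-in for a TSGenerationMixin model; a registered buffer anchors the
    # module to a device so the device property works like the mixin's self.device.
    def __init__(self):
        super().__init__()
        self.register_buffer("marker", torch.zeros(1))

    @property
    def device(self) -> torch.device:
        return self.marker.device

    def greedy_like(self, input_ids: torch.Tensor) -> torch.Tensor:
        input_ids_origin_device = input_ids.device  # remember where the input lives
        input_ids = input_ids.to(self.device)       # decode on the model's device
        input_ids = input_ids.unsqueeze(-1)         # placeholder for the decoding loop
        # exit path: in-place squeeze, then move back to the caller's device
        return input_ids.squeeze_(dim=-1).to(input_ids_origin_device)

model = DummyModel()  # on CPU here; model.to("cuda") would exercise the real round trip
out = model.greedy_like(torch.randn(2, 8))
assert out.device == torch.device("cpu")

One subtlety in the patched line itself: input_ids.squeeze_(dim=-1).to(input_ids_origin_device) discards the result of .to(...), and tensor .to is not in-place, so only the squeeze takes effect and the tensor handed to the return paths stays on the model's device. The sketch above returns the .to(...) result to complete the move.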
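The last hunk generalizes _update_model_kwargs_for_generation from one-token decoding to horizon decoding: each forward pass now accounts for horizon_length new positions, so the decoder attention mask grows by horizon_length columns and cache_position advances by horizon_length instead of 1. A standalone sketch of that update follows; update_kwargs is a hypothetical helper mirroring the patched lines, not the mixin's actual method.

import torch

def update_kwargs(model_kwargs: dict, horizon_length: int) -> dict:
    # Mirror of the patched kwargs update: extend the mask and advance the
    # cache pointer by the number of positions emitted in this step.
    if "decoder_attention_mask" in model_kwargs:
        mask = model_kwargs["decoder_attention_mask"]
        model_kwargs["decoder_attention_mask"] = torch.cat(
            [mask, mask.new_ones((mask.shape[0], horizon_length))], dim=-1
        )
    if model_kwargs.get("cache_position") is not None:
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
    return model_kwargs

kwargs = {"decoder_attention_mask": torch.ones(2, 5), "cache_position": torch.arange(5)}
kwargs = update_kwargs(kwargs, horizon_length=4)
print(kwargs["decoder_attention_mask"].shape)  # torch.Size([2, 9])
print(kwargs["cache_position"])                # tensor([8])

With horizon_length = 1 this reduces exactly to the replaced lines, which is why the old single-step update survives only as the commented-out line in the diff.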