Update ts_generation_mixin.py
Browse files- ts_generation_mixin.py +3 -3
ts_generation_mixin.py
CHANGED
@@ -226,12 +226,12 @@ class TSGenerationMixin(GenerationMixin):
|
|
226 |
if "decoder_attention_mask" in model_kwargs:
|
227 |
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
228 |
model_kwargs["decoder_attention_mask"] = torch.cat(
|
229 |
-
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
230 |
dim=-1,
|
231 |
)
|
232 |
|
233 |
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
234 |
-
|
235 |
-
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
|
236 |
|
237 |
return model_kwargs
|
|
|
226 |
if "decoder_attention_mask" in model_kwargs:
|
227 |
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
228 |
model_kwargs["decoder_attention_mask"] = torch.cat(
|
229 |
+
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], horizon_length))],
|
230 |
dim=-1,
|
231 |
)
|
232 |
|
233 |
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
234 |
+
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
|
235 |
+
# model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
|
236 |
|
237 |
return model_kwargs
|