Update ts_generation_mixin.py
Browse files- ts_generation_mixin.py +7 -4
ts_generation_mixin.py
CHANGED
@@ -13,7 +13,7 @@ class TSGenerationMixin(GenerationMixin):
|
|
13 |
|
14 |
def _greedy_search(
|
15 |
self,
|
16 |
-
input_ids: torch.LongTensor,
|
17 |
logits_processor: Optional[LogitsProcessorList] = None,
|
18 |
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
19 |
max_length: Optional[int] = None,
|
@@ -27,7 +27,11 @@ class TSGenerationMixin(GenerationMixin):
|
|
27 |
synced_gpus: bool = False,
|
28 |
streamer: Optional["BaseStreamer"] = None,
|
29 |
**model_kwargs,
|
30 |
-
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
|
|
|
|
|
|
|
|
31 |
# init values
|
32 |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
33 |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
@@ -82,7 +86,6 @@ class TSGenerationMixin(GenerationMixin):
|
|
82 |
)
|
83 |
|
84 |
# keep track of which sequences are already finished
|
85 |
-
batch_size, cur_len = input_ids.shape
|
86 |
if "inputs_embeds" in model_kwargs:
|
87 |
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
88 |
this_peer_finished = False
|
@@ -189,7 +192,7 @@ class TSGenerationMixin(GenerationMixin):
|
|
189 |
past_key_values=model_kwargs.get("past_key_values"),
|
190 |
)
|
191 |
else:
|
192 |
-
return input_ids
|
193 |
|
194 |
def _update_model_kwargs_for_generation(
|
195 |
self,
|
|
|
13 |
|
14 |
def _greedy_search(
|
15 |
self,
|
16 |
+
input_ids: torch.Tensor,
|
17 |
logits_processor: Optional[LogitsProcessorList] = None,
|
18 |
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
19 |
max_length: Optional[int] = None,
|
|
|
27 |
synced_gpus: bool = False,
|
28 |
streamer: Optional["BaseStreamer"] = None,
|
29 |
**model_kwargs,
|
30 |
+
) -> Union[GenerateNonBeamOutput, torch.Tensor]:
|
31 |
+
if len(input_ids.shape) == 2:
|
32 |
+
batch_size, cur_len = input_ids.shape
|
33 |
+
else:
|
34 |
+
raise ValueError('Input shape must be: [batch_size, seq_len]')
|
35 |
# init values
|
36 |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
37 |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
|
|
86 |
)
|
87 |
|
88 |
# keep track of which sequences are already finished
|
|
|
89 |
if "inputs_embeds" in model_kwargs:
|
90 |
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
91 |
this_peer_finished = False
|
|
|
192 |
past_key_values=model_kwargs.get("past_key_values"),
|
193 |
)
|
194 |
else:
|
195 |
+
return input_ids.squeeze(dim=-1)
|
196 |
|
197 |
def _update_model_kwargs_for_generation(
|
198 |
self,
|