Maple728 commited on
Commit
83dec66
1 Parent(s): 740b1e0

Update ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +7 -4
ts_generation_mixin.py CHANGED
@@ -13,7 +13,7 @@ class TSGenerationMixin(GenerationMixin):
13
 
14
  def _greedy_search(
15
  self,
16
- input_ids: torch.LongTensor,
17
  logits_processor: Optional[LogitsProcessorList] = None,
18
  stopping_criteria: Optional[StoppingCriteriaList] = None,
19
  max_length: Optional[int] = None,
@@ -27,7 +27,11 @@ class TSGenerationMixin(GenerationMixin):
27
  synced_gpus: bool = False,
28
  streamer: Optional["BaseStreamer"] = None,
29
  **model_kwargs,
30
- ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
 
 
 
 
31
  # init values
32
  logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
33
  stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -82,7 +86,6 @@ class TSGenerationMixin(GenerationMixin):
82
  )
83
 
84
  # keep track of which sequences are already finished
85
- batch_size, cur_len = input_ids.shape
86
  if "inputs_embeds" in model_kwargs:
87
  cur_len = model_kwargs["inputs_embeds"].shape[1]
88
  this_peer_finished = False
@@ -189,7 +192,7 @@ class TSGenerationMixin(GenerationMixin):
189
  past_key_values=model_kwargs.get("past_key_values"),
190
  )
191
  else:
192
- return input_ids
193
 
194
  def _update_model_kwargs_for_generation(
195
  self,
 
13
 
14
  def _greedy_search(
15
  self,
16
+ input_ids: torch.Tensor,
17
  logits_processor: Optional[LogitsProcessorList] = None,
18
  stopping_criteria: Optional[StoppingCriteriaList] = None,
19
  max_length: Optional[int] = None,
 
27
  synced_gpus: bool = False,
28
  streamer: Optional["BaseStreamer"] = None,
29
  **model_kwargs,
30
+ ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
31
+ if len(input_ids.shape) == 2:
32
+ batch_size, cur_len = input_ids.shape
33
+ else:
34
+ raise ValueError('Input shape must be: [batch_size, seq_len]')
35
  # init values
36
  logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
37
  stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
86
  )
87
 
88
  # keep track of which sequences are already finished
 
89
  if "inputs_embeds" in model_kwargs:
90
  cur_len = model_kwargs["inputs_embeds"].shape[1]
91
  this_peer_finished = False
 
192
  past_key_values=model_kwargs.get("past_key_values"),
193
  )
194
  else:
195
+ return input_ids.squeeze(dim=-1)
196
 
197
  def _update_model_kwargs_for_generation(
198
  self,