Update model_class.py
model_class.py (+3 -0)
--- a/model_class.py
+++ b/model_class.py
@@ -26,6 +26,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         forced_ac_decoder_ids: Optional[torch.LongTensor] = None,  # added to be ignored when passed from trainer
+        decoder_position_ids: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         return super().forward(
             input_features=input_features,
@@ -43,6 +44,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            decoder_position_ids=decoder_position_ids,
         )
 
     # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
@@ -156,3 +158,4 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
             decoder_input_ids=decoder_input_ids,
             **kwargs,
         )
+
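For context, a minimal sketch of how the new pass-through could be exercised after this commit. It is not part of the change itself: the checkpoint name, the dummy tensor shapes, and the model_class import path are illustrative assumptions, and it presumes a transformers version whose WhisperForConditionalGeneration.forward accepts decoder_position_ids.

# Minimal sketch (illustrative, not part of the commit): exercising the new
# decoder_position_ids pass-through. Checkpoint name, dummy shapes, and the
# import path are assumptions.
import torch

from model_class import WhisperForAudioCaptioning  # hypothetical import path

model = WhisperForAudioCaptioning.from_pretrained("openai/whisper-tiny")

# Dummy log-mel features: (batch, n_mels, frames) = (1, 80, 3000) for whisper-tiny.
input_features = torch.zeros(1, 80, 3000)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# Explicit positions for the decoder tokens. Before this commit the subclass's
# forward() did not accept this argument; now it is threaded through to
# WhisperForConditionalGeneration.forward.
decoder_position_ids = torch.arange(decoder_input_ids.shape[1]).unsqueeze(0)

outputs = model(
    input_features=input_features,
    decoder_input_ids=decoder_input_ids,
    decoder_position_ids=decoder_position_ids,
)
print(outputs.logits.shape)  # torch.Size([1, 1, vocab_size])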