_set_gradient_checkpointing() got an unexpected keyword argument 'enable'
#3 · opened by ehartford
I have worked around this by modifying modeling_qwen.py as follows:

```python
from typing import Callable

def _set_gradient_checkpointing(self, enable: bool = False, gradient_checkpointing_func: Callable = None):
    is_gradient_checkpointing_set = False

    # Enable/disable on the top-level model itself.
    if isinstance(self, QWenModel):
        self.gradient_checkpointing = enable
        self._gradient_checkpointing_func = gradient_checkpointing_func
        is_gradient_checkpointing_set = True

    # Enable/disable on any QWenModel submodules.
    for module in self.modules():
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = enable
            module._gradient_checkpointing_func = gradient_checkpointing_func
            is_gradient_checkpointing_set = True

    if not is_gradient_checkpointing_set:
        raise ValueError(
            f"{self.__class__.__name__} is not compatible with gradient checkpointing. "
            "Make sure all the architectures support it by setting a boolean attribute "
            "'gradient_checkpointing' on the modules of the model that use checkpointing."
        )
```
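The root cause is that the Trainer's calling convention changed across transformers releases: older versions pass a `value=` keyword to `_set_gradient_checkpointing`, while newer ones pass `enable=` and `gradient_checkpointing_func=`. A remote-code model can tolerate both by accepting either keyword. Here is a minimal, self-contained sketch of that pattern — `ToyModel` and its attributes are illustrative stand-ins, not code from modeling_qwen.py:

```python
from typing import Callable, Optional

class ToyModel:
    """Stand-in for a remote-code model; shows a signature that
    accepts both old- and new-style gradient-checkpointing calls."""

    def __init__(self):
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func: Optional[Callable] = None

    def _set_gradient_checkpointing(
        self,
        enable: Optional[bool] = None,   # new-style keyword
        value: Optional[bool] = None,    # old-style keyword
        gradient_checkpointing_func: Optional[Callable] = None,
    ):
        # Whichever keyword the installed transformers used, fold it into one flag.
        flag = enable if enable is not None else bool(value)
        self.gradient_checkpointing = flag
        if gradient_checkpointing_func is not None:
            self._gradient_checkpointing_func = gradient_checkpointing_func

m = ToyModel()
m._set_gradient_checkpointing(enable=True)    # new transformers call style
print(m.gradient_checkpointing)               # True
m._set_gradient_checkpointing(value=False)    # old transformers call style
print(m.gradient_checkpointing)               # False
```

This keeps the checkpoint working without pinning transformers to one release, at the cost of a slightly looser signature.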
@ehartford
Hello!
I am not the creator of this model, but I ran into the same problem and want to share my solution: check your installed transformers version and downgrade it, e.g. `pip install transformers==4.34.0`.
Thank you!
That's not a solution when you are using software that requires the latest transformers.