Model checkpoint loading throws an error
#5, opened by alvations
With unbabel-comet==2.2.2, when loading the checkpoint with:
import os
from huggingface_hub import snapshot_download
from comet.models.multitask.xcomet_metric import XCOMETMetric
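# Download the snapshot into the current directory; model_path must be defined before it is used below.
model_path = snapshot_download(repo_id="Unbabel/XCOMET-XL", cache_dir=os.path.abspath(os.path.dirname('.')))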
model_checkpoint_path = f"{model_path}/checkpoints/model.ckpt"
xcometxl = XCOMETMetric.load_from_checkpoint(model_checkpoint_path)
It raises an error while loading the state_dict:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-103976254713> in <cell line: 7>()
5
6 model_checkpoint_path = f"{model_path}/checkpoints/model.ckpt"
----> 7 xcometxl = XCOMETMetric.load_from_checkpoint(model_checkpoint_path)
4 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for XCOMETMetric:
Unexpected key(s) in state_dict: "encoder.model.embeddings.position_ids".
Is there something else that needs to be initialized for the model checkpoint to load properly?
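For anyone hitting the same message, here is a minimal sketch of one possible workaround, not a confirmed fix from the maintainers: drop the checkpoint entries the instantiated model does not expect before loading the state_dict. The helper name `drop_unexpected_keys` and the toy dictionaries are hypothetical and only illustrate the filtering step. It is an assumption that the extra `position_ids` entry is safe to discard, for example if it is a buffer that a newer installed library version no longer stores in the state_dict.

```python
from typing import Any, Dict, Iterable, List, Tuple


def drop_unexpected_keys(
    state_dict: Dict[str, Any], expected_keys: Iterable[str]
) -> Tuple[Dict[str, Any], List[str]]:
    """Return a filtered copy of state_dict and the list of keys that were dropped.

    Hypothetical helper: it keeps only entries whose key is in expected_keys
    and does not modify the input mapping.
    """
    expected = set(expected_keys)
    kept = {key: value for key, value in state_dict.items() if key in expected}
    dropped = [key for key in state_dict if key not in expected]
    return kept, dropped


# Toy stand-ins: plain lists play the role of tensors in a real checkpoint.
checkpoint_state = {
    "encoder.model.embeddings.position_ids": [0, 1, 2],
    "encoder.model.embeddings.word_embeddings.weight": [[0.1, 0.2]],
}
# Keys the instantiated model would report via model.state_dict().keys().
model_keys = ["encoder.model.embeddings.word_embeddings.weight"]

filtered_state, dropped_keys = drop_unexpected_keys(checkpoint_state, model_keys)
print("Dropped:", dropped_keys)
```

With a real model, `expected_keys` would come from `model.state_dict().keys()`, and the filtered mapping would then be passed to `model.load_state_dict(...)`. The traceback above also shows that `load_state_dict` takes a `strict` argument; passing `strict=False` is the more common way to tolerate unexpected keys, at the cost of also silencing genuinely missing ones, so it is worth inspecting the reported keys either way.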