import tensorflow as tf | |
from transformers.modeling_tf_utils import unpack_inputs | |
from transformers.modeling_tf_utils import TFPreTrainedModel | |
from .configuration_my_model import MyModelConfig | |
class TFMyModelPretrainedModel(TFPreTrainedModel):
    """Shared base class that binds the TF model classes to ``MyModelConfig``."""

    # Configuration class associated with this model family; presumably read
    # by the TFPreTrainedModel loading helpers -- confirm against the
    # installed transformers version.
    config_class = MyModelConfig
class TFMyModel(TFMyModelPretrainedModel):
    """Minimal TF model that applies one dense projection to its input.

    The projection has ``config.n_layers`` output units and is applied to the
    ``hidden`` tensor given to :meth:`call`.
    """

    def __init__(self, config: MyModelConfig):
        """Build the model from ``config``.

        Args:
            config: Model configuration; ``n_layers`` and ``hidden_dim`` are
                read from it.
        """
        super().__init__(config)
        self.config = config
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_dim
        # NOTE(review): the projection width is ``n_layers`` rather than
        # ``hidden_dim``; kept as in the original -- confirm this is intended.
        self.linear = tf.keras.layers.Dense(units=config.n_layers)

    @property
    def dummy_inputs(self):
        """Return placeholder inputs for this model.

        Exposed as a property because ``TFPreTrainedModel`` declares
        ``dummy_inputs`` as a property and reads it as an attribute; a plain
        method here would hand the base class a bound method instead of the
        input dictionary.

        Returns:
            A dict with one key, ``"hidden"``, mapped to a zero tensor of
            shape ``(1, config.hidden_dim)``.
        """
        hidden = tf.zeros(shape=(1, self.config.hidden_dim))
        return {"hidden": hidden}

    def call(
        self,
        hidden,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Apply the dense projection to ``hidden`` and return the result.

        Args:
            hidden: Input tensor for the dense layer. ``dummy_inputs`` feeds a
                tensor of shape ``(1, config.hidden_dim)``.
            output_attentions: Accepted but not used by this model.
            output_hidden_states: Accepted but not used by this model.
            return_dict: Accepted but not used by this model; the projected
                tensor is returned directly.

        Returns:
            The output of ``self.linear(hidden)``.
        """
        # The original discarded this value (so the model returned ``None``)
        # and stopped in a leftover ``breakpoint()`` on every forward pass.
        return self.linear(hidden)