from typing import List

from transformers import PretrainedConfig
class MnistConfig(PretrainedConfig):
    """Configuration for a small two-convolution MNIST classifier.

    Holds the model's hyperparameters so they are serialized by
    ``save_pretrained`` / ``to_dict`` and restored by ``from_pretrained``.
    They are stored together with the standard ``PretrainedConfig`` fields.

    Args:
        conv1: Hyperparameter for the first convolution layer. Presumably its
            number of output channels; TODO confirm against the model code.
        conv2: Hyperparameter for the second convolution layer. Presumably its
            number of output channels; TODO confirm against the model code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__`` (for
            example ``num_labels`` or ``id2label``).
    """

    # ``model_type`` is not inert. It is written into config.json under the
    # "model_type" key. ``AutoConfig.register`` / ``AutoConfig.from_pretrained``
    # use that key to map a saved config back to this class, so keep the value
    # stable once checkpoints exist.
    # NOTE(review): the built-in Transformers MobileNetV1 config uses
    # "mobilenet_v1", so this string does not match a built-in model type.
    # Confirm that is intended.
    model_type = "MobileNetV1"

    def __init__(self, conv1: int = 10, conv2: int = 20, **kwargs):
        self.conv1 = conv1
        self.conv2 = conv2
        # The base class initializes the shared fields (labels, output flags,
        # etc.) and absorbs any extra keys loaded from a saved config.
        super().__init__(**kwargs)