idefics2-8b-init?
#5
Opened by giobin
My compliments on the great work you have been doing with idefics2 (and IDEFICS before it)! Would it be possible to release a checkpoint of idefics2 from before the pretraining phase (i.e., before idefics2-8b-base)? That would help people trying to reproduce at least part of the training. Specifically, I am asking for the initialization code or the initial weights of the idefics2 modality projection layers. That would be great!
Thanks!
giobin changed discussion status to closed
giobin changed discussion status to open
Hi @giobin, here is our code for initializing the modules:

# Method of our model class. MLP, PerceiverResampler, DecoupledLinear and
# deepspeed_gathered_parameters_context_manager come from our training codebase (not shown here).
import torch.nn as nn
from transformers.utils import ContextManagers

def _init_weights(self, module):
    def init_a_linear(module, mean=0.0, std=self.config.initializer_range):
        # Normal init for the weight, zeros for the bias; each parameter is wrapped in the
        # DeepSpeed gathering context so it can be modified in place
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
            module.weight.data.normal_(mean=mean, std=std)
        if module.bias is not None:
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
                module.bias.data.zero_()

    # Modality projection MLP: std = sqrt(0.4 / (hidden_size * factor)), with factor 2 for down_proj
    if isinstance(module, MLP):
        for sub_module_name, sub_module in module.named_modules():
            if isinstance(sub_module, nn.Linear):
                factor = 1.0
                if "down_proj" in sub_module_name:
                    factor = 2.0
                init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

    # Perceiver resampler: latents ~ N(0, 1 / hidden_size); o_proj scaled by 2 * resampler_depth
    if isinstance(module, PerceiverResampler):
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
            module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
        for sub_module_name, sub_module in module.named_modules():
            if isinstance(sub_module, nn.Linear):
                factor = 1.0
                if "o_proj" in sub_module_name:
                    factor = 2.0 * self.config.perceiver_config.resampler_depth
                init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)
    elif isinstance(module, nn.Embedding):
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
            module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, DecoupledLinear):
        # Only the additional (newly added) rows are initialized here: std = sqrt(1 / in_features)
        if hasattr(module, "additional_fc"):
            init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
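To make the scaling rules in the `_init_weights` code concrete, here is a minimal, stdlib-only sketch that computes the standard deviations those rules produce. The helper names (`linear_init_std`, `latent_init_std`, `additional_fc_init_std`, `init_stds`) are hypothetical, and the example values `hidden_size=4096` and `resampler_depth=3` are illustrative assumptions, not values confirmed in this thread; read the real ones from the model config.

```python
import math


def linear_init_std(hidden_size: int, factor: float = 1.0) -> float:
    """Std passed to init_a_linear for MLP / PerceiverResampler linears:
    sqrt(0.4 / (hidden_size * factor))."""
    return math.sqrt(0.4 / (hidden_size * factor))


def latent_init_std(hidden_size: int) -> float:
    """Std for PerceiverResampler latents and nn.Embedding weights: sqrt(1 / hidden_size)."""
    return math.sqrt(1.0 / hidden_size)


def additional_fc_init_std(in_features: int) -> float:
    """Std for DecoupledLinear.additional_fc: sqrt(1 / in_features)."""
    return math.sqrt(1.0 / in_features)


def init_stds(hidden_size: int, resampler_depth: int) -> dict:
    """Hypothetical summary of the per-module stds implied by the rules above."""
    return {
        "mlp_linear": linear_init_std(hidden_size),
        "mlp_down_proj": linear_init_std(hidden_size, factor=2.0),
        "perceiver_linear": linear_init_std(hidden_size),
        "perceiver_o_proj": linear_init_std(hidden_size, factor=2.0 * resampler_depth),
        "latents_and_embeddings": latent_init_std(hidden_size),
    }


if __name__ == "__main__":
    # Assumed, illustrative values; not taken from the official idefics2 config
    for name, std in init_stds(hidden_size=4096, resampler_depth=3).items():
        print(f"{name}: {std:.5f}")
```

The output layers that write back into the residual stream (`down_proj`, `o_proj`) get a smaller std, and `o_proj` shrinks further as the resampler gets deeper, which keeps the summed contributions of stacked blocks from growing with depth.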
Thank you, Hugo, nice!
giobin changed discussion status to closed