idefics2-8b-init?

#5
by giobin - opened

My compliments on the great work you have been doing with idefics2 (and IDEFICS before it)! Would it be possible to get the idefics2 checkpoint from before the pretraining phase (i.e., before idefics2-8b-base)? That would help people trying to reproduce at least part of the training. Basically, I am asking for the initialization code or weights of the idefics2 modality projection layers. That would be great!
Thanks!


Hi @giobin, here is our code for the initialization of the modules:

# Note: nn is torch.nn; ContextManagers, deepspeed_gathered_parameters_context_manager,
# MLP, PerceiverResampler, and DecoupledLinear are internal to our training codebase.
def _init_weights(self, module):
    def init_a_linear(module, mean=0.0, std=self.config.initializer_range):
        # Gather the (possibly DeepSpeed-partitioned) weight before modifying it in place.
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
            module.weight.data.normal_(mean=mean, std=std)
            if module.bias is not None:
                with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
                    module.bias.data.zero_()

    if isinstance(module, MLP):
        for sub_module_name, sub_module in module.named_modules():
            if isinstance(sub_module, nn.Linear):
                factor = 1.0
                if "down_proj" in sub_module_name:
                    factor = 2.0
                init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

    if isinstance(module, PerceiverResampler):
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
            module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
        for sub_module_name, sub_module in module.named_modules():
            if isinstance(sub_module, nn.Linear):
                factor = 1.0
                if "o_proj" in sub_module_name:
                    factor = 2.0 * self.config.perceiver_config.resampler_depth
                init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

    elif isinstance(module, nn.Embedding):
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
            module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    elif isinstance(module, DecoupledLinear):
        if hasattr(module, "additional_fc"):
            init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
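For intuition, here is a minimal standalone sketch of the standard deviations that scheme produces. The `hidden_size` and `resampler_depth` values below are illustrative placeholders, not the actual idefics2 config values:

```python
# Sketch of the std values used by the init scheme above.
# hidden_size and resampler_depth are assumed example values only.
hidden_size = 4096
resampler_depth = 3

def init_std(hidden_size, factor=1.0):
    # std passed to init_a_linear above: sqrt(0.4 / (hidden_size * factor))
    return (0.4 / (hidden_size * factor)) ** 0.5

mlp_std = init_std(hidden_size)                                   # generic MLP linears, ~0.00988
down_proj_std = init_std(hidden_size, factor=2.0)                 # MLP down_proj, ~0.00699
o_proj_std = init_std(hidden_size, factor=2.0 * resampler_depth)  # perceiver o_proj, ~0.00403
latents_std = (1.0 / hidden_size) ** 0.5                          # latents / embeddings, 0.015625

print(mlp_std, down_proj_std, o_proj_std, latents_std)
```

Note the depth-dependent shrinking of `o_proj`: deeper resamplers get smaller output-projection inits, which keeps the residual stream variance roughly constant across depth.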

Thank you Hugo, nice!

giobin changed discussion status to closed
