File size: 927 Bytes
aea73e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
# Helper functions for models
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from torch import nn


def reset_weights(model: nn.Module) -> None:
    """Reset the parameters of the model to avoid weight leakage.

    Walks the complete module tree (the model itself and every nested
    submodule) and calls ``reset_parameters()`` on each module that
    provides such a method. Modules without it are left untouched.

    Args:
        model (nn.Module): PyTorch Model
    """
    # ``modules()`` recurses through the whole tree, whereas ``children()``
    # only yields direct children and would skip layers that are wrapped in
    # containers such as ``nn.Sequential``.
    for layer in model.modules():
        reset = getattr(layer, "reset_parameters", None)
        if callable(reset):
            reset()


def initialize_weights(module: nn.Module) -> None:
    """Initialize Module weights according to xavier.

    Iterates over the module and all of its submodules and re-initializes
    them in place:

    * ``nn.Linear``: weight drawn with ``nn.init.xavier_normal_``, bias set
      to zero.
    * ``nn.BatchNorm1d``: weight set to one, bias set to zero.

    Parameters that are absent (``nn.Linear(..., bias=False)`` or
    ``nn.BatchNorm1d(..., affine=False)``) are skipped. All other module
    types are left unchanged.

    Args:
        module (nn.Module): Model
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # The bias is None when the layer was built with bias=False.
            if m.bias is not None:
                nn.init.zeros_(m.bias)

        elif isinstance(m, nn.BatchNorm1d):
            # Weight and bias are None when the layer was built with
            # affine=False.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)