LKCell / models /utils /tools.py
xiazhi1
initial commit
aea73e2
raw
history blame
927 Bytes
# -*- coding: utf-8 -*-
# Helper functions for models
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
from torch import nn
def reset_weights(model: nn.Module) -> None:
    """Reset the parameters of the model to avoid weight leakage.

    The whole module tree is visited (not only the direct children), so
    layers nested inside containers such as ``nn.Sequential`` are
    re-initialized as well. Modules that do not define a
    ``reset_parameters`` method are skipped.

    Args:
        model (nn.Module): PyTorch Model
    """
    # ``modules()`` yields the model itself plus every descendant exactly
    # once, so shared submodules are not reset twice.
    for layer in model.modules():
        reset = getattr(layer, "reset_parameters", None)
        if callable(reset):
            reset()
def initialize_weights(module: nn.Module) -> None:
    """Initialize Module weights according to xavier.

    For every ``nn.Linear`` in the module tree the weight is drawn with
    Xavier-normal initialization and the bias (if present) is set to zero.
    For every ``nn.BatchNorm1d`` the weight is set to one and the bias to
    zero (if the layer has affine parameters).

    Args:
        module (nn.Module): Model
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # Linear layers built with ``bias=False`` have ``bias is None``.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm1d):
            # BatchNorm built with ``affine=False`` has no weight/bias.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)