# MOFA-Video_Traj/models/cmp/utils/common_utils.py
import os
import logging
import numpy as np
import torch
from torch.nn import init

def init_weights(net, init_type='normal', init_gain=0.02):
"""Initialize network weights.
Parameters:
net (network) -- network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    'normal' was used in the original pix2pix and CycleGAN paper, but xavier and
    kaiming might work better for some applications. Feel free to experiment.
"""
def init_func(m): # define the initialization function
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
net.apply(init_func) # apply the initialization function <init_func>
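
# Usage sketch for init_weights (hedged: the small Conv/BatchNorm net below is
# illustrative, not part of this repo):
#
#     net = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.BatchNorm2d(16))
#     init_weights(net, init_type='kaiming')
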
def create_logger(name, log_file, level=logging.INFO):
    """Create a logger that writes to both `log_file` and the console."""
    logger = logging.getLogger(name)
    formatter = logging.Formatter('[%(asctime)s] %(message)s')
    fh = logging.FileHandler(log_file)
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.setLevel(level)
    logger.addHandler(fh)
    logger.addHandler(sh)
    return logger
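
# Usage sketch for create_logger (the logger name and log path are illustrative):
#
#     logger = create_logger('cmp', '/tmp/train.log')
#     logger.info('starting training')  # written to both the file and stderr
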
class AverageMeter(object):
    """Computes and stores the average and current value.

    With length > 0, a sliding window of the last `length` values is kept and
    averaged; with length == 0 (the default), a running mean over all updates
    is kept instead.
    """
def __init__(self, length=0):
self.length = length
self.reset()
def reset(self):
if self.length > 0:
self.history = []
else:
self.count = 0
self.sum = 0.0
self.val = 0.0
self.avg = 0.0
def update(self, val):
if self.length > 0:
self.history.append(val)
if len(self.history) > self.length:
del self.history[0]
self.val = self.history[-1]
self.avg = np.mean(self.history)
else:
self.val = val
self.sum += val
self.count += 1
self.avg = self.sum / self.count
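
# Usage sketch for AverageMeter: with length=3 only the last three values are
# averaged; with the default length=0 a running mean over all updates is kept.
#
#     meter = AverageMeter(length=3)
#     for v in [1.0, 2.0, 3.0, 4.0]:
#         meter.update(v)
#     # meter.val == 4.0, meter.avg == 3.0 (mean of [2.0, 3.0, 4.0])
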
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
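
# Usage sketch for accuracy (random logits and labels, purely illustrative):
#
#     logits = torch.randn(8, 10)             # (batch_size, num_classes)
#     labels = torch.randint(0, 10, (8,))
#     top1, top5 = accuracy(logits, labels, topk=(1, 5))
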
def load_state(path, model, optimizer=None):
    def map_func(storage, location):
        # maps checkpoint storages onto the GPU (used by the commented-out
        # torch.load call below)
        return storage.cuda()
    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # checkpoint = torch.load(path, map_location=map_func)
        checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'], strict=False)
ckpt_keys = set(checkpoint['state_dict'].keys())
own_keys = set(model.state_dict().keys())
missing_keys = own_keys - ckpt_keys
for k in missing_keys:
print('caution: missing keys from checkpoint {}: {}'.format(path, k))
last_iter = checkpoint['step']
        if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> also loaded optimizer from checkpoint '{}' (iter {})"
.format(path, last_iter))
return last_iter
else:
print("=> no checkpoint found at '{}'".format(path))