Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import logging | |
import numpy as np | |
import spaces | |
import torch | |
from torch.nn import init | |
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized; every submodule is visited via ``net.apply``
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Modules whose class name contains 'Conv' or 'Linear' get their weight
    initialized with the selected method and their bias (if any) set to 0.
    Modules whose class name contains 'BatchNorm2d' get weight ~ N(1.0, init_gain)
    and bias 0, when those affine parameters exist (they are None with affine=False).

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError -- if ``init_type`` is unknown and ``net`` contains at
                               least one Conv/Linear module with a weight.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        # getattr(..., None) covers both "attribute missing" and "registered as None".
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                # init_gain is not used by the kaiming branch.
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False both weight and bias are None, so skip them.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
def create_logger(name, log_file, level=logging.INFO):
    """Return logger ``name`` that writes to ``log_file`` and to the console.

    Both handlers use the format ``[<asctime>] <message>``. Calling this again
    with the same ``name`` reuses the handlers already attached (a FileHandler
    for the same path, a plain StreamHandler) instead of adding duplicates,
    which would otherwise emit every record once per call.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_file: path of the log file.
        level: logging level set on the logger (default ``logging.INFO``).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter('[%(asctime)s] %(message)s')

    # FileHandler records the absolute path of its file in ``baseFilename``.
    log_path = os.path.abspath(log_file)
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so match the exact type here.
    has_stream = any(type(h) is logging.StreamHandler for h in logger.handlers)

    if not has_file:
        fh = logging.FileHandler(log_file)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    if not has_stream:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    return logger
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``length`` selects the mode:

    * ``length > 0``: sliding window. The last ``length`` values are kept in
      ``history`` and ``avg`` is their ``np.mean``.
    * ``length <= 0``: cumulative. ``sum`` and ``count`` accumulate every value
      passed to ``update`` and ``avg`` is ``sum / count``.

    In both modes ``val`` is the most recent value. ``val`` and ``avg`` are 0.0
    until the first ``update`` (and again after ``reset``).
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Reset in both modes so val/avg are always defined before the first update.
        self.val = 0.0
        self.avg = 0.0

    def update(self, val):
        """Record ``val`` and refresh ``val`` and ``avg``."""
        if self.length > 0:
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]  # drop the oldest value to keep the window size
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor indexed as ``output[sample, class]``; top-k is
            taken along dim 1.
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage (0-100) of
        samples whose target is among their k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    # After the transpose, row i holds the i-th best class of every sample.
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # ``correct`` derives from a transposed (non-contiguous) tensor, so
        # ``.view(-1)`` can fail here; ``.reshape(-1)`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def load_state(path, model, optimizer=None, map_location=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is read with ``torch.load`` and indexed as a dict with keys
    ``'state_dict'`` and ``'step'`` (plus ``'optimizer'`` when ``optimizer`` is
    given). The model is loaded with ``strict=False``; model keys absent from
    the checkpoint are printed, one line per key, in sorted order.

    Args:
        path: checkpoint file path.
        model: module receiving ``checkpoint['state_dict']``.
        optimizer: optional optimizer restored from ``checkpoint['optimizer']``.
        map_location: forwarded to ``torch.load``. ``None`` (default) matches
            a plain ``torch.load(path)``; e.g. ``'cpu'`` allows loading a GPU
            checkpoint on a machine without CUDA.

    Returns:
        ``checkpoint['step']`` when ``path`` exists, otherwise ``None``.

    Raises:
        KeyError: if the checkpoint lacks ``'state_dict'`` or ``'step'``, or
            lacks ``'optimizer'`` while ``optimizer`` is given.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None

    print("=> loading checkpoint '{}'".format(path))
    checkpoint = torch.load(path, map_location=map_location)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict, strict=False)

    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    for k in sorted(missing_keys):
        print('caution: missing keys from checkpoint {}: {}'.format(path, k))

    last_iter = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> also loaded optimizer from checkpoint '{}' (iter {})"
              .format(path, last_iter))
    return last_iter