|
def NAME_TO_WIDTH(name):
    """Return the width value encoded in the first four characters of ``name``.

    Recognized prefixes are ``'mn04'`` (0.4), ``'mn05'`` (0.5), ``'mn10'``
    (1.0), ``'mn20'`` (2.0), ``'mn30'`` (3.0) and ``'mn40'`` (4.0).

    Args:
        name: Model name; only ``name[:4]`` is inspected.

    Returns:
        The mapped width, or ``1.0`` if the prefix is not recognized or
        ``name`` cannot be sliced / its prefix cannot be hashed (e.g. ``None``).
    """
    # Named so it does not shadow the builtin ``map``.
    prefix_to_width = {
        'mn04': 0.4,
        'mn05': 0.5,
        'mn10': 1.0,
        'mn20': 2.0,
        'mn30': 3.0,
        'mn40': 4.0,
    }
    try:
        return prefix_to_width[name[:4]]
    except (KeyError, TypeError):
        # KeyError: unknown prefix. TypeError: ``name`` is not sliceable
        # (e.g. None) or its slice is unhashable. Both keep the 1.0 default
        # without hiding unrelated errors the way a bare ``except:`` did.
        return 1.0
|
|
|
|
|
import csv |
|
|
|
|
|
# Load the class-label metadata CSV. The first row is treated as a header and
# skipped; from every following row, column index 1 is collected into ``ids``
# and column index 2 into ``labels``. The path is relative to the current
# working directory.
# ``newline=''`` is what the csv module documentation asks for so that the
# reader handles line endings / quoted newlines itself.
with open('efficientat/metadata/class_labels_indices.csv', 'r',
          encoding='utf-8', newline='') as f:
    lines = list(csv.reader(f, delimiter=','))

# Comprehensions avoid binding a module-level ``id`` variable, which would
# shadow the builtin ``id`` for all code later in this module.
ids = [row[1] for row in lines[1:]]
labels = [row[2] for row in lines[1:]]

classes_num = len(labels)
|
|
|
|
|
import numpy as np |
|
|
|
|
|
def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):
    """Build a per-epoch factor that multiplies a warmup and a rampdown schedule.

    Args:
        warmup: Ramp-up length passed to ``exp_rampup``.
        rampdown_length: Length passed to ``linear_rampdown``.
        start_rampdown: Epoch after which ``linear_rampdown`` starts decaying.
        last_value: Final value of ``linear_rampdown``.

    Returns:
        A callable ``f(epoch)`` returning
        ``exp_rampup(warmup)(epoch) * linear_rampdown(...)(epoch)``.
    """
    warm = exp_rampup(warmup)
    decay = linear_rampdown(rampdown_length, start_rampdown, last_value)

    def schedule(epoch):
        # Both factors are evaluated at the same epoch and multiplied.
        return warm(epoch) * decay(epoch)

    return schedule
|
|
|
|
|
def exp_rampup(rampup_length):
    """Build an exponential ramp-up schedule (https://arxiv.org/abs/1610.02242).

    The returned callable maps ``epoch`` to:

    * ``exp(-5 * (1 - t / rampup_length) ** 2)`` while ``epoch < rampup_length``,
      where ``t`` is ``epoch`` clipped to ``[0.5, rampup_length]`` (so epoch 0
      is evaluated as if it were 0.5);
    * ``1.0`` otherwise.

    The ramp value is returned as a Python ``float``.
    """
    def ramp(epoch):
        if epoch < rampup_length:
            t = np.clip(epoch, 0.5, rampup_length)
            remaining = 1.0 - t / rampup_length
            return float(np.exp(-5.0 * remaining * remaining))
        return 1.0

    return ramp
|
|
|
|
|
def linear_rampdown(rampdown_length, start=0, last_value=0):
    """Build a piecewise-linear decay schedule.

    The returned callable maps ``epoch`` to:

    * ``1.0`` for ``epoch <= start``;
    * a linear interpolation from ``1.0`` down towards ``last_value`` while
      ``epoch - start < rampdown_length``;
    * ``last_value`` afterwards.
    """
    def decay(epoch):
        if epoch <= start:
            return 1.
        if epoch - start < rampdown_length:
            span = 1. - last_value
            # Grouping matches (span * remaining) / length, left to right.
            return last_value + span * (rampdown_length - epoch + start) / rampdown_length
        return last_value

    return decay
|
|
|
|
|
import torch |
|
|
|
|
|
def mixup(size, alpha):
    """Draw a shuffling permutation and per-sample mixup weights.

    Args:
        size: Number of samples to draw for (e.g. the batch size).
        alpha: Both shape parameters of the Beta distribution the raw
            weights are sampled from (numpy global RNG).

    Returns:
        ``(rn_indices, lam)`` where ``rn_indices`` is ``torch.randperm(size)``
        and ``lam`` is a float32 tensor of shape ``(size,)`` whose entries are
        ``max(l, 1 - l)`` for ``l ~ Beta(alpha, alpha)``, i.e. in ``[0.5, 1]``.
    """
    rn_indices = torch.randperm(size)
    lambd = np.random.beta(alpha, alpha, size).astype(np.float32)
    # Keep the larger of l and 1 - l (same result as concatenating the two
    # columns and taking the row-wise max); stays float32.
    lambd = np.maximum(lambd, 1 - lambd)
    # torch.from_numpy replaces the legacy torch.FloatTensor(ndarray)
    # constructor; lambd is a fresh local float32 array, so sharing its
    # memory with the tensor is safe.
    lam = torch.from_numpy(lambd)
    return rn_indices, lam
|
|
|
|
|
from torch.distributions.beta import Beta |
|
|
|
|
|
def mixstyle(x, p=0.4, alpha=0.4, eps=1e-6, mix_labels=False):
    """Randomly mix per-sample feature statistics across the batch.

    With probability ``p`` the input is normalized by its own mean/std
    (computed over dims 1 and 3, so one statistic per (sample, dim-2 index))
    and de-normalized with a convex combination of its own statistics and
    those of a randomly permuted batch member.

    Args:
        x: 4-D tensor; dim 0 is the batch. NOTE(review): presumably
            (batch, channels, freq, time) — confirm with callers.
        p: Probability of applying the augmentation (a numpy uniform draw
            ``> p`` skips it).
        alpha: Concentration of the ``Beta(alpha, alpha)`` mixing weights.
        eps: Added to the variance before the square root.
        mix_labels: If True, also return the permutation and weights.

    Returns:
        The (possibly) augmented tensor, or ``(x, perm, lmda)`` when
        ``mix_labels`` is True. ``lmda`` has shape ``(batch, 1, 1, 1)``. When
        the augmentation is skipped and ``mix_labels`` is True, ``perm`` is the
        identity and ``lmda`` is all ones, so the return type never depends
        on the random draw.
    """
    if np.random.rand() > p:
        if mix_labels:
            # Previously a bare tensor was returned here, which broke
            # ``x, perm, lmda = mixstyle(..., mix_labels=True)``. Identity
            # permutation with weight 1 means "no mixing".
            batch_size = x.size(0)
            perm = torch.arange(batch_size, device=x.device)
            lmda = torch.ones((batch_size, 1, 1, 1), device=x.device)
            return x, perm, lmda
        return x

    batch_size = x.size(0)

    # Statistics over dims 1 and 3 (keepdim -> shape (B, 1, D2, 1)).
    f_mu = x.mean(dim=[1, 3], keepdim=True)
    f_var = x.var(dim=[1, 3], keepdim=True)

    f_sig = (f_var + eps).sqrt()
    # Statistics are treated as constants: no gradient flows through them.
    f_mu, f_sig = f_mu.detach(), f_sig.detach()
    x_normed = (x - f_mu) / f_sig

    # Per-sample convex weights and a batch shuffle to pick mixing partners.
    lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device)
    perm = torch.randperm(batch_size).to(x.device)
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]

    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)

    # Re-apply the mixed statistics to the normalized input.
    x = x_normed * sig_mix + mu_mix

    if mix_labels:
        return x, perm, lmda
    return x
|
|