Spaces:
Runtime error
Runtime error
File size: 6,337 Bytes
e6bfa26 e9d4572 e6bfa26 e9d4572 e6bfa26 e9d4572 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.saving.object_registration import get_custom_objects
import tensorflow as tf
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if (self.axis is not None):
del reduction_axes[self.axis]
del reduction_axes[0]
mean, var = tf.nn.moments(inputs, reduction_axes, keepdims=True)
stddev = tf.sqrt(var) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
|