# Spaces:
# Runtime error
# Runtime error
from keras.layers import Conv2D, Activation, Input, Concatenate, LeakyReLU, Lambda, AveragePooling2D, UpSampling2D, Convolution2D, BatchNormalization, Conv2DTranspose, Add | |
from keras.models import Model | |
from InstanceNorm import InstanceNormalization | |
def make_standard_UNET(channels, outs):
    """Build a 128x128 U-Net made of conv -> BatchNormalization -> ReLU stages.

    Args:
        channels: channel count of the input; the model input has shape
            (128, 128, channels) (channels-last, BN uses axis=3).
        outs: channel count produced by the final 3x3 convolution.

    Returns:
        A keras ``Model`` mapping (128, 128, channels) to (128, 128, outs).
        The last convolution ('dc0') has no normalisation and no activation.

    Every BatchNormalization layer is called with ``training=False``, so it
    always normalises with its stored moving statistics, even during fit().
    Layer names ('c0'..'c8', 'bnc0'..'bnc8', 'dc8_'..'dc0', 'bnd1'..'bnd8')
    are set explicitly -- presumably so pretrained weights can be matched by
    name; TODO confirm against the weight-loading code.
    """
    def conv_bn_relu(tensor, conv_layer, norm_layer):
        # BN is pinned to inference mode (training=False) on every call.
        return Activation('relu')(norm_layer(conv_layer(tensor), training=False))

    # Encoder stages c0..c8 as (filters, kernel, stride); the stride-2 convs
    # halve H/W, so the bottleneck (c8) runs at 8x8.
    encoder_plan = (
        (32, 3, 1), (64, 4, 2), (64, 3, 1), (128, 4, 2), (128, 3, 1),
        (256, 4, 2), (256, 3, 1), (512, 4, 2), (512, 3, 1),
    )
    # Decoder stages dc8..dc1 as (index, filters, skip_index).  A stage with a
    # skip index first concatenates encoder output e<skip_index> in front of
    # the running tensor, then applies a 4x4 stride-2 transposed conv named
    # 'dc<i>_'; a stage without one applies a plain 3x3 conv named 'dc<i>'.
    decoder_plan = (
        (8, 512, 7), (7, 256, None),
        (6, 256, 6), (5, 128, None),
        (4, 128, 4), (3, 64, None),
        (2, 64, 2), (1, 32, None),
    )

    net_in = Input(shape=(128, 128, channels))
    encoded = []
    h = net_in
    for idx, (n_filters, k, s) in enumerate(encoder_plan):
        layer = Convolution2D(filters=n_filters, kernel_size=k, strides=s,
                              padding='same', name='c%d' % idx)
        h = conv_bn_relu(h, layer, BatchNormalization(axis=3, name='bnc%d' % idx))
        encoded.append(h)

    for idx, n_filters, skip in decoder_plan:
        if skip is None:
            layer = Convolution2D(filters=n_filters, kernel_size=3, strides=1,
                                  padding='same', name='dc%d' % idx)
        else:
            h = Concatenate()([encoded[skip], h])
            layer = Conv2DTranspose(filters=n_filters, kernel_size=4, strides=2,
                                    padding='same', name='dc%d_' % idx)
        h = conv_bn_relu(h, layer, BatchNormalization(axis=3, name='bnd%d' % idx))

    head = Convolution2D(filters=outs, kernel_size=3, strides=1,
                         padding='same', name='dc0')
    return Model(inputs=net_in, outputs=head(Concatenate()([encoded[0], h])))
def make_diff_net():
    """Build a 512x512 encoder/decoder whose skips are difference-of-pooling maps.

    The encoder has five levels (16, 32, 64, 128 and 256 filters, running at
    512, 256, 128, 64 and 32 px).  After each level the features ``x`` are
    2x average-pooled to ``down``.  The pooled map is upsampled again, and
    ``x - up`` (the detail lost by pooling) is kept as that level's skip.
    A 512-filter block then runs at 16x16.  Each decoder level upsamples 2x,
    concatenates the matching difference map, and applies a block with the
    same filter count as the corresponding encoder level.

    Returns:
        A keras ``Model`` mapping a (512, 512, 3) input to a (512, 512, 1)
        output.  The final 3x3 conv ('op') has no activation.

    Conv layers are explicitly named '<level>_c1' / '<level>_c2' (levels
    'c512'..'c16' and 'd32'..'d512') plus 'op' -- presumably so weights can
    be loaded by name; TODO confirm.
    """
    def conv(x, filters, name):
        return Conv2D(filters=filters, strides=(1, 1), kernel_size=(3, 3), padding='same', name=name)(x)

    def relu(x):
        return Activation('relu')(x)

    def r_block(x, filters, name=None):
        # conv3x3 -> ReLU -> conv3x3 -> ReLU; sub-convs get '<name>_c1' / '<name>_c2'.
        return relu(conv(relu(conv(x, filters, None if name is None else name + '_c1')), filters,
                         None if name is None else name + '_c2'))

    def cat(a, b):
        # Upsample `a` 2x and concatenate it with skip tensor `b` on the last axis.
        return Concatenate()([UpSampling2D((2, 2))(a), b])

    def dog(x):
        # Returns (2x average-pooled x, x minus its pooled-then-upsampled copy).
        down = AveragePooling2D((2, 2))(x)
        up = UpSampling2D((2, 2))(down)
        diff = Lambda(lambda p: p[0] - p[1])([x, up])
        return down, diff

    # (spatial size used in the layer name, filters) per encoder level, finest first.
    levels = ((512, 16), (256, 32), (128, 64), (64, 128), (32, 256))

    ip = Input(shape=(512, 512, 3))
    x = ip
    skips = []
    for size, filters in levels:
        x = r_block(x, filters, 'c%d' % size)
        x, diff = dog(x)
        skips.append((size, filters, diff))

    x = r_block(x, 512, 'c16')

    # Decode coarsest-first, mirroring the encoder levels.
    for size, filters, diff in reversed(skips):
        x = r_block(cat(x, diff), filters, 'd%d' % size)

    op = conv(x, 1, 'op')
    return Model(inputs=ip, outputs=op)
def make_wnet256():
    """Build a two-input 256x256 network from ``ip_sketch`` and ``ip_color``.

    ``ip_sketch`` has shape (256, 256, 1) and ``ip_color`` has shape
    (256, 256, 3).

    The ``ip_sketch`` branch passes through three levels (32, 64 and 128
    filters).  Each level is followed by 2x average pooling, and the level
    keeps ``x - upsample(pool(x))`` as its skip.  At the 32x32 bottleneck,
    ``ip_color`` (8x average-pooled to 32x32) is concatenated in.  This is
    followed by one conv block and nine residual blocks, all with 256
    filters.  The decoder upsamples back to 256x256 using the difference
    skips.

    Returns:
        A keras ``Model`` with inputs ``[ip_sketch, ip_color]`` and a
        (256, 256, 3) output from a final 3x3 conv with no activation.
    """
    def conv(x, filters):
        return Conv2D(filters=filters, strides=(1, 1), kernel_size=(3, 3), padding='same')(x)

    def relu(x):
        return Activation('relu')(x)

    def r_block(x, filters):
        # conv3x3 -> ReLU -> conv3x3 -> ReLU
        return relu(conv(relu(conv(x, filters)), filters))

    def res_block(x, filters):
        # Identity shortcut around two 3x3 convs; x must already have `filters` channels.
        return relu(Add()([x, conv(relu(conv(x, filters)), filters)]))

    def cat(a, b):
        # Upsample `a` 2x and concatenate it with skip tensor `b` on the last axis.
        return Concatenate()([UpSampling2D((2, 2))(a), b])

    def dog(x):
        # Returns (2x average-pooled x, x minus its pooled-then-upsampled copy).
        down = AveragePooling2D((2, 2))(x)
        up = UpSampling2D((2, 2))(down)
        diff = Lambda(lambda p: p[0] - p[1])([x, up])
        return down, diff

    ip_sketch = Input(shape=(256, 256, 1))
    ip_color = Input(shape=(256, 256, 3))

    c256 = r_block(ip_sketch, 32)
    c128, l256 = dog(c256)
    c128 = r_block(c128, 64)
    c64, l128 = dog(c128)
    c64 = r_block(c64, 128)
    c32, l64 = dog(c64)

    # 8x8 average pooling brings ip_color from 256x256 to the 32x32 bottleneck size.
    c32 = r_block(Concatenate()([c32, AveragePooling2D((8, 8))(ip_color)]), 256)
    for _ in range(9):
        c32 = res_block(c32, 256)

    d64 = r_block(cat(c32, l64), 128)
    d128 = r_block(cat(d64, l128), 64)
    d256 = r_block(cat(d128, l256), 32)
    op = conv(d256, 3)
    return Model(inputs=[ip_sketch, ip_color], outputs=op)
def make_unet512():
    """Build a 512x512 residual U-Net with instance normalisation and a sigmoid head.

    Encoder: a 16-filter 3x3 conv at full resolution.  It is followed by six
    4x4 stride-2 convs (32, 64, 128, 256, 512 and 1024 filters), each with
    0, 1, 2, 8, 8 and 4 residual blocks respectively.  That gives an 8x8
    bottleneck.

    Decoder: for each encoder resolution from 16x16 up to 512x512, a 4x4
    stride-2 transposed conv is applied, the same-resolution encoder features
    are concatenated in, and a 3x3 conv follows.  Every conv in the body is
    followed by the project's ``InstanceNormalization(axis=3)`` and a ReLU.

    Returns:
        A keras ``Model`` mapping a (512, 512, 3) input to a (512, 512, 1)
        output passed through a sigmoid.
    """
    def conv(x, filters, strides=(1, 1), kernel_size=(3, 3)):
        return Conv2D(filters=filters, strides=strides, kernel_size=kernel_size, padding='same')(x)

    def donv(x, filters, strides=(2, 2), kernel_size=(4, 4)):
        return Conv2DTranspose(filters=filters, strides=strides, kernel_size=kernel_size, padding='same')(x)

    def relu(x):
        return Activation('relu')(x)

    def sigmoid(x):
        return Activation('sigmoid')(x)

    def norm(x):
        return InstanceNormalization(axis=3)(x)

    def cat(a, b):
        return Concatenate()([a, b])

    def down(x, filters):
        # 4x4 stride-2 conv halves H/W.
        return relu(norm(conv(x, filters, strides=(2, 2), kernel_size=(4, 4))))

    def res(x, filters):
        # Narrow (filters // 2) conv, widen back to `filters`, add the identity,
        # then ReLU; x must already have `filters` channels.
        branch = relu(norm(conv(x, filters // 2)))
        branch = norm(conv(branch, filters))
        return relu(Add()([x, branch]))

    ip = Input(shape=(512, 512, 3))

    h = relu(norm(conv(ip, 16)))
    skips = [(16, h)]  # (filters, features) at each resolution, finest first
    for filters, n_res in ((32, 0), (64, 1), (128, 2), (256, 8), (512, 8)):
        h = down(h, filters)
        for _ in range(n_res):
            h = res(h, filters)
        skips.append((filters, h))

    # 8x8 bottleneck: 1024 filters, four residual blocks, no skip of its own.
    h = down(h, 1024)
    for _ in range(4):
        h = res(h, 1024)

    for filters, skip in reversed(skips):
        h = relu(norm(donv(h, filters)))
        h = cat(h, skip)
        h = relu(norm(conv(h, filters)))

    ot = sigmoid(conv(h, 1))
    return Model(inputs=ip, outputs=ot)