# Source: HighCWu, commit e6bfa26 ("update tf version").
from keras.layers import Conv2D, Activation, Input, Concatenate, LeakyReLU, Lambda, AveragePooling2D, UpSampling2D, Convolution2D, BatchNormalization, Conv2DTranspose, Add
from keras.models import Model
from InstanceNorm import InstanceNormalization
def make_standard_UNET(channels,outs):
    """Build a U-Net for fixed 128x128 inputs.

    Encoder: nine conv -> BatchNormalization -> ReLU stages (c0..c8). The
    stride-2 stages (c1, c3, c5, c7) halve the resolution
    128 -> 64 -> 32 -> 16 -> 8.

    Decoder: four upsampling stages. Each stage concatenates an encoder skip
    tensor with the running decoder tensor and applies a stride-2
    Conv2DTranspose + BN + ReLU. It then refines the result with a 3x3 conv +
    BN + ReLU. A final 3x3 conv with no activation, fed by the skip from c0,
    produces ``outs`` channels.

    Every BatchNormalization call passes ``training=False``, so the layers
    always normalize with their stored moving statistics, even inside fit().
    NOTE(review): presumably intended for running pretrained weights; confirm.

    Args:
        channels: number of channels of the 128x128 input image.
        outs: number of channels produced by the final conv ``dc0``.

    Returns:
        A keras ``Model`` with one input of shape (128, 128, channels).
    """
    # (layer name, filters, kernel size, stride) for each encoder conv.
    encoder_spec = (
        ('c0', 32, 3, 1),
        ('c1', 64, 4, 2),
        ('c2', 64, 3, 1),
        ('c3', 128, 4, 2),
        ('c4', 128, 3, 1),
        ('c5', 256, 4, 2),
        ('c6', 256, 3, 1),
        ('c7', 512, 4, 2),
        ('c8', 512, 3, 1),
    )
    encoder_norm_names = ('bnc0', 'bnc1', 'bnc2', 'bnc3', 'bnc4',
                          'bnc5', 'bnc6', 'bnc7', 'bnc8')

    encoder_convs = [
        Convolution2D(filters=n_filters, kernel_size=k_size, strides=stride,
                      padding='same', name=layer_name)
        for layer_name, n_filters, k_size, stride in encoder_spec
    ]
    encoder_norms = [BatchNormalization(axis=3, name=layer_name)
                     for layer_name in encoder_norm_names]

    # Per decoder stage, deepest first:
    # (upsampler name, upsampler filters, upsampler BN name,
    #  refiner name, refiner filters, refiner BN name)
    decoder_spec = (
        ('dc8_', 512, 'bnd8', 'dc7', 256, 'bnd7'),
        ('dc6_', 256, 'bnd6', 'dc5', 128, 'bnd5'),
        ('dc4_', 128, 'bnd4', 'dc3', 64, 'bnd3'),
        ('dc2_', 64, 'bnd2', 'dc1', 32, 'bnd1'),
    )
    decoder_stages = []
    for up_name, up_filters, up_bn, ref_name, ref_filters, ref_bn in decoder_spec:
        decoder_stages.append((
            Conv2DTranspose(filters=up_filters, kernel_size=4, strides=2,
                            padding='same', name=up_name),
            BatchNormalization(axis=3, name=up_bn),
            Convolution2D(filters=ref_filters, kernel_size=3, strides=1,
                          padding='same', name=ref_name),
            BatchNormalization(axis=3, name=ref_bn),
        ))
    head = Convolution2D(filters=outs, kernel_size=3, strides=1, padding='same', name='dc0')

    inputs = Input(shape=(128, 128, channels))

    # Encoder: keep every stage's activation so the decoder can reuse it.
    features = []
    hidden = inputs
    for conv_layer, norm_layer in zip(encoder_convs, encoder_norms):
        hidden = Activation('relu')(norm_layer(conv_layer(hidden), training=False))
        features.append(hidden)

    # Decoder: skip sources are c7, c6, c4 and c2, in that order.
    for skip_index, (up, up_norm, refine, refine_norm) in zip((7, 6, 4, 2), decoder_stages):
        merged = Concatenate()([features[skip_index], hidden])
        hidden = Activation('relu')(up_norm(up(merged), training=False))
        hidden = Activation('relu')(refine_norm(refine(hidden), training=False))

    prediction = head(Concatenate()([features[0], hidden]))
    return Model(inputs=inputs, outputs=prediction)
def make_diff_net():
    """Build a 512x512 encoder/decoder producing a single-channel map.

    Encoder: at each of five resolutions (512, 256, 128, 64, 32) a named
    two-conv ReLU block is applied. Its output is then split by ``dog`` into
    a 2x average-pooled tensor, passed to the next level, and a residual
    ``x - upsample(avgpool(x))``, kept for the decoder. A sixth block runs at
    16x16.

    Decoder: at each level the running tensor is upsampled 2x and
    concatenated with that level's residual. It then goes through another
    named two-conv ReLU block.

    The output is a linear (no activation) 3x3 conv with one filter.

    Returns:
        A keras ``Model`` mapping a (512, 512, 3) input to a 1-channel output.
    """
    def conv(x, filters, name):
        # 3x3, stride-1, 'same'-padded convolution with no activation.
        return Conv2D(filters=filters, strides=(1, 1), kernel_size=(3, 3), padding='same', name=name)(x)

    def relu(x):
        return Activation('relu')(x)

    def r_block(x, filters, name=None):
        # conv -> ReLU -> conv -> ReLU. When `name` is given, the convs are
        # named '<name>_c1' / '<name>_c2'; otherwise Keras auto-names them.
        first_name = None if name is None else name + '_c1'
        second_name = None if name is None else name + '_c2'
        hidden = relu(conv(x, filters, first_name))
        return relu(conv(hidden, filters, second_name))

    def cat(a, b):
        # Upsample `a` 2x and stack it with `b` along the channel axis.
        return Concatenate()([UpSampling2D((2, 2))(a), b])

    def dog(x):
        # Return (down, diff): `down` is x average-pooled 2x, and `diff` is
        # x minus `down` upsampled back (UpSampling2D default: nearest).
        # `diff` is the detail the downsampling discards.
        down = AveragePooling2D((2, 2))(x)
        up = UpSampling2D((2, 2))(down)
        diff = Lambda(lambda p: p[0] - p[1])([x, up])
        return down, diff

    ip = Input(shape=(512, 512, 3))

    # Encoder: (block name, filters) per level, finest first.
    hidden = ip
    residuals = []
    for block_name, filters in (('c512', 16), ('c256', 32), ('c128', 64),
                                ('c64', 128), ('c32', 256)):
        hidden = r_block(hidden, filters, block_name)
        hidden, residual = dog(hidden)
        residuals.append(residual)
    hidden = r_block(hidden, 512, 'c16')

    # Decoder: coarsest residual first, so pop() walks back up the pyramid.
    for block_name, filters in (('d32', 256), ('d64', 128), ('d128', 64),
                                ('d256', 32), ('d512', 16)):
        hidden = r_block(cat(hidden, residuals.pop()), filters, block_name)

    op = conv(hidden, 1, 'op')
    return Model(inputs=ip, outputs=op)
def make_wnet256():
    """Build a two-input 256x256 network (sketch plus color hint).

    The 1-channel sketch goes through three encoder levels (256, 128, 64).
    Each level applies a two-conv ReLU block and then ``dog``, which keeps the
    2x average-pooled tensor for the next level. ``dog`` also stores the
    residual ``x - upsample(avgpool(x))`` for the decoder.

    At 32x32 the sketch features are concatenated with the 3-channel color
    input average-pooled by 8 (256 / 8 = 32). The result goes through a
    256-filter block and nine residual blocks.

    The decoder upsamples 2x per level, concatenates the matching residual,
    and applies a two-conv ReLU block. The output is a linear 3x3 conv with
    3 filters.

    Returns:
        A keras ``Model`` with inputs [sketch (256, 256, 1), color (256, 256, 3)].
    """
    def conv(x, filters):
        # 3x3, stride-1, 'same'-padded convolution with no activation.
        return Conv2D(filters=filters, strides=(1, 1), kernel_size=(3, 3), padding='same')(x)

    def relu(x):
        return Activation('relu')(x)

    def r_block(x, filters):
        # conv -> ReLU -> conv -> ReLU.
        return relu(conv(relu(conv(x, filters)), filters))

    def res_block(x, filters):
        # ReLU(x + conv(ReLU(conv(x)))). Requires x to already have `filters` channels.
        return relu(Add()([x, conv(relu(conv(x, filters)), filters)]))

    def cat(a, b):
        # Upsample `a` 2x and stack it with `b` along the channel axis.
        return Concatenate()([UpSampling2D((2, 2))(a), b])

    def dog(x):
        # Return (down, diff): `down` is x average-pooled 2x, and `diff` is
        # x minus `down` upsampled back (UpSampling2D default: nearest).
        down = AveragePooling2D((2, 2))(x)
        up = UpSampling2D((2, 2))(down)
        diff = Lambda(lambda p: p[0] - p[1])([x, up])
        return down, diff

    # Creation order (sketch first) fixes both the auto-names and the
    # order of Model inputs.
    ip_sketch = Input(shape=(256, 256, 1))
    ip_color = Input(shape=(256, 256, 3))

    hidden = ip_sketch
    residuals = []
    for filters in (32, 64, 128):
        hidden = r_block(hidden, filters)
        hidden, residual = dog(hidden)
        residuals.append(residual)

    # hidden is now 32x32. Pool the color hint by 8 to match before fusing.
    hidden = r_block(Concatenate()([hidden, AveragePooling2D((8, 8))(ip_color)]), 256)
    for _ in range(9):
        hidden = res_block(hidden, 256)

    # Decoder: pop() returns the coarsest remaining residual first.
    for filters in (128, 64, 32):
        hidden = r_block(cat(hidden, residuals.pop()), filters)

    op = conv(hidden, 3)
    return Model(inputs=[ip_sketch, ip_color], outputs=op)
def make_unet512():
    """Build a 512x512 residual U-Net with instance normalization.

    Stem: 3x3 conv (16 filters) -> InstanceNormalization -> ReLU at full
    resolution.

    Encoder: six stride-2, 4x4 conv -> IN -> ReLU stages (32, 64, 128, 256,
    512, 1024 filters) bring the resolution from 512 down to 8. Each stage is
    followed by 0, 1, 2, 8, 8 and 4 residual blocks respectively. A residual
    block is ReLU(x + IN(conv(ReLU(IN(conv_{filters // 2}(x)))))).

    Decoder: six stages. Each stage is a stride-2, 4x4 transposed conv -> IN
    -> ReLU, followed by concatenation with the encoder tensor at the same
    resolution and a 3x3 conv -> IN -> ReLU.

    Head: 3x3 conv with one filter, then a sigmoid.

    Returns:
        A keras ``Model`` mapping a (512, 512, 3) input to a 1-channel
        sigmoid output.
    """
    def conv_layer(tensor, filters, strides=(1, 1), kernel_size=(3, 3)):
        return Conv2D(filters=filters, strides=strides, kernel_size=kernel_size, padding='same')(tensor)

    def deconv_layer(tensor, filters, strides=(2, 2), kernel_size=(4, 4)):
        return Conv2DTranspose(filters=filters, strides=strides, kernel_size=kernel_size, padding='same')(tensor)

    def norm_relu(tensor):
        # InstanceNormalization over the channel axis, then ReLU.
        return Activation('relu')(InstanceNormalization(axis=3)(tensor))

    def residual(tensor, filters):
        # Bottleneck to filters // 2, expand back, add the identity, then ReLU.
        narrowed = norm_relu(conv_layer(tensor, filters // 2))
        widened = InstanceNormalization(axis=3)(conv_layer(narrowed, filters))
        return Activation('relu')(Add()([tensor, widened]))

    ip = Input(shape=(512, 512, 3))
    hidden = norm_relu(conv_layer(ip, 16, strides=(1, 1), kernel_size=(3, 3)))

    # Encoder: (filters, residual blocks after the downsampling conv).
    skips = []
    for filters, n_residual in ((32, 0), (64, 1), (128, 2), (256, 8), (512, 8), (1024, 4)):
        skips.append(hidden)
        hidden = norm_relu(conv_layer(hidden, filters, strides=(2, 2), kernel_size=(4, 4)))
        for _ in range(n_residual):
            hidden = residual(hidden, filters)

    # Decoder: skips are consumed deepest-first (16x16 ... 512x512).
    for filters, skip in zip((512, 256, 128, 64, 32, 16), reversed(skips)):
        hidden = norm_relu(deconv_layer(hidden, filters, strides=(2, 2), kernel_size=(4, 4)))
        hidden = Concatenate()([hidden, skip])
        hidden = norm_relu(conv_layer(hidden, filters, strides=(1, 1), kernel_size=(3, 3)))

    ot = Activation('sigmoid')(conv_layer(hidden, 1))
    return Model(inputs=ip, outputs=ot)