|
import os |
|
import cv2 |
|
import numpy as np |
|
import tensorflow as tf |
|
import wbc.network as network |
|
import wbc.guided_filter as guided_filter |
|
from tqdm import tqdm |
|
|
|
|
|
def resize_crop(image):
    """Downscale an oversized image and trim it to multiples of 8.

    If the shorter side of ``image`` is larger than 720 pixels, the image is
    resized with ``cv2.INTER_AREA`` so that the shorter side becomes 720 and
    the longer side is scaled by the same factor (truncated to an int).
    Afterwards height and width are cut down to the nearest lower multiple
    of 8 by dropping rows/columns at the bottom/right.

    Args:
        image: array whose shape unpacks into exactly three values
            (height, width, channels).

    Returns:
        The resized and cropped image.
    """
    height, width, _ = np.shape(image)

    if min(height, width) > 720:
        if height > width:
            target_h, target_w = int(720 * height / width), 720
        else:
            target_h, target_w = 720, int(720 * width / height)
        # cv2.resize takes the target size as (width, height).
        image = cv2.resize(image, (target_w, target_h),
                           interpolation=cv2.INTER_AREA)
        height, width = target_h, target_w

    # Round each side down to a multiple of 8.
    cropped_h = height - height % 8
    cropped_w = width - width % 8
    return image[:cropped_h, :cropped_w, :]
|
|
|
|
|
def cartoonize(load_folder, save_folder, model_path):
    """Cartoonize every file in ``load_folder`` and save it to ``save_folder``.

    Builds the generator graph, restores the latest checkpoint found in
    ``model_path`` and pushes each file of ``load_folder`` through the
    network. Each result is written to ``save_folder`` under the same file
    name. A file that cannot be processed is reported on stdout and skipped;
    the remaining files are still handled.

    Args:
        load_folder: directory containing the input images.
        save_folder: directory the results are written to (must exist).
        model_path: directory passed to ``tf.train.latest_checkpoint``.
    """
    print(model_path)

    # Graph: input placeholder -> U-Net generator -> guided filter.
    input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)

    # Only variables whose name contains 'generator' are restored.
    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session even if restoring or
    # processing raises.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, tf.train.latest_checkpoint(model_path))

        for name in tqdm(os.listdir(load_folder)):
            # Computed outside the try so the failure message below can
            # always refer to load_path.
            load_path = os.path.join(load_folder, name)
            save_path = os.path.join(save_folder, name)
            try:
                image = cv2.imread(load_path)
                if image is None:
                    # cv2.imread signals an unreadable file by returning None.
                    raise IOError('cannot read image {}'.format(load_path))
                image = resize_crop(image)

                # Scale pixel values from 0..255 to -1..1 and add a batch axis.
                batch_image = image.astype(np.float32) / 127.5 - 1
                batch_image = np.expand_dims(batch_image, axis=0)

                output = sess.run(final_out, feed_dict={input_photo: batch_image})

                # Undo the scaling and convert back to 8-bit pixels.
                output = (np.squeeze(output) + 1) * 127.5
                output = np.clip(output, 0, 255).astype(np.uint8)

                # cv2.imwrite reports most failures via its return value.
                if not cv2.imwrite(save_path, output):
                    raise IOError('cannot write image {}'.format(save_path))
            except Exception:
                # Exception (not a bare except) so that Ctrl-C / SystemExit
                # still abort the batch.
                print('cartoonize {} failed'.format(load_path))
|
|
|
|
|
class Cartoonize:
    """Reusable cartoonizer that keeps the restored model in one session.

    The graph is built and the checkpoint restored once in ``__init__``;
    ``run`` and ``run_sigle`` can then be called repeatedly. Call ``close``
    when the instance is no longer needed to release the session.
    """

    def __init__(self, model_path):
        """Build the generator graph and restore its weights.

        Args:
            model_path: directory passed to ``tf.train.latest_checkpoint``.
        """
        print(model_path)

        # Graph: input placeholder -> U-Net generator -> guided filter.
        self.input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
        network_out = network.unet_generator(self.input_photo)
        self.final_out = guided_filter.guided_filter(self.input_photo, network_out, r=1, eps=5e-3)

        # Only variables whose name contains 'generator' are restored.
        all_vars = tf.trainable_variables()
        gene_vars = [var for var in all_vars if 'generator' in var.name]
        saver = tf.train.Saver(var_list=gene_vars)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        self.sess.run(tf.global_variables_initializer())
        saver.restore(self.sess, tf.train.latest_checkpoint(model_path))

    def close(self):
        """Close the TensorFlow session owned by this instance."""
        self.sess.close()

    def _cartoonize_file(self, load_path, save_path):
        """Run one image through the network; raise if it cannot be read or written."""
        image = cv2.imread(load_path)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError('cannot read image {}'.format(load_path))
        image = resize_crop(image)

        # Scale pixel values from 0..255 to -1..1 and add a batch axis.
        batch_image = image.astype(np.float32) / 127.5 - 1
        batch_image = np.expand_dims(batch_image, axis=0)

        output = self.sess.run(self.final_out, feed_dict={self.input_photo: batch_image})

        # Undo the scaling and convert back to 8-bit pixels.
        output = (np.squeeze(output) + 1) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)

        # cv2.imwrite reports most failures via its return value.
        if not cv2.imwrite(save_path, output):
            raise IOError('cannot write image {}'.format(save_path))

    def run(self, load_folder, save_folder):
        """Cartoonize every file in ``load_folder`` into ``save_folder``.

        Output files keep the input file name. Files that fail are reported
        on stdout and skipped.
        """
        for name in tqdm(os.listdir(load_folder)):
            load_path = os.path.join(load_folder, name)
            save_path = os.path.join(save_folder, name)
            self.run_sigle(load_path, save_path)

    def run_sigle(self, load_path, save_path):
        """Cartoonize the single image at ``load_path`` and write it to ``save_path``.

        Failures are reported on stdout instead of being raised. The
        misspelled method name is kept so existing callers keep working.
        """
        try:
            self._cartoonize_file(load_path, save_path)
        except Exception:
            # Exception (not a bare except) so that Ctrl-C / SystemExit
            # still propagate to the caller.
            print('cartoonize {} failed'.format(load_path))
|
|
|
|
|
if __name__ == '__main__':
    # Default locations used when this file is run as a script.
    model_path = 'saved_models'
    load_folder = 'test_images'
    save_folder = 'cartoonized_images'

    # exist_ok=True replaces the check-then-create pair: there is no window
    # between the existence test and the creation, and a path that exists
    # but is not a directory is reported here rather than ignored.
    os.makedirs(save_folder, exist_ok=True)

    cartoonize(load_folder, save_folder, model_path)
|
|