Spaces:
Runtime error
Runtime error
| import time | |
| import os | |
| import argparse | |
| import cv2 | |
| import numpy as np | |
| import tensorflow as tf | |
| import neuralgym as ng | |
| from inpaint_model import InpaintCAModel | |
def _build_parser():
    """Create the command-line parser used by the ``__main__`` block.

    Options (all optional as far as argparse is concerned):
      --flist           str, default ''. Path of a text file the main block
                        reads line by line as "input mask output" triples.
      --image_height    int, default -1.
      --image_width     int, default -1.
      --checkpoint_dir  str, default ''. Directory handed to the TF
                        checkpoint loader.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--flist', type=str, default='',
        help='The filenames of image to be processed: input, mask, output.')
    # Height and width share the same help wording; register them in the
    # same order as before (height first) so --help output is unchanged.
    for dim in ('height', 'width'):
        cli.add_argument(
            '--image_' + dim, type=int, default=-1,
            help='The ' + dim + ' of images should be defined, otherwise'
                 ' batch mode is not supported.')
    cli.add_argument(
        '--checkpoint_dir', type=str, default='',
        help='The directory of tensorflow checkpoint.')
    return cli


parser = _build_parser()
def _restore_variables(sess, checkpoint_dir):
    """Assign every global variable of the current graph from a checkpoint.

    Args:
        sess: session in which the assign ops are run.
        checkpoint_dir: location passed to
            ``tf.contrib.framework.load_variable``.

    Each variable is looked up under its own graph name (``var.name``), so
    the graph must be built with the same variable names as the saved model.
    """
    assign_ops = []
    for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        value = tf.contrib.framework.load_variable(checkpoint_dir, var.name)
        assign_ops.append(tf.assign(var, value))
    sess.run(assign_ops)


def _read_image(path):
    """Read ``path`` with ``cv2.imread``; raise IOError if it cannot be read.

    ``cv2.imread`` reports failure by returning ``None`` rather than raising,
    which would otherwise surface later as an unrelated ``cv2.resize`` error.
    """
    img = cv2.imread(path)
    if img is None:
        raise IOError('Could not read image: {}'.format(path))
    return img


def _build_input(image_path, mask_path, width, height, crop_h, crop_w):
    """Return the network input for one sample as a (1, H, 3*W, 3) array.

    The input image, its edge guidance and the mask are each read with
    ``cv2.imread`` (3-channel by default), resized to ``(width, height)``,
    cropped to ``crop_h`` x ``crop_w`` and concatenated left-to-right along
    the width axis. Because all three go through the same resize and crop,
    their shapes always agree.
    """
    # The edge guidance is expected next to the input as '<stem>_edge.jpg';
    # splitext handles extensions of any length (e.g. '.jpeg').
    guidance_path = os.path.splitext(image_path)[0] + '_edge.jpg'
    parts = []
    for path in (image_path, guidance_path, mask_path):
        img = cv2.resize(_read_image(path), (width, height))
        parts.append(img[:crop_h, :crop_w, :])
    print('Shape of image: {}'.format(parts[0].shape))
    return np.expand_dims(np.concatenate(parts, axis=1), 0)


if __name__ == "__main__":
    ng.get_gpus(1)
    # To run on CPU instead, set os.environ['CUDA_VISIBLE_DEVICES'] = ''.
    args = parser.parse_args()

    # Inputs are cropped to a multiple of `grid` pixels (presumably the
    # model's downsampling factor -- confirm against InpaintCAModel).
    grid = 4
    crop_h = args.image_height // grid * grid
    crop_w = args.image_width // grid * grid
    if crop_h <= 0 or crop_w <= 0:
        parser.error('--image_height and --image_width must both be at '
                     'least {}'.format(grid))
    if not args.flist or not args.checkpoint_dir:
        parser.error('--flist and --checkpoint_dir are required')

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        model = InpaintCAModel()
        # Image, guidance and mask are tiled side by side (width * 3). The
        # shape uses the cropped size so it matches what _build_input feeds.
        input_image_ph = tf.placeholder(
            tf.float32, shape=(1, crop_h, crop_w * 3, 3))
        output = model.build_server_graph(input_image_ph)
        # Rescale to pixel values (assumes the model emits values in
        # [-1, 1] -- confirm); saturate_cast clamps into the uint8 range.
        output = (output + 1.) * 127.5
        output = tf.reverse(output, [-1])
        output = tf.saturate_cast(output, tf.uint8)

        _restore_variables(sess, args.checkpoint_dir)
        print('Model loaded.')

        with open(args.flist, 'r') as f:
            lines = f.read().splitlines()

        t = time.time()
        for line in lines:
            if not line.strip():
                continue  # tolerate blank lines in the list file
            image_path, mask_path, out_path = line.split()
            input_image = _build_input(
                image_path, mask_path, args.image_width, args.image_height,
                crop_h, crop_w)
            result = sess.run(output, feed_dict={input_image_ph: input_image})
            print('Processed: {}'.format(out_path))
            # Flip channels back, undoing the tf.reverse applied in the graph.
            if not cv2.imwrite(out_path, result[0][:, :, ::-1]):
                raise IOError('Could not write image: {}'.format(out_path))
        print('Time total: {}'.format(time.time() - t))