# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
| """KeypointNet!! | |
| A reimplementation of 'Discovery of Latent 3D Keypoints via End-to-end | |
| Geometric Reasoning' keypoint network. Given a single 2D image of a known class, | |
| this network can predict a set of 3D keypoints that are consistent across | |
| viewing angles of the same object and across object instances. These keypoints | |
| and their detectors are discovered and learned automatically without | |
| keypoint location supervision. | |
| """ | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy import misc
import tensorflow as tf
import tensorflow.contrib.slim as slim

import utils

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_boolean("predict", False, "Run inference if true.")
tf.app.flags.DEFINE_string("input", "", "Input folder containing images.")
tf.app.flags.DEFINE_string("model_dir", None, "Estimator model_dir.")
tf.app.flags.DEFINE_string("dset", "",
                           "Path to the directory containing the dataset.")
tf.app.flags.DEFINE_integer("steps", 200000, "Training steps.")
tf.app.flags.DEFINE_integer("batch_size", 8, "Size of mini-batch.")
tf.app.flags.DEFINE_string(
    "hparams", "",
    "A comma-separated list of `name=value` hyperparameter values. This flag "
    "is used to override hyperparameter settings either when manually "
    "selecting hyperparameters or when using Vizier.")
tf.app.flags.DEFINE_integer(
    "sync_replicas", -1,
    "If > 0, use SyncReplicasOptimizer and use this many replicas per sync.")

# Fixed input size 128 x 128.
vw = vh = 128


def create_input_fn(split, batch_size):
  """Returns input_fn for tf.estimator.Estimator.

  Reads tfrecords and constructs input_fn for either training or eval. All
  tfrecords not in test.txt or dev.txt will be assigned to the training set.

  Args:
    split: A string indicating the split. Can be either 'train' or
      'validation'.
    batch_size: The batch size.

  Returns:
    input_fn for tf.estimator.Estimator.

  Raises:
    IOError: If test.txt or dev.txt are not found.
  """
  if (not os.path.exists(os.path.join(FLAGS.dset, "test.txt")) or
      not os.path.exists(os.path.join(FLAGS.dset, "dev.txt"))):
    raise IOError("test.txt or dev.txt not found")

  with open(os.path.join(FLAGS.dset, "test.txt"), "r") as f:
    testset = [x.strip() for x in f.readlines()]

  with open(os.path.join(FLAGS.dset, "dev.txt"), "r") as f:
    validset = [x.strip() for x in f.readlines()]

  files = os.listdir(FLAGS.dset)
  filenames = []
  for f in files:
    sp = os.path.splitext(f)
    if sp[1] != ".tfrecord" or sp[0] in testset:
      continue
    if ((split == "validation" and sp[0] in validset) or
        (split == "train" and sp[0] not in validset)):
      filenames.append(os.path.join(FLAGS.dset, f))

  def input_fn():
    """input_fn for tf.estimator.Estimator."""

    def parser(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      fs = tf.parse_single_example(
          serialized_example,
          features={
              "img0": tf.FixedLenFeature([], tf.string),
              "img1": tf.FixedLenFeature([], tf.string),
              "mv0": tf.FixedLenFeature([16], tf.float32),
              "mvi0": tf.FixedLenFeature([16], tf.float32),
              "mv1": tf.FixedLenFeature([16], tf.float32),
              "mvi1": tf.FixedLenFeature([16], tf.float32),
          })
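      # img0/img1 are PNG-encoded views of the same object; mv*/mvi* hold a
      # flattened 4x4 pose (modelview) matrix and its inverse for each view,
      # reshaped back to [4, 4] in model_fn.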
| fs["img0"] = tf.div(tf.to_float(tf.image.decode_png(fs["img0"], 4)), 255) | |
| fs["img1"] = tf.div(tf.to_float(tf.image.decode_png(fs["img1"], 4)), 255) | |
| fs["img0"].set_shape([vh, vw, 4]) | |
| fs["img1"].set_shape([vh, vw, 4]) | |
| # fs["lr0"] = [fs["mv0"][0]] | |
| # fs["lr1"] = [fs["mv1"][0]] | |
| fs["lr0"] = tf.convert_to_tensor([fs["mv0"][0]]) | |
| fs["lr1"] = tf.convert_to_tensor([fs["mv1"][0]]) | |
| return fs | |
| np.random.shuffle(filenames) | |
| dataset = tf.data.TFRecordDataset(filenames) | |
| dataset = dataset.map(parser, num_parallel_calls=4) | |
| dataset = dataset.shuffle(400).repeat().batch(batch_size) | |
| dataset = dataset.prefetch(buffer_size=256) | |
| return dataset.make_one_shot_iterator().get_next(), None | |
| return input_fn | |


class Transformer(object):
  """A utility for projecting 3D points to 2D coordinates and vice versa.

  3D points are represented in 4D-homogeneous world coordinates. The pixel
  coordinates are represented in normalized device coordinates [-1, 1].
  See https://learnopengl.com/Getting-started/Coordinate-Systems.
  """

  def __get_matrix(self, lines):
    return np.array([[float(y) for y in x.strip().split(" ")] for x in lines])

  def __read_projection_matrix(self, filename):
    if not os.path.exists(filename):
      raise IOError("Projection matrix file not found: %s" % filename)
    with open(filename, "r") as f:
      lines = f.readlines()
    return self.__get_matrix(lines)

  def __init__(self, w, h, dataset_dir):
    self.w = w
    self.h = h
    p = self.__read_projection_matrix(
        os.path.join(dataset_dir, "projection.txt"))

    # Transpose of the inverse projection matrix.
    self.pinv_t = tf.constant([[1.0 / p[0, 0], 0, 0, 0],
                               [0, 1.0 / p[1, 1], 0, 0],
                               [0, 0, 1, 0],
                               [0, 0, 0, 1]])
    self.f = p[0, 0]

  def project(self, xyzw):
    """Projects homogeneous 3D coordinates to normalized device coordinates."""
    z = xyzw[:, :, 2:3] + 1e-8
    return tf.concat([-self.f * xyzw[:, :, :2] / z, z], axis=2)

  def unproject(self, xyz):
    """Unprojects normalized device coordinates with depth to 3D coordinates."""
    z = xyz[:, :, 2:]
    xy = -xyz * z

    def batch_matmul(a, b):
      return tf.reshape(
          tf.matmul(tf.reshape(a, [-1, a.shape[2].value]), b),
          [-1, a.shape[1].value, a.shape[2].value])

    return batch_matmul(
        tf.concat([xy[:, :, :2], z, tf.ones_like(z)], axis=2), self.pinv_t)


def meshgrid(h):
  """Returns a meshgrid ranging from [-1, 1] in x, y axes."""
  r = np.arange(0.5, h, 1) / (h / 2) - 1
  ranx, rany = tf.meshgrid(r, -r)
  return tf.to_float(ranx), tf.to_float(rany)


def estimate_rotation(xyz0, xyz1, pconf, noise):
  """Estimates the rotation between two sets of keypoints.

  The rotation is estimated by first subtracting the mean from each set of
  keypoints and computing the SVD of the covariance matrix.

  Args:
    xyz0: [batch, num_kp, 3] The first set of keypoints.
    xyz1: [batch, num_kp, 3] The second set of keypoints.
    pconf: [batch, num_kp] The weights used to compute the rotation estimate.
    noise: A number indicating the noise added to the keypoints.

  Returns:
    [batch, 3, 3] A batch of transposed 3 x 3 rotation matrices.
  """
  xyz0 += tf.random_normal(tf.shape(xyz0), mean=0, stddev=noise)
  xyz1 += tf.random_normal(tf.shape(xyz1), mean=0, stddev=noise)

  pconf2 = tf.expand_dims(pconf, 2)
  cen0 = tf.reduce_sum(xyz0 * pconf2, 1, keepdims=True)
  cen1 = tf.reduce_sum(xyz1 * pconf2, 1, keepdims=True)

  x = xyz0 - cen0
  y = xyz1 - cen1

  cov = tf.matmul(tf.matmul(x, tf.matrix_diag(pconf), transpose_a=True), y)
  _, u, v = tf.svd(cov, full_matrices=True)

  d = tf.matrix_determinant(tf.matmul(v, u, transpose_b=True))
  ud = tf.concat(
      [u[:, :, :-1], u[:, :, -1:] * tf.expand_dims(tf.expand_dims(d, 1), 1)],
      axis=2)
  return tf.matmul(ud, v, transpose_b=True)


def relative_pose_loss(xyz0, xyz1, rot, pconf, noise):
  """Computes the relative pose loss (chordal, angular).

  Args:
    xyz0: [batch, num_kp, 3] The first set of keypoints.
    xyz1: [batch, num_kp, 3] The second set of keypoints.
    rot: [batch, 4, 4] The ground-truth rotation matrices.
    pconf: [batch, num_kp] The weights used to compute the rotation estimate.
    noise: A number indicating the noise added to the keypoints.

  Returns:
    A tuple (chordal loss, angular loss).
  """
  r_transposed = estimate_rotation(xyz0, xyz1, pconf, noise)
  rotation = rot[:, :3, :3]
  frob_sqr = tf.reduce_sum(tf.square(r_transposed - rotation), axis=[1, 2])
  frob = tf.sqrt(frob_sqr)

  return tf.reduce_mean(frob_sqr), 2.0 * tf.reduce_mean(
      tf.asin(tf.minimum(1.0, frob / (2 * math.sqrt(2)))))


def separation_loss(xyz, delta):
  """Computes the separation loss.

  Args:
    xyz: [batch, num_kp, 3] Input keypoints.
    delta: A separation threshold. Incur 0 cost if the distance >= delta.

  Returns:
    The separation loss.
  """
  num_kp = tf.shape(xyz)[1]
  t1 = tf.tile(xyz, [1, num_kp, 1])
  t2 = tf.reshape(tf.tile(xyz, [1, 1, num_kp]), tf.shape(t1))
  diffsq = tf.square(t1 - t2)

  # -> [batch, num_kp ^ 2]
  lensqr = tf.reduce_sum(diffsq, axis=2)

  return (tf.reduce_sum(tf.maximum(-lensqr + delta, 0.0)) / tf.to_float(
      num_kp * FLAGS.batch_size * 2))


def consistency_loss(uv0, uv1, pconf):
  """Computes multi-view consistency loss between two sets of keypoints.

  Args:
    uv0: [batch, num_kp, 2] The first set of keypoint 2D coordinates.
    uv1: [batch, num_kp, 2] The second set of keypoint 2D coordinates.
    pconf: [batch, num_kp] The weights used to compute the rotation estimate.

  Returns:
    The consistency loss.
  """
  # [batch, num_kp, 2]
  wd = tf.square(uv0 - uv1) * tf.expand_dims(pconf, 2)
  wd = tf.reduce_sum(wd, axis=[1, 2])
  return tf.reduce_mean(wd)


def variance_loss(probmap, ranx, rany, uv):
  """Computes the variance loss as part of silhouette consistency.

  Args:
    probmap: [batch, num_kp, h, w] The distribution map of keypoint locations.
    ranx: X-axis meshgrid.
    rany: Y-axis meshgrid.
    uv: [batch, num_kp, 2] Keypoint locations (in NDC).

  Returns:
    The variance loss.
  """
  ran = tf.stack([ranx, rany], axis=2)

  sh = tf.shape(ran)
  # [batch, num_kp, vh, vw, 2]
  ran = tf.reshape(ran, [1, 1, sh[0], sh[1], 2])

  sh = tf.shape(uv)
  uv = tf.reshape(uv, [sh[0], sh[1], 1, 1, 2])

  diff = tf.reduce_sum(tf.square(uv - ran), axis=4)
  diff *= probmap

  return tf.reduce_mean(tf.reduce_sum(diff, axis=[2, 3]))


def dilated_cnn(images, num_filters, is_training):
  """Constructs a base dilated convolutional network.

  Args:
    images: [batch, h, w, 3] Input RGB images.
    num_filters: The number of filters for all layers.
    is_training: True if this function is called during training.

  Returns:
    Output of this dilated CNN.
  """
  net = images

  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      normalizer_fn=slim.batch_norm,
      activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=0.1),
      normalizer_params={"is_training": is_training}):
    for i, r in enumerate([1, 1, 2, 4, 8, 16, 1, 2, 4, 8, 16, 1]):
      net = slim.conv2d(net, num_filters, [3, 3], rate=r, scope="dconv%d" % i)

  return net


def orientation_network(images, num_filters, is_training):
  """Constructs a network that infers the orientation of an object.

  Args:
    images: [batch, h, w, 3] Input RGB images.
    num_filters: The number of filters for all layers.
    is_training: True if this function is called during training.

  Returns:
    Output of the orientation network.
  """
  with tf.variable_scope("OrientationNetwork"):
    net = dilated_cnn(images, num_filters, is_training)

    modules = 2
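    # A spatial soft-argmax: softmax each of the two output channels over all
    # pixels, then take the probability-weighted mean of the meshgrid
    # coordinates to get one expected 2D location per module.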
    prob = slim.conv2d(net, 2, [3, 3], rate=1, activation_fn=None)
    prob = tf.transpose(prob, [0, 3, 1, 2])

    prob = tf.reshape(prob, [-1, modules, vh * vw])
    prob = tf.nn.softmax(prob)

    ranx, rany = meshgrid(vh)

    prob = tf.reshape(prob, [-1, 2, vh, vw])

    sx = tf.reduce_sum(prob * ranx, axis=[2, 3])
    sy = tf.reduce_sum(prob * rany, axis=[2, 3])  # -> batch x modules

    out_xy = tf.reshape(tf.stack([sx, sy], -1), [-1, modules, 2])

  return out_xy


def keypoint_network(rgba,
                     num_filters,
                     num_kp,
                     is_training,
                     lr_gt=None,
                     anneal=1):
  """Constructs our main keypoint network that predicts 3D keypoints.

  Args:
    rgba: [batch, h, w, 4] Input RGB images with alpha channel.
    num_filters: The number of filters for all layers.
    num_kp: The number of keypoints.
    is_training: True if this function is called during training.
    lr_gt: The ground-truth orientation flag used at the beginning of
      training. Then we linearly anneal in the prediction.
    anneal: A number between [0, 1] where 1 means using the ground-truth
      orientation and 0 means using our estimate.

  Returns:
    uv: [batch, num_kp, 2] 2D locations of keypoints.
    z: [batch, num_kp] The depth of keypoints.
    orient: [batch, 2, 2] Two 2D coordinates that correspond to [1, 0, 0] and
      [-1, 0, 0] in object space.
    sill: The silhouette loss.
    variance: The variance loss.
    prob_viz: A visualization of all predicted keypoints.
    prob_vizs: A list of visualizations of each keypoint.
  """
  images = rgba[:, :, :, :3]

  # [batch, 2, 2]
  orient = orientation_network(images, num_filters // 2, is_training)

  # [batch, 1]
  lr_estimated = tf.maximum(0.0, tf.sign(orient[:, 0, :1] - orient[:, 1, :1]))
  if lr_gt is None:
    lr = lr_estimated
  else:
    lr_gt = tf.maximum(0.0, tf.sign(lr_gt[:, :1]))
    lr = tf.round(lr_gt * anneal + lr_estimated * (1 - anneal))
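  # Broadcast the scalar left/right flag to a full h x w channel so the
  # keypoint CNN can condition on the object's dominant orientation.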
  lrtiled = tf.tile(
      tf.expand_dims(tf.expand_dims(lr, 1), 1),
      [1, images.shape[1], images.shape[2], 1])

  images = tf.concat([images, lrtiled], axis=3)

  mask = rgba[:, :, :, 3]
  mask = tf.cast(tf.greater(mask, tf.zeros_like(mask)), dtype=tf.float32)
  net = dilated_cnn(images, num_filters, is_training)

  # The probability distribution map.
  prob = slim.conv2d(
      net, num_kp, [3, 3], rate=1, scope="conv_xy", activation_fn=None)

  # We added the fixed camera distance as a bias.
  z = -30 + slim.conv2d(
      net, num_kp, [3, 3], rate=1, scope="conv_z", activation_fn=None)
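  # Both heads are read out with a spatial softmax below: uv is the expected
  # meshgrid coordinate under each keypoint's distribution map, and z is the
  # expectation of the depth map under that same distribution.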
  prob = tf.transpose(prob, [0, 3, 1, 2])
  z = tf.transpose(z, [0, 3, 1, 2])

  prob = tf.reshape(prob, [-1, num_kp, vh * vw])
  prob = tf.nn.softmax(prob, name="softmax")

  ranx, rany = meshgrid(vh)
  prob = tf.reshape(prob, [-1, num_kp, vh, vw])

  # These are for visualizing the distribution maps.
  prob_viz = tf.expand_dims(tf.reduce_sum(prob, 1), 3)
  prob_vizs = [tf.expand_dims(prob[:, i, :, :], 3) for i in range(num_kp)]

  sx = tf.reduce_sum(prob * ranx, axis=[2, 3])
  sy = tf.reduce_sum(prob * rany, axis=[2, 3])  # -> batch x num_kp

  # [batch, num_kp]
  sill = tf.reduce_sum(prob * tf.expand_dims(mask, 1), axis=[2, 3])
  sill = tf.reduce_mean(-tf.log(sill + 1e-12))

  z = tf.reduce_sum(prob * z, axis=[2, 3])
  uv = tf.reshape(tf.stack([sx, sy], -1), [-1, num_kp, 2])

  variance = variance_loss(prob, ranx, rany, uv)

  return uv, z, orient, sill, variance, prob_viz, prob_vizs


def model_fn(features, labels, mode, hparams):
  """Returns model_fn for tf.estimator.Estimator."""
  del labels

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  t = Transformer(vw, vh, FLAGS.dset)

  def func1(x):
    return tf.transpose(tf.reshape(features[x], [-1, 4, 4]), [0, 2, 1])

  mv = [func1("mv%d" % i) for i in range(2)]
  mvi = [func1("mvi%d" % i) for i in range(2)]

  uvz = [None] * 2
  uvz_proj = [None] * 2  # uvz coordinates projected on to the other view.
  viz = [None] * 2
  vizs = [None] * 2

  loss_sill = 0
  loss_variance = 0
  loss_con = 0
  loss_sep = 0
  loss_lr = 0

  for i in range(2):
    with tf.variable_scope("KeypointNetwork", reuse=i > 0):
      # anneal: 1 = using ground-truth, 0 = using our estimated orientation.
      anneal = tf.to_float(hparams.lr_anneal_end - tf.train.get_global_step())
      anneal = tf.clip_by_value(
          anneal / (hparams.lr_anneal_end - hparams.lr_anneal_start), 0.0, 1.0)

      uv, z, orient, sill, variance, viz[i], vizs[i] = keypoint_network(
          features["img%d" % i],
          hparams.num_filters,
          hparams.num_kp,
          is_training,
          lr_gt=features["lr%d" % i],
          anneal=anneal)

      # x-positive/negative axes (dominant direction).
      xp_axis = tf.tile(
          tf.constant([[[1.0, 0, 0, 1], [-1.0, 0, 0, 1]]]),
          [tf.shape(orient)[0], 1, 1])

      # [batch, 2, 4] = [batch, 2, 4] x [batch, 4, 4]
      xp = tf.matmul(xp_axis, mv[i])

      # [batch, 2, 3]
      xp = t.project(xp)

      loss_lr += tf.losses.mean_squared_error(orient[:, :, :2], xp[:, :, :2])
      loss_variance += variance
      loss_sill += sill

      uv = tf.reshape(uv, [-1, hparams.num_kp, 2])
      z = tf.reshape(z, [-1, hparams.num_kp, 1])

      # [batch, num_kp, 3]
      uvz[i] = tf.concat([uv, z], axis=2)

      world_coords = tf.matmul(t.unproject(uvz[i]), mvi[i])

      # [batch, num_kp, 3]
      uvz_proj[i] = t.project(tf.matmul(world_coords, mv[1 - i]))

  pconf = tf.ones(
      [tf.shape(uv)[0], tf.shape(uv)[1]], dtype=tf.float32) / hparams.num_kp
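  # Uniform keypoint confidences: every keypoint contributes equally to the
  # pose estimate and to the consistency loss.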
  for i in range(2):
    loss_con += consistency_loss(uvz_proj[i][:, :, :2], uvz[1 - i][:, :, :2],
                                 pconf)
    loss_sep += separation_loss(
        t.unproject(uvz[i])[:, :, :3], hparams.sep_delta)

  chordal, angular = relative_pose_loss(
      t.unproject(uvz[0])[:, :, :3],
      t.unproject(uvz[1])[:, :, :3], tf.matmul(mvi[0], mv[1]), pconf,
      hparams.noise)

  loss = (
      hparams.loss_pose * angular +
      hparams.loss_con * loss_con +
      hparams.loss_sep * loss_sep +
      hparams.loss_sill * loss_sill +
      hparams.loss_lr * loss_lr +
      hparams.loss_variance * loss_variance
  )

  def touint8(img):
    return tf.cast(img * 255.0, tf.uint8)

  with tf.variable_scope("output"):
    tf.summary.image("0_img0", touint8(features["img0"][:, :, :, :3]))
    tf.summary.image("1_combined", viz[0])
    for i in range(hparams.num_kp):
      tf.summary.image("2_f%02d" % i, vizs[0][i])

  with tf.variable_scope("stats"):
    tf.summary.scalar("anneal", anneal)
    tf.summary.scalar("closs", loss_con)
    tf.summary.scalar("seploss", loss_sep)
    tf.summary.scalar("angular", angular)
    tf.summary.scalar("chordal", chordal)
    tf.summary.scalar("lrloss", loss_lr)
    tf.summary.scalar("sill", loss_sill)
    tf.summary.scalar("vloss", loss_variance)

  return {
      "loss": loss,
      "predictions": {
          "img0": features["img0"],
          "img1": features["img1"],
          "uvz0": uvz[0],
          "uvz1": uvz[1]
      },
      "eval_metric_ops": {
          "closs": tf.metrics.mean(loss_con),
          "angular_loss": tf.metrics.mean(angular),
          "chordal_loss": tf.metrics.mean(chordal),
      }
  }


def predict(input_folder, hparams):
  """Predicts keypoints on all images in input_folder."""
  cols = plt.cm.get_cmap("rainbow")(
      np.linspace(0, 1.0, hparams.num_kp))[:, :4]

  img = tf.placeholder(tf.float32, shape=(1, 128, 128, 4))

  with tf.variable_scope("KeypointNetwork"):
    ret = keypoint_network(
        img, hparams.num_filters, hparams.num_kp, False)

  uv = tf.reshape(ret[0], [-1, hparams.num_kp, 2])
  z = tf.reshape(ret[1], [-1, hparams.num_kp, 1])
  uvz = tf.concat([uv, z], axis=2)

  sess = tf.Session()
  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)

  print("loading model: ", ckpt.model_checkpoint_path)
  saver.restore(sess, ckpt.model_checkpoint_path)

  files = [x for x in os.listdir(input_folder)
           if x[-3:] in ["jpg", "png"]]

  output_folder = os.path.join(input_folder, "output")
  if not os.path.exists(output_folder):
    os.mkdir(output_folder)

  for f in files:
    orig = misc.imread(os.path.join(input_folder, f)).astype(float) / 255

    if orig.shape[2] == 3:
      orig = np.concatenate((orig, np.ones_like(orig[:, :, :1])), axis=2)

    uv_ret = sess.run(uvz, feed_dict={img: np.expand_dims(orig, 0)})
    utils.draw_ndc_points(orig, uv_ret.reshape(hparams.num_kp, 3), cols)
    misc.imsave(os.path.join(output_folder, f), orig)


def _default_hparams():
  """Returns default or overridden user-specified hyperparameters."""
  hparams = tf.contrib.training.HParams(
      num_filters=64,  # Number of filters.
      num_kp=10,  # Number of keypoints.

      loss_pose=0.2,  # Pose loss.
      loss_con=1.0,  # Multiview consistency loss.
      loss_sep=1.0,  # Separation loss.
      loss_sill=1.0,  # Silhouette loss.
      loss_lr=1.0,  # Orientation loss.
      loss_variance=0.5,  # Variance loss (part of silhouette loss).

      sep_delta=0.05,  # Separation threshold.
      noise=0.1,  # Noise added during estimating rotation.

      learning_rate=1.0e-3,
      lr_anneal_start=30000,  # When to anneal in the orientation prediction.
      lr_anneal_end=60000,  # When to use the prediction completely.
  )
  if FLAGS.hparams:
    hparams = hparams.parse(FLAGS.hparams)
  return hparams


def main(argv):
  del argv

  hparams = _default_hparams()

  if FLAGS.predict:
    predict(FLAGS.input, hparams)
  else:
    utils.train_and_eval(
        model_dir=FLAGS.model_dir,
        model_fn=model_fn,
        input_fn=create_input_fn,
        hparams=hparams,
        steps=FLAGS.steps,
        batch_size=FLAGS.batch_size,
        save_checkpoints_secs=600,
        eval_throttle_secs=1800,
        eval_steps=5,
        sync_replicas=FLAGS.sync_replicas,
    )


if __name__ == "__main__":
  sys.excepthook = utils.colored_hook(
      os.path.dirname(os.path.realpath(__file__)))
  tf.app.run()