# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import logging
import numpy as np

import src.utils as utils
import datasets.nav_env_config as nec
from datasets import factory
def adjust_args_for_mode(args, mode):
  if mode == 'train':
    args.control.train = True

  elif mode == 'val1':
    # Same settings as for training, to make sure nothing wonky is happening
    # there.
    args.control.test = True
    args.control.test_mode = 'val'
    args.navtask.task_params.batch_size = 32

  elif mode == 'val2':
    # No data augmentation, not sampling but taking the argmax action, not
    # sampling from the ground truth at all.
    args.control.test = True
    args.arch.action_sample_type = 'argmax'
    args.arch.sample_gt_prob_type = 'zero'
    args.navtask.task_params.data_augment = \
      utils.Foo(lr_flip=0, delta_angle=0, delta_xy=0, relight=False,
                relight_fast=False, structured=False)
    args.control.test_mode = 'val'
    args.navtask.task_params.batch_size = 32

  elif mode == 'bench':
    # Actually test the agent, in settings that are kept the same across
    # different runs.
    args.navtask.task_params.batch_size = 16
    args.control.test = True
    args.arch.action_sample_type = 'argmax'
    args.arch.sample_gt_prob_type = 'zero'
    args.navtask.task_params.data_augment = \
      utils.Foo(lr_flip=0, delta_angle=0, delta_xy=0, relight=False,
                relight_fast=False, structured=False)
    args.summary.test_iters = 250
    args.control.only_eval_when_done = True
    args.control.reset_rng_seed = True
    args.control.test_mode = 'test'

  else:
    logging.fatal('Unknown mode: %s.', mode)
    assert(False)
  return args
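# Illustrative sketch (not part of the original file): 'train', 'val1', 'val2'
# and 'bench' are the supported modes, e.g.
#   args = adjust_args_for_mode(args, 'bench')
#   # args.control.test == True, args.control.test_mode == 'test',
#   # args.navtask.task_params.batch_size == 16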
def get_solver_vars(solver_str):
  if solver_str == '': vals = []
  else: vals = solver_str.split('_')
  ks = ['clip', 'dlw', 'long', 'typ', 'rlw', 'isdk', 'adam_eps', 'init_lr']
  ks = ks[:len(vals)]

  # Gradient clipping or not.
  if len(vals) == 0: ks.append('clip'); vals.append('noclip')
  # Data loss weight.
  if len(vals) == 1: ks.append('dlw'); vals.append('dlw20')
  # How long to train for.
  if len(vals) == 2: ks.append('long'); vals.append('nolong')
  # Optimizer type (Adam).
  if len(vals) == 3: ks.append('typ'); vals.append('adam2')
  # Regularization loss weight.
  if len(vals) == 4: ks.append('rlw'); vals.append('rlw1')
  # isd_k
  if len(vals) == 5: ks.append('isdk'); vals.append('isdk415')  # 415, inflexion at 2.5k.
  # Adam epsilon.
  if len(vals) == 6: ks.append('adam_eps'); vals.append('aeps1en8')
  # Initial learning rate.
  if len(vals) == 7: ks.append('init_lr'); vals.append('lr1en3')
  assert(len(vals) == 8)

  vars = utils.Foo()
  for k, v in zip(ks, vals):
    setattr(vars, k, v)
  logging.error('solver_vars: %s', vars)
  return vars
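# Illustrative sketch (not part of the original file): an empty solver_str
# resolves to the defaults, equivalent to the fully specified string
# 'noclip_dlw20_nolong_adam2_rlw1_isdk415_aeps1en8_lr1en3'. A partially
# specified string overrides only the leading fields, e.g.
#   v = get_solver_vars('clip2_dlw20_long2')
#   # v.clip == 'clip2', v.dlw == 'dlw20', v.long == 'long2',
#   # v.typ == 'adam2' (default), v.init_lr == 'lr1en3' (default)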
def process_solver_str(solver_str):
  solver = utils.Foo(
      seed=0, learning_rate_decay=None, clip_gradient_norm=None, max_steps=None,
      initial_learning_rate=None, momentum=None, steps_per_decay=None,
      logdir=None, sync=False, adjust_lr_sync=True, wt_decay=0.0001,
      data_loss_wt=None, reg_loss_wt=None, freeze_conv=True, num_workers=1,
      task=0, ps_tasks=0, master='local', typ=None, momentum2=None,
      adam_eps=None)

  # Clobber with overrides from solver str.
  solver_vars = get_solver_vars(solver_str)

  solver.data_loss_wt = float(solver_vars.dlw[3:].replace('x', '.'))
  solver.adam_eps = float(solver_vars.adam_eps[4:].replace('x', '.').replace('n', '-'))
  solver.initial_learning_rate = float(solver_vars.init_lr[2:].replace('x', '.').replace('n', '-'))
  solver.reg_loss_wt = float(solver_vars.rlw[3:].replace('x', '.'))
  solver.isd_k = float(solver_vars.isdk[4:].replace('x', '.'))

  long = solver_vars.long
  if long == 'long':
    solver.steps_per_decay = 40000
    solver.max_steps = 120000
  elif long == 'long2':
    solver.steps_per_decay = 80000
    solver.max_steps = 120000
  elif long == 'nolong' or long == 'nol':
    solver.steps_per_decay = 20000
    solver.max_steps = 60000
  else:
    logging.fatal('solver_vars.long should be long, long2, nolong or nol.')
    assert(False)

  clip = solver_vars.clip
  if clip == 'noclip' or clip == 'nocl':
    solver.clip_gradient_norm = 0
  elif clip[:4] == 'clip':
    solver.clip_gradient_norm = float(clip[4:].replace('x', '.'))
  else:
    logging.fatal('Unknown solver_vars.clip: %s', clip)
    assert(False)

  typ = solver_vars.typ
  if typ == 'adam':
    solver.typ = 'adam'
    solver.momentum = 0.9
    solver.momentum2 = 0.999
    solver.learning_rate_decay = 1.0
  elif typ == 'adam2':
    solver.typ = 'adam'
    solver.momentum = 0.9
    solver.momentum2 = 0.999
    solver.learning_rate_decay = 0.1
  elif typ == 'sgd':
    solver.typ = 'sgd'
    solver.momentum = 0.99
    solver.momentum2 = None
    solver.learning_rate_decay = 0.1
  else:
    logging.fatal('Unknown solver_vars.typ: %s', typ)
    assert(False)

  logging.error('solver: %s', solver)
  return solver
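# Illustrative sketch (not part of the original file): the encoded fields use
# 'x' for a decimal point and 'n' for a minus sign, so e.g.
#   s = process_solver_str('')
#   # s.data_loss_wt == 20.0            (from 'dlw20')
#   # s.reg_loss_wt == 1.0              (from 'rlw1')
#   # s.adam_eps == 1e-08               (from 'aeps1en8')
#   # s.initial_learning_rate == 0.001  (from 'lr1en3')
#   # s.steps_per_decay == 20000, s.max_steps == 60000  (from 'nolong')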
def get_navtask_vars(navtask_str):
  if navtask_str == '': vals = []
  else: vals = navtask_str.split('_')

  ks_all = ['dataset_name', 'modality', 'task', 'history', 'max_dist',
            'num_steps', 'step_size', 'n_ori', 'aux_views', 'data_aug']
  ks = ks_all[:len(vals)]

  # All data or not.
  if len(vals) == 0: ks.append('dataset_name'); vals.append('sbpd')
  # Modality.
  if len(vals) == 1: ks.append('modality'); vals.append('rgb')
  # Semantic task?
  if len(vals) == 2: ks.append('task'); vals.append('r2r')
  # Number of history frames.
  if len(vals) == 3: ks.append('history'); vals.append('h0')
  # Maximum distance to the goal.
  if len(vals) == 4: ks.append('max_dist'); vals.append('32')
  # Number of steps.
  if len(vals) == 5: ks.append('num_steps'); vals.append('40')
  # Step size.
  if len(vals) == 6: ks.append('step_size'); vals.append('8')
  # n_ori
  if len(vals) == 7: ks.append('n_ori'); vals.append('4')
  # Auxiliary views.
  if len(vals) == 8: ks.append('aux_views'); vals.append('nv0')
  # Normal data augmentation, as opposed to structured data augmentation (if
  # set to 'straug').
  if len(vals) == 9: ks.append('data_aug'); vals.append('straug')

  assert(len(vals) == 10)
  for i in range(len(ks)):
    assert(ks[i] == ks_all[i])

  vars = utils.Foo()
  for k, v in zip(ks, vals):
    setattr(vars, k, v)
  logging.error('navtask_vars: %s', vars)
  return vars
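# Illustrative sketch (not part of the original file): an empty navtask_str
# resolves to the defaults, equivalent to the fully specified string
# 'sbpd_rgb_r2r_h0_32_40_8_4_nv0_straug'. A partially specified string
# overrides only the leading fields, e.g.
#   v = get_navtask_vars('sbpd_d_ST')
#   # v.dataset_name == 'sbpd', v.modality == 'd', v.task == 'ST',
#   # v.history == 'h0' (default), v.data_aug == 'straug' (default)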
def process_navtask_str(navtask_str):
  navtask = nec.nav_env_base_config()

  # Clobber with overrides from strings.
  navtask_vars = get_navtask_vars(navtask_str)

  navtask.task_params.n_ori = int(navtask_vars.n_ori)
  navtask.task_params.max_dist = int(navtask_vars.max_dist)
  navtask.task_params.num_steps = int(navtask_vars.num_steps)
  navtask.task_params.step_size = int(navtask_vars.step_size)
  navtask.task_params.data_augment.delta_xy = int(navtask_vars.step_size)/2.

  n_aux_views_each = int(navtask_vars.aux_views[2])
  aux_delta_thetas = np.concatenate((np.arange(n_aux_views_each) + 1,
                                     -1 - np.arange(n_aux_views_each)))
  aux_delta_thetas = aux_delta_thetas * np.deg2rad(navtask.camera_param.fov)
  navtask.task_params.aux_delta_thetas = aux_delta_thetas

  if navtask_vars.data_aug == 'aug':
    navtask.task_params.data_augment.structured = False
  elif navtask_vars.data_aug == 'straug':
    navtask.task_params.data_augment.structured = True
  else:
    logging.fatal('Unknown navtask_vars.data_aug %s.', navtask_vars.data_aug)
    assert(False)

  navtask.task_params.num_history_frames = int(navtask_vars.history[1:])
  navtask.task_params.n_views = 1 + navtask.task_params.num_history_frames

  navtask.task_params.goal_channels = int(navtask_vars.n_ori)
  if navtask_vars.task == 'hard':
    navtask.task_params.type = 'rng_rejection_sampling_many'
    navtask.task_params.rejection_sampling_M = 2000
    navtask.task_params.min_dist = 10
  elif navtask_vars.task == 'r2r':
    navtask.task_params.type = 'room_to_room_many'
  elif navtask_vars.task == 'ST':
    # Semantic task at hand.
    navtask.task_params.goal_channels = \
        len(navtask.task_params.semantic_task.class_map_names)
    navtask.task_params.rel_goal_loc_dim = \
        len(navtask.task_params.semantic_task.class_map_names)
    navtask.task_params.type = 'to_nearest_obj_acc'
  else:
    logging.fatal('navtask_vars.task should be hard, r2r or ST.')
    assert(False)
  if navtask_vars.modality == 'rgb':
    navtask.camera_param.modalities = ['rgb']
    navtask.camera_param.img_channels = 3
  elif navtask_vars.modality == 'd':
    navtask.camera_param.modalities = ['depth']
    navtask.camera_param.img_channels = 2
  else:
    logging.fatal('Unknown navtask_vars.modality: %s.', navtask_vars.modality)
    assert(False)

  navtask.task_params.img_height = navtask.camera_param.height
  navtask.task_params.img_width = navtask.camera_param.width
  navtask.task_params.modalities = navtask.camera_param.modalities
  navtask.task_params.img_channels = navtask.camera_param.img_channels
  navtask.task_params.img_fov = navtask.camera_param.fov

  navtask.dataset = factory.get_dataset(navtask_vars.dataset_name)
  return navtask
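# Illustrative sketch (not part of the original file): the navtask string is
# positional, so e.g.
#   navtask = process_navtask_str('sbpd_rgb_r2r_h0_32_40_8_4_nv0_straug')
#   # navtask.task_params.max_dist == 32, navtask.task_params.num_steps == 40,
#   # navtask.task_params.step_size == 8, navtask.task_params.n_views == 1,
#   # navtask.task_params.type == 'room_to_room_many'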