Spaces:
Build error
Build error
| from __future__ import division | |
| from __future__ import print_function | |
| import os.path as osp | |
| import numpy as np | |
| from easydict import EasyDict as edict | |
# Default configuration, declared as one nested literal. EasyDict turns every
# nested dict into an edict, so attribute access (cfg.TRAIN.BATCH_SIZE) works
# at every level. `cfg` is the public alias for the same object as `__C`.
__C = edict(
    {
        # Dataset name: flowers, birds
        "DATASET_NAME": "birds",
        "CONFIG_NAME": "",
        "DATA_DIR": "",
        "GPU_ID": 0,
        "CUDA": True,
        "WORKERS": 6,
        "RNN_TYPE": "LSTM",  # alternative noted by the original author: 'GRU'
        "B_VALIDATION": False,
        "TREE": {
            "BRANCH_NUM": 3,
            "BASE_SIZE": 64,
        },
        # Training options
        "TRAIN": {
            "BATCH_SIZE": 64,
            "MAX_EPOCH": 600,
            "SNAPSHOT_INTERVAL": 2000,
            "DISCRIMINATOR_LR": 2e-4,
            "GENERATOR_LR": 2e-4,
            "ENCODER_LR": 2e-4,
            "RNN_GRAD_CLIP": 0.25,
            "FLAG": True,
            "NET_E": "",
            "NET_G": "",
            "B_NET_D": True,
            "SMOOTH": {
                "GAMMA1": 5.0,
                "GAMMA3": 10.0,
                "GAMMA2": 5.0,
                "LAMBDA": 1.0,
            },
        },
        # GAN options
        "GAN": {
            "DF_DIM": 64,
            "GF_DIM": 128,
            "Z_DIM": 100,
            "CONDITION_DIM": 100,
            "R_NUM": 2,
            "B_ATTENTION": True,
            "B_DCGAN": False,
        },
        # Text options
        "TEXT": {
            "CAPTIONS_PER_IMAGE": 10,
            "EMBEDDING_DIM": 256,
            "WORDS_NUM": 18,
        },
    }
)
cfg = __C
def _merge_a_into_b(a, b):
    """Overlay the options in ``a`` onto ``b``, modifying ``b`` in place.

    Every key of ``a`` must already exist in ``b``. A value replaces the
    existing one only if both have exactly the same type; the single
    exception is an existing ``np.ndarray``, in which case the new value is
    converted to an array with the existing dtype. Values that are ``edict``
    instances are merged recursively instead of being replaced wholesale.

    If ``a`` is not exactly an ``edict`` the call is a silent no-op.

    Raises:
        KeyError: ``a`` contains a key that ``b`` does not have.
        ValueError: the new value's type differs from the existing value's
            type and the existing value is not an ``np.ndarray``.
    """
    if type(a) is not edict:
        return

    for key, value in a.items():
        # Only known options may be overridden.
        if key not in b:
            raise KeyError("{} is not a valid config key".format(key))

        current = b[key]

        # Types must agree exactly; ndarray defaults accept anything that
        # np.array can coerce to their dtype.
        if type(current) is not type(value):
            if not isinstance(current, np.ndarray):
                raise ValueError(
                    "Type mismatch ({} vs. {}) for config key: {}".format(
                        type(current), type(value), key
                    )
                )
            value = np.array(value, dtype=current.dtype)

        # Leaf option: overwrite and move on.
        if type(value) is not edict:
            b[key] = value
            continue

        # Nested section: descend, reporting the enclosing key on failure.
        try:
            _merge_a_into_b(value, current)
        except BaseException:  # same reach as a bare except; always re-raised
            print("Error under config key: {}".format(key))
            raise
def cfg_from_file(filename):
    """Load a YAML config file and merge it into the default options.

    The file is parsed, wrapped in an ``edict`` and passed to
    ``_merge_a_into_b`` together with the module-level defaults ``__C``,
    which are updated in place.

    Args:
        filename: Path of the YAML file to read.

    Raises:
        KeyError: Propagated from ``_merge_a_into_b`` when the file names an
            option that is not present in the defaults.
        ValueError: Propagated from ``_merge_a_into_b`` when an option's type
            does not match the type of its default.
    """
    import yaml

    with open(filename, "r") as f:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6.0 (and could construct arbitrary Python objects on old
        # releases). safe_load handles the plain scalars/mappings/lists a
        # config file needs and works on every PyYAML version.
        yaml_cfg = edict(yaml.safe_load(f))

    _merge_a_into_b(yaml_cfg, __C)