Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| """Objects for storing configuration and passing config into binaries. | |
| Config class stores settings and hyperparameters for models, data, and anything | |
| else that may be specific to a particular run. | |
| """ | |
| import ast | |
| import itertools | |
| from six.moves import xrange | |
class _ConfigAttributeError(AttributeError, KeyError):
    """Raised when a key that is not in a Config is read as an attribute.

    It is an AttributeError so that hasattr(), getattr() with a default,
    copy.deepcopy() and pickle treat a missing key as a missing attribute, and
    a KeyError so that callers which already catch KeyError keep working.
    """


class Config(dict):
    """Stores model configuration, hyperparameters, or dataset parameters.

    A dict whose keys can also be read and written as attributes:
    `config.lr` reads `config['lr']` and `config.lr = 0.1` sets `config['lr']`.
    Values may be nested Config objects or OneOf branching configs.
    """

    def __getattr__(self, attr):
        # Python only calls __getattr__ after normal attribute lookup fails, so
        # dict methods (e.g. `items`) shadow keys that have the same name.
        try:
            return self[attr]
        except KeyError:
            raise _ConfigAttributeError(attr)

    def __setattr__(self, attr, value):
        # Every attribute assignment is stored as a dict item.
        self[attr] = value

    def pretty_str(self, new_lines=True, indent=2, final_indent=0):
        """Returns a readable 'Config(key=value, ...)' string for this Config.

        Nested Config values are formatted recursively; every other value is
        formatted with repr().

        Args:
          new_lines: If True, put each key=value pair on its own line. If
              False, put everything on one line separated by ', '.
          indent: Spaces before each key=value line when `new_lines` is True.
              Nested Configs are indented 2 more spaces than their parent.
          final_indent: Spaces before the closing ')' when `new_lines` is True.

        Returns:
          The formatted string.
        """
        prefix = (' ' * indent) if new_lines else ''
        final_prefix = (' ' * final_indent) if new_lines else ''
        kv = ['%s%s=%s' % (prefix, k,
                           (repr(v) if not isinstance(v, Config)
                            else v.pretty_str(new_lines=new_lines,
                                              indent=indent + 2,
                                              final_indent=indent)))
              for k, v in self.items()]
        if new_lines:
            return 'Config(\n%s\n%s)' % (',\n'.join(kv), final_prefix)
        return 'Config(%s)' % ', '.join(kv)

    def _update_iterator(self, *args, **kwargs):
        """Convert mixed input into an iterator over (key, value) tuples.

        Follows the dict.update call signature.

        Args:
          *args: (Optional) Pass a dict or iterable of (key, value) 2-tuples as
              an unnamed argument. Only one unnamed argument allowed.
          **kwargs: (Optional) Pass (key, value) pairs as named arguments, where
              the argument name is the key and the argument value is the value.

        Returns:
          An iterator over the (key, value) tuples given in the input: pairs from
          the unnamed argument first, then the named arguments.

        Raises:
          TypeError: If more than one unnamed argument is given.
        """
        if len(args) > 1:
            raise TypeError('Expected at most 1 unnamed arguments, got %d'
                            % len(args))
        obj = args[0] if args else dict()
        if isinstance(obj, dict):
            return itertools.chain(obj.items(), kwargs.items())
        # Assume obj is an iterable of 2-tuples.
        return itertools.chain(obj, kwargs.items())

    def make_default(self, keys=None):
        """Replaces OneOf values with their default configs, recursively.

        Every Config value -- including one just produced from a OneOf, so
        nested OneOfs are handled -- is also passed through make_default.

        Args:
          keys: Iterable of key names to check. If None, all keys in self are
              used.
        """
        if keys is None:
            keys = self.keys()
        for k in keys:
            if isinstance(self[k], OneOf):
                # default() hands back the OneOf's stored config object (not a
                # copy), so the recursive call below may modify it in place.
                self[k] = self[k].default()
            if isinstance(self[k], Config):
                self[k].make_default()

    def update(self, *args, **kwargs):
        """Same as dict.update except nested Config and OneOf values are updated.

        For each given (key, value):
          * if the current value is a Config and the new value is a dict, the
            nested Config is updated recursively with it;
          * if the current value is a OneOf and the new value is a dict, the
            OneOf is replaced by the result of OneOf.update(value);
          * otherwise the value is simply assigned.
        Keys that already existed but were not given are then passed through
        make_default, so any OneOf among them becomes its default config.

        Args:
          *args: (Optional) Pass a dict or list of (key, value) 2-tuples as
              unnamed argument.
          **kwargs: (Optional) Pass (key, value) pairs as named arguments, where
              the argument name is the key and the argument value is the value.

        Raises:
          TypeError: If more than one unnamed argument is given.
        """
        key_set = set(self.keys())
        for k, v in self._update_iterator(*args, **kwargs):
            key_set.discard(k)  # Updated keys are excluded from make_default.
            if k in self and isinstance(self[k], Config) and isinstance(v, dict):
                self[k].update(v)
            elif k in self and isinstance(self[k], OneOf) and isinstance(v, dict):
                # Replace OneOf with the chosen config.
                self[k] = self[k].update(v)
            else:
                self[k] = v
        self.make_default(key_set)

    def strict_update(self, *args, **kwargs):
        """Same as Config.update except keys and types are not allowed to change.

        If a given key is not already in this instance, an exception is raised.
        If a given value does not have the same type as the existing value for
        the same key, an exception is raised. Use this method to catch config
        mistakes.

        Args:
          *args: (Optional) Pass a dict or list of (key, value) 2-tuples as
              unnamed argument.
          **kwargs: (Optional) Pass (key, value) pairs as named arguments, where
              the argument name is the key and the argument value is the value.

        Raises:
          TypeError: If more than one unnamed argument is given.
          TypeError: If new value type does not match existing type.
          KeyError: If a given key is not already defined in this instance.
        """
        key_set = set(self.keys())
        for k, v in self._update_iterator(*args, **kwargs):
            if k not in self:
                raise KeyError(
                    'Key %s does not exist. New key creation not allowed in '
                    'strict_update.' % k)
            # discard, not remove: the same key may legitimately be given twice
            # (in the unnamed dict and again as a named argument).
            key_set.discard(k)  # Updated keys are excluded from make_default.
            if isinstance(self[k], Config):
                if not isinstance(v, dict):
                    raise TypeError('dict required for Config value, got %s' % type(v))
                self[k].strict_update(v)
            elif isinstance(self[k], OneOf):
                if not isinstance(v, dict):
                    raise TypeError('dict required for OneOf value, got %s' % type(v))
                # Replace OneOf with the chosen config.
                self[k] = self[k].strict_update(v)
            else:
                if not isinstance(v, type(self[k])):
                    raise TypeError('Expecting type %s for key %s, got type %s'
                                    % (type(self[k]), k, type(v)))
                self[k] = v
        self.make_default(key_set)

    @staticmethod
    def from_str(config_str):
        """Inverse of Config.__str__.

        Parses dict-literal text such as "{'a': 1, 'b': {'c': 2}}" (the form
        str(config) produces) with ast.literal_eval, turning the top-level dict
        and every dict nested directly inside it as a value into a Config.

        Args:
          config_str: A string holding a Python dict literal.

        Returns:
          A Config object.
        """
        parsed = ast.literal_eval(config_str)
        assert isinstance(parsed, dict)

        def _make_config(dictionary):
            for k, v in dictionary.items():
                if isinstance(v, dict):
                    dictionary[k] = _make_config(v)
            # Positional form (not **) so non-string keys are accepted too.
            return Config(dictionary)

        return _make_config(parsed)

    @staticmethod
    def parse(key_val_string):
        """Parse hyperparameter string into Config object.

        Format is 'key=val,key=val,...'
        Values can be any python literal, or another Config object encoded as
        'c(key=val,key=val,...)'. c(...) expressions can be arbitrarily nested.
        Blank segments (e.g. from a trailing comma) are ignored.

        Example:
          'a=1,b=3e-5,c=[1,2,3],d="hello world",e={"a":1,"b":2},f=c(x=1,y=[10,20])'

        Args:
          key_val_string: The hyperparameter string.

        Returns:
          Config object parsed from the input string.

        Raises:
          ValueError: If a segment has no '=' or its value is not a valid
              Python literal.
        """
        if not key_val_string.strip():
            return Config()

        def _pair_to_kv(pair):
            split_index = pair.find('=')
            if split_index == -1:
                raise ValueError(
                    'Expected key=value in hyperparameter string, got %r.' % pair)
            key, val = pair[:split_index].strip(), pair[split_index + 1:].strip()
            if val.startswith('c(') and val.endswith(')'):
                val = Config.parse(val[2:-1])
            else:
                val = ast.literal_eval(val)
            return key, val

        return Config(**dict([_pair_to_kv(pair)
                              for pair in _comma_iterator(key_val_string)
                              if pair.strip()]))
class OneOf(object):
    """Stores branching config.

    In some cases there may be options which each have their own set of config
    params. For example, if specifying config for an environment, each
    environment can have custom config options. OneOf is a way to organize
    branching config.

    Usage example:
      one_of = OneOf(
          [Config(a=1, b=2),
           Config(a=2, c='hello'),
           Config(a=3, d=10, e=-10)],
          a=1)
      config = one_of.strict_update(Config(a=3, d=20))
      config == {'a': 3, 'd': 20, 'e': -10}

    Note: default(), update() and strict_update() return the stored choice
    object itself, not a copy, so later changes to it are also seen through
    this OneOf.
    """

    def __init__(self, choices, **kwargs):
        """Constructor.

        Usage: OneOf([Config(...), Config(...), ...], attribute=default_value)

        Args:
          choices: An iterable of Config objects. When update/strict_update is
              called on this OneOf, one of these Configs will be selected. Each
              must contain the branching attribute (indexing it raises KeyError
              otherwise) with a value distinct from the other choices.
          **kwargs: Give exactly one config attribute to branch on. The value of
              this attribute during update/strict_update will determine which
              Config is used; the value given here selects the default Config.

        Raises:
          ValueError: If kwargs does not contain exactly one entry. Should give
              one named argument which is used as the attribute to condition on.
          ValueError: If two choices share the same branching value.
          ValueError: If the default value matches no choice.
          TypeError: If a choice is not a Config object.
        """
        if len(kwargs) != 1:
            raise ValueError(
                'Incorrect usage. Must give exactly one named argument. The argument '
                'name is the config attribute to condition on, and the argument '
                'value is the default choice. Got %d named arguments.' % len(kwargs))
        # next(iter(...)) works for both Python 2 lists and Python 3 views;
        # indexing dict.items() fails on Python 3.
        key, default_value = next(iter(kwargs.items()))
        self.key = key
        self.default_value = default_value
        # Materialize so any iterable (even a generator) can be both validated
        # and counted below.
        choices = list(choices)
        # Make sure each choice is a Config object.
        for config in choices:
            if not isinstance(config, Config):
                raise TypeError('choices must be a list of Config objects. Got %s.'
                                % type(config))
        # Map value for key to the config with that value.
        self.value_map = {config[key]: config for config in choices}
        # Make sure there are no duplicate values.
        if len(self.value_map) != len(choices):
            raise ValueError('Multiple choices given for the same value of %s.' % key)
        # Check that the default value is valid. This must happen before the
        # lookup below, otherwise an invalid default raises a bare KeyError.
        if self.default_value not in self.value_map:
            raise ValueError(
                'Default value is not an available choice. Got %s=%s. Choices are %s.'
                % (key, self.default_value, list(self.value_map.keys())))
        self.default_config = self.value_map[self.default_value]

    def default(self):
        """Returns the choice selected by the default value given at construction."""
        return self.default_config

    def update(self, other):
        """Choose a config and update it.

        If `other` is a dict (a Config included) whose value for the branching
        key matches one of the choices, that choice is selected, updated with
        `other` via its `update` method, and returned. Otherwise `other` is
        returned unchanged.

        Args:
          other: Will update chosen config with this value by calling `update`
              on the config.

        Returns:
          The chosen config after updating it, or `other` if no config could be
          selected.
        """
        # Config.update forwards any dict here, so accept plain dicts as well.
        if not isinstance(other, dict):
            return other
        if self.key not in other or other[self.key] not in self.value_map:
            return other
        target = self.value_map[other[self.key]]
        target.update(other)
        return target

    def strict_update(self, config):
        """Choose a config and update it.

        `config` must be a dict (a Config included). `config` must have the key
        used to select among the config choices, and that key must have a value
        which one of the config choices has.

        Args:
          config: A dict or Config object. The chosen config will be updated by
              calling `strict_update`.

        Returns:
          The chosen config after updating it.

        Raises:
          TypeError: If `config` is not a dict (or Config) instance.
          ValueError: If `config` does not have the branching key in its key set.
          ValueError: If the value of the config's branching key is not one of
              the valid choices.
        """
        # Config.strict_update only guarantees a dict here, so accept any dict.
        if not isinstance(config, dict):
            raise TypeError('Expecting dict or Config instance, got %s.' % type(config))
        if self.key not in config:
            raise ValueError(
                'Branching key %s required but not found in %s' % (self.key, config))
        if config[self.key] not in self.value_map:
            raise ValueError(
                'Value %s for key %s is not a possible choice. Choices are %s.'
                % (config[self.key], self.key, list(self.value_map.keys())))
        target = self.value_map[config[self.key]]
        target.strict_update(config)
        return target
| def _next_comma(string, start_index): | |
| """Finds the position of the next comma not used in a literal collection.""" | |
| paren_count = 0 | |
| for i in xrange(start_index, len(string)): | |
| c = string[i] | |
| if c == '(' or c == '[' or c == '{': | |
| paren_count += 1 | |
| elif c == ')' or c == ']' or c == '}': | |
| paren_count -= 1 | |
| if paren_count == 0 and c == ',': | |
| return i | |
| return -1 | |
def _comma_iterator(string):
    """Yields the pieces of `string` between top-level commas.

    Split points are found with `_next_comma`. The text after the last split
    point is always yielded, so '' yields one empty piece and a trailing comma
    yields a trailing empty piece.

    Args:
      string: The text to split.

    Yields:
      Substrings of `string`, in order, without the separating commas.
    """
    start = 0
    comma = _next_comma(string, start)
    while comma != -1:
        yield string[start:comma]
        start = comma + 1
        comma = _next_comma(string, start)
    yield string[start:]