Spaces:
Runtime error
Runtime error
| import abc | |
| import dataclasses | |
| import functools | |
| import inspect | |
| import sys | |
| from dataclasses import Field, fields | |
| from typing import Any, Callable, Dict, Optional, Tuple, Union, Type, get_type_hints | |
| from enum import Enum | |
| from marshmallow.exceptions import ValidationError # type: ignore | |
| from dataclasses_json.utils import CatchAllVar | |
# Aliases for the two dictionaries produced when incoming key/value pairs are
# split against a dataclass's declared fields by
# `_UndefinedParameterAction._separate_defined_undefined_kvs`:
#   KnownParameters   -> entries whose key names a declared field
#   UnknownParameters -> entries whose key names no declared field
KnownParameters = Dict[str, Any]
UnknownParameters = Dict[str, Any]
class _UndefinedParameterAction(abc.ABC):
    """
    Strategy interface for dealing with input keys that do not correspond to
    any field of a dataclass.

    All hooks are static methods; in this module the strategies are used
    through the classes themselves (e.g. as the values of the ``Undefined``
    enum) rather than through instances.
    """

    @staticmethod
    @abc.abstractmethod
    def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
        """
        Return the parameters to initialize the class with.

        :param cls: the dataclass being constructed.
        :param kvs: the raw key/value pairs received as input.
        """
        pass

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Return the parameters that will be written to the output dict.

        The default implementation returns ``kvs`` unchanged.
        """
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the parameters that will be added to the schema dump.

        The default implementation adds nothing.
        """
        return {}

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Return the ``__init__`` to use for ``obj``.

        The default implementation returns ``obj.__init__`` unchanged.
        """
        return obj.__init__

    @staticmethod
    def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
            Tuple[KnownParameters, UnknownParameters]:
        """
        Split ``kvs`` into two dictionaries: entries whose key names a field of
        the dataclass ``cls`` (defined) and entries whose key does not
        (undefined).

        :returns: ``(known, unknown)``; each preserves the iteration order of
            ``kvs``.
        """
        # Use a set for O(1) membership tests, and fill both dicts in one pass.
        field_names = {field.name for field in fields(cls)}
        known_given_parameters = {}
        unknown_given_parameters = {}
        for key, value in kvs.items():
            if key in field_names:
                known_given_parameters[key] = value
            else:
                unknown_given_parameters[key] = value
        return known_given_parameters, unknown_given_parameters
class _RaiseUndefinedParameters(_UndefinedParameterAction):
    """
    This action raises UndefinedParameterError if it encounters an undefined
    parameter during initialization.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the entries of ``kvs`` that name a field of ``cls``.

        :raises UndefinedParameterError: if ``kvs`` contains any key that is
            not a field of ``cls``.
        """
        known, unknown = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        if unknown:
            raise UndefinedParameterError(
                f"Received undefined initialization arguments {unknown}")
        return known
# Annotation for the single catch-all field of a class that uses
# ``Undefined.INCLUDE``. ``Union[X, None]`` is the same type object as
# ``Optional[X]``; catch-all fields are found by comparing a field's resolved
# type hint against ``Optional[CatchAllVar]``.
CatchAll = Union[CatchAllVar, None]
class _IgnoreUndefinedParameters(_UndefinedParameterAction):
    """
    This action does nothing when it encounters undefined parameters.
    The undefined parameters can not be retrieved after the class has been
    created.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """Return only the entries of ``kvs`` that name a field of ``cls``."""
        known_given_parameters, _ = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        return known_given_parameters

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Build an ``__init__`` for ``obj`` that discards keyword arguments
        which are not fields. It also discards positional arguments beyond
        what the original signature can still accept.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        # wraps() keeps the original name/docstring and sets __wrapped__, so
        # introspection of the replacement still reports the original
        # signature.
        @functools.wraps(obj.__init__)
        def _ignore_init(self, *args, **kwargs):
            known_kwargs, _ = \
                _UndefinedParameterAction._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            num_args_takeable = num_params_takeable - len(known_kwargs)
            # Positional arguments beyond what the signature can still accept
            # are undefined parameters and are dropped.
            args = args[:num_args_takeable]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)
            bound_parameters.apply_defaults()

            arguments = bound_parameters.arguments
            arguments.pop("self", None)
            final_parameters = \
                _IgnoreUndefinedParameters.handle_from_dict(obj, arguments)
            original_init(self, **final_parameters)

        return _ignore_init
class _CatchAllUndefinedParameters(_UndefinedParameterAction):
    """
    Allows a class to declare one field of type utils.CatchAll, which acts as
    a dictionary into which all undefined parameters are written.
    These parameters are not affected by LetterCase.
    If no undefined parameters are given, this dictionary will be empty.
    """

    class _SentinelNoDefault:
        # Marker returned by _get_default when the catch-all field declares
        # neither a default nor a default factory.
        pass

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the known parameters of ``kvs``, with every unknown parameter
        collected under the catch-all field of ``cls``.

        If the input itself supplies a dict for the catch-all field, the
        unknown parameters are merged into a copy of that dict; the caller's
        dict is left unmodified.

        :raises UndefinedParameterError: if the input supplies a value for the
            catch-all field that is neither the field's default nor a dict.
        """
        known, unknown = _UndefinedParameterAction \
            ._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=cls)

        if catch_all_field.name in known:
            given_value = known[catch_all_field.name]
            already_parsed = isinstance(given_value, dict)
            default_value = _CatchAllUndefinedParameters._get_default(
                catch_all_field=catch_all_field)
            received_default = default_value == given_value

            value_to_write: Any
            if received_default and not unknown:
                value_to_write = default_value
            elif received_default:
                value_to_write = unknown
            elif already_parsed:
                # Did not receive the default: keep the given dict and merge
                # the unknown parameters into a copy, so the caller's dict is
                # not mutated as a side effect.
                value_to_write = given_value
                if unknown:
                    value_to_write = given_value.copy()
                    value_to_write.update(unknown)
            else:
                error_message = f"Received input field with " \
                                f"same name as catch-all field: " \
                                f"'{catch_all_field.name}': " \
                                f"'{given_value}'"
                raise UndefinedParameterError(error_message)
        else:
            value_to_write = unknown

        known[catch_all_field.name] = value_to_write
        return known

    @staticmethod
    def _get_default(catch_all_field: Field) -> Any:
        """
        Return the default of ``catch_all_field``.

        If the field has a default factory, the factory is called. If the
        field has neither a default nor a factory, return
        ``_SentinelNoDefault``.
        """
        # access to the default factory currently causes
        # a false-positive mypy error (16. Dec 2019):
        # https://github.com/python/mypy/issues/6910

        # noinspection PyProtectedMember
        has_default = not isinstance(catch_all_field.default,
                                     dataclasses._MISSING_TYPE)
        # noinspection PyProtectedMember
        has_default_factory = not isinstance(
            catch_all_field.default_factory,  # type: ignore
            dataclasses._MISSING_TYPE)

        default_value: Union[
            Type[_CatchAllUndefinedParameters._SentinelNoDefault], Any] = \
            _CatchAllUndefinedParameters._SentinelNoDefault

        if has_default:
            default_value = catch_all_field.default
        elif has_default_factory:
            # This might be unwanted if the default factory constructs
            # something expensive, because we have to construct it again
            # just for this test.
            default_value = catch_all_field.default_factory()  # type: ignore

        return default_value

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Remove the catch-all entry from ``kvs``. If its value is a dict, merge
        that dict's items into ``kvs`` at the top level. Return ``kvs``.
        """
        catch_all_field = \
            _CatchAllUndefinedParameters._get_catch_all_field(obj.__class__)
        undefined_parameters = kvs.pop(catch_all_field.name)
        if isinstance(undefined_parameters, dict):
            kvs.update(
                undefined_parameters)  # If desired handle letter case here
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """Return the current value of ``obj``'s catch-all field."""
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=obj)
        return getattr(obj, catch_all_field.name)

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Build an ``__init__`` for ``obj`` that routes undefined keyword
        arguments into the catch-all field. Surplus positional arguments
        (those the signature cannot accept) go there too.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        # wraps() keeps the original name/docstring and sets __wrapped__, so
        # introspection of the replacement still reports the original
        # signature.
        @functools.wraps(obj.__init__)
        def _catch_all_init(self, *args, **kwargs):
            known_kwargs, unknown_kwargs = \
                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            if _CatchAllUndefinedParameters._get_catch_all_field(
                    obj).name not in known_kwargs:
                # Do not count the catch-all parameter among those that
                # positional arguments may fill.
                num_params_takeable -= 1
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args, unknown_args = args[:num_args_takeable], args[
                num_args_takeable:]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)

            # Surplus positional arguments get synthetic names, so that
            # handle_from_dict collects them into the catch-all dict together
            # with the unknown keyword arguments.
            unknown_args = {f"_UNKNOWN{i}": v for i, v in
                            enumerate(unknown_args)}
            arguments = bound_parameters.arguments
            arguments.update(unknown_args)
            arguments.update(unknown_kwargs)
            arguments.pop("self", None)
            final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
                obj, arguments)
            original_init(self, **final_parameters)

        return _catch_all_init

    @staticmethod
    def _get_catch_all_field(cls) -> Field:
        """
        Return the single field of ``cls`` whose type hint is
        ``Optional[CatchAllVar]`` (i.e. ``CatchAll``).

        ``cls`` may be a dataclass or an instance of one; ``handle_dump``
        passes an instance.

        :raises UndefinedParameterError: if there is no such field, or more
            than one.
        """
        if not isinstance(cls, type):
            # get_type_hints merges annotations along the MRO only when it is
            # given a class, so resolve hints on the instance's class.
            cls = type(cls)
        cls_globals = vars(sys.modules[cls.__module__])
        type_hints = get_type_hints(cls, globalns=cls_globals)
        catch_all_fields = [
            f for f in fields(cls)
            if type_hints[f.name] == Optional[CatchAllVar]
        ]
        number_of_catch_all_fields = len(catch_all_fields)
        if number_of_catch_all_fields == 0:
            raise UndefinedParameterError(
                "No field of type dataclasses_json.CatchAll defined")
        elif number_of_catch_all_fields > 1:
            raise UndefinedParameterError(
                f"Multiple catch-all fields supplied: "
                f"{number_of_catch_all_fields}.")
        else:
            return catch_all_fields[0]
class Undefined(Enum):
    """
    Selects how a dataclass reacts to input parameters that match none of its
    fields during initialization.

    Each member's value is the action class that implements the behavior:

    * ``INCLUDE`` - collect them into the class's ``CatchAll`` field.
    * ``RAISE``   - raise ``UndefinedParameterError``.
    * ``EXCLUDE`` - drop them silently.
    """
    INCLUDE = _CatchAllUndefinedParameters  # keep them in the catch-all field
    RAISE = _RaiseUndefinedParameters  # reject them with an error
    EXCLUDE = _IgnoreUndefinedParameters  # discard them
class UndefinedParameterError(ValidationError):
    """
    Signals a failure while handling undefined parameters. Examples: an
    undefined parameter rejected under ``Undefined.RAISE``, a class whose
    catch-all field is missing or duplicated, or an input value that
    collides with the catch-all field's name.
    """