Spaces:
Runtime error
Runtime error
| # flake8: noqa | |
| import typing | |
| import warnings | |
| import sys | |
| from copy import deepcopy | |
| from dataclasses import MISSING, is_dataclass, fields as dc_fields | |
| from datetime import datetime | |
| from decimal import Decimal | |
| from uuid import UUID | |
| from enum import Enum | |
| from typing_inspect import is_union_type # type: ignore | |
| from marshmallow import fields, Schema, post_load # type: ignore | |
| from marshmallow.exceptions import ValidationError # type: ignore | |
| from dataclasses_json.core import (_is_supported_generic, _decode_dataclass, | |
| _ExtendedEncoder, _user_overrides_or_exts) | |
| from dataclasses_json.utils import (_is_collection, _is_optional, | |
| _issubclass_safe, _timestamp_to_dt_aware, | |
| _is_new_type, _get_type_origin, | |
| _handle_undefined_parameters_safe, | |
| CatchAllVar) | |
class _TimestampField(fields.Field):
    """Field that dumps a value through ``.timestamp()`` and loads it back
    through ``_timestamp_to_dt_aware``.

    In both directions ``None`` is passed through for non-required fields
    and rejected with the "required" validation message otherwise.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value.timestamp()``; handle ``None`` per ``self.required``."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.timestamp()

    def _deserialize(self, value, attr, data, **kwargs):
        """Return ``_timestamp_to_dt_aware(value)``; handle ``None`` per ``self.required``."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return _timestamp_to_dt_aware(value)
class _IsoField(fields.Field):
    """Field that dumps a value with ``.isoformat()`` and loads it with
    ``datetime.fromisoformat``.

    In both directions ``None`` is passed through for non-required fields
    and rejected with the "required" validation message otherwise.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value.isoformat()``; handle ``None`` per ``self.required``."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.isoformat()

    def _deserialize(self, value, attr, data, **kwargs):
        """Return ``datetime.fromisoformat(value)``; handle ``None`` per ``self.required``."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return datetime.fromisoformat(value)
class _UnionField(fields.Field):
    """Field for ``typing.Union`` members of a dataclass.

    ``desc`` maps each (non-``None``) member type of the union to the
    marshmallow field built for it; ``cls`` and ``field`` are the owning
    dataclass and the dataclass field, used only for warning messages.
    Dataclass values are tagged with a ``'__type'`` key on dump so that the
    matching member can be selected again on load.
    """

    def __init__(self, desc, cls, field, *args, **kwargs):
        self.desc = desc
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialize ``value`` with the field of the first matching union member.

        Falls back to the base ``Field`` serialization when no member field
        handles the value; a warning is emitted only when no member type
        matched at all.
        """
        if self.allow_none and value is None:
            return None
        for type_, schema_ in self.desc.items():
            if _issubclass_safe(type(value), type_):
                if is_dataclass(value):
                    res = schema_._serialize(value, attr, obj, **kwargs)
                    # Tag the payload so _deserialize can pick the right member.
                    res['__type'] = str(type_.__name__)
                    return res
                # Matching non-dataclass type: leave the loop without
                # triggering the "not in Union" warning below.
                break
            elif isinstance(value, _get_type_origin(type_)):
                return schema_._serialize(value, attr, obj, **kwargs)
        else:
            # Loop exhausted without a match.
            warnings.warn(
                f'The type "{type(value).__name__}" (value: "{value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value cannot be serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value`` with the field of the matching union member.

        Dicts carrying ``'__type'`` are routed to the dataclass member with
        that name; other dicts and values of unknown type produce a warning
        and go through the base ``Field`` deserialization.
        """
        # Work on a copy: the '__type' key is removed before delegating.
        tmp_value = deepcopy(value)
        if isinstance(tmp_value, dict) and '__type' in tmp_value:
            dc_name = tmp_value['__type']
            for type_, schema_ in self.desc.items():
                if is_dataclass(type_) and type_.__name__ == dc_name:
                    del tmp_value['__type']
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
        elif isinstance(tmp_value, dict):
            # Fixed message: closing quote after the value and the missing
            # spaces between the concatenated pieces.
            warnings.warn(
                f'Attempting to deserialize "dict" (value: "{tmp_value}") '
                f'that does not have a "__type" type specifier field into '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Deserialization may fail, or deserialization to wrong type may occur.'
            )
            return super()._deserialize(tmp_value, attr, data, **kwargs)
        else:
            for type_, schema_ in self.desc.items():
                if isinstance(tmp_value, _get_type_origin(type_)):
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
            else:
                warnings.warn(
                    f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
                    f'is not in the list of possible types of typing.Union '
                    f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                    f'Value cannot be deserialized properly.')
        return super()._deserialize(tmp_value, attr, data, **kwargs)
class _TupleVarLen(fields.List):
    """
    List field whose deserialized result is converted to a tuple
    (variable-length homogeneous tuples).
    """

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize as a list, then return it as a tuple (``None`` stays ``None``)."""
        items = super()._deserialize(value, attr, data, **kwargs)
        if items is None:
            return None
        return tuple(items)
# Lookup table from a Python/typing type (or generic origin) to the
# marshmallow field class used to represent it.
TYPES = {
    # typing generics
    typing.Mapping: fields.Mapping,
    typing.MutableMapping: fields.Mapping,
    typing.List: fields.List,
    typing.Dict: fields.Dict,
    typing.Tuple: fields.Tuple,
    typing.Callable: fields.Function,
    typing.Any: fields.Raw,
    # builtin containers
    dict: fields.Dict,
    list: fields.List,
    tuple: fields.Tuple,
    # scalars
    str: fields.Str,
    int: fields.Int,
    float: fields.Float,
    bool: fields.Bool,
    datetime: _TimestampField,
    UUID: fields.UUID,
    Decimal: fields.Decimal,
    # catch-all container for undefined parameters
    CatchAllVar: fields.Dict,
}

# Type variable standing for the dataclass a schema is built for.
A = typing.TypeVar('A')
# Accepted raw JSON inputs.
JsonData = typing.Union[str, bytes, bytearray]
# One encoded object.
TEncoded = typing.Dict[str, typing.Any]
# Either a single item or a list of items, decoded and encoded flavours.
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
    class SchemaF(Schema, typing.Generic[A]):
        """Lift Schema into a type constructor.

        Typing helper only: it gives ``dump``/``dumps``/``load``/``loads``
        signatures parameterised by ``A``. The typed variants are declared
        with ``typing.overload`` so the earlier signatures are overload
        stubs rather than definitions silently replaced by the last one.
        """

        def __init__(self, *args, **kwargs):
            """
            Raises exception because this class should not be inherited.
            This class is helper only.
            """
            super().__init__(*args, **kwargs)
            raise NotImplementedError()

        @typing.overload
        def dump(self, obj: typing.List[A], many: typing.Optional[bool] = None) -> typing.List[TEncoded]:  # type: ignore
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            pass

        @typing.overload
        def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
            pass

        def dump(self, obj: TOneOrMulti,  # type: ignore
                 many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
            pass

        @typing.overload
        def dumps(self, obj: typing.List[A], many: typing.Optional[bool] = None, *args,
                  **kwargs) -> str:
            pass

        @typing.overload
        def dumps(self, obj: A, many: typing.Optional[bool] = None, *args, **kwargs) -> str:
            pass

        def dumps(self, obj: TOneOrMulti, many: typing.Optional[bool] = None, *args,  # type: ignore
                  **kwargs) -> str:
            pass

        @typing.overload  # type: ignore
        def load(self, data: typing.List[TEncoded],
                 many: bool = True, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> \
                typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define lists as an allowed input type
            pass

        @typing.overload
        def load(self, data: TEncoded,
                 many: None = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> A:
            pass

        def load(self, data: TOneOrMultiEncoded,
                 many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> TOneOrMulti:
            pass

        @typing.overload  # type: ignore
        def loads(self, json_data: JsonData,  # type: ignore
                  many: typing.Optional[bool] = True, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
                  **kwargs) -> typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define bytes as correct input data
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            # for the return type overlap
            pass

        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
                  **kwargs) -> TOneOrMulti:
            pass

    SchemaType = SchemaF[A]
else:
    SchemaType = Schema
def build_type(type_, options, mixin, field, cls):
    """Build the marshmallow field instance for ``type_``.

    ``options`` are the keyword arguments for the outermost field (the dict
    is updated in place); ``mixin`` is the class nested dataclasses must
    derive from to be serialized recursively; ``field`` and ``cls`` are the
    dataclass field and its owning class, used for collection detection and
    warning messages.
    """

    def inner(type_, options):
        # Unwrap NewType layers down to the underlying type.
        while _is_new_type(type_):
            type_ = type_.__supertype__

        if is_dataclass(type_):
            if not _issubclass_safe(type_, mixin):
                warnings.warn(f"Nested dataclass field {field.name} of type "
                              f"{field.type} detected in "
                              f"{cls.__name__} that is not an instance of "
                              f"dataclass_json. Did you mean to recursively "
                              f"serialize this field? If so, make sure to "
                              f"augment {type_} with either the "
                              f"`dataclass_json` decorator or mixin.")
                return fields.Field(**options)
            is_many = _is_supported_generic(field.type) and _is_collection(
                field.type)
            options['field_many'] = bool(is_many)
            return fields.Nested(type_.schema(), **options)

        origin = getattr(type_, '__origin__', type_)
        # Type parameters with NoneType dropped; each gets its own sub-field.
        member_types = [a for a in getattr(type_, '__args__', [])
                        if a is not type(None)]
        args = [inner(member, {}) for member in member_types]

        if type_ == Ellipsis:
            # Passed through as-is so the tuple branch can detect `X, ...`.
            return type_

        if _is_optional(type_):
            options["allow_none"] = True

        if origin is tuple:
            if len(args) == 2 and args[1] == Ellipsis:
                return _TupleVarLen(args[0], **options)
            return fields.Tuple(args, **options)

        if origin in TYPES:
            field_cls = TYPES[origin]
            return field_cls(*args, **options)

        if _issubclass_safe(origin, Enum):
            return fields.Enum(*args, enum=origin, by_value=True, **options)

        if is_union_type(type_):
            union_desc = dict(zip(member_types, args))
            return _UnionField(union_desc, cls, field, **options)

        warnings.warn(
            f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
            f"It's advised to pass the correct marshmallow type to `mm_field`.")
        return fields.Field(**options)

    return inner(type_, options)
def schema(cls, mixin, infer_missing):
    """Return a ``{field name: marshmallow field}`` dict for dataclass ``cls``.

    A user-supplied ``mm_field`` override is used verbatim; otherwise the
    field is derived from the annotation via ``build_type``. Defaults are
    stored under ``'missing'`` when ``infer_missing`` is true and under
    ``'default'`` otherwise. Fields annotated ``Optional[CatchAllVar]`` are
    left out of the result.
    """
    mm_fields = {}
    overrides = _user_overrides_or_exts(cls)
    missing_key = 'missing' if infer_missing else 'default'
    # TODO check the undefined parameters and add the proper schema action
    # https://marshmallow.readthedocs.io/en/stable/quickstart.html
    for field in dc_fields(cls):
        metadata = overrides[field.name]
        if metadata.mm_field is not None:
            mm_fields[field.name] = metadata.mm_field
            continue

        type_ = field.type
        options: typing.Dict[str, typing.Any] = {}
        if field.default is not MISSING:
            options[missing_key] = field.default
        elif field.default_factory is not MISSING:
            options[missing_key] = field.default_factory()
        else:
            options['required'] = True

        # An explicit default of None implies None is an acceptable value.
        if missing_key in options and options[missing_key] is None:
            options['allow_none'] = True

        if _is_optional(type_):
            options.setdefault(missing_key, None)
            options['allow_none'] = True
            if len(type_.__args__) == 2:
                # Union[str, int, None] is optional too, but it has more than 1 typed field.
                non_none = [tp for tp in type_.__args__ if tp is not type(None)]
                type_ = non_none[0]

        if metadata.letter_case is not None:
            options['data_key'] = metadata.letter_case(field.name)

        mm_field = build_type(type_, options, mixin, field, cls)
        if field.metadata.get('dataclasses_json', {}).get('decoder'):
            # A custom decoder completely replaces the marshmallow field's
            # conversion logic. `_deserialize` (rather than `deserialize`)
            # is overridden to minimize side effects: only the actual value
            # conversion is cancelled, the raw value is passed through.
            mm_field._deserialize = lambda v, *_a, **_kw: v

        # if type(t) is not fields.Field: # If we use `isinstance` we would return nothing.
        if field.type != typing.Optional[CatchAllVar]:
            mm_fields[field.name] = mm_field
    return mm_fields
def build_schema(cls: typing.Type[A],
                 mixin,
                 infer_missing,
                 partial) -> typing.Type["SchemaType[A]"]:
    """Create the marshmallow ``Schema`` subclass for dataclass ``cls``.

    The generated class is named ``<Cls>Schema`` and contains:

    * a ``Meta`` listing every dataclass field except
      ``dataclass_json_config`` and ``Optional[CatchAllVar]`` fields,
    * the field definitions produced by ``schema(cls, mixin, infer_missing)``,
    * a post-load hook that turns loaded data into ``cls`` through
      ``_decode_dataclass(cls, kvs, partial)``,
    * ``dumps``/``dump`` overrides (default JSON encoder, catch-all merge).
    """
    Meta = type('Meta',
                (),
                {'fields': tuple(field.name for field in dc_fields(cls)  # type: ignore
                                 if
                                 field.name != 'dataclass_json_config' and field.type !=
                                 typing.Optional[CatchAllVar]),
                 # TODO #180
                 # 'render_module': global_config.json_module
                 })

    # Must be registered with @post_load: marshmallow only invokes hook
    # methods tagged by its decorators, so without it `load` would never
    # call this function.
    @post_load
    def make_instance(self, kvs, **kwargs):
        return _decode_dataclass(cls, kvs, partial)

    def dumps(self, *args, **kwargs):
        # Default to the extended encoder unless the caller chose one.
        if 'cls' not in kwargs:
            kwargs['cls'] = _ExtendedEncoder
        return Schema.dumps(self, *args, **kwargs)

    def dump(self, obj, *, many=None):
        many = self.many if many is None else bool(many)
        dumped = Schema.dump(self, obj, many=many)
        # TODO This is hacky, but the other option I can think of is to generate a different schema
        # depending on dump and load, which is even more hacky
        # The only problem is the catch-all field, we can't statically create a schema for it,
        # so we just update the dumped dict
        if many:
            for i, _obj in enumerate(obj):
                dumped[i].update(
                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
                                                      usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
                                                            usage="dump"))
        return dumped

    schema_ = schema(cls, mixin, infer_missing)
    DataClassSchema: typing.Type["SchemaType[A]"] = type(
        f'{cls.__name__.capitalize()}Schema',
        (Schema,),
        {'Meta': Meta,
         f'make_{cls.__name__.lower()}': make_instance,
         'dumps': dumps,
         'dump': dump,
         **schema_})
    return DataClassSchema