Source code for flask_unchained.bundles.api.model_serializer

from types import FunctionType
from typing import *

from flask_unchained import unchained
from flask_unchained.di import _set_up_class_dependency_injection
from flask_unchained.string_utils import title_case
from py_meta_utils import McsArgs
from speaklater import _LazyString

try:
    from flask_marshmallow.sqla import (
        SQLAlchemyAutoSchema as BaseModelSerializer,
        SQLAlchemyAutoSchemaOpts as BaseModelSerializerOptionsClass)
    from marshmallow.fields import Field
    from marshmallow.class_registry import _registry
    from marshmallow.exceptions import ValidationError as MarshmallowValidationError
    from marshmallow_sqlalchemy.convert import ModelConverter as BaseModelConverter
    from marshmallow_sqlalchemy.schema import (
        SQLAlchemyAutoSchemaMeta as BaseModelSerializerMetaclass)
    from sqlalchemy.orm import SynonymProperty
except ImportError:
    _registry = {}
    from py_meta_utils import OptionalClass as BaseModelSerializer
    from py_meta_utils import OptionalClass as BaseModelSerializerOptionsClass
    from py_meta_utils import OptionalClass as MarshmallowValidationError
    from py_meta_utils import OptionalClass as BaseModelConverter
    from py_meta_utils import OptionalMetaclass as BaseModelSerializerMetaclass

from .config import Config


class ModelConverter(BaseModelConverter):
    def fields_for_model(self,
                         model,
                         *,
                         include_fk=False,
                         include_relationships=False,
                         fields=None,
                         exclude=None,
                         base_fields=None,
                         dict_cls=dict):
        """
        Overridden to correctly name hybrid_property fields, eg given::

            class User(db.Model):
                _password = db.Column('password', db.String)

                @db.hybrid_property
                def password(self):
                    return self._password

                @password.setter
                def password(self, password):
                    self._password = hash_password(password)

        In this case upstream marshmallow_sqlalchemy uses '_password' for the
        field name, but we use 'password', as would be expected because it's
        the attribute name used for the public interface of the Model. In order
        for this logic to work, the column name must be specified and it must be
        the same as the hybrid property name. Otherwise we just fallback to the
        upstream naming convention.
        """
        # this prevents an error when building the docs
        if not hasattr(model, '__mapper__'):
            return

        result = dict_cls()
        base_fields = base_fields or {}
        for prop in model.__mapper__.iterate_properties:
            key = self._get_field_name(prop)
            if self._should_exclude_field(prop, fields=fields, exclude=exclude):
                result[key] = None
                continue
            if isinstance(prop, SynonymProperty):
                continue
            if hasattr(prop, 'columns'):
                if not include_fk:
                    # Only skip a column if there is no overridden column
                    # which does not have a Foreign Key.
                    for column in prop.columns:
                        if not column.foreign_keys:
                            break
                    else:
                        continue

                col_name = prop.columns[0].name
                if key != col_name and hasattr(model, col_name):
                    key = col_name

            if not include_relationships and hasattr(prop, "direction"):
                continue

            field = base_fields.get(key) or self.property2field(prop)
            if field:
                result[key] = field

        return result

    def property2field(self, prop, instance=True, field_class=None, **kwargs):
        """
        Overridden to mark non-nullable model columns as required (unless it's the
        primary key, because there's no way to tell if we're generating fields
        for a create or an update).
        """
        field = super().property2field(prop, instance=instance, field_class=field_class, **kwargs)
        # when a column is not nullable, mark the field as required
        if hasattr(prop, 'columns'):
            col = prop.columns[0]
            if not col.primary_key and not col.nullable:
                field.required = True
        return field


class _ModelDescriptor:
    def __get__(self, instance, owner):
        # make sure to always return the correct mapped model class
        if not unchained._models_initialized or not instance._model:
            return instance._model
        return unchained.sqlalchemy_bundle.models[instance._model.__name__]

    def __set__(self, instance, value):
        instance._model = value


class _ModelSerializerMetaMetaclass(type):
    model = _ModelDescriptor()


class _ModelSerializerMeta(metaclass=_ModelSerializerMetaMetaclass):
    pass


class ModelSerializerMetaclass(BaseModelSerializerMetaclass):
    def __new__(mcs, name, bases, clsdict):
        mcs_args = McsArgs(mcs, name, bases, clsdict)
        _set_up_class_dependency_injection(mcs_args)
        if mcs_args.is_abstract:
            return super().__new__(*mcs_args)

        meta = mcs_args.getattr('Meta', None)
        model_missing = False
        try:
            if meta.model is None:
                model_missing = True
        except AttributeError:
            model_missing = True

        if model_missing:
            raise AttributeError(f'{name} is missing the ``class Meta`` model attribute')

        model = meta.model
        try:
            model = unchained.sqlalchemy_bundle.models[meta.model.__name__]
        except AttributeError as e:
            # this happens when attempting to generate documentation and the
            # sqlalchemy bundle hasn't been loaded
            safe_error = "'DeferredBundleBlueprintFunctions' object has no attribute 'models'"
            if safe_error not in str(e):
                raise e
        except KeyError:
            pass

        meta_dict = dict(meta.__dict__)

        additional_fields = meta_dict.pop('additional', None)
        if additional_fields:
            fields = [name for name, field in clsdict.items()
                      if isinstance(field, Field)]
            meta_dict['fields'] = fields + list(additional_fields)

        meta_dict.pop('model', None)
        clsdict['Meta'] = type('Meta', (_ModelSerializerMeta,), meta_dict)
        clsdict['Meta'].model = model

        return super().__new__(*mcs_args)

    def __init__(cls, name, bases, attrs):
        if name and name in _registry:
            for existing_cls in _registry[name]:
                fullname = f'{existing_cls.__module__}.{existing_cls.__name__}'
                _registry.pop(fullname, None)
            fullname = f'{cls.__module__}.{cls.__name__}'
            _registry[name] = _registry[fullname] = [cls]
        super().__init__(name, bases, attrs)

    @classmethod
    def get_declared_fields(mcs, klass, cls_fields, inherited_fields, dict_cls):
        # overridden to fix building the docs
        try:
            return super().get_declared_fields(klass, cls_fields, inherited_fields, dict_cls)
        except TypeError:
            pass


class ModelSerializerOptionsClass(BaseModelSerializerOptionsClass):
    """
    Sets the default ``model_converter`` to :class:`_ModelConverter`.
    """
    def __init__(self, meta, **kwargs):
        self._model = None
        self.dump_key_fn = getattr(meta, 'dump_key_fn', Config.DUMP_KEY_FN)
        self.load_key_fn = getattr(meta, 'load_key_fn', Config.LOAD_KEY_FN)

        # override the upstream default values for load_instance and model_converter
        meta.load_instance = getattr(meta, 'load_instance', True)
        meta.model_converter = getattr(meta, 'model_converter', ModelConverter)
        super().__init__(meta, **kwargs)

    @property
    def model(self):
        # make sure to always return the correct mapped model class
        if not unchained._models_initialized or not self._model:
            return self._model
        return unchained.sqlalchemy_bundle.models[self._model.__name__]

    @model.setter
    def model(self, model):
        self._model = model


def maybe_convert_keys(data: Any,
                       key_fn: Optional[FunctionType] = None,
                       fields: Tuple[str] = (),
                       many: bool = False,
                       ) -> Any:
    if not key_fn or not fields:
        return data

    if many:
        return [maybe_convert_keys(el, key_fn, fields, many=False) for el in data]
    elif isinstance(data, dict):
        rv = data.copy()
        for k, v in data.items():
            new_k = key_fn(k)
            if k not in fields and new_k not in fields:
                continue
            if k != new_k:
                rv.pop(k)
            rv[new_k] = v
        return rv
    return data


[docs]class ModelSerializer(BaseModelSerializer, metaclass=ModelSerializerMetaclass): """ Base class for SQLAlchemy model serializers. This is pretty much a stock :class:`flask_marshmallow.sqla.ModelSchema`, except: - dependency injection is set up automatically on ModelSerializer - when loading to update an existing instance, validate the primary keys are the same - automatically make fields named ``slug``, ``model.Meta.created_at``, and ``model.Meta.updated_at`` dump-only For example:: from flask_unchained.bundles.api import ModelSerializer from flask_unchained.bundles.security.models import Role class RoleSerializer(ModelSerializer): class Meta: model = Role Is roughly equivalent to:: from marshmallow import Schema, fields class RoleSerializer(Schema): id = fields.Integer(dump_only=True) name = fields.String() description = fields.String() created_at = fields.DateTime(dump_only=True) updated_at = fields.DateTime(dump_only=True) """ __abstract__ = True OPTIONS_CLASS = ModelSerializerOptionsClass opts: ModelSerializerOptionsClass = None # set by the metaclass
[docs] def is_create(self): """ Check if we're creating a new object. Note that this context flag must be set from the outside, ie when the class gets instantiated. """ return self.context.get('is_create', False)
[docs] def load( self, data: Mapping, *, many: bool = None, partial: Union[bool, Sequence[str], Set[str]] = None, unknown: str = None, **kwargs, ): """Deserialize a dict to an object defined by this ModelSerializer's fields. A :exc:`ValidationError <marshmallow.exceptions.ValidationError>` is raised if invalid data is passed. :param data: The data to deserialize. :param many: Whether to deserialize `data` as a collection. If `None`, the value for `self.many` is used. :param partial: Whether to ignore missing fields and not require any fields declared. Propagates down to ``Nested`` fields as well. If its value is an iterable, only missing fields listed in that iterable will be ignored. Use dot delimiters to specify nested fields. :param unknown: Whether to exclude, include, or raise an error for unknown fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. If `None`, the value for `self.unknown` is used. :return: Deserialized data """ # when data is None, which happens when a POST request was made with an # empty body, convert it to an empty dict. makes validation errors work # as expected data = data or {} # maybe convert all keys in data with the configured fn data = maybe_convert_keys( data, self.opts.load_key_fn, self.opts.fields or set(self.declared_fields.keys()), many=many, ) try: return super().load(data, many=many, partial=partial, unknown=unknown, **kwargs) except MarshmallowValidationError as e: e.messages = maybe_convert_keys( e.messages, self.opts.dump_key_fn, self.opts.fields or set(self.declared_fields.keys()), many=False, ) raise e
[docs] def dump(self, obj, *, many: bool = None): """Serialize an object to native Python data types according to this ModelSerializer's fields. :param obj: The object to serialize. :param many: Whether to serialize `obj` as a collection. If `None`, the value for `self.many` is used. :return: A dict of serialized data :rtype: dict """ data = super().dump(obj, many=many) # maybe convert all keys in data with the configured fn return maybe_convert_keys( data, self.opts.dump_key_fn, fields=self.opts.fields or set(self.declared_fields.keys()), many=many, )
[docs] def handle_error(self, error: MarshmallowValidationError, data: Any, # skipcq: PYL-W0613 (unused arg) **kwargs ) -> None: """ Customize the error messages for required/not-null validators with dynamically generated field names. This is definitely a little hacky (it mutates state, uses hardcoded strings), but unsure how better to do it """ required_messages = {'Missing data for required field.', 'Field may not be null.'} for field_name in error.normalized_messages(): for i, msg in enumerate(error.messages[field_name]): if isinstance(msg, _LazyString): msg = str(msg) if msg in required_messages: label = title_case(field_name) error.messages[field_name][i] = f'{label} is required.'
def _init_fields(self): """ Overridden to: - automatically validate ids (primary keys) are the same when updating objects. - automatically convert slug, created_at, and updated_at to dump-only fields """ super()._init_fields() read_only_fields = {field for field in { self.Meta.model.Meta.pk, 'slug', self.Meta.model.Meta.created_at, self.Meta.model.Meta.updated_at, } if field is not None} for name in read_only_fields: if name in self.fields: field = self.fields[name] field.dump_only = True self.dump_fields[name] = field self.load_fields.pop(name, None)
__all__ = [ 'ModelConverter', 'ModelSerializer', 'ModelSerializerMetaclass', 'ModelSerializerOptionsClass', ]