# Source code for flask_unchained.unchained

import functools
import inspect
import itertools
import jinja2
import markupsafe
import networkx as nx

from flask import Flask, current_app
from py_meta_utils import _missing
from typing import *

from ._compat import QUART_ENABLED, LocalProxy
from .constants import (DEV, PROD, STAGING, TEST,
                        _DI_AUTOMATICALLY_HANDLED, _INJECT_CLS_ATTRS)
from .di import _ensure_service_name, _get_injected_value, injectable, _inject_cls_attrs
from .exceptions import ServiceUsageError
from .utils import AttrDict


class DeferredBundleBlueprintFunctions:
    """
    The public interface for replacing Blueprints with Bundles. Must be accessed
    via the :class:`flask_unchained.Unchained` extension instance, eg::

       from flask_unchained import Bundle, unchained

       class Foobar(Bundle):
           name = 'the_bundle_name'

       unchained.the_bundle_name.before_request()  # or any other public method on this class

    Each registration method does not touch a blueprint directly. Instead it
    appends a callback to ``_deferred_functions``. The callback takes a
    blueprint (``bp``) and forwards the registration to it. The callbacks are
    presumably replayed by the bundle once its blueprint exists; that replay
    is not part of this class.
    """

    def __init__(self):
        # callables of the form ``callback(bp)``, in registration order
        self._deferred_functions = []

    def _defer(self, fn):
        self._deferred_functions.append(fn)

    def _defer_blueprint_call(self, method_name, fn):
        """
        Queue ``bp.<method_name>(fn)`` and hand ``fn`` back unchanged.

        The method is looked up on the blueprint by name only when the
        callback runs.
        """
        self._defer(lambda bp: getattr(bp, method_name)(fn))
        return fn

    def before_request(self, fn=None):
        """
        Like :meth:`flask.Blueprint.before_request` but for a bundle. This function
        is only executed before each request that is handled by a view function
        of that bundle.
        """
        if fn is None:
            # used as ``@...before_request()`` -- return the real decorator
            return self.before_request
        return self._defer_blueprint_call('before_request', fn)

    def after_request(self, fn=None):
        """
        Like :meth:`flask.Blueprint.after_request` but for a bundle. This function
        is only executed after each request that is handled by a function of
        that bundle.
        """
        if fn is None:
            return self.after_request
        return self._defer_blueprint_call('after_request', fn)

    def teardown_request(self, fn=None):
        """
        Like :meth:`flask.Blueprint.teardown_request` but for a bundle. This
        function is only executed when tearing down requests handled by a
        function of that bundle.  Teardown request functions are executed
        when the request context is popped, even when no actual request was
        performed.
        """
        if fn is None:
            return self.teardown_request
        return self._defer_blueprint_call('teardown_request', fn)

    def context_processor(self, fn=None):
        """
        Like :meth:`flask.Blueprint.context_processor` but for a bundle. This
        function is only executed for requests handled by a bundle.
        """
        if fn is None:
            return self.context_processor
        return self._defer_blueprint_call('context_processor', fn)

    def url_defaults(self, fn=None):
        """
        Callback function for URL defaults for this bundle. It's called
        with the endpoint and values and should update the values passed
        in place.
        """
        if fn is None:
            return self.url_defaults
        return self._defer_blueprint_call('url_defaults', fn)

    def url_value_preprocessor(self, fn=None):
        """
        Registers a function as URL value preprocessor for this
        bundle. It's called before the view functions are called and
        can modify the url values provided.
        """
        if fn is None:
            return self.url_value_preprocessor
        return self._defer_blueprint_call('url_value_preprocessor', fn)

    def errorhandler(self, code_or_exception):
        """
        Registers an error handler that becomes active for this bundle
        only.  Please be aware that routing does not happen local to a
        bundle so an error handler for 404 usually is not handled by
        a bundle unless it is caused inside a view function.  Another
        special case is the 500 internal server error which is always looked
        up from the application.

        Otherwise works as the :meth:`flask.Blueprint.errorhandler` decorator.
        """
        def register(handler):
            self._defer(
                lambda bp: bp.register_error_handler(code_or_exception, handler))
            return handler
        return register

    if QUART_ENABLED:
        def before_websocket(self, fn=None):
            """Defer ``bp.before_websocket(fn)`` for this bundle (Quart only)."""
            if fn is None:
                return self.before_websocket
            return self._defer_blueprint_call('before_websocket', fn)

        def after_websocket(self, fn=None):
            """Defer ``bp.after_websocket(fn)`` for this bundle (Quart only)."""
            if fn is None:
                return self.after_websocket
            return self._defer_blueprint_call('after_websocket', fn)

        def teardown_websocket(self, fn=None):
            """Defer ``bp.teardown_websocket(fn)`` for this bundle (Quart only)."""
            if fn is None:
                return self.teardown_websocket
            return self._defer_blueprint_call('teardown_websocket', fn)

    def __call__(self, *args, **kwargs):
        # calling this object directly almost always means a typo'd / missing
        # method name was looked up on ``unchained.<bundle_name>``
        raise NotImplementedError("Probably you're trying to call a method at import "
                                  "time on the Unchained extension that doesn't exist?")


class _DeferredBundleBlueprintFunctionsStore:
    """
    An intermediary store that lives on :class:`~flask_unchained.Unchained` to
    return an instance of :class:`~flask_unchained.DeferredBundleBlueprintFunctions`
    for each bundle name lookup on us (implements a dict-like readonly interface).
    """
    def __init__(self):
        self._bundles = {}

    def __getitem__(self, bundle_name):
        if bundle_name not in self._bundles:
            self._bundles[bundle_name] = DeferredBundleBlueprintFunctions()
        return self._bundles[bundle_name]


class Unchained:
    """
    The ``Unchained`` extension. Responsible for initializing the app by loading
    all the things from bundles, keeping references to all of the various
    discovered bundles and things inside them, and for doing dependency injection.
    To get access to the ``unchained`` extension instance::

        from flask_unchained import unchained

    Also acts as a replacement for some of the public API of :class:`flask.Flask`.
    (The part that allows registering url rules, functions to run for handling
    errors, functions to run during the normal request response cycle, and
    methods for setting up the Jinja templating environment.)
    """

    def __init__(self, env: Optional[Union[DEV, PROD, STAGING, TEST]] = None):
        """
        :param env: The initial environment name. :meth:`init_app` replaces it
                    with the app's ``env`` when the app provides a truthy one.
        """
        self.bundles = AttrDict()
        # per-bundle-name deferred blueprint functions, registered at import time
        self._deferred_bundle_functions = _DeferredBundleBlueprintFunctionsStore()
        self.babel_bundle = None
        self.env = env
        self.extensions = AttrDict()
        self.services = AttrDict()
        self._app = None
        self._app_bundle_cls = None
        # callables of the form ``callback(app)``, replayed in init_app
        self._deferred_functions = []
        self._initialized = False
        self._models_initialized = False
        self._services_initialized = False
        # service name -> registered (not yet instantiated) service
        self._services_registry = {}
        self._shell_ctx = {}

    def init_app(self, app: Flask,
                 bundles: Optional[List] = None,  # FIXME Optional[List[Bundle]] on 3.7+
                 unchained_config: Optional[Dict[str, Any]] = None,
                 ) -> None:
        """
        Initialize the extension for ``app``.

        The steps, in order:

        - register the extension on the app;
        - hand each bundle its deferred blueprint functions;
        - expose the bundles in the shell context;
        - replay the deferred app-level functions;
        - run ``RunHooksHook``.

        :param app: The Flask (or Quart) app to initialize.
        :param bundles: The app's bundle instances (may be ``None``/empty).
        :param unchained_config: Passed through to ``RunHooksHook.run_hook``.
        """
        # deferred import to prevent circular dependency
        from .hooks.run_hooks_hook import RunHooksHook

        # ``Flask.env`` was removed in Flask 2.3, so don't assume it exists
        self.env = getattr(app, 'env', None) or self.env
        app.extensions['unchained'] = self
        app.unchained = self
        self._app = app

        bundles = bundles or []
        for bundle in bundles:
            # share the list so functions deferred via ``unchained.<bundle_name>``
            # before the app existed end up on the bundle
            bundle._deferred_functions = \
                self._deferred_bundle_functions[bundle.name]._deferred_functions

        self.bundles = AttrDict({b.name: b for b in bundles})
        self.babel_bundle = self.bundles.get('babel_bundle', None)

        self._shell_ctx = {b.name: b for b in bundles}
        self._shell_ctx['unchained'] = self
        app.shell_context_processor(lambda: self._shell_ctx)

        for deferred in self._deferred_functions:
            deferred(app)

        run_hooks_hook = RunHooksHook(self)
        run_hooks_hook.run_hook(app, bundles, unchained_config)
        self._initialized = True

    def get_local_proxy(self, name):
        """
        Returns a :class:`~werkzeug.local.LocalProxy` to the extension or service
        with ``name`` as registered with the current app.
        """
        def get_extension_or_service_by_name():
            value = _get_injected_value(current_app.unchained, name, throw=False)
            if value is _missing:
                raise KeyError(f'No extension or service was found with the name {name}.')
            return value

        return LocalProxy(get_extension_or_service_by_name)

    def service(self, name: Optional[str] = None):
        """
        Decorator to mark something as a service.

        Usable as ``@unchained.service()``, ``@unchained.service('name')``, or
        bare ``@unchained.service``.

        :param name: The name to register the service under. If omitted, the
                     name is presumably derived from the decorated object by
                     ``_ensure_service_name``.
        """
        if callable(name):
            # used without parentheses: ``name`` is actually the decorated object
            return self.service()(name)

        if self._services_initialized:
            from warnings import warn
            warn('Services have already been initialized. Please register '
                 f'{name} sooner.')
            return lambda x: x

        def wrapper(service):
            self.register_service(name, service)
            return service
        return wrapper

    def register_service(self, name: str, service: Any):
        """
        Method to register a service.

        Non-class objects are stored in :attr:`services` immediately. Classes are
        queued in the registry and instantiated later by :meth:`_init_services`;
        registering a class after that has already run only emits a warning.
        """
        if not isinstance(service, type):
            if hasattr(service, '__class__'):
                _ensure_service_name(service.__class__, name)
            self.services[name] = service
            return

        if self._services_initialized:
            from warnings import warn
            warn('Services have already been initialized. Please register '
                 f'{name} sooner.')
            return

        self._services_registry[_ensure_service_name(service, name)] = service

    def inject(self, *args):
        """
        Decorator to mark a class, method, or function as needing dependencies
        injected. Example usage::

            from flask_unchained import unchained, injectable

            # automatically figure out which params to inject
            @unchained.inject()
            def my_function(not_injected, some_service: SomeService = injectable):
                # do stuff

            # or declare injectables explicitly (makes the ``injectable`` default optional)
            @unchained.inject('some_service')
            def my_function(not_injected, some_service: SomeService):
                # do stuff

            # use it on a class to set up class attributes injection (and the constructor)
            @unchained.inject()
            class MyClass:
                some_service: SomeService = injectable

                def __init__(self, another_service: AnotherService = injectable):
                    self.another_service = another_service

        :raises ServiceUsageError: (from the wrapped callable, at call time) when
                                   a parameter still holds ``injectable`` after
                                   injection was attempted.
        """
        used_without_parenthesis = len(args) and callable(args[0])
        has_explicit_args = len(args) and all(isinstance(x, str) for x in args)

        def wrapper(fn):
            cls = None
            if isinstance(fn, type):
                cls = fn
                fn = cls.__init__

            # check if the fn/class has already been wrapped with inject
            # NOTE(review): the nesting of the checks below was reconstructed
            # from a whitespace-collapsed copy of this file -- confirm upstream.
            if hasattr(fn, '__signature__'):
                if not cls:
                    return fn
                if not hasattr(cls, '__signature__'):
                    # this happens when both the class and its __init__ method
                    # were decorated with @inject. which would be silly, but,
                    # it should still work regardless
                    cls.__signature__ = fn.__signature__
            if cls and hasattr(cls, '__di_name__'):
                return cls

            sig = inspect.signature(fn)

            # create a new function wrapping the original to inject params
            @functools.wraps(fn)
            def dependency_injector(*fn_args, **fn_kwargs):
                # figure out which params we need to inject (we don't want to
                # interfere with any params the user has passed manually)
                bound_args = sig.bind_partial(*fn_args, **fn_kwargs)
                required = set(sig.parameters.keys())
                have = set(bound_args.arguments.keys())
                need = required - have
                to_inject = need & (set(args) if has_explicit_args
                                    else {k for k, v in sig.parameters.items()
                                          if isinstance(v.default, str)
                                          and v.default == injectable})

                # try to inject needed params from extensions or services
                for param_name in to_inject:
                    try:
                        fn_kwargs[param_name] = _get_injected_value(
                            unchained_ext=self,
                            param_name=param_name,
                            requested_by=dependency_injector.__di_name__,
                        )
                    except Exception:
                        # best effort: anything still missing is reported below
                        continue

                # check to make sure we're not missing anything required
                # this check must live here so that it works when services get used
                # outside of flask unchained
                bound_args = sig.bind_partial(*fn_args, **fn_kwargs)
                bound_args.apply_defaults()
                for k, v in bound_args.arguments.items():
                    if isinstance(v, str) and v == injectable:
                        di_name = dependency_injector.__di_name__
                        # class names (the constructor case) are capitalized
                        # and contain no dots
                        is_constructor = ('.' not in di_name
                                          and di_name != di_name.lower())
                        action = 'initialized' if is_constructor else 'called'
                        raise ServiceUsageError(
                            f'{di_name} was {action} without the {k} parameter. '
                            f'Please supply it manually, or make sure it gets injected.'
                        )

                if cls and not getattr(cls, _DI_AUTOMATICALLY_HANDLED, False):
                    # copy: the list may be inherited from a parent class, and
                    # appending to it in place would leak this class's attrs
                    # into the parent
                    cls_attrs_to_inject = list(getattr(cls, _INJECT_CLS_ATTRS, []))
                    for attr, value in vars(cls).items():
                        if (isinstance(value, str) and value == injectable
                                and attr not in cls_attrs_to_inject):
                            cls_attrs_to_inject.append(attr)
                    if has_explicit_args:
                        cls_attrs_to_inject = list(set(cls_attrs_to_inject) & set(args))
                    if cls_attrs_to_inject:
                        setattr(cls, _INJECT_CLS_ATTRS, cls_attrs_to_inject)
                        _inject_cls_attrs()(cls)

                return fn(*bound_args.args, **bound_args.kwargs)

            dependency_injector.__signature__ = sig
            dependency_injector.__di_name__ = getattr(fn, '__di_name__', fn.__name__)

            if cls:
                cls.__init__ = dependency_injector
                cls.__signature__ = sig
                return cls
            return dependency_injector

        if used_without_parenthesis:
            return wrapper(args[0])
        return wrapper

    def _init_services(self):
        """
        Instantiate every registered service in dependency order.

        A directed graph is built with an edge from each service to every
        constructor parameter / injectable class attribute that names a known
        service or extension. Services are then created dependencies-first.

        :raises Exception: on a circular dependency, or when instantiating a
                           service raises ``TypeError`` (reported as a missing
                           service).
        """
        dag = nx.DiGraph()
        for name, service in self._services_registry.items():
            if not callable(service):
                self.services[name] = service
                continue

            dag.add_node(name)
            cls_attrs = getattr(service, _INJECT_CLS_ATTRS, [])
            for param_name in itertools.chain(inspect.signature(service).parameters,
                                              cls_attrs):
                if (param_name in self.services
                        or param_name in self.extensions
                        or param_name in self._services_registry):
                    dag.add_edge(name, param_name)

        try:
            # edges point from dependent -> dependency, so reverse the sort to
            # instantiate dependencies first
            instantiation_order = reversed(list(nx.topological_sort(dag)))
        except nx.NetworkXUnfeasible:
            msg = 'Circular dependency detected between services'
            problem_graph = ', '.join(f'{a} -> {b}' for a, b in nx.find_cycle(dag))
            raise Exception(f'{msg}: {problem_graph}')

        for name in instantiation_order:
            if name in self.services or name in self.extensions:
                continue

            service = self._services_registry[name]
            # class attributes are injected separately, not via the constructor
            cls_attrs = getattr(service, _INJECT_CLS_ATTRS, [])
            params = {n: self.extensions.get(n, self.services.get(n))
                      for n in dag.successors(name)
                      if n not in cls_attrs
                      and (n in self.extensions or n in self.services)}

            if 'config' in inspect.signature(service).parameters:
                params['config'] = self._app.config

            if not isinstance(service, type):
                self.services[name] = functools.partial(service, **params)
            else:
                try:
                    self.services[name] = service(**params)
                except TypeError as e:
                    # FIXME this exception is too generic, need to better parse
                    # its string repr (eg, got unexpected keyword argument)
                    missing = str(e).rsplit(': ')[-1]
                    requester = f'{service.__module__}.{service.__name__}'
                    raise Exception(f'No service found with the name {missing} '
                                    f'(required by {requester})') from e

        self._services_initialized = True

    def __getattr__(self, name: str):
        """
        Implemented to allow accessing bundles by their name as attributes on
        the ``unchained`` extension instance.

        *Before* the app has been initialized (ie at import time), we don't
        actually know what bundles the user has configured, and therefore we
        need to make a compromise: *any* unrecognized attribute access before
        the app has been initialized is assumed to be a valid bundle name, and
        so we return a :class:`~flask_unchained.bundle.DeferredBundleBlueprintFunctions`
        instance that allows registering deferred functions (as a replacement
        for that part of the public Blueprint API).

        *After* the app has been initialized, we know what bundles the user
        configured, and can therefore return the correct bundle instance (or
        raise ``AttributeError`` if an invalid attribute was requested).

        Dunder names are never treated as bundle names. This keeps protocol
        probes such as ``copy``/``pickle`` looking up ``__setstate__`` on a
        not-yet-initialized instance from recursing into ``self.bundles``.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        if name in self.bundles:
            return self.bundles[name]
        elif not self._initialized:
            return self._deferred_bundle_functions[name]
        raise AttributeError(name)

    def _defer(self, fn):
        """
        Queue ``fn(app)`` to run during :meth:`init_app`; only warns (and drops
        ``fn``) if the app has already been initialized.
        """
        if self._initialized:
            from warnings import warn
            warn('The app has already been initialized. '
                 f'Please register {fn.__name__} sooner.')
            return
        self._deferred_functions.append(fn)

    @staticmethod
    def _pass_context(fn, legacy_decorator_name: str):
        """
        Mark ``fn`` as wanting the Jinja template context as its first argument.

        Prefers ``jinja2.pass_context`` (Jinja 3.0+). It falls back to the legacy
        ``contextfilter``/``contextfunction`` decorators, which Jinja 3.1 removed.
        """
        pass_context = getattr(jinja2, 'pass_context', None)
        if pass_context is None:
            pass_context = getattr(jinja2, legacy_decorator_name)
        return pass_context(fn)

    def add_url_rule(self, rule, endpoint=None, view_func=None, **options):
        """
        Register a new url rule. Acts the same as :meth:`flask.Flask.add_url_rule`.
        """
        self._defer(lambda app: app.add_url_rule(rule,
                                                 endpoint=endpoint,
                                                 view_func=view_func,
                                                 **options))

    def before_request(self, fn=None):
        """
        Registers a function to run before each request.

        For example, this can be used to open a database connection, or to load
        the logged in user from the session.

        The function will be called without any arguments. If it returns a
        non-None value, the value is handled as if it was the return value from
        the view, and further request handling is stopped.
        """
        if fn is None:
            return self.before_request

        self._defer(lambda app: app.before_request(fn))
        return fn

    def before_first_request(self, fn=None):
        """
        Registers a function to be run before the first request to this
        instance of the application.

        The function will be called without any arguments and its return
        value is ignored.
        """
        if fn is None:
            return self.before_first_request

        self._defer(lambda app: app.before_first_request(fn))
        return fn

    def after_request(self, fn=None):
        """
        Register a function to be run after each request.

        Your function must take one parameter, an instance of
        :attr:`response_class` and return a new response object or the
        same (see :meth:`process_response`).

        As of Flask 0.7 this function might not be executed at the end of the
        request in case an unhandled exception occurred.
        """
        if fn is None:
            return self.after_request

        self._defer(lambda app: app.after_request(fn))
        return fn

    def teardown_request(self, fn=None):
        """
        Register a function to be run at the end of each request,
        regardless of whether there was an exception or not. These functions
        are executed when the request context is popped, even if not an
        actual request was performed.

        Example::

            ctx = app.test_request_context()
            ctx.push()
            ...
            ctx.pop()

        When ``ctx.pop()`` is executed in the above example, the teardown
        functions are called just before the request context moves from the
        stack of active contexts. This becomes relevant if you are using
        such constructs in tests.

        Generally teardown functions must take every necessary step to avoid
        that they will fail. If they do execute code that might fail they
        will have to surround the execution of these code by try/except
        statements and log occurring errors.

        When a teardown function was called because of an exception it will
        be passed an error object.

        The return values of teardown functions are ignored.

        .. admonition:: Debug Note

           In debug mode Flask will not tear down a request on an exception
           immediately. Instead it will keep it alive so that the interactive
           debugger can still access it. This behavior can be controlled
           by the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration variable.
        """
        if fn is None:
            return self.teardown_request

        self._defer(lambda app: app.teardown_request(fn))
        return fn

    def teardown_appcontext(self, fn=None):
        """
        Registers a function to be called when the application context
        ends. These functions are typically also called when the request
        context is popped.

        Example::

            ctx = app.app_context()
            ctx.push()
            ...
            ctx.pop()

        When ``ctx.pop()`` is executed in the above example, the teardown
        functions are called just before the app context moves from the
        stack of active contexts. This becomes relevant if you are using
        such constructs in tests.

        Since a request context typically also manages an application
        context it would also be called when you pop a request context.

        When a teardown function was called because of an unhandled exception
        it will be passed an error object. If an :meth:`errorhandler` is
        registered, it will handle the exception and the teardown will not
        receive it.

        The return values of teardown functions are ignored.
        """
        if fn is None:
            return self.teardown_appcontext

        self._defer(lambda app: app.teardown_appcontext(fn))
        return fn

    def context_processor(self, fn=None):
        """
        Registers a template context processor function.
        """
        if fn is None:
            return self.context_processor

        self._defer(lambda app: app.context_processor(fn))
        return fn

    def shell_context_processor(self, fn=None):
        """
        Registers a shell context processor function.
        """
        if fn is None:
            return self.shell_context_processor

        self._defer(lambda app: app.shell_context_processor(fn))
        return fn

    def url_value_preprocessor(self, fn=None):
        """
        Register a URL value preprocessor function for all view
        functions in the application. These functions will be called before the
        :meth:`before_request` functions.

        The function can modify the values captured from the matched url before
        they are passed to the view. For example, this can be used to pop a
        common language code value and place it in ``g`` rather than pass it to
        every view.

        The function is passed the endpoint name and values dict. The return
        value is ignored.
        """
        if fn is None:
            return self.url_value_preprocessor

        self._defer(lambda app: app.url_value_preprocessor(fn))
        return fn

    def url_defaults(self, fn=None):
        """
        Callback function for URL defaults for all view functions of the
        application. It's called with the endpoint and values and should
        update the values passed in place.
        """
        if fn is None:
            return self.url_defaults

        self._defer(lambda app: app.url_defaults(fn))
        return fn

    def errorhandler(self, code_or_exception):
        """
        Register a function to handle errors by code or exception class.

        A decorator that is used to register a function given an
        error code. Example::

            @app.errorhandler(404)
            def page_not_found(error):
                return 'This page does not exist', 404

        You can also register handlers for arbitrary exceptions::

            @app.errorhandler(DatabaseError)
            def special_exception_handler(error):
                return 'Database connection failed', 500

        :param code_or_exception: the code as integer for the handler, or
                                  an arbitrary exception
        """
        def decorator(fn):
            self._defer(lambda app: app.register_error_handler(code_or_exception, fn))
            return fn
        return decorator

    def template_filter(self,
                        arg: Optional[Callable] = None,
                        *,
                        name: Optional[str] = None,
                        pass_context: bool = False,
                        inject: Optional[Union[bool, Iterable[str]]] = None,
                        safe: bool = False,
                        ) -> Callable:
        """
        Decorator to mark a function as a Jinja template filter.

        :param name: The name of the filter, if different from the function name.
        :param pass_context: Whether or not to pass the template context into the filter.
                             If ``True``, the first argument must be the context.
        :param inject: Whether or not this filter needs any dependencies injected.
        :param safe: Whether or not to mark the output of this filter as html-safe.
        """
        def wrapper(fn):
            fn = _inject(fn, inject)
            if safe:
                fn = _make_safe(fn)
            if pass_context:
                fn = self._pass_context(fn, 'contextfilter')
            self._defer(lambda app: app.add_template_filter(fn, name=name))
            return fn

        if callable(arg):
            return wrapper(arg)
        return wrapper

    def template_global(self,
                        arg: Optional[Callable] = None,
                        *,
                        name: Optional[str] = None,
                        pass_context: bool = False,
                        inject: Optional[Union[bool, Iterable[str]]] = None,
                        safe: bool = False,
                        ) -> Callable:
        """
        Decorator to mark a function as a Jinja template global (tag).

        :param name: The name of the tag, if different from the function name.
        :param pass_context: Whether or not to pass the template context into the tag.
                             If ``True``, the first argument must be the context.
        :param inject: Whether or not this tag needs any dependencies injected.
        :param safe: Whether or not to mark the output of this tag as html-safe.
        """
        def wrapper(fn):
            fn = _inject(fn, inject)
            if safe:
                fn = _make_safe(fn)
            if pass_context:
                fn = self._pass_context(fn, 'contextfunction')
            self._defer(lambda app: app.add_template_global(fn, name=name))
            return fn

        if callable(arg):
            return wrapper(arg)
        return wrapper

    def template_tag(self,
                     arg: Optional[Callable] = None,
                     *,
                     name: Optional[str] = None,
                     pass_context: bool = False,
                     inject: Optional[Union[bool, Iterable[str]]] = None,
                     safe: bool = False,
                     ) -> Callable:
        """
        Alias for :meth:`template_global`.

        :param name: The name of the tag, if different from the function name.
        :param pass_context: Whether or not to pass the template context into the tag.
                             If ``True``, the first argument must be the context.
        :param inject: Whether or not this tag needs any dependencies injected.
        :param safe: Whether or not to mark the output of this tag as html-safe.
        """
        return self.template_global(arg, name=name, pass_context=pass_context,
                                    inject=inject, safe=safe)

    def template_test(self,
                      arg: Optional[Callable] = None,
                      *,
                      name: Optional[str] = None,
                      inject: Optional[Union[bool, Iterable[str]]] = None,
                      safe: bool = False,
                      ) -> Callable:
        """
        Decorator to mark a function as a Jinja template test.

        :param name: The name of the test, if different from the function name.
        :param inject: Whether or not this test needs any dependencies injected.
        :param safe: Whether or not to mark the output of this test as html-safe.
        """
        def wrapper(fn):
            fn = _inject(fn, inject)
            if safe:
                fn = _make_safe(fn)
            self._defer(lambda app: app.add_template_test(fn, name=name))
            return fn

        if callable(arg):
            return wrapper(arg)
        return wrapper

    def _reset(self):
        """
        This method is for use by tests only!
        """
        self.bundles = AttrDict()
        self._deferred_bundle_functions = _DeferredBundleBlueprintFunctionsStore()
        self.babel_bundle = None
        self.env = None
        self.extensions = AttrDict()
        self.services = AttrDict()
        self._deferred_functions = []
        self._initialized = False
        self._models_initialized = False
        self._services_initialized = False
        self._services_registry = {}
        self._shell_ctx = {}

    if QUART_ENABLED:
        def add_websocket(
            self,
            path: str,
            endpoint: Optional[str] = None,
            view_func: Optional[Callable] = None,
            defaults: Optional[dict] = None,
            host: Optional[str] = None,
            subdomain: Optional[str] = None,
            *,
            strict_slashes: Optional[bool] = None,
        ):
            """Defer ``app.add_websocket(...)`` with the given arguments (Quart only)."""
            self._defer(lambda app: app.add_websocket(
                path,
                endpoint=endpoint,
                view_func=view_func,
                defaults=defaults,
                host=host,
                subdomain=subdomain,
                strict_slashes=strict_slashes,
            ))

        def before_serving(self, fn=None):
            """Defer ``app.before_serving(fn)`` (Quart only)."""
            if fn is None:
                return self.before_serving

            self._defer(lambda app: app.before_serving(fn))
            return fn

        def after_serving(self, fn=None):
            """Defer ``app.after_serving(fn)`` (Quart only)."""
            if fn is None:
                return self.after_serving

            self._defer(lambda app: app.after_serving(fn))
            return fn

        def before_websocket(self, fn=None):
            """Add a before websocket function.

            This is designed to be used as a decorator, if used to
            decorate a synchronous function, the function will be wrapped
            in :func:`~quart.utils.run_sync` and run in a thread executor
            (with the wrapped function returned). An example usage,

            .. code-block:: python

                @app.before_websocket
                async def func():
                    ...

            :param fn: The before websocket function itself.
            """
            if fn is None:
                return self.before_websocket

            self._defer(lambda app: app.before_websocket(fn))
            return fn

        def after_websocket(self, fn=None):
            """Defer ``app.after_websocket(fn)`` (Quart only)."""
            if fn is None:
                return self.after_websocket

            self._defer(lambda app: app.after_websocket(fn))
            return fn

        def teardown_websocket(self, fn=None):
            """Defer ``app.teardown_websocket(fn)`` (Quart only)."""
            if fn is None:
                return self.teardown_websocket

            self._defer(lambda app: app.teardown_websocket(fn))
            return fn
def _inject(fn, inject_args):
    """
    Optionally wrap ``fn`` with :meth:`Unchained.inject` on the module-level
    ``unchained`` instance.

    :param fn: The function to (maybe) wrap.
    :param inject_args: Falsy to return ``fn`` unchanged. Any other
                        non-iterable (eg ``True``) calls ``inject()`` with no
                        names, which injects every parameter defaulting to
                        ``injectable``. A single parameter name (``str``) or
                        an iterable of names injects exactly those.
    """
    if not inject_args:
        return fn

    if isinstance(inject_args, str):
        # a lone string is one parameter name; iterating it would instead
        # yield (and try to inject) its individual characters
        inject_args = [inject_args]
    elif not isinstance(inject_args, Iterable):
        inject_args = []
    return unchained.inject(*inject_args)(fn)


def _make_safe(fn):
    """
    Wrap ``fn`` so its return value is passed through :class:`markupsafe.Markup`
    (ie marked as html-safe).
    """
    @functools.wraps(fn)
    def safe_fn(*args, **kwargs):
        return markupsafe.Markup(fn(*args, **kwargs))
    return safe_fn


unchained = Unchained()


__all__ = [
    'unchained',
    'Unchained',
]