Spaces:
Paused
Paused
| # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license | |
| from typing import Dict | |
| import dns.exception | |
| # pylint: disable=unused-import | |
| from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import] | |
| Backend, | |
| DatagramSocket, | |
| Socket, | |
| StreamSocket, | |
| ) | |
| # pylint: enable=unused-import | |
# Module-wide default backend.  Starts as None and is assigned by
# set_default_backend(); get_default_backend() reads it.
_default_backend = None

# Cache of backend instances created by get_backend(), keyed by backend name.
_backends: Dict[str, Backend] = {}

# Allow sniffio import to be disabled for testing purposes.  When true,
# sniff() raises ImportError itself instead of importing sniffio, which
# sends it down the same asyncio-only fallback used when sniffio is missing.
_no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException):
    # Raised by sniff() when the in-use asynchronous I/O library cannot be
    # determined.
    # NOTE(review): documented with a comment rather than a docstring so that
    # the class's ``__doc__`` stays unchanged; the base class is defined
    # outside this file -- confirm nothing there reads ``__doc__`` before
    # converting this to a docstring.
    pass
def get_backend(name: str) -> Backend:
    """Return the asynchronous backend registered under *name*.

    *name*, a ``str``, selects the backend; ``"trio"`` and ``"asyncio"``
    are the names currently understood.

    A backend is instantiated the first time its name is requested and the
    same instance is handed back on later calls.

    Raises ``NotImplementedError`` if *name* is not a known backend.
    """
    # pylint: disable=import-outside-toplevel
    cached = _backends.get(name)
    if cached:
        return cached

    # The backend modules are imported lazily so that only the library that
    # is actually requested has to be importable.
    if name == "trio":
        import dns._trio_backend as backend_module
    elif name == "asyncio":
        import dns._asyncio_backend as backend_module
    else:
        raise NotImplementedError(f"unimplemented async backend {name}")

    created = backend_module.Backend()
    _backends[name] = created
    return created
def sniff() -> str:
    """Work out which asynchronous I/O library is currently running.

    ``sniffio`` is consulted when it can be imported (and has not been
    disabled through ``_no_sniffio``).  Otherwise the only check made is
    whether an ``asyncio`` event loop is running.

    Returns the name of the library as a ``str``.

    Raises ``AsyncLibraryNotFoundError`` if the library cannot be determined.
    """
    # pylint: disable=import-outside-toplevel
    try:
        if _no_sniffio:
            # Pretend sniffio is absent so the fallback path below runs.
            raise ImportError
        import sniffio

        try:
            library = sniffio.current_async_library()
        except sniffio.AsyncLibraryNotFoundError:
            raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
        return library
    except ImportError:
        import asyncio

        try:
            # get_running_loop() raises RuntimeError when no loop is running.
            asyncio.get_running_loop()
        except RuntimeError:
            raise AsyncLibraryNotFoundError("no async library detected")
        return "asyncio"
def get_default_backend() -> Backend:
    """Return the default backend, choosing one first if none is set.

    When no default has been established yet, the running async library is
    detected with ``sniff()`` and installed via ``set_default_backend()``.
    """
    current = _default_backend
    if not current:
        current = set_default_backend(sniff())
    return current
def set_default_backend(name: str) -> Backend:
    """Make the backend called *name* the default and return it.

    Calling this is usually unnecessary, because
    ``get_default_backend()`` initializes the default appropriately in
    many cases.  It allows the backend to be chosen explicitly when
    ``sniffio`` is not installed, or in testing situations.

    *name*, a ``str``, is resolved with ``get_backend()``.
    """
    global _default_backend
    selected = get_backend(name)
    _default_backend = selected
    return selected