Spaces:
Paused
Paused
| # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license | |
| """trio async I/O library query support""" | |
| import socket | |
| import trio | |
| import trio.socket # type: ignore | |
| import dns._asyncbackend | |
| import dns._features | |
| import dns.exception | |
| import dns.inet | |
| if not dns._features.have("trio"): | |
| raise ImportError("trio not found or too old") | |
| def _maybe_timeout(timeout): | |
| if timeout is not None: | |
| return trio.move_on_after(timeout) | |
| else: | |
| return dns._asyncbackend.NullContext() | |
| # for brevity | |
| _lltuple = dns.inet.low_level_address_tuple | |
| # pylint: disable=redefined-outer-name | |
| class DatagramSocket(dns._asyncbackend.DatagramSocket): | |
| def __init__(self, socket): | |
| super().__init__(socket.family) | |
| self.socket = socket | |
| async def sendto(self, what, destination, timeout): | |
| with _maybe_timeout(timeout): | |
| return await self.socket.sendto(what, destination) | |
| raise dns.exception.Timeout( | |
| timeout=timeout | |
| ) # pragma: no cover lgtm[py/unreachable-statement] | |
| async def recvfrom(self, size, timeout): | |
| with _maybe_timeout(timeout): | |
| return await self.socket.recvfrom(size) | |
| raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] | |
| async def close(self): | |
| self.socket.close() | |
| async def getpeername(self): | |
| return self.socket.getpeername() | |
| async def getsockname(self): | |
| return self.socket.getsockname() | |
| async def getpeercert(self, timeout): | |
| raise NotImplementedError | |
| class StreamSocket(dns._asyncbackend.StreamSocket): | |
| def __init__(self, family, stream, tls=False): | |
| self.family = family | |
| self.stream = stream | |
| self.tls = tls | |
| async def sendall(self, what, timeout): | |
| with _maybe_timeout(timeout): | |
| return await self.stream.send_all(what) | |
| raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] | |
| async def recv(self, size, timeout): | |
| with _maybe_timeout(timeout): | |
| return await self.stream.receive_some(size) | |
| raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] | |
| async def close(self): | |
| await self.stream.aclose() | |
| async def getpeername(self): | |
| if self.tls: | |
| return self.stream.transport_stream.socket.getpeername() | |
| else: | |
| return self.stream.socket.getpeername() | |
| async def getsockname(self): | |
| if self.tls: | |
| return self.stream.transport_stream.socket.getsockname() | |
| else: | |
| return self.stream.socket.getsockname() | |
| async def getpeercert(self, timeout): | |
| if self.tls: | |
| with _maybe_timeout(timeout): | |
| await self.stream.do_handshake() | |
| return self.stream.getpeercert() | |
| else: | |
| raise NotImplementedError | |
| if dns._features.have("doh"): | |
| import httpcore | |
| import httpcore._backends.trio | |
| import httpx | |
| _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend | |
| _CoreTrioStream = httpcore._backends.trio.TrioStream | |
| from dns.query import _compute_times, _expiration_for_this_attempt, _remaining | |
| class _NetworkBackend(_CoreAsyncNetworkBackend): | |
| def __init__(self, resolver, local_port, bootstrap_address, family): | |
| super().__init__() | |
| self._local_port = local_port | |
| self._resolver = resolver | |
| self._bootstrap_address = bootstrap_address | |
| self._family = family | |
| async def connect_tcp( | |
| self, host, port, timeout, local_address, socket_options=None | |
| ): # pylint: disable=signature-differs | |
| addresses = [] | |
| _, expiration = _compute_times(timeout) | |
| if dns.inet.is_address(host): | |
| addresses.append(host) | |
| elif self._bootstrap_address is not None: | |
| addresses.append(self._bootstrap_address) | |
| else: | |
| timeout = _remaining(expiration) | |
| family = self._family | |
| if local_address: | |
| family = dns.inet.af_for_address(local_address) | |
| answers = await self._resolver.resolve_name( | |
| host, family=family, lifetime=timeout | |
| ) | |
| addresses = answers.addresses() | |
| for address in addresses: | |
| try: | |
| af = dns.inet.af_for_address(address) | |
| if local_address is not None or self._local_port != 0: | |
| source = (local_address, self._local_port) | |
| else: | |
| source = None | |
| destination = (address, port) | |
| attempt_expiration = _expiration_for_this_attempt(2.0, expiration) | |
| timeout = _remaining(attempt_expiration) | |
| sock = await Backend().make_socket( | |
| af, socket.SOCK_STREAM, 0, source, destination, timeout | |
| ) | |
| return _CoreTrioStream(sock.stream) | |
| except Exception: | |
| continue | |
| raise httpcore.ConnectError | |
| async def connect_unix_socket( | |
| self, path, timeout, socket_options=None | |
| ): # pylint: disable=signature-differs | |
| raise NotImplementedError | |
| async def sleep(self, seconds): # pylint: disable=signature-differs | |
| await trio.sleep(seconds) | |
| class _HTTPTransport(httpx.AsyncHTTPTransport): | |
| def __init__( | |
| self, | |
| *args, | |
| local_port=0, | |
| bootstrap_address=None, | |
| resolver=None, | |
| family=socket.AF_UNSPEC, | |
| **kwargs, | |
| ): | |
| if resolver is None: | |
| # pylint: disable=import-outside-toplevel,redefined-outer-name | |
| import dns.asyncresolver | |
| resolver = dns.asyncresolver.Resolver() | |
| super().__init__(*args, **kwargs) | |
| self._pool._network_backend = _NetworkBackend( | |
| resolver, local_port, bootstrap_address, family | |
| ) | |
| else: | |
| _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore | |
| class Backend(dns._asyncbackend.Backend): | |
| def name(self): | |
| return "trio" | |
| async def make_socket( | |
| self, | |
| af, | |
| socktype, | |
| proto=0, | |
| source=None, | |
| destination=None, | |
| timeout=None, | |
| ssl_context=None, | |
| server_hostname=None, | |
| ): | |
| s = trio.socket.socket(af, socktype, proto) | |
| stream = None | |
| try: | |
| if source: | |
| await s.bind(_lltuple(source, af)) | |
| if socktype == socket.SOCK_STREAM: | |
| connected = False | |
| with _maybe_timeout(timeout): | |
| await s.connect(_lltuple(destination, af)) | |
| connected = True | |
| if not connected: | |
| raise dns.exception.Timeout( | |
| timeout=timeout | |
| ) # lgtm[py/unreachable-statement] | |
| except Exception: # pragma: no cover | |
| s.close() | |
| raise | |
| if socktype == socket.SOCK_DGRAM: | |
| return DatagramSocket(s) | |
| elif socktype == socket.SOCK_STREAM: | |
| stream = trio.SocketStream(s) | |
| tls = False | |
| if ssl_context: | |
| tls = True | |
| try: | |
| stream = trio.SSLStream( | |
| stream, ssl_context, server_hostname=server_hostname | |
| ) | |
| except Exception: # pragma: no cover | |
| await stream.aclose() | |
| raise | |
| return StreamSocket(af, stream, tls) | |
| raise NotImplementedError( | |
| "unsupported socket " + f"type {socktype}" | |
| ) # pragma: no cover | |
| async def sleep(self, interval): | |
| await trio.sleep(interval) | |
| def get_transport_class(self): | |
| return _HTTPTransport | |
| async def wait_for(self, awaitable, timeout): | |
| with _maybe_timeout(timeout): | |
| return await awaitable | |
| raise dns.exception.Timeout( | |
| timeout=timeout | |
| ) # pragma: no cover lgtm[py/unreachable-statement] | |