from abc import ABC, abstractmethod, abstractproperty
import asyncio
import logging
import weakref
import dask
from ..metrics import time
from ..utils import parse_timedelta, ignoring, TimeoutError
from . import registry
from .addressing import parse_address
logger = logging.getLogger(__name__)
class CommClosedError(IOError):
    """``IOError`` subclass used by the comm layer.

    Judging by its name it signals that a communication was closed;
    the class itself adds no behaviour on top of ``IOError``.
    """
class FatalCommClosedError(CommClosedError):
    """``CommClosedError`` variant that must not be retried.

    ``connect()`` in this module re-raises it immediately instead of
    retrying the connection attempt.  The class adds no behaviour of
    its own.
    """
class Comm(ABC):
    """
    A message-oriented communication object, representing an established
    communication channel.  There should be only one reader and one
    writer at a time: to manage concurrent communications, even with a
    single peer, you must create distinct ``Comm`` objects.

    Messages are arbitrary Python objects.  Concrete implementations
    of this class can implement different serialization mechanisms
    depending on the underlying transport's characteristics.
    """

    # Weak references to every instance whose ``__init__`` ran; entries
    # vanish automatically once the instance is garbage collected.
    _instances = weakref.WeakSet()

    def __init__(self):
        # Subclasses must call ``super().__init__()`` to be registered
        # in ``_instances`` and to get the ``name`` attribute.
        self._instances.add(self)
        self.name = None

    # XXX add set_close_callback()?

    @abstractmethod
    def read(self, deserializers=None):
        """
        Read and return a message (a Python object).

        This method is a coroutine.

        Parameters
        ----------
        deserializers : Optional[Dict[str, Tuple[Callable, Callable, bool]]]
            An optional dict appropriate for distributed.protocol.deserialize.
            See :ref:`serialization` for more.
        """

    @abstractmethod
    def write(self, msg, serializers=None, on_error=None):
        """
        Write a message (a Python object).

        This method is a coroutine.

        Parameters
        ----------
        msg :
            The message to send (an arbitrary Python object).
        serializers :
            Optional serialization settings; their interpretation is left
            to the concrete implementation.
        on_error : Optional[str]
            The behavior when serialization fails. See
            ``distributed.protocol.core.dumps`` for valid values.
        """

    @abstractmethod
    def close(self):
        """
        Close the communication cleanly.  This will attempt to flush
        outgoing buffers before actually closing the underlying transport.

        This method is a coroutine.
        """

    @abstractmethod
    def abort(self):
        """
        Close the communication immediately and abruptly.
        Useful in destructors or generators' ``finally`` blocks.
        """

    @abstractmethod
    def closed(self):
        """
        Return whether the stream is closed.
        """

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the supported spelling.
    @property
    @abstractmethod
    def local_address(self):
        """
        The local address.  For logging and debugging purposes only.
        """

    @property
    @abstractmethod
    def peer_address(self):
        """
        The peer's address.  For logging and debugging purposes only.
        """

    @property
    def extra_info(self):
        """
        Return backend-specific information about the communication,
        as a dict.  Typically, this is information which is initialized
        when the communication is established and doesn't vary afterwards.

        The base implementation returns a new empty dict.
        """
        return {}

    def __repr__(self):
        clsname = self.__class__.__name__
        if self.closed():
            # The addresses are not read at all for a closed comm.
            return "<closed %s>" % (clsname,)
        return "<%s %s local=%s remote=%s>" % (
            clsname,
            self.name or "",
            self.local_address,
            self.peer_address,
        )
class Listener(ABC):
    """
    Abstract base class for objects that accept incoming connections.

    A listener can be used as an asynchronous context manager: entering
    awaits ``start()`` and exiting calls ``stop()``.
    """

    @abstractmethod
    async def start(self):
        """
        Start listening for incoming connections.
        """

    @abstractmethod
    def stop(self):
        """
        Stop listening.  This does not shutdown already established
        communications, but prevents accepting new ones.
        """

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the supported spelling.
    @property
    @abstractmethod
    def listen_address(self):
        """
        The listening address as a URI string.
        """

    @property
    @abstractmethod
    def contact_address(self):
        """
        An address this listener can be contacted on.  This can be
        different from `listen_address` if the latter is some wildcard
        address such as 'tcp://0.0.0.0:123'.
        """

    async def __aenter__(self):
        """Start the listener and return it."""
        await self.start()
        return self

    async def __aexit__(self, *exc):
        """Stop the listener; exception details are ignored and any
        exception from the ``async with`` body propagates."""
        self.stop()
class Connector(ABC):
    """Abstract interface for objects that open outgoing communications."""

    @abstractmethod
    def connect(self, address, deserialize=True):
        """
        Open a connection to *address* and return the resulting Comm.

        This function is a coroutine.

        Raises
        ------
        EnvironmentError
            May be raised when the other endpoint is unreachable or
            unavailable.
        ValueError
            May be raised when the address is malformed.
        """
async def connect(addr, timeout=None, deserialize=True, connection_args=None):
    """
    Connect to the given address (a URI such as ``tcp://127.0.0.1:1234``)
    and return a ``Comm`` object.  If the connection attempt fails, it is
    retried until the *timeout* is expired.

    Parameters
    ----------
    addr : str
        Address to connect to.  Its scheme selects the backend through
        ``registry.get_backend``.
    timeout : optional
        Overall time budget, converted with
        ``parse_timedelta(timeout, default="seconds")``.  When ``None``,
        the ``distributed.comm.timeouts.connect`` config value is used.
    deserialize : bool
        Forwarded unchanged to the backend connector's ``connect``.
    connection_args : Optional[dict]
        Extra keyword arguments forwarded to the connector's ``connect``.

    Returns
    -------
    The object produced by the backend connector's ``connect``.

    Raises
    ------
    IOError
        If no connection was established before the deadline.  The
        message carries the text of the last ``EnvironmentError`` seen,
        or a generic note if there was none.
    FatalCommClosedError
        Propagated immediately, without retrying.
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, default="seconds")

    scheme, loc = parse_address(addr)
    backend = registry.get_backend(scheme)
    connector = backend.get_connector()
    connect_kwargs = connection_args or {}

    deadline = time() + timeout
    error = None  # text of the most recent connection failure, if any

    def _raise_timeout():
        # Reads ``error`` from the enclosing scope at call time, so it
        # always reports the latest recorded failure.
        reason = error or "connect() didn't finish in time"
        msg = "Timed out trying to connect to %r after %s s: %s" % (
            addr,
            timeout,
            reason,
        )
        raise IOError(msg)

    while True:
        remaining = deadline - time()
        if remaining <= 0:
            # Raised outside the ``try`` below on purpose: IOError is an
            # EnvironmentError, so raising it inside would be caught by
            # our own handler and wrapped into a second, nested
            # "Timed out trying to connect ..." message.
            _raise_timeout()

        comm = None
        try:
            # This starts a thread (original note; presumably refers to
            # the connector's connect call -- unverified).
            future = connector.connect(
                loc, deserialize=deserialize, **connect_kwargs
            )
            # Each attempt is capped at one second; an attempt that
            # times out is silently dropped and a new one is started.
            with ignoring(TimeoutError):
                comm = await asyncio.wait_for(future, timeout=min(remaining, 1))
        except FatalCommClosedError:
            raise
        except EnvironmentError as e:
            error = str(e)
            if time() < deadline:
                await asyncio.sleep(0.01)
                logger.debug("sleeping on connect")
            else:
                _raise_timeout()
        else:
            if comm is not None:
                return comm
def listen(addr, handle_comm, deserialize=True, connection_args=None):
    """
    Build a listener object for *addr* without starting it.

    Once its ``start()`` method is called, the listener will listen on
    the given address (a URI such as ``tcp://0.0.0.0``) and call
    *handle_comm* with a ``Comm`` object for each incoming connection.

    *handle_comm* can be a regular function or a coroutine.

    When strict parsing of *addr* fails with ``ValueError``, a scheme is
    prepended and parsing is retried: ``tls://`` if *connection_args*
    holds a truthy ``"ssl_context"`` entry, ``tcp://`` otherwise.
    *connection_args*, if given, is forwarded as keyword arguments to the
    backend's ``get_listener``.
    """
    try:
        scheme, loc = parse_address(addr, strict=True)
    except ValueError:
        # Strict parsing rejected the address (presumably because it has
        # no scheme); pick a default scheme and parse again.
        wants_tls = connection_args and connection_args.get("ssl_context")
        prefix = "tls://" if wants_tls else "tcp://"
        addr = prefix + addr
        scheme, loc = parse_address(addr, strict=True)

    listener_kwargs = connection_args or {}
    backend = registry.get_backend(scheme)
    return backend.get_listener(loc, handle_comm, deserialize, **listener_kwargs)