Source code for distributed.lock

import asyncio
from collections import defaultdict, deque
import logging
import uuid

from .client import _get_global_client
from .utils import log_errors, TimeoutError
from .worker import get_worker

logger = logging.getLogger(__name__)


class LockExtension:
    """ An extension for the scheduler to manage Locks

    This adds the following routes to the scheduler

    *  lock_acquire
    *  lock_release

    State kept by the extension:

    *  ``ids``: maps the name of each currently held lock to the id that
       was supplied by its holder when it acquired the lock
    *  ``events``: maps a lock name to a FIFO queue of ``asyncio.Event``
       objects, one per coroutine currently waiting on that lock
    """

    def __init__(self, scheduler):
        """ Register the lock handlers and this extension on ``scheduler`` """
        self.scheduler = scheduler
        self.events = defaultdict(deque)
        self.ids = dict()

        self.scheduler.handlers.update(
            {"lock_acquire": self.acquire, "lock_release": self.release}
        )

        self.scheduler.extensions["locks"] = self

    async def acquire(self, stream=None, name=None, id=None, timeout=None):
        """ Acquire the lock ``name`` on behalf of ``id``

        Parameters
        ----------
        stream :
            Not used by this handler.
        name : hashable or list
            Name of the lock.  A list is converted to a tuple so that it can
            be used as a dictionary key.
        id :
            Identifier of the requester; recorded as the holder on success.
        timeout : number, optional
            Seconds to wait for the lock.  ``None`` waits indefinitely.  The
            timeout is applied to each wait on the lock, so a waiter that is
            woken but finds the lock taken again waits for another full
            timeout.

        Returns
        -------
        True if the lock was acquired, False if waiting timed out.
        """
        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            if name not in self.ids:
                result = True
            else:
                while name in self.ids:
                    # Queue up behind any other waiters; ``release`` wakes
                    # the event at the head of the queue.
                    event = asyncio.Event()
                    self.events[name].append(event)
                    future = event.wait()
                    if timeout is not None:
                        future = asyncio.wait_for(future, timeout)
                    try:
                        await future
                    except TimeoutError:
                        result = False
                        break
                    else:
                        result = True
                    finally:
                        # Remove *our own* event.  A waiter that times out is
                        # not necessarily at the head of the queue, so popping
                        # from the left could discard another waiter's event
                        # and leave that waiter unreachable by ``release``.
                        self.events[name].remove(event)
            if result:
                assert name not in self.ids
                self.ids[name] = id
            return result

    def release(self, stream=None, name=None, id=None):
        """ Release the lock ``name`` held by ``id`` and wake the next waiter

        Parameters
        ----------
        stream :
            Not used by this handler.
        name : hashable or list
            Name of the lock.  A list is converted to a tuple.
        id :
            Identifier that was supplied when the lock was acquired.

        Raises
        ------
        ValueError
            If the lock is not currently held, or is held under another id.
        """
        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            # Check membership explicitly: ``self.ids.get(name)`` returns None
            # for an unknown lock, which would compare equal to ``id=None``
            # and fall through to a KeyError on the ``del`` below.
            if name not in self.ids or self.ids[name] != id:
                raise ValueError("This lock has not yet been acquired")
            del self.ids[name]
            if self.events[name]:
                # Wake the longest-waiting coroutine from the scheduler loop.
                self.scheduler.loop.add_callback(self.events[name][0].set)
            else:
                # No waiters: drop the empty queue so the mapping cannot grow
                # without bound.
                del self.events[name]


class Lock:
    """ Distributed Centralized Lock

    Parameters
    ----------
    name: string (optional)
        Name of the lock to acquire.  Choosing the same name allows two
        disconnected processes to coordinate a lock.  If not given, a random
        name will be generated.
    client: Client (optional)
        Client to use for communication with the scheduler.  If not given, the
        default global client will be used.

    Examples
    --------
    >>> lock = Lock('x')  # doctest: +SKIP
    >>> lock.acquire(timeout=1)  # doctest: +SKIP
    >>> # do things with protected resource
    >>> lock.release()  # doctest: +SKIP
    """

    def __init__(self, name=None, client=None):
        # Fall back to the global client, then to the current worker's client.
        self.client = client or _get_global_client() or get_worker().client
        self.name = name or "lock-" + uuid.uuid4().hex
        # Per-instance identifier sent with every acquire/release request.
        self.id = uuid.uuid4().hex
        self._locked = False

    def acquire(self, blocking=True, timeout=None):
        """ Acquire the lock

        Parameters
        ----------
        blocking : bool, optional
            If false, don't wait on the lock in the scheduler at all.
        timeout : number, optional
            Seconds to wait on the lock in the scheduler.  This does not
            include local coroutine time, network transfer time, etc..
            It is forbidden to specify a timeout when blocking is false.

        Returns
        -------
        True or False whether or not it successfully acquired the lock

        Raises
        ------
        ValueError
            If a timeout is given together with ``blocking=False``.

        Examples
        --------
        >>> lock = Lock('x')  # doctest: +SKIP
        >>> lock.acquire(timeout=1)  # doctest: +SKIP
        """
        if not blocking:
            if timeout is not None:
                raise ValueError("can't specify a timeout for a non-blocking call")
            timeout = 0

        result = self.client.sync(
            self.client.scheduler.lock_acquire,
            name=self.name,
            id=self.id,
            timeout=timeout,
        )
        # Only record the lock as held when the request did not report
        # failure; a timed-out acquire returns False and must not leave
        # ``locked()`` claiming ownership.  Any other truthy result (for
        # example an awaitable whose outcome is not known yet) is treated
        # as held, as before.
        if result:
            self._locked = True
        return result

    def release(self):
        """ Release the lock if already acquired

        Raises
        ------
        ValueError
            If this instance does not consider the lock acquired.
        """
        if not self.locked():
            raise ValueError("Lock is not yet acquired")
        result = self.client.sync(
            self.client.scheduler.lock_release, name=self.name, id=self.id
        )
        self._locked = False
        return result

    def locked(self):
        """ Return whether this instance has recorded the lock as acquired """
        return self._locked

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, *args, **kwargs):
        self.release()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, *args, **kwargs):
        await self.release()

    def __reduce__(self):
        # Only the name is serialized; the reconstructed Lock resolves its
        # own client and generates a fresh id in ``__init__``.
        return (Lock, (self.name,))