from collections import defaultdict

import attr
from async_generator import asynccontextmanager

from .. import _core
from .. import _util
from .. import Event

if False:
    from typing import DefaultDict, Set


@attr.s(eq=False, hash=False)
class Sequencer(metaclass=_util.Final):
    """Force code running in separate tasks to execute in a fixed linear order.

    Calling an instance with an integer sequence number returns an async
    context manager. Block 0 is entered immediately. Block N is entered only
    after block N-1 has exited, whether it exited normally or by raising.

    Each sequence number may be used at most once; re-using one raises
    :exc:`RuntimeError`. If a task is cancelled while waiting for its turn,
    the sequencer becomes permanently broken. The cancelled task raises
    :exc:`RuntimeError` instead of propagating the cancellation. Every other
    waiting task is woken and also raises :exc:`RuntimeError`, as does every
    later call.

    Example:
      A very roundabout way to print the numbers 0 through 5 in order::

         async def worker1(seq):
             async with seq(0):
                 print(0)
             async with seq(4):
                 print(4)

         async def worker2(seq):
             async with seq(2):
                 print(2)
             async with seq(5):
                 print(5)

         async def worker3(seq):
             async with seq(1):
                 print(1)
             async with seq(3):
                 print(3)

         async def main():
             seq = trio.testing.Sequencer()
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(worker1, seq)
                 nursery.start_soon(worker2, seq)
                 nursery.start_soon(worker3, seq)

         trio.run(main)

    """

    # One Event per sequence number. Setting event N lets block N begin.
    # Entries are created lazily on first lookup.
    _sequence_points = attr.ib(
        factory=lambda: defaultdict(Event), init=False
    )  # type: DefaultDict[int, Event]
    # Sequence numbers that have already been handed out.
    _claimed = attr.ib(factory=set, init=False)  # type: Set[int]
    # Becomes True once any waiter is cancelled; never reset afterwards.
    _broken = attr.ib(default=False, init=False)  # type: bool

    @asynccontextmanager
    async def __call__(self, position: int):
        # Reject bad requests before mutating any shared state.
        if position in self._claimed:
            raise RuntimeError("Attempted to re-use sequence point {}".format(position))
        if self._broken:
            raise RuntimeError("sequence broken!")
        self._claimed.add(position)

        # Block 0 proceeds at once; every other block waits for its predecessor.
        if position != 0:
            gate = self._sequence_points[position]
            try:
                await gate.wait()
            except _core.Cancelled:
                self._broken = True
                # Wake every currently-known waiter so it can see the broken
                # flag and fail, rather than hanging forever.
                for pending in self._sequence_points.values():
                    pending.set()
                raise RuntimeError("Sequencer wait cancelled -- sequence broken")
            # We were released normally, but possibly only because some other
            # waiter was cancelled and broke the sequence.
            if self._broken:
                raise RuntimeError("sequence broken!")

        try:
            yield
        finally:
            # Hand off to the next block even if ours raised.
            self._sequence_points[position + 1].set()
