# This should eventually be cleaned up and become public, but for right now I'm just
# implementing enough to test DTLS.

# TODO:
# - user-defined routers
# - TCP
# - UDP broadcast

import enum
import errno
import ipaddress
import os
from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

import attr
import trio

from trio._util import Final, NoPublicConstructor

IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]


def _family_for(ip: IPAddress) -> int:
    """Return the trio.socket address family matching an ipaddress object."""
    if isinstance(ip, ipaddress.IPv6Address):
        return trio.socket.AF_INET6
    if isinstance(ip, ipaddress.IPv4Address):
        return trio.socket.AF_INET
    assert False  # pragma: no cover


def _wildcard_ip_for(family: int) -> IPAddress:
    """Return the unspecified ("any") address for *family*: 0.0.0.0 or ::."""
    if family == trio.socket.AF_INET6:
        return ipaddress.ip_address("::")
    if family == trio.socket.AF_INET:
        return ipaddress.ip_address("0.0.0.0")
    assert False


def _localhost_ip_for(family: int) -> IPAddress:
    """Return the loopback address for *family*: 127.0.0.1 or ::1."""
    if family == trio.socket.AF_INET6:
        return ipaddress.ip_address("::1")
    if family == trio.socket.AF_INET:
        return ipaddress.ip_address("127.0.0.1")
    assert False


def _fake_err(code):
    """Raise an ``OSError`` for errno *code*, the way a failing syscall would.

    The exception carries the numeric code plus ``os.strerror(code)`` as its
    message, so callers can inspect ``exc.errno``.
    """
    description = os.strerror(code)
    raise OSError(code, description)


def _scatter(data, buffers):
    """Copy *data* into the writable *buffers*, filling them in order.

    Copying stops as soon as all of *data* has been placed. Returns the
    number of bytes written, which is less than ``len(data)`` when the
    buffers together are too small to hold everything.
    """
    end = len(data)
    cursor = 0
    for target in buffers:
        chunk = data[cursor : cursor + len(target)]
        size = len(chunk)
        with memoryview(target) as view:
            view[:size] = chunk
        cursor += size
        if cursor == end:
            break
    return cursor


@attr.frozen
class UDPEndpoint:
    """An (IP address, port) pair naming one end of a UDP conversation."""

    ip: IPAddress
    port: int

    def as_python_sockaddr(self):
        """Return this endpoint as a stdlib-style sockaddr tuple.

        IPv4 yields ``(host, port)``. IPv6 yields
        ``(host, port, flowinfo, scope_id)`` with flowinfo and scope_id set
        to 0.
        """
        host = self.ip.compressed
        if isinstance(self.ip, ipaddress.IPv6Address):
            return (host, self.port, 0, 0)
        return (host, self.port)

    @classmethod
    def from_python_sockaddr(cls, sockaddr):
        """Build an endpoint from a sockaddr tuple; extra IPv6 fields are ignored."""
        host, port = sockaddr[:2]
        parsed_ip = ipaddress.ip_address(host)
        return cls(ip=parsed_ip, port=port)


@attr.frozen
class UDPBinding:
    """Hashable key for a bound socket: the local endpoint it occupies."""

    local: UDPEndpoint = attr.ib()


@attr.frozen
class UDPPacket:
    """One UDP datagram: its source, its destination, and its payload bytes."""

    source: UDPEndpoint
    destination: UDPEndpoint
    # repr() shows the payload as hex so binary data stays readable.
    payload: bytes = attr.ib(repr=lambda p: p.hex())

    def reply(self, payload):
        """Build a packet carrying *payload* back along the reverse path."""
        reversed_source = self.destination
        reversed_destination = self.source
        return UDPPacket(
            source=reversed_source,
            destination=reversed_destination,
            payload=payload,
        )


@attr.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
    """trio socket factory that creates FakeSockets attached to ``fake_net``."""

    fake_net: "FakeNet"

    def socket(self, family: int, type: int, proto: int) -> "FakeSocket":
        # FakeSocket uses NoPublicConstructor, so construction goes through _create.
        net = self.fake_net
        return FakeSocket._create(net, family, type, proto)


@attr.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
    """Hostname resolver installed by FakeNet; every lookup is unsupported."""

    fake_net: "FakeNet"

    async def getaddrinfo(
        self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0
    ):
        """Always raises NotImplementedError: FakeNet has no fake DNS."""
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")

    async def getnameinfo(self, sockaddr, flags: int):
        """Always raises NotImplementedError: FakeNet has no fake DNS."""
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")


class FakeNet(metaclass=Final):
    """An in-memory UDP "network" that connects FakeSockets to each other.

    ``enable()`` makes trio.socket create FakeSockets and use a resolver that
    refuses DNS lookups. Outgoing packets enter through ``send_packet``. By
    default they are delivered straight to whichever socket is bound to the
    destination endpoint, or silently dropped if none is.
    """

    def __init__(self):
        # When we need to pick an arbitrary unique ip address/port, use these:
        self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts()
        self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts()
        self._auto_port_iter = iter(range(50000, 65535))

        # Which socket currently owns each bound local endpoint.
        self._bound: Dict[UDPBinding, "FakeSocket"] = {}

        # Optional routing hook. When set, send_packet hands every packet to it
        # instead of delivering directly, so the hook decides whether and when
        # to call deliver_packet. Presumably this lets tests simulate loss or
        # reordering.
        self.route_packet = None

    def _bind(self, binding: UDPBinding, socket: "FakeSocket") -> None:
        """Register *socket* as the owner of *binding*.

        Raises OSError(EADDRINUSE) if another socket already holds it.
        """
        if binding in self._bound:
            _fake_err(errno.EADDRINUSE)
        self._bound[binding] = socket

    def enable(self) -> None:
        """Install this net's socket factory and hostname resolver into trio.socket."""
        trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
        trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))

    def send_packet(self, packet) -> None:
        """Accept an outgoing packet: pass it to route_packet if set, else deliver it."""
        if self.route_packet is None:
            self.deliver_packet(packet)
        else:
            self.route_packet(packet)

    def deliver_packet(self, packet) -> None:
        """Hand *packet* to the socket bound to its destination, if there is one."""
        receiver = self._bound.get(UDPBinding(local=packet.destination))
        if receiver is not None:
            receiver._deliver_packet(packet)
        # Otherwise there is no valid destination, so the packet is dropped
        # silently, just as UDP would.


class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
    """A UDP socket attached to a FakeNet instead of the real network.

    Only AF_INET / AF_INET6 SOCK_DGRAM sockets are supported. Other
    families and types raise NotImplementedError at construction. Because of
    NoPublicConstructor, instances are made via ``FakeSocket._create(...)``
    (as FakeSocketFactory does) rather than by calling the class directly.
    """

    def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
        self._fake_net = fake_net

        # 0 means "unspecified": fall back to AF_INET / SOCK_STREAM. The latter
        # is then rejected below, since only datagram sockets are supported.
        if not family:
            family = trio.socket.AF_INET
        if not type:
            type = trio.socket.SOCK_STREAM

        if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
            raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
        if type != trio.socket.SOCK_DGRAM:
            raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")

        self.family = family
        self.type = type
        self.proto = proto

        self._closed = False

        # Unbounded queue of incoming UDPPackets. _deliver_packet pushes into
        # the sender end, and recvmsg_into pulls from the receiver end.
        self._packet_sender, self._packet_receiver = trio.open_memory_channel(
            float("inf")
        )

        # This is the source-of-truth for what port etc. this socket is bound to
        self._binding: Optional[UDPBinding] = None

    def _check_closed(self):
        """Raise OSError(EBADF) if the socket has already been closed."""
        if self._closed:
            _fake_err(errno.EBADF)

    def close(self):
        """Close the socket and release its address binding.

        Safe to call more than once; later calls do nothing.
        """
        if self._closed:
            return
        self._closed = True
        if self._binding is not None:
            del self._fake_net._bound[self._binding]
        # After this, _deliver_packet's send_nowait hits BrokenResourceError,
        # which it treats as a dropped packet.
        self._packet_receiver.close()

    async def _resolve_address_nocp(self, address, *, local):
        """Normalize *address* for this socket using trio's internal resolver helper."""
        return await trio._socket._resolve_address_nocp(
            self.type,
            self.family,
            self.proto,
            address=address,
            ipv6_v6only=False,
            local=local,
        )

    def _deliver_packet(self, packet: UDPPacket):
        """Queue an incoming *packet* for recvmsg_into (called by FakeNet)."""
        try:
            self._packet_sender.send_nowait(packet)
        except trio.BrokenResourceError:
            # sending to a closed socket -- UDP packets get dropped
            pass

    ################################################################
    # Actual IO operation implementations
    ################################################################

    async def bind(self, addr):
        """Bind the socket to *addr* on the fake network.

        Binding to a wildcard address (0.0.0.0 / ::) binds to the matching
        loopback address instead, and port 0 picks the next automatic port.
        Raises OSError(EINVAL) if already bound, and OSError(EADDRINUSE) if
        the endpoint is taken.
        """
        self._check_closed()
        if self._binding is not None:
            _fake_err(errno.EINVAL)
        await trio.lowlevel.checkpoint()
        # The resolved address may carry extra fields: IPv6 sockaddrs are
        # (host, port, flowinfo, scope_id). Only host and port matter here.
        ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True)
        ip = ipaddress.ip_address(ip_str)
        assert _family_for(ip) == self.family
        # We convert binds to INET_ANY into binds to localhost
        if ip == _wildcard_ip_for(self.family):
            ip = _localhost_ip_for(self.family)
        if port == 0:
            port = next(self._fake_net._auto_port_iter)
        binding = UDPBinding(local=UDPEndpoint(ip, port))
        self._fake_net._bind(binding, self)
        self._binding = binding

    async def connect(self, peer):
        """Not supported: FakeNet has no connected UDP sockets yet."""
        raise NotImplementedError("FakeNet does not (yet) support connected sockets")

    async def sendmsg(self, *args):
        """Send one datagram, made of the concatenated *buffers*, to *address*.

        Accepted forms: ``(buffers)``, ``(buffers, address)``,
        ``(buffers, flags, address)`` and ``(buffers, ancdata, flags, address)``.
        NOTE: the 2- and 3-argument forms follow ``sendto``'s argument order,
        not the stdlib ``sendmsg(buffers[, ancdata[, flags[, address]]])``
        order. An unbound socket is first bound to an automatically chosen
        port. Returns the number of payload bytes sent.
        """
        self._check_closed()
        ancdata = []
        flags = 0
        address = None
        if len(args) == 1:
            (buffers,) = args
        elif len(args) == 2:
            buffers, address = args
        elif len(args) == 3:
            buffers, flags, address = args
        elif len(args) == 4:
            buffers, ancdata, flags, address = args
        else:
            raise TypeError("wrong number of arguments")

        await trio.lowlevel.checkpoint()

        if address is not None:
            address = await self._resolve_address_nocp(address, local=False)
        if ancdata:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags:
            raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")

        # connect() is unsupported, so there is never a default peer.
        if address is None:
            _fake_err(errno.ENOTCONN)

        destination = UDPEndpoint.from_python_sockaddr(address)

        # Like a real UDP socket, sending implicitly binds an unbound socket.
        if self._binding is None:
            await self.bind((_wildcard_ip_for(self.family).compressed, 0))

        payload = b"".join(buffers)

        packet = UDPPacket(
            source=self._binding.local,
            destination=destination,
            payload=payload,
        )

        self._fake_net.send_packet(packet)

        return len(payload)

    async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
        """Wait for the next datagram and scatter its payload into *buffers*.

        Returns ``(nbytes, ancdata, msg_flags, address)``. MSG_TRUNC is set in
        msg_flags when the buffers were too small for the whole payload; the
        excess bytes are discarded along with the consumed packet.
        """
        if ancbufsize != 0:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags != 0:
            raise NotImplementedError("FakeNet doesn't support any recv flags")

        self._check_closed()

        ancdata = []
        msg_flags = 0

        packet = await self._packet_receiver.receive()
        address = packet.source.as_python_sockaddr()
        written = _scatter(packet.payload, buffers)
        if written < len(packet.payload):
            msg_flags |= trio.socket.MSG_TRUNC
        return written, ancdata, msg_flags, address

    ################################################################
    # Simple state query stuff
    ################################################################

    def getsockname(self):
        """Return the local address; the family's wildcard with port 0 if unbound."""
        self._check_closed()
        if self._binding is not None:
            return self._binding.local.as_python_sockaddr()
        # Build the unbound answer the same way as the bound one, so AF_INET6
        # sockets consistently get a 4-tuple (host, port, flowinfo, scope_id).
        return UDPEndpoint(_wildcard_ip_for(self.family), 0).as_python_sockaddr()

    def getpeername(self):
        """Return the connected peer's address, or raise OSError(ENOTCONN).

        connect() is not implemented and UDPBinding has no ``remote`` field
        here, so the attribute is looked up defensively. In practice this
        raises ENOTCONN.
        """
        self._check_closed()
        remote = getattr(self._binding, "remote", None)
        if remote is not None:
            return remote.as_python_sockaddr()
        _fake_err(errno.ENOTCONN)

    def getsockopt(self, level, item):
        """No socket options are readable on a FakeSocket."""
        self._check_closed()
        raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})")

    def setsockopt(self, level, item, value):
        """No socket options are settable on a FakeSocket.

        Turning IPV6_V6ONLY off gets a specific NotImplementedError; every
        other request, including IPV6_V6ONLY=True, raises OSError.
        """
        self._check_closed()

        if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY):
            if not value:
                raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")

        raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)")

    ################################################################
    # Various boilerplate and trivial stubs
    ################################################################

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

    async def send(self, data, flags=0):
        """Send without an address; always ENOTCONN since connect() is unsupported."""
        return await self.sendto(data, flags, None)

    async def sendto(self, *args):
        """``sendto(data, address)`` or ``sendto(data, flags, address)``."""
        if len(args) == 2:
            data, address = args
            flags = 0
        elif len(args) == 3:
            data, flags, address = args
        else:
            raise TypeError("wrong number of arguments")
        return await self.sendmsg([data], [], flags, address)

    async def recv(self, bufsize, flags=0):
        """Receive up to *bufsize* bytes of the next datagram, dropping the address."""
        data, address = await self.recvfrom(bufsize, flags)
        return data

    async def recv_into(self, buf, nbytes=0, flags=0):
        """Receive the next datagram into *buf*; returns the byte count."""
        got_bytes, address = await self.recvfrom_into(buf, nbytes, flags)
        return got_bytes

    async def recvfrom(self, bufsize, flags=0):
        """Receive up to *bufsize* bytes of the next datagram plus its source address."""
        # recvmsg's second positional parameter is ancbufsize, so flags must
        # be passed third (previously it landed in the ancbufsize slot).
        data, ancdata, msg_flags, address = await self.recvmsg(bufsize, 0, flags)
        return data, address

    async def recvfrom_into(self, buf, nbytes=0, flags=0):
        """Receive the next datagram into all of *buf*; returns (nbytes, address)."""
        if nbytes != 0 and nbytes != len(buf):
            raise NotImplementedError("partial recvfrom_into")
        got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
            [buf], 0, flags
        )
        return got_nbytes, address

    async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
        """Receive up to *bufsize* bytes as ``(data, ancdata, msg_flags, address)``."""
        buf = bytearray(bufsize)
        got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
            [buf], ancbufsize, flags
        )
        return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)

    def fileno(self):
        raise NotImplementedError("can't get fileno() for FakeNet sockets")

    def detach(self):
        raise NotImplementedError("can't detach() a FakeNet socket")

    def get_inheritable(self):
        return False

    def set_inheritable(self, inheritable):
        if inheritable:
            raise NotImplementedError("FakeNet can't make inheritable sockets")

    def share(self, process_id):
        raise NotImplementedError("FakeNet can't share sockets")
