# coding: utf-8
from __future__ import print_function, unicode_literals

import socket
import time

import ipaddress
from ipaddress import (
    IPv4Address,
    IPv4Network,
    IPv6Address,
    IPv6Network,
    ip_address,
    ip_network,
)

from .__init__ import MACOS, TYPE_CHECKING
from .util import Daemon, Netdev, find_prefix, min_ex, spack

if TYPE_CHECKING:
    from .svchub import SvcHub

# make sure socket.IPPROTO_IPV6 exists even when this python's socket module
# does not define it; 41 is the standard value of that constant
try:
    socket.IPPROTO_IPV6
except AttributeError:
    setattr(socket, "IPPROTO_IPV6", 41)


class NoIPs(Exception):
    """Raised by MCast.create_servers when no IPs remain after filtering."""


class MC_Sck(object):
    """there is one socket for each server ip

    Bundles a multicast socket with the netdev, group and ip it was opened
    for. `ips` maps every ip this socket answers for to its network; it
    starts out containing only the primary ip.
    """

    def __init__(
        self,
        sck,
        nd,
        grp,
        ip,
        net,
    ):
        is_v6 = ":" in ip

        # the socket and the interface it belongs to
        self.sck = sck
        self.idx = nd.idx
        self.name = nd.name

        # multicast group; the membership request starts empty and is
        # assigned elsewhere (MCast.setup_socket in this file)
        self.grp = grp
        self.mreq = b""

        # primary address, its network, and the ip -> network table
        self.ip = ip
        self.net = net
        self.ips = {ip: net}

        # address family of the primary ip
        self.v6 = is_v6
        self.have4 = not is_v6
        self.have6 = is_v6


class MCast(object):
    """Multicast socket manager.

    Opens one UDP socket per selected server ip, joins the configured
    multicast group on it, periodically re-joins (hopper), and maps client
    ips back to the socket whose subnet contains them (map_client).
    """

    def __init__(
        self,
        hub,
        Srv,
        on,
        off,
        mc_grp_4,
        mc_grp_6,
        port,
        vinit,
    ):
        """disable ipv%d by setting mc_grp_%d empty

        hub: provides args, asrv, log, and tcpsrv (netdevs / bound)
        Srv: per-socket class, called as Srv(sck, netdev, grp, ip, net)
        on / off: nic filters; each entry may be an ip/subnet, an interface
            index, or an interface name; `on` takes precedence over `off`
        mc_grp_4 / mc_grp_6: multicast group addresses
        port: udp port to bind
        vinit: log per-socket details while setting up
        """
        self.hub = hub
        self.Srv = Srv
        self.args = hub.args
        self.asrv = hub.asrv
        self.log_func = hub.log
        self.on = on
        self.off = off
        self.grp4 = mc_grp_4
        self.grp6 = mc_grp_6
        self.port = port
        self.vinit = vinit

        self.srv = {}  # listening sockets
        self.sips = set()  # all listening ips (including failed attempts)
        self.ll_ok = set()  # fallback linklocal IPv4 and IPv6 addresses
        self.b2srv = {}  # binary-ip -> server socket
        self.b4 = []  # sorted list of binary-ips
        self.b6 = []  # sorted list of binary-ips
        self.cscache = {}  # client ip -> server cache

        self.running = True

    def log(self, msg, c=0):
        self.log_func("multicast", msg, c)

    def create_servers(self):
        """Open, bind and join one multicast socket per selected server ip.

        Returns the list of ips (without prefix) that were set up successfully.
        Raises NoIPs if no ips remain after filtering.
        """
        bound = []
        netdevs = self.hub.tcpsrv.netdevs
        ips = [x[0] for x in self.hub.tcpsrv.bound]

        # expand wildcard binds into the concrete addresses of every netdev
        if "::" in ips:
            ips = [x for x in ips if x != "::"] + [
                x.split("/")[0] for x in netdevs if ":" in x
            ]
            ips.append("0.0.0.0")

        if "0.0.0.0" in ips:
            ips = [x for x in ips if x != "0.0.0.0"] + [
                x.split("/")[0] for x in netdevs if ":" not in x
            ]

        ips = [x for x in ips if x not in ("::1", "127.0.0.1")]
        # the results are used as netdevs keys below ("ip/prefix" form)
        ips = find_prefix(ips, list(netdevs))

        # expand the nic filters into netdev keys: an entry may be an
        # ip/subnet (matching every netdev inside it), an index, or a name
        on = self.on[:]
        off = self.off[:]
        for lst in (on, off):
            for av in list(lst):
                try:
                    arg_net = ip_network(av, False)
                except ValueError:
                    arg_net = None

                for sk, sv in netdevs.items():
                    if arg_net:
                        net_ip = ip_address(sk.split("/")[0])
                        if net_ip in arg_net and sk not in lst:
                            lst.append(sk)

                    if (av == str(sv.idx) or av == sv.name) and sk not in lst:
                        lst.append(sk)

        if on:
            ips = [x for x in ips if x in on]
        elif off:
            ips = [x for x in ips if x not in off]

        if not self.grp4:
            ips = [x for x in ips if ":" in x]

        if not self.grp6:
            ips = [x for x in ips if ":" not in x]

        ips = list(set(ips))
        all_selected = ips[:]

        # discard non-linklocal ipv6
        ips = [x for x in ips if ":" not in x or x.startswith("fe80")]

        if not ips:
            raise NoIPs()

        for ip in ips:
            v6 = ":" in ip
            netdev = netdevs[ip]
            if not netdev.idx:
                t = "using INADDR_ANY for ip [{}], netdev [{}]"
                if not self.srv and ip not in ["::", "0.0.0.0"]:
                    self.log(t.format(ip, netdev), 3)

            ipv = socket.AF_INET6 if v6 else socket.AF_INET
            sck = socket.socket(ipv, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
            sck.settimeout(None)
            sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                # safe for this purpose; https://lwn.net/Articles/853637/
                sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            except Exception:
                # best-effort; the option may be missing or rejected here
                pass

            # most ipv6 clients expect multicast on linklocal ip only;
            # add a/aaaa records for the other nic IPs
            other_ips = set()
            if v6:
                for nd in netdevs.values():
                    if nd.idx == netdev.idx and nd.ip in all_selected and ":" in nd.ip:
                        other_ips.add(nd.ip)

            net = ipaddress.ip_network(ip, False)
            ip = ip.split("/")[0]
            srv = self.Srv(sck, netdev, self.grp6 if ":" in ip else self.grp4, ip, net)
            for oth_ip in other_ips:
                srv.ips[oth_ip.split("/")[0]] = ipaddress.ip_network(oth_ip, False)

            # gvfs breaks if a linklocal ip appears in a dns reply
            ll = {k: v for k, v in srv.ips.items() if k.startswith(("169.254", "fe80"))}
            rt = {k: v for k, v in srv.ips.items() if k not in ll}

            if self.args.ll or not rt:
                self.ll_ok.update(list(ll))

            if not self.args.ll:
                srv.ips = rt or ll

            if not srv.ips:
                self.log("no IPs on {}; skipping [{}]".format(netdev, ip), 3)
                sck.close()  # don't leak the socket we just opened
                continue

            try:
                self.setup_socket(srv)
                self.srv[sck] = srv
                bound.append(ip)
            except Exception:
                t = "announce failed on {} [{}]:\n{}"
                self.log(t.format(netdev, ip, min_ex()), 3)
                sck.close()

        # let every socket on the same interface index answer for the
        # primary ips of its siblings
        if self.args.zm_msub:
            for s1 in self.srv.values():
                for s2 in self.srv.values():
                    if s1.idx != s2.idx:
                        continue

                    if s1.ip not in s2.ips:
                        s2.ips[s1.ip] = s1.net

        # share ips between sockets (on any nic) whose networks are equal
        if self.args.zm_mnic:
            for s1 in self.srv.values():
                for s2 in self.srv.values():
                    for ip1, net1 in list(s1.ips.items()):
                        for ip2, net2 in list(s2.ips.items()):
                            if net1 == net2 and ip1 != ip2:
                                s1.ips[ip2] = net2

        self.sips = set([x.split("/")[0] for x in all_selected])
        for srv in self.srv.values():
            assert srv.ip in self.sips

        Daemon(self.hopper, "mc-hop")
        return bound

    def setup_socket(self, srv):
        """Bind srv.sck, set multicast options, and join srv's group.

        srv's binary ips are registered in b2srv/b4/b6 (which map_client
        searches) only after every step has succeeded. If setup fails, the
        caller closes the socket, so no stale entry can hand it out later.
        """
        sck = srv.sck
        if srv.v6:
            if self.vinit:
                zsl = list(srv.ips.keys())
                self.log("v6({}) idx({}) {}".format(srv.ip, srv.idx, zsl), 6)

            bips = [socket.inet_pton(socket.AF_INET6, ip) for ip in srv.ips]

            grp = self.grp6 if srv.idx else ""
            try:
                # on macos, skip binding to the group address
                if MACOS:
                    raise Exception()

                sck.bind((grp, self.port, 0, srv.idx))
            except Exception:
                sck.bind(("", self.port, 0, srv.idx))

            bgrp = socket.inet_pton(socket.AF_INET6, self.grp6)
            dev = spack(b"@I", srv.idx)
            srv.mreq = bgrp + dev
            if srv.idx != socket.INADDR_ANY:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev)

            try:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
            except Exception:
                # macos
                t = "failed to set IPv6 TTL/LOOP; announcements may not survive multiple switches/routers"
                self.log(t, 3)
        else:
            if self.vinit:
                self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6)

            bips = [socket.inet_aton(srv.ip)]

            grp = self.grp4 if srv.idx else ""
            try:
                # on macos, skip binding to the group address
                if MACOS:
                    raise Exception()

                sck.bind((grp, self.port))
            except Exception:
                sck.bind(("", self.port))

            bgrp = socket.inet_aton(self.grp4)
            dev = (
                spack(b"=I", socket.INADDR_ANY)
                if srv.idx == socket.INADDR_ANY
                else socket.inet_aton(srv.ip)
            )
            srv.mreq = bgrp + dev
            if srv.idx != socket.INADDR_ANY:
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev)

            try:
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
            except Exception:
                # probably can't happen but dontcare if it does
                t = "failed to set IPv4 TTL/LOOP; announcements may not survive multiple switches/routers"
                self.log(t, 3)

        # a successful leave means we were already a member; pause before
        # joining again
        if self.hop(srv, False):
            self.log("igmp was already joined?? chilling for a sec", 3)
            time.sleep(1.2)

        self.hop(srv, True)

        # setup succeeded; make srv discoverable by map_client, which
        # searches the binary ips in descending order
        blist = self.b6 if srv.v6 else self.b4
        for bip in bips:
            self.b2srv[bip] = srv
            blist.append(bip)

        self.b4.sort(reverse=True)
        self.b6.sort(reverse=True)

    def hop(self, srv, on):
        """rejoin to keepalive on routers/switches without igmp-snooping

        on=False: leave srv's group; returns True if the leave succeeded,
            False if it raised.
        on=True: join srv's group; returns True, or raises on failure.
        """
        sck = srv.sck
        req = srv.mreq
        if ":" in srv.ip:
            if not on:
                try:
                    sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_LEAVE_GROUP, req)
                    return True
                except Exception:
                    return False
            else:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, req)
        else:
            if not on:
                try:
                    sck.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, req)
                    return True
                except Exception:
                    return False
            else:
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, req)

        return True

    def hopper(self):
        """Every args.mc_hop seconds, leave and then rejoin all groups.

        A failed rejoin on one socket is logged and does not stop the loop,
        so a transient error cannot end the keepalive for every socket.
        """
        while self.args.mc_hop and self.running:
            time.sleep(self.args.mc_hop)
            if not self.running:
                return

            for srv in self.srv.values():
                self.hop(srv, False)

            # linux does leaves/joins twice with 0.2~1.05s spacing
            time.sleep(1.2)
            if not self.running:
                return

            for srv in self.srv.values():
                try:
                    self.hop(srv, True)
                except Exception:
                    t = "rejoin failed on {} [{}]:\n{}"
                    self.log(t.format(srv.name, srv.ip, min_ex()), 3)

    def map_client(self, cip):
        """Return the server socket whose networks contain client ip `cip`.

        Loopback clients get any socket, and 169.254.x.x clients get a socket
        holding an allowed linklocal ip. If nothing fits, returns None. The
        result (including None) is cached per client ip.
        """
        try:
            return self.cscache[cip]
        except KeyError:
            pass

        ret = None
        v6 = ":" in cip
        ci = IPv6Address(cip) if v6 else IPv4Address(cip)
        for bip in self.b6 if v6 else self.b4:
            srv = self.b2srv[bip]
            if any(ci in net for net in srv.ips.values()):
                ret = srv
                break

        if not ret and cip in ("127.0.0.1", "::1"):
            # just give it something (nothing if no socket came up)
            ret = next(iter(self.srv.values()), None)

        if not ret and cip.startswith("169.254"):
            # idk how to map LL IPv4 msgs to nics;
            # just pick one and hope for the best
            lls = (
                x
                for x in self.srv.values()
                if any(y in self.ll_ok for y in x.ips)
            )
            ret = next(lls, None)

        if ret:
            t = "new client on {} ({}): {}"
            self.log(t.format(ret.name, ret.net, cip), 6)
        else:
            t = "could not map client {} to known subnet; maybe forwarded from another network?"
            self.log(t.format(cip), 3)

        # crude bound on the cache size
        if len(self.cscache) > 9000:
            self.cscache = {}

        self.cscache[cip] = ret
        return ret
