# coding: utf-8
from __future__ import print_function, unicode_literals

import hashlib
import math
import os
import re
import socket
import sys
import threading
import time

import queue

from .__init__ import ANYWIN, CORES, EXE, MACOS, PY2, TYPE_CHECKING, EnvParams, unicode

try:
    MNFE = ModuleNotFoundError
except:
    MNFE = ImportError

try:
    import jinja2
except MNFE:
    if EXE:
        raise

    print(
        """\033[1;31m
  you do not have jinja2 installed,\033[33m
  choose one of these:\033[0m
   * apt install python-jinja2
   * {} -m pip install --user jinja2
   * (try another python version, if you have one)
   * (try copyparty.sfx instead)
""".format(
            sys.executable
        )
    )
    sys.exit(1)
except SyntaxError:
    if EXE:
        raise

    print(
        """\033[1;31m
  your jinja2 version is incompatible with your python version;\033[33m
  please try to replace it with an older version:\033[0m
   * {} -m pip install --user jinja2==2.11.3
   * (try another python version, if you have one)
   * (try copyparty.sfx instead)
""".format(
            sys.executable
        )
    )
    sys.exit(1)

from .httpconn import HttpConn
from .metrics import Metrics
from .u2idx import U2idx
from .util import (
    E_SCK,
    FHC,
    CachedDict,
    Daemon,
    Garda,
    Magician,
    Netdev,
    NetMap,
    build_netmap,
    has_resource,
    ipnorm,
    load_ipu,
    load_resource,
    min_ex,
    shut_socket,
    spack,
    start_log_thrs,
    start_stackmon,
    ub64enc,
)

if TYPE_CHECKING:
    from .authsrv import VFS
    from .broker_util import BrokerCli
    from .ssdp import SSDPr

if PY2:
    # py2: every range() in this module becomes the lazy xrange
    range = xrange  # type: ignore

if not hasattr(socket, "AF_UNIX"):
    # this python's socket module has no unix-socket support; define a
    # placeholder family value so the `sck.family` comparisons below work
    socket.AF_UNIX = -9001  # type: ignore


def load_jinja2_resource(E, name):
    """return the text of template `name` from the bundled web/ resources;
    used as the jinja2 FunctionLoader callback"""
    path = "web/" + name
    with load_resource(E, path, "r") as fh:
        txt = fh.read()
    return txt


class HttpSrv(object):
    """
    handles incoming connections using HttpConn to process http,
    relying on MpSrv for performance (HttpSrv is just plain threads)
    """

    def __init__(self, broker, nid):
        """
        broker: owner of this server; provides `args`, `log`, `asrv`,
          and the `ask` / `say` messaging used below
        nid: node number; when truthy the name gets a "-n<nid>-i<pid>"
          suffix and tdli/tdls become separate dicts (presumably the
          multiprocessing case; see read_dls / write_dls)
        """
        self.broker = broker
        self.nid = nid
        self.args = broker.args
        self.E = self.args.E
        self.log = broker.log
        self.asrv = broker.asrv

        # redefine in case of multiprocessing
        socket.setdefaulttimeout(120)

        self.t0 = time.time()
        nsuf = "-n{}-i{:x}".format(nid, os.getpid()) if nid else ""
        self.magician = Magician()
        self.nm = NetMap([], [])
        self.ssdp = None
        # Garda instances configured from the ban_* options
        # (presumably per-category strike counters leading to bans)
        self.gpwd = Garda(self.args.ban_pw)
        self.gpwc = Garda(self.args.ban_pwc)
        self.g404 = Garda(self.args.ban_404)
        self.g403 = Garda(self.args.ban_403)
        self.g422 = Garda(self.args.ban_422, False)
        self.gmal = Garda(self.args.ban_422)
        self.gurl = Garda(self.args.ban_url)
        self.bans = {}  # ip -> unix time when the ban expires
        self.aclose = {}  # ip -> unix time when the connection:close downgrade expires

        dli = {}  # info
        dls = {}  # state
        self.dli = self.tdli = dli
        self.dls = self.tdls = dls
        self.iiam = '<img src="%s.cpr/iiam.gif?cache=i" />' % (self.args.SRS,)

        self.bound = set()  # (ip, port) of every listening socket
        self.name = "hsrv" + nsuf
        self.mutex = threading.Lock()
        self.u2mutex = threading.Lock()
        self.stopping = False

        self.tp_nthr = 0  # actual
        self.tp_ncli = 0  # fading
        self.tp_time = 0.0  # latest worker collect
        self.tp_q = (
            None if self.args.no_htp else queue.LifoQueue()
        )
        self.t_periodic = None

        self.u2fh = FHC()
        self.u2sc = {}
        self.pipes = CachedDict(0.2)
        self.metrics = Metrics(self)
        self.nreq = 0
        self.nsus = 0
        self.nban = 0
        self.srvs = []
        self.ncli = 0  # exact
        self.clients = set()  # laggy
        self.nclimax = 0
        self.cb_ts = 0.0
        self.cb_v = ""

        self.u2idx_free = {}
        self.u2idx_n = 0

        env = jinja2.Environment()
        env.loader = jinja2.FunctionLoader(lambda f: load_jinja2_resource(self.E, f))
        jn = [
            "browser",
            "browser2",
            "cf",
            "idp",
            "md",
            "mde",
            "msg",
            "rups",
            "shares",
            "splash",
            "svcs",
        ]
        self.j2 = {x: env.get_template(x + ".html") for x in jn}
        self.prism = has_resource(self.E, "web/deps/prism.js.gz")

        if self.args.ipu:
            self.ipu_iu, self.ipu_nm = load_ipu(self.log, self.args.ipu)
        else:
            self.ipu_iu = self.ipu_nm = None

        self.ipa_nm = build_netmap(self.args.ipa)
        self.xff_nm = build_netmap(self.args.xff_src)
        self.xff_lan = build_netmap("lan")

        self.mallow = "GET HEAD POST PUT DELETE OPTIONS".split()
        if not self.args.no_dav:
            zs = "PROPFIND PROPPATCH LOCK UNLOCK MKCOL COPY MOVE"
            self.mallow += zs.split()

        if self.args.zs:
            from .ssdp import SSDPr

            self.ssdp = SSDPr(broker)

        if self.tp_q:
            self.start_threads(4)

        if nid:
            self.tdli = {}
            self.tdls = {}

            if self.args.stackmon:
                start_stackmon(self.args.stackmon, nid)

            if self.args.log_thrs:
                start_log_thrs(self.log, self.args.log_thrs, nid)

        self.th_cfg = {}
        Daemon(self.post_init, "hsrv-init2")

    def post_init(self):
        """fetch the thumbnailer config from the broker (in a background
        thread); best-effort, th_cfg stays {} if that fails"""
        try:
            x = self.broker.ask("thumbsrv.getcfg")
            self.th_cfg = x.get()
        except:
            pass

    def set_netdevs(self, netdevs):
        """rebuild the NetMap from the bound listen-ips and `netdevs`"""
        ips = {ip for ip, _ in self.bound}
        self.nm = NetMap(list(ips), list(netdevs))

    def start_threads(self, n):
        """grow the worker pool by `n` thr_poolw threads"""
        self.tp_nthr += n
        if self.args.log_htp:
            self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6)

        for _ in range(n):
            Daemon(self.thr_poolw, self.name + "-poolw")

    def stop_threads(self, n):
        """shrink the worker pool by queueing `n` None sentinels"""
        self.tp_nthr -= n
        if self.args.log_htp:
            self.log(self.name, "workers -= {} = {}".format(n, self.tp_nthr), 6)

        for _ in range(n):
            self.tp_q.put(None)

    def periodic(self):
        """
        housekeeping loop: cleans the upload filehandle cache and trims the
        worker pool; exits (clearing t_periodic so accept() restarts it)
        once there are no clients, no cached filehandles and <= 8 workers
        """
        while True:
            time.sleep(2 if self.tp_ncli or self.ncli else 10)
            with self.u2mutex, self.mutex:
                self.u2fh.clean()
                if self.tp_q:
                    self.tp_ncli = max(self.ncli, self.tp_ncli - 2)
                    if self.tp_nthr > self.tp_ncli + 8:
                        self.stop_threads(4)

                if not self.ncli and not self.u2fh.cache and self.tp_nthr <= 8:
                    self.t_periodic = None
                    return

    def listen(self, sck, nlisteners):
        """
        register listening socket `sck` and start a thr_listen thread for
        it; the per-listener client limit is args.nc / nlisteners
        """
        tcp = sck.family != socket.AF_UNIX

        if self.args.j != 1:
            # lost in the pickle; redefine
            if not ANYWIN or self.args.reuseaddr:
                sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            if tcp:
                sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

            sck.settimeout(None)  # < does not inherit, ^ opts above do

        if tcp:
            ip, port = sck.getsockname()[:2]
        else:
            ip = re.sub(r"\.[0-9]+$", "", sck.getsockname().split("/")[-1])
            port = 0

        self.srvs.append(sck)
        self.bound.add((ip, port))
        self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners)
        Daemon(
            self.thr_listen,
            "httpsrv-n{}-listen-{}-{}".format(self.nid or "0", ip, port),
            (sck,),
        )

    def thr_listen(self, srv_sck):
        """accept loop for one listening socket (tcp or unix);
        hands every new connection to accept()"""
        fno = srv_sck.fileno()
        if srv_sck.family == socket.AF_UNIX:
            ip = re.sub(r"\.[0-9]+$", "", srv_sck.getsockname())
            msg = "subscribed @ %s  f%d p%d" % (ip, fno, os.getpid())
            ip = ip.split("/")[-1]
            port = 0
            tcp = False
        else:
            tcp = True
            ip, port = srv_sck.getsockname()[:2]
            hip = "[%s]" % (ip,) if ":" in ip else ip
            msg = "subscribed @ %s:%d  f%d p%d" % (hip, port, fno, os.getpid())

        self.log(self.name, msg)

        Daemon(self.broker.say, "sig-hsrv-up1", ("cb_httpsrv_up",))

        saddr = ("", 0)  # fwd-decl for `except TypeError as ex:`

        while not self.stopping:
            if self.args.log_conn:
                self.log(self.name, "|%sC-ncli" % ("-" * 1,), c="90")

            spins = 0
            while self.ncli >= self.nclimax:
                if not spins:
                    t = "at connection limit (global-option 'nc'); waiting"
                    self.log(self.name, t, 3)

                spins += 1
                time.sleep(0.1)
                # after ~5 sec (50 x 0.1s) stuck at the limit, go after
                # whichever client is hogging the connections (once)
                if spins == 50 and self.args.aclose:
                    self._aclose_busiest()

            if self.args.log_conn:
                self.log(self.name, "|%sC-acc1" % ("-" * 2,), c="90")

            try:
                sck, saddr = srv_sck.accept()
                if tcp:
                    cip = unicode(saddr[0])
                    if cip.startswith("::ffff:"):
                        cip = cip[7:]
                    addr = (cip, saddr[1])
                else:
                    addr = ("127.8.3.7", sck.fileno())
            except (OSError, socket.error) as ex:
                if self.stopping:
                    break

                self.log(self.name, "accept({}): {}".format(fno, ex), c=6)
                time.sleep(0.02)
                continue
            except TypeError as ex:
                # on macOS, accept() may return a None saddr if blocked by LittleSnitch;
                # unicode(saddr[0]) ==> TypeError: 'NoneType' object is not subscriptable
                if tcp and not saddr:
                    t = "accept(%s): failed to accept connection from client due to firewall or network issue"
                    self.log(self.name, t % (fno,), c=3)
                    try:
                        sck.close()  # type: ignore
                    except:
                        pass
                    time.sleep(0.02)
                    continue
                raise

            if self.args.log_conn:
                t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format(
                    "-" * 3, ip, port % 8, port
                )
                self.log("%s %s" % addr, t, c="90")

            self.accept(sck, addr)

    def _aclose_busiest(self):
        """
        called by thr_listen after being stuck at the connection limit;
        finds the client ip holding the most connections and, if it holds
        at least half the limit, downgrades it to connection:close for
        args.aclose minutes and drops (most of) its connections; if at
        least half of its connections never got past the request headers
        (slowloris), the ip is also banned for args.loris minutes
        """
        ipfreq = {}
        with self.mutex:
            for c in self.clients:
                cip = ipnorm(c.ip)
                ipfreq[cip] = ipfreq.get(cip, 0) + 1

        if not ipfreq:
            # self.clients lags behind ncli (connections still waiting in
            # tp_q are counted but not listed yet); nobody to blame
            return

        # first-seen ip wins ties, same as a stable reverse sort
        ip, n = max(ipfreq.items(), key=lambda kv: kv[1])
        if n < self.nclimax / 2:
            return

        self.aclose[ip] = int(time.time() + self.args.aclose * 60)
        nclose = 0
        nloris = 0
        nconn = 0
        with self.mutex:
            for c in self.clients:
                if ipnorm(c.ip) != ip:
                    continue

                nconn += 1
                try:
                    if (
                        c.nreq >= 1
                        or not c.cli
                        or c.cli.in_hdr_recv
                        or c.cli.keepalive
                    ):
                        Daemon(c.shutdown)
                        nclose += 1
                        if c.nreq <= 0 and (not c.cli or c.cli.in_hdr_recv):
                            nloris += 1
                except Exception:
                    # best-effort; presumably the connection can be
                    # tearing itself down while we inspect it
                    pass

        t = "{} downgraded to connection:close for {} min; dropped {}/{} connections"
        self.log(self.name, t.format(ip, self.args.aclose, nclose, nconn), 1)

        if nloris < nconn / 2:
            return

        t = "slowloris (idle-conn): {} banned for {} min"
        self.log(self.name, t.format(ip, self.args.loris), 1)
        self.bans[ip] = int(time.time() + self.args.loris * 60)

    def accept(self, sck, addr):
        """takes an incoming tcp connection and creates a thread to handle it"""
        now = time.time()

        # tp_time is set when a task is queued and zeroed when a worker
        # collects one; stuck for 5 min means the pool is presumed dead
        if now - (self.tp_time or now) > 300:
            t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}"
            self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1)
            self.tp_time = 0
            self.tp_q = None

        with self.mutex:
            self.ncli += 1
            if not self.t_periodic:
                name = "hsrv-pt"
                if self.nid:
                    name += "-%d" % (self.nid,)

                self.t_periodic = Daemon(self.periodic, name)

            if self.tp_q:
                self.tp_time = self.tp_time or now
                self.tp_ncli = max(self.tp_ncli, self.ncli)
                if self.tp_nthr < self.ncli + 4:
                    self.start_threads(8)

                self.tp_q.put((sck, addr))
                return

        if not self.args.no_htp:
            t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n"
            self.log(self.name, t, 1)

        # no pool (disabled or dead); one thread per connection
        Daemon(
            self.thr_client,
            "httpconn-%s-%d" % (addr[0].split(".", 2)[-1][-6:], addr[1]),
            (sck, addr),
        )

    def thr_poolw(self):
        """pool worker: serves queued (sck, addr) tasks until it gets None"""
        while True:
            task = self.tp_q.get()
            if not task:
                break

            with self.mutex:
                self.tp_time = 0

            me = threading.current_thread()
            try:
                sck, addr = task
                me.name = "httpconn-%s-%d" % (addr[0].split(".", 2)[-1][-6:], addr[1])
                self.thr_client(sck, addr)
            except Exception as ex:
                if str(ex).startswith("client d/c "):
                    self.log(self.name, "thr_client: " + str(ex), 6)
                else:
                    self.log(self.name, "thr_client: " + min_ex(), 3)
            finally:
                # restore even on failure so the idle worker is not left
                # named after a finished connection
                me.name = self.name + "-poolw"

    def shutdown(self):
        """
        stop accepting, close the listening sockets, shut down every client
        in parallel, retire the pool workers, then wait for the clients
        """
        self.stopping = True
        for srv in self.srvs:
            try:
                srv.close()
            except:
                pass

        # snapshot under the mutex; client threads add/remove concurrently
        with self.mutex:
            clients = list(self.clients)

        thrs = []
        for cli in clients:
            t = threading.Thread(target=cli.shutdown)
            thrs.append(t)
            t.start()

        if self.tp_q:
            self.stop_threads(self.tp_nthr)
            for _ in range(10):
                time.sleep(0.05)
                if self.tp_q.empty():
                    break

        for t in thrs:
            t.join()

        self.log(self.name, "ok bye")

    def thr_client(self, sck, addr):
        """thread managing one tcp client"""
        cli = HttpConn(sck, addr, self)
        with self.mutex:
            self.clients.add(cli)

        fno = sck.fileno()
        try:
            if self.args.log_conn:
                self.log("%s %s" % addr, "|%sC-crun" % ("-" * 4,), c="90")

            cli.run()

        except (OSError, socket.error) as ex:
            if ex.errno not in E_SCK:
                self.log(
                    "%s %s" % addr,
                    "run({}): {}".format(fno, ex),
                    c=6,
                )

        finally:
            # the conn may have swapped its socket (cli.s); close that one
            sck = cli.s
            if self.args.log_conn:
                self.log("%s %s" % addr, "|%sC-cdone" % ("-" * 5,), c="90")

            try:
                fno = sck.fileno()
                shut_socket(cli.log, sck)
            except (OSError, socket.error) as ex:
                if not MACOS:
                    self.log(
                        "%s %s" % addr,
                        "shut({}): {}".format(fno, ex),
                        c="90",
                    )
                if ex.errno not in E_SCK:
                    raise
            finally:
                with self.mutex:
                    self.clients.remove(cli)
                    self.ncli -= 1

                if cli.u2idx:
                    self.put_u2idx(str(addr), cli.u2idx)

    def cachebuster(self):
        """
        short string derived from the newest mtime in the web/ folder
        (or E.t0 if newer / unreadable); recomputed at most once per second
        """
        if time.time() - self.cb_ts < 1:
            return self.cb_v

        with self.mutex:
            # re-check under the lock so only one thread rescans
            if time.time() - self.cb_ts < 1:
                return self.cb_v

            v = self.E.t0
            try:
                with os.scandir(os.path.join(self.E.mod, "web")) as dh:
                    for fh in dh:
                        inf = fh.stat()
                        v = max(v, inf.st_mtime)
            except:
                pass

            # spack gives 4 lsb, take 3 lsb, get 4 ch
            self.cb_v = ub64enc(spack(b">L", int(v))[1:]).decode("ascii")
            self.cb_ts = time.time()
            return self.cb_v

    def get_u2idx(self, ident):
        """
        borrow a U2idx: the one last returned under `ident` if free, else
        any free one, else a new one while fewer than CORES exist;
        polls for ~5 sec and returns None if nothing became available
        """
        utab = self.u2idx_free
        for _ in range(100):  # 100 * 0.05 = 5sec
            with self.mutex:
                if utab:
                    if ident in utab:
                        return utab.pop(ident)

                    return utab.pop(next(iter(utab)))

                if self.u2idx_n < CORES:
                    self.u2idx_n += 1
                    return U2idx(self)

            time.sleep(0.05)
            # not using conditional waits, on a hunch that
            # average performance will be faster like this
            # since most servers won't be fully saturated

        return None

    def put_u2idx(self, ident, u2idx):
        """return `u2idx` to the free pool; `ident` is suffixed with "a"
        until unique so no pooled instance gets overwritten"""
        with self.mutex:
            while ident in self.u2idx_free:
                ident += "a"

            self.u2idx_free[ident] = u2idx

    def read_dls(self):
        """
        mp-broker asking for local dl-info + dl-state;
        reduce overhead by sending just the vfs vpath
        """
        dli = {k: (a, b, c.vpath, d, e) for k, (a, b, c, d, e) in self.dli.items()}
        return (dli, self.dls)

    def write_dls(self, sdli, dls):
        """
        mp-broker pushing total dl-info + dl-state;
        swap out the vfs vpath with the vfs node
        """
        dli = {}
        for k, (a, b, c, d, e) in sdli.items():
            vn = self.asrv.vfs.all_nodes[c]
            dli[k] = (a, b, vn, d, e)

        self.tdli = dli
        self.tdls = dls
