#!/usr/bin/env python3
# coding: utf-8
from __future__ import print_function, unicode_literals

import io
import os
import pprint
import shutil
import tarfile
import tempfile
import time
import unittest
import zipfile

from copyparty.authsrv import AuthSrv
from copyparty.httpcli import HttpCli
from tests import util as tu
from tests.util import Cfg, eprint, pfind2ls


def hdr(query):
    h = "GET /{} HTTP/1.1\r\nCookie: cppwd=o\r\nConnection: close\r\n\r\n"
    return h.format(query).encode("utf-8")


class TestHttpCli(unittest.TestCase):
    def setUp(self):
        self.td = tu.get_ramdisk()
        self.maxDiff = 99999

    def tearDown(self):
        os.chdir(tempfile.gettempdir())
        shutil.rmtree(self.td)

    def test(self):
        test_tar = True
        test_zip = True

        td = os.path.join(self.td, "vfs")
        os.mkdir(td)
        os.chdir(td)

        # "perm+user"; r/w/a (a=rw) for user a/o/x (a=all)
        self.dtypes = ["ra", "ro", "rx", "wa", "wo", "wx", "aa", "ao", "ax"]
        self.can_read = ["ra", "ro", "aa", "ao"]
        self.can_write = ["wa", "wo", "aa", "ao"]
        self.fn = "g{:x}g".format(int(time.time() * 3))

        tctr = 0
        allfiles = []
        allvols = []
        for top in self.dtypes:
            allvols.append(top)
            allfiles.append("/".join([top, self.fn]))
            for s1 in self.dtypes:
                p = "/".join([top, s1])
                allvols.append(p)
                allfiles.append(p + "/" + self.fn)
                allfiles.append(p + "/n/" + self.fn)
                for s2 in self.dtypes:
                    p = "/".join([top, s1, "n", s2])
                    os.makedirs(p)
                    allvols.append(p)
                    allfiles.append(p + "/" + self.fn)

        for fp in allfiles:
            with open(fp, "w") as f:
                f.write("ok {}\n".format(fp))

        for top in self.dtypes:
            vcfg = []
            for vol in allvols:
                if not vol.startswith(top):
                    continue

                mode = vol[-2].replace("a", "rw")
                usr = vol[-1]
                if usr == "a":
                    usr = ""

                if "/" not in vol:
                    vol += "/"

                top, sub = vol.split("/", 1)
                vcfg.append("{0}/{1}:{1}:{2},{3}".format(top, sub, mode, usr))

            pprint.pprint(vcfg)

            self.args = Cfg(v=vcfg, a=["o:o", "x:x"])
            self.asrv = AuthSrv(self.args, self.log)
            self.conn = tu.VHttpConn(self.args, self.asrv, self.log, b"")
            vfiles = [x for x in allfiles if x.startswith(top)]
            for fp in vfiles:
                tctr += 1
                rok, wok = self.can_rw(fp)
                furl = fp.split("/", 1)[1]
                durl = furl.rsplit("/", 1)[0] if "/" in furl else ""

                # file download
                h, ret = self.curl(furl)
                res = "ok " + fp in ret
                print("[{}] {} {} = {}".format(fp, rok, wok, res))
                if rok != res:
                    eprint("\033[33m{}\n# {}\033[0m".format(ret, furl))
                    self.fail()

                # file browser: html
                h, ret = self.curl(durl)
                res = "'{}'".format(self.fn) in ret
                print(res)
                if rok != res:
                    eprint("\033[33m{}\n# {}\033[0m".format(ret, durl))
                    self.fail()

                # file browser: json
                url = durl + "?ls"
                h, ret = self.curl(url)
                res = '"{}"'.format(self.fn) in ret
                print(res)
                if rok != res:
                    eprint("\033[33m{}\n# {}\033[0m".format(ret, url))
                    self.fail()

                # expected files in archives
                if rok:
                    zs = top + "/" + durl
                    ref = [x for x in vfiles if self.in_dive(zs, x)]
                    ref.sort()
                else:
                    ref = []

                h, b = self.propfind(durl, 1)
                fns = [x for x in pfind2ls(b) if not x.endswith("/")]
                if ref:
                    self.assertIn("<D:propstat>", b)
                elif not rok and not wok:
                    self.assertListEqual([], fns)
                else:
                    self.assertIn("<D:multistatus", b)

                h, b = self.propfind(durl, 0)
                fns = [x for x in pfind2ls(b) if not x.endswith("/")]
                if ref:
                    self.assertIn("<D:propstat>", b)
                elif not rok:
                    self.assertListEqual([], fns)
                else:
                    self.assertIn("<D:multistatus", b)

                if test_tar:
                    url = durl + "?tar"
                    h, b = self.curl(url, True)
                    try:
                        tar = tarfile.open(fileobj=io.BytesIO(b), mode="r|").getnames()
                    except:
                        if "HTTP/1.1 403 Forbidden" not in h and b != b"\nJ2EOT":
                            eprint("bad tar?", url, h, b)
                            raise
                        tar = []
                    tar = [x.split("/", 1)[1] for x in tar]
                    tar = ["/".join([y for y in [top, durl, x] if y]) for x in tar]
                    tar = [[x] + self.can_rw(x) for x in tar]
                    tar_ok = [x[0] for x in tar if x[1]]
                    tar_ng = [x[0] for x in tar if not x[1]]
                    tar_ok.sort()
                    self.assertEqual(ref, tar_ok)
                    self.assertEqual([], tar_ng)

                if test_zip:
                    url = durl + "?zip"
                    h, b = self.curl(url, True)
                    try:
                        with zipfile.ZipFile(io.BytesIO(b), "r") as zf:
                            zfi = zf.infolist()
                    except:
                        if "HTTP/1.1 403 Forbidden" not in h and b != b"\nJ2EOT":
                            eprint("bad zip?", url, h, b)
                            raise
                        zfi = []
                    zfn = [x.filename.split("/", 1)[1] for x in zfi]
                    zfn = ["/".join([y for y in [top, durl, x] if y]) for x in zfn]
                    zfn = [[x] + self.can_rw(x) for x in zfn]
                    zf_ok = [x[0] for x in zfn if x[1]]
                    zf_ng = [x[0] for x in zfn if not x[1]]
                    zf_ok.sort()
                    self.assertEqual(ref, zf_ok)
                    self.assertEqual([], zf_ng)

                # stash
                h, ret = self.put(durl)
                res = h.startswith("HTTP/1.1 201 ")
                self.assertEqual(res, wok)
                if wok:
                    vp = h.split("\nLocation: http://a:1/")[1].split("\r")[0]
                    vn, rem = self.asrv.vfs.get(vp, "*", False, False)
                    ap = os.path.join(vn.realpath, rem)
                    os.unlink(ap)

            self.conn.shutdown()

    def can_rw(self, fp):
        # lowest non-neutral folder declares permissions
        expect = fp.split("/")[:-1]
        for x in reversed(expect):
            if x != "n":
                expect = x
                break

        return [expect in self.can_read, expect in self.can_write]

    def in_dive(self, top, fp):
        # archiver bails at first inaccessible subvolume
        top = top.strip("/").split("/")
        fp = fp.split("/")
        for f1, f2 in zip(top, fp):
            if f1 != f2:
                return False

        for f in fp[len(top) :]:
            if f == self.fn:
                return True
            if f not in self.can_read and f != "n":
                return False

        return True

    def put(self, url):
        buf = "PUT /{0} HTTP/1.1\r\nCookie: cppwd=o\r\nConnection: close\r\nContent-Length: {1}\r\n\r\nok {0}\n"
        buf = buf.format(url, len(url) + 4).encode("utf-8")
        print("PUT -->", buf)
        conn = self.conn.setbuf(buf)
        HttpCli(conn).run()
        ret = conn.s._reply.decode("utf-8").split("\r\n\r\n", 1)
        print("PUT <--", ret)
        return ret

    def curl(self, url, binary=False):
        conn = self.conn.setbuf(hdr(url))
        HttpCli(conn).run()
        if binary:
            h, b = conn.s._reply.split(b"\r\n\r\n", 1)
            return [h.decode("utf-8"), b]

        return conn.s._reply.decode("utf-8").split("\r\n\r\n", 1)

    def propfind(self, url, depth=1):
        zs = "PROPFIND /%s HTTP/1.1\r\nDepth: %d\r\nPW: o\r\nConnection: close\r\n\r\n"
        buf = zs % (url, depth)
        conn = self.conn.setbuf(buf.encode("utf-8"))
        HttpCli(conn).run()
        return conn.s._reply.decode("utf-8").split("\r\n\r\n", 1)

    def log(self, src, msg, c=0):
        print(msg)
