from __future__ import annotations

from dataclasses import dataclass
import json
import mimetypes
import os
from pathlib import Path
import ssl
import threading
import time
from typing import Any, Callable

from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingMixIn
from http.server import HTTPServer

from backend.chatapp import db as db_mod
from backend.chatapp.api import register_routes
from backend.chatapp.backup import start_backup_loop
from backend.chatapp.config import Config, load_config
from backend.chatapp.http import HttpError, Request, Response, Router, json_response, parse_request_target
from backend.chatapp.state import AppState
from backend.chatapp.util import now_ts, safe_mkdir


class _TLSContextReloader:
    def __init__(self, cert_path: Path, key_path: Path) -> None:
        self._cert_path = cert_path
        self._key_path = key_path
        self._lock = threading.Lock()
        self._context: ssl.SSLContext | None = None
        self._fingerprint: tuple[int, int, int, int] | None = None

    def get(self) -> ssl.SSLContext | None:
        with self._lock:
            return self._context

    def load_if_changed(self) -> bool:
        fingerprint = self._compute_fingerprint()
        with self._lock:
            if fingerprint == self._fingerprint and self._context is not None:
                return False
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        ctx.load_cert_chain(certfile=str(self._cert_path), keyfile=str(self._key_path))
        with self._lock:
            self._context = ctx
            self._fingerprint = fingerprint
        return True

    def _compute_fingerprint(self) -> tuple[int, int, int, int]:
        cert_stat = self._cert_path.stat()
        key_stat = self._key_path.stat()
        return (
            int(cert_stat.st_mtime),
            int(cert_stat.st_size),
            int(key_stat.st_mtime),
            int(key_stat.st_size),
        )


class _FlexibleHTTPServer(ThreadingMixIn, HTTPServer):
    daemon_threads = True

    def __init__(
        self,
        server_address: tuple[str, int],
        RequestHandlerClass: type[BaseHTTPRequestHandler],
        *,
        ssl_context_getter: Callable[[], ssl.SSLContext | None],
    ) -> None:
        self._ssl_context_getter = ssl_context_getter
        super().__init__(server_address, RequestHandlerClass)

    def get_request(self) -> tuple[Any, Any]:
        sock, addr = super().get_request()
        ctx = self._ssl_context_getter()
        if ctx is None:
            return sock, addr
        try:
            tls_sock = ctx.wrap_socket(sock, server_side=True)
        except Exception:  # noqa: BLE001
            sock.close()
            raise
        return tls_sock, addr


@dataclass(frozen=True)
class _Runtime:
    state: AppState
    router: Router


def _cors_headers() -> dict[str, str]:
    return {
        "access-control-allow-origin": "*",
        "access-control-allow-headers": "authorization,content-type,x-chunk-sha256",
        "access-control-allow-methods": "GET,POST,PUT,PATCH,DELETE,OPTIONS",
        "access-control-max-age": "600",
    }


def _discover_web_root() -> Path | None:
    override = os.getenv("CHATAPP_WEB_ROOT", "").strip()
    if override:
        root = Path(override).expanduser()
        if root.exists() and root.is_dir():
            return root

    packaged = Path(__file__).parent / "static" / "web"
    if packaged.exists() and packaged.is_dir():
        return packaged

    repo_dist = Path(__file__).resolve().parents[2] / "web" / "dist"
    if repo_dist.exists() and repo_dist.is_dir():
        return repo_dist

    return None


class _Handler(BaseHTTPRequestHandler):
    server: _FlexibleHTTPServer  # type: ignore[assignment]

    def do_OPTIONS(self) -> None:  # noqa: N802
        headers = _cors_headers()
        headers["content-length"] = "0"
        self._send(Response(status=204, headers=headers, body=b""))

    def do_GET(self) -> None:  # noqa: N802
        self._dispatch()

    def do_POST(self) -> None:  # noqa: N802
        self._dispatch()

    def do_PUT(self) -> None:  # noqa: N802
        self._dispatch()

    def do_PATCH(self) -> None:  # noqa: N802
        self._dispatch()

    def do_DELETE(self) -> None:  # noqa: N802
        self._dispatch()

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002
        # Avoid noisy default logging; API uses DB logs.
        return

    def _dispatch(self) -> None:
        runtime: _Runtime = self.server.runtime  # type: ignore[attr-defined]

        path, query = parse_request_target(self.path)

        # Built-in admin UI static files.
        if path.startswith("/admin/") or path == "/admin":
            self._serve_admin_static(runtime.state, path)
            return
        # Optional built web client (Docker image copies it into static/web).
        if self.command == "GET" and self._maybe_serve_web_static(path):
            return

        match = runtime.router.match(self.command, path)
        if match is None:
            self._send(
                json_response(
                    404,
                    {"error": {"code": "not_found", "message": "Not found"}},
                    headers=_cors_headers(),
                )
            )
            return
        handler, path_params = match

        headers = {k.lower(): v for k, v in self.headers.items()}
        content_length = int(headers.get("content-length", "-1"))
        req = Request(
            method=self.command,
            raw_path=self.path,
            path=path,
            query=query,
            headers=headers,
            client_ip=self.client_address[0],
            rfile=self.rfile,
            content_length=content_length,
            path_params=path_params,
        )
        try:
            resp = handler(req)
        except HttpError as exc:
            resp = json_response(
                exc.status,
                {"error": {"code": exc.code, "message": exc.message}},
                headers=_cors_headers(),
            )
        except Exception as exc:  # noqa: BLE001
            resp = json_response(
                500,
                {"error": {"code": "internal_error", "message": "Internal server error"}},
                headers=_cors_headers(),
            )
            _log_internal_error(runtime.state, req, exc)
        else:
            resp.headers.update(_cors_headers())

        self._send(resp)

    def _send(self, resp: Response) -> None:
        self.send_response(resp.status)
        for k, v in resp.headers.items():
            self.send_header(k, v)
        self.end_headers()
        if resp.file_path is not None:
            with open(resp.file_path, "rb") as f:
                while True:
                    chunk = f.read(1024 * 256)
                    if not chunk:
                        break
                    self.wfile.write(chunk)
            return
        self.wfile.write(resp.body)

    def _serve_admin_static(self, state: AppState, path: str) -> None:
        admin_root = Path(__file__).parent / "static" / "admin"
        if path in ("/admin", "/admin/"):
            rel = "index.html"
        else:
            rel = path.removeprefix("/admin/").lstrip("/")
            if rel == "":
                rel = "index.html"
        file_path = (admin_root / rel).resolve()
        if admin_root not in file_path.parents and file_path != admin_root:
            self._send(
                json_response(
                    404,
                    {"error": {"code": "not_found", "message": "Not found"}},
                    headers=_cors_headers(),
                )
            )
            return
        if not file_path.exists() or not file_path.is_file():
            self._send(
                json_response(
                    404,
                    {"error": {"code": "not_found", "message": "Not found"}},
                    headers=_cors_headers(),
                )
            )
            return

        content = file_path.read_bytes()
        ctype, _ = mimetypes.guess_type(str(file_path))
        headers = {
            "content-type": ctype or "application/octet-stream",
            "content-length": str(len(content)),
            "cache-control": "no-store",
        }
        headers.update(_cors_headers())
        self._send(Response(status=200, headers=headers, body=content))

    def _maybe_serve_web_static(self, path: str) -> bool:
        web_root = _discover_web_root()
        if web_root is None:
            return False

        if path in ("/", "/index.html"):
            rel = "index.html"
        else:
            rel = path.lstrip("/")

        file_path = (web_root / rel).resolve()
        if web_root not in file_path.parents and file_path != web_root:
            return False

        if not file_path.exists() or not file_path.is_file():
            if path.startswith("/api/") or path.startswith("/admin"):
                return False
            file_path = (web_root / "index.html").resolve()
            if not file_path.exists() or not file_path.is_file():
                return False

        content = file_path.read_bytes()
        ctype, _ = mimetypes.guess_type(str(file_path))
        cache = "no-store"
        if path.startswith("/assets/"):
            cache = "public, max-age=31536000, immutable"
        headers = {
            "content-type": ctype or "application/octet-stream",
            "content-length": str(len(content)),
            "cache-control": cache,
        }
        headers.update(_cors_headers())
        self._send(Response(status=200, headers=headers, body=content))
        return True


def _log_internal_error(state: AppState, req: Request, exc: Exception) -> None:
    """Best-effort: record an unhandled handler exception in the ``logs`` table.

    Opens a dedicated connection to ``state.db_path`` and inserts one ``ERROR``
    row with event ``internal_error``. Its ``details_json`` holds the request
    path and method plus the exception message and type. Every failure is
    swallowed on purpose: the caller is already building a 500 response, and
    logging must not turn that into a crash of the handler thread.
    """
    try:
        conn = db_mod.connect(state.db_path)
        try:
            details = {
                "path": req.path,
                "method": req.method,
                "message": str(exc),
                # The message alone is often ambiguous (e.g. KeyError('x') -> "'x'").
                "exception_type": type(exc).__name__,
            }
            conn.execute(
                "INSERT INTO logs (id, level, event, details_json, created_at) VALUES (?,?,?,?,?)",
                (
                    _new_id(),
                    "ERROR",
                    "internal_error",
                    json.dumps(details),
                    now_ts(),
                ),
            )
            # Assuming db_mod.connect returns a sqlite3.Connection (the ?-placeholders
            # suggest so): close() does not commit, so with the default isolation
            # level the INSERT would be rolled back. In autocommit mode this is a no-op.
            conn.commit()
        finally:
            conn.close()
    except Exception:  # noqa: BLE001
        return


def _new_id() -> str:
    import uuid

    return uuid.uuid4().hex


def _start_cert_reload_loop(config: Config, reloader: _TLSContextReloader) -> threading.Thread:
    def loop() -> None:
        while True:
            try:
                reloader.load_if_changed()
            except Exception:  # noqa: BLE001
                pass
            time.sleep(max(5, int(config.cert_reload_seconds)))

    t = threading.Thread(target=loop, name="cert-reloader", daemon=True)
    t.start()
    return t


def run() -> None:
    """Start the chat backend and serve until interrupted.

    Loads the config, creates the data/storage/backup directories, initialises
    the database, registers API routes and starts the backup loop. If
    ``config.cert_dir`` is set it also enables TLS. It then serves on
    ``config.bind:config.port`` until ``KeyboardInterrupt``.

    If the certificate cannot be loaded at startup, the server falls back to
    plain HTTP and prints a warning to stderr. The reload loop is not started
    in that case, so TLS stays off until restart.
    """
    import sys

    config = load_config()
    safe_mkdir(config.data_dir)
    safe_mkdir(config.storage_dir)
    safe_mkdir(config.backups_dir)

    db_mod.init_db(config.db_path)

    state = AppState(
        config=config,
        db_path=config.db_path,
        storage_dir=config.storage_dir,
        backups_dir=config.backups_dir,
    )

    router = Router()
    register_routes(router, state)

    start_backup_loop(state)

    reloader: _TLSContextReloader | None = None
    if config.cert_dir is not None:
        cert_path = config.cert_dir / config.cert_file
        key_path = config.cert_dir / config.key_file
        reloader = _TLSContextReloader(cert_path=cert_path, key_path=key_path)
        try:
            reloader.load_if_changed()
        except Exception as exc:  # noqa: BLE001
            # Keep the best-effort fallback to plain HTTP, but make the downgrade
            # visible: with TLS configured, silently serving plaintext is dangerous.
            print(
                f"WARNING: TLS disabled, could not load {cert_path} / {key_path}: "
                f"{type(exc).__name__}: {exc}",
                file=sys.stderr,
                flush=True,
            )
            reloader = None
        else:
            _start_cert_reload_loop(config, reloader)

    def get_ctx() -> ssl.SSLContext | None:
        if reloader is None:
            return None
        return reloader.get()

    httpd = _FlexibleHTTPServer((config.bind, int(config.port)), _Handler, ssl_context_getter=get_ctx)
    httpd.runtime = _Runtime(state=state, router=router)  # type: ignore[attr-defined]

    scheme = "https" if reloader is not None else "http"
    # flush: stdout is block-buffered when not a TTY (e.g. under Docker), which
    # would otherwise delay this line indefinitely.
    print(f"Chat backend listening on {scheme}://{config.bind}:{config.port}", flush=True)
    try:
        httpd.serve_forever(poll_interval=0.5)
    except KeyboardInterrupt:
        pass
    finally:
        httpd.server_close()
