"""Shared Streamable HTTP + bearer middleware for Python MCP servers.

This is the Python parallel to ``streamable_transport.ts``. It wires an
``mcp.server.Server`` to Starlette via the SDK's Streamable HTTP ASGI
app, fronted by a bearer-token middleware, with an unauthenticated
``/health`` probe.

Routes:
  GET  /health         -> 200 {"ok": true, "service": ...} UNAUTHENTICATED
  ANY  /mcp            -> Streamable HTTP MCP transport (bearer required)
  GET  /sse            -> legacy SSE stream (bearer required, only if legacy_sse)
  POST /messages       -> legacy SSE POST (bearer required, only if legacy_sse)

Phase A scope: ringcentral-admin and docstrange are the only Python
consumers of this module. Phase B will migrate them as part of the
rollout that follows the AdvancedMD pilot.
"""

from __future__ import annotations

import contextlib
import logging
import time
from dataclasses import dataclass
from typing import Any, AsyncIterator, Optional

import uvicorn
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route

from .auth import require_bearer

# Module logger. It uses the fixed name "mcp_shared" rather than __name__,
# presumably so every consumer shares one logger that is configured in one
# place -- confirm with consumers.
logger: logging.Logger = logging.getLogger("mcp_shared")


# Rate-limited bearer-failure logging.
#
# Never logs the offered token, path, or any header values.
class _AuthFailureCounter:
    """Rate-limited logger for bearer-token authentication failures.

    The first failure in each window is logged immediately, so operators
    see an attack as it begins. Later failures inside the same window are
    only counted. That count is flushed as a summary line when the next
    failure arrives after the window has elapsed.

    A burst that stops abruptly therefore still produces at least one log
    line (its first failure). Its suppressed count is reported only once a
    later failure occurs.

    Never logs the offered token, path, or any header values.
    """

    def __init__(self, window_s: float = 60.0) -> None:
        """Create a counter.

        Parameters
        ----------
        window_s:
            Length of the suppression window in seconds. Defaults to 60.

        Raises
        ------
        ValueError
            If ``window_s`` is negative.
        """
        if window_s < 0:
            raise ValueError("window_s must be >= 0")
        # -inf guarantees that the very first failure opens a window and is
        # logged. Starting at 0.0 would silently suppress it whenever
        # time.monotonic() is still below window_s; its reference point is
        # undefined and can be small (e.g. shortly after boot).
        self._window_start = float("-inf")
        self._suppressed = 0
        self._window_s = window_s

    def record(self) -> None:
        """Record one failure, logging it or counting it for a later summary."""
        now = time.monotonic()
        if now - self._window_start >= self._window_s:
            # New window: flush the previous window's suppressed count first,
            # then log this failure and start counting again.
            if self._suppressed > 0:
                logger.warning(
                    "mcp-shared: %d additional bearer auth failure(s) in previous window",
                    self._suppressed,
                )
            logger.warning("mcp-shared: bearer auth failure")
            self._window_start = now
            self._suppressed = 0
        else:
            self._suppressed += 1


class _PathRewriteMiddleware:
    """ASGI middleware that maps a bare mount path onto its slash form.

    A Starlette ``Mount`` only matches the trailing-slash form of its path
    (``/mcp/``), and would otherwise answer the bare form (``/mcp``) with a
    307. Some MCP clients do not follow that redirect on POSTs. The
    TypeScript transport, the existing client config, and the Tailscale
    funnel all use the bare form. So an HTTP request whose path equals the
    configured base is rewritten in the ASGI scope before routing, instead
    of being redirected.
    """

    def __init__(self, app: Any, base_path: str) -> None:
        self.app = app
        trimmed = base_path.rstrip("/")
        self._base = trimmed
        self._base_slash = f"{trimmed}/"

    async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
        # Only HTTP scopes whose path is exactly the bare base are touched.
        # Every other scope, including lifespan, passes through as the same
        # object.
        needs_rewrite = (
            scope.get("type") == "http" and scope.get("path", "") == self._base
        )
        if not needs_rewrite:
            await self.app(scope, receive, send)
            return

        # Work on a shallow copy so the caller's scope dict is left intact.
        rewritten = {**scope, "path": self._base_slash}
        raw = scope.get("raw_path")
        if isinstance(raw, (bytes, bytearray)):
            rewritten["raw_path"] = bytes(raw) + b"/"
        await self.app(rewritten, receive, send)


class BearerAuthMiddleware(BaseHTTPMiddleware):
    """Require a valid bearer token on every request outside an allow-list.

    Requests whose URL path is listed in ``allow_paths`` (by default only
    ``/health``) are forwarded without any check. For everything else,
    ``require_bearer`` validates the headers against ``token``. If it
    returns an error response, the failure is recorded on the
    rate-limited counter and that response is returned instead of calling
    the wrapped app.
    """

    def __init__(
        self,
        app: Any,
        *,
        token: str,
        allow_paths: tuple[str, ...] = ("/health",),
        counter: Optional[_AuthFailureCounter] = None,
    ) -> None:
        super().__init__(app)
        self._token = token
        self._allow_paths = allow_paths
        # Fall back to a private counter when the caller does not share one.
        self._counter = counter if counter else _AuthFailureCounter()

    async def dispatch(self, request: Request, call_next):  # type: ignore[override]
        path_is_public = request.url.path in self._allow_paths
        if not path_is_public:
            rejection = require_bearer(request.headers, self._token)
            if rejection is not None:
                self._counter.record()
                return rejection
        return await call_next(request)


@dataclass()
class ServeMcpResult:
    """Objects assembled by ``serve_mcp_over_http``.

    When that function is called with ``run=False``, the caller uses these
    handles to start the server itself.
    """

    # The Starlette ASGI application (routes + middleware + lifespan).
    app: Starlette
    # The uvicorn configuration wrapping ``app``.
    config: uvicorn.Config
    # The uvicorn server constructed from ``config``.
    server: uvicorn.Server


def _health_route_factory(service_name: str):
    """Build the unauthenticated ``/health`` endpoint for *service_name*.

    The returned coroutine function ignores its request and answers with a
    JSON body ``{"ok": true, "service": <service_name>}``. The inner
    function keeps the name ``health`` because Starlette derives the
    default route name from the endpoint's ``__name__``.
    """
    body = {"ok": True, "service": service_name}

    async def health(_req: Request) -> Response:
        return JSONResponse(body)

    return health


def _legacy_sse_routes(server: Any) -> list[Any]:
    """Build the legacy SSE routes (``GET /sse`` and ``/messages``) for *server*.

    Both routes are backed by the SDK's ``SseServerTransport``. If
    ``mcp.server.sse`` cannot be imported, a warning is logged and an
    empty list is returned, so the Streamable HTTP transport still starts.
    """
    # Deferred import: only needed when legacy SSE is requested, and its
    # absence must not prevent the main transport from booting.
    try:
        from mcp.server.sse import SseServerTransport  # type: ignore[import-not-found]
    except ImportError as exc:  # pragma: no cover - depends on SDK version
        logger.warning(
            "mcp-shared: legacy_sse requested but unavailable: %s", exc
        )
        return []

    sse = SseServerTransport("/messages")

    async def handle_sse(request: Request) -> Response:
        # connect_sse drives the raw ASGI send callable directly.
        async with sse.connect_sse(
            request.scope, request.receive, request._send  # noqa: SLF001
        ) as streams:
            await server.run(
                streams[0], streams[1], server.create_initialization_options()
            )
        # Return a Response object rather than None once the stream ends,
        # because Starlette calls whatever the endpoint returns.
        return Response()

    return [
        Route("/sse", endpoint=handle_sse),
        Mount("/messages", app=sse.handle_post_message),
    ]


def serve_mcp_over_http(
    server: Any,
    *,
    port: int,
    token: str,
    path: str = "/mcp",
    legacy_sse: bool = False,
    host: str = "0.0.0.0",
    run: bool = True,
) -> ServeMcpResult:
    """Boot a Streamable HTTP listener for an ``mcp.server.Server``.

    Parameters
    ----------
    server:
        An ``mcp.server.Server`` instance with tools already registered.
    port:
        TCP port to bind.
    token:
        Bearer token required for all routes except ``/health``. Must be
        non-empty and not whitespace-only.
    path:
        MCP route path. Defaults to ``/mcp``. Both ``<path>`` and
        ``<path>/`` reach the transport without a redirect.
    legacy_sse:
        If True, also mount the SDK's legacy SSE app at ``/sse`` +
        ``/messages``, gated by the same bearer. If the SDK's SSE module
        cannot be imported, a warning is logged and only the Streamable
        HTTP transport is served.
    host:
        Bind host. Defaults to ``0.0.0.0``.
    run:
        If True (default), start uvicorn synchronously. If False,
        return the configured app/server for the caller to run.

    Returns
    -------
    ServeMcpResult with the Starlette app, uvicorn config, and server.

    Raises
    ------
    ValueError
        If ``token`` is empty, whitespace-only, or None.
    """
    # Fail closed: an empty token (e.g. an unset environment variable) is a
    # misconfiguration, not a request to disable authentication.
    if not token or not token.strip():
        raise ValueError("mcp-shared: a non-empty bearer token is required")

    # Import lazily: this lets the auth module be imported and unit
    # tested without the full mcp SDK installed.
    #
    # We require mcp>=1.6 -- that's the version that ships the
    # ``streamable_http_manager`` module, ``handle_request`` ASGI entry
    # point, and ``session_manager.run()`` async context manager that
    # this transport depends on. Consumers must pin accordingly in
    # their pyproject. We do not attempt a fallback for older SDKs:
    # the API surface differs in three places (module path, method
    # name, lifecycle) and a partial fallback would only defer the
    # crash to request time.
    from mcp.server.streamable_http_manager import (  # type: ignore[import-not-found]
        StreamableHTTPSessionManager,
    )

    service_name = getattr(server, "name", None) or "mcp"

    # The SDK ships a session manager whose handle_request is an ASGI
    # callable. It requires a long-lived task group started via run() as
    # an async context manager, so we wire that into the Starlette
    # lifespan.
    session_manager = StreamableHTTPSessionManager(
        app=server,
        event_store=None,
        json_response=False,
        stateless=False,
    )

    @contextlib.asynccontextmanager
    async def _lifespan(_app: Starlette) -> AsyncIterator[None]:
        async with session_manager.run():
            yield

    # The TypeScript transport (and existing client config) speaks
    # plain ``/mcp`` -- no trailing slash. Starlette's ``Mount`` would
    # answer ``/mcp`` with a 307 to ``/mcp/``, which some MCP clients
    # do not follow on POSTs. ``_PathRewriteMiddleware`` (installed
    # below) rewrites the path in the ASGI scope before routing, so
    # both ``/mcp`` and ``/mcp/`` reach the session manager without a
    # redirect.
    routes = [
        Route("/health", endpoint=_health_route_factory(service_name)),
        Mount(path, app=session_manager.handle_request),
    ]

    sse_routes = _legacy_sse_routes(server) if legacy_sse else []
    routes.extend(sse_routes)

    # Path rewrites must run before routing.
    middleware = [Middleware(_PathRewriteMiddleware, base_path=path)]
    if sse_routes:
        # Same trailing-slash issue as the MCP mount. The transport is
        # configured with the bare ``/messages`` endpoint for client POSTs,
        # but ``Mount("/messages")`` only matches ``/messages/``. Rewrite
        # instead of letting the router 307 the POST.
        middleware.append(Middleware(_PathRewriteMiddleware, base_path="/messages"))
    middleware.append(
        Middleware(
            BearerAuthMiddleware,
            token=token,
            allow_paths=("/health",),
            counter=_AuthFailureCounter(),
        )
    )

    app = Starlette(
        routes=routes,
        middleware=middleware,
        lifespan=_lifespan,
    )

    # Only advertise legacy SSE if its routes were actually mounted.
    logger.info(
        "mcp-shared: listening on :%d path %s (service '%s'%s)",
        port,
        path,
        service_name,
        ", legacy SSE on /sse + /messages" if sse_routes else "",
    )

    config = uvicorn.Config(app, host=host, port=port, log_level="warning")
    uv_server = uvicorn.Server(config)
    if run:
        uv_server.run()
    return ServeMcpResult(app=app, config=config, server=uv_server)
