"""Bearer token middleware for Python MCP servers.

Returns a Starlette JSONResponse(401) if the Authorization header is
missing or does not match the expected token, or if the expected token
itself is empty (fail closed). Returns None on success.

- Case-insensitive Authorization header lookup.
- The "Bearer" scheme is matched case-insensitively and must be followed
  by exactly one space.
- Constant-time comparison via secrets.compare_digest.
- Empty / whitespace-only tokens always fail.
- Token value is never logged.
"""

from __future__ import annotations

import secrets
from typing import Mapping, Optional

from starlette.responses import JSONResponse


# JSON body used for every 401 produced by _unauthorized(). It is
# deliberately generic: it never says which check rejected the request.
_UNAUTHORIZED_BODY = dict(
    error="unauthorized",
    message="missing or invalid bearer token",
)


def _unauthorized() -> JSONResponse:
    """Build the 401 response returned for any failed bearer check.

    The body is always ``_UNAUTHORIZED_BODY``. The ``www-authenticate``
    header advertises the Bearer scheme with realm ``"mcp"``.
    """
    challenge = {"www-authenticate": 'Bearer realm="mcp"'}
    return JSONResponse(_UNAUTHORIZED_BODY, headers=challenge, status_code=401)


def _get_header(headers: Mapping[str, str], name: str) -> Optional[str]:
    """Look up a header by name, ignoring case.

    Works with Starlette's ``Headers`` (whose ``.get`` is already
    case-insensitive), plain dicts, and any other ``Mapping``.

    Args:
        headers: Request headers as a mapping of name -> value.
        name: Header name to look up, in any case.

    Returns:
        The value stored under exactly ``name`` if present; otherwise the
        value of the first key (in iteration order) that equals ``name``
        case-insensitively; otherwise None.
    """
    # Exact-key lookup is only a fast path. A non-dict Mapping is NOT
    # assumed to be case-insensitive (e.g. MappingProxyType is not), so a
    # miss always falls through to the case-insensitive scan below.
    value = headers.get(name)
    if value is not None:
        return value

    target = name.lower()
    for key, candidate in headers.items():
        if key.lower() == target:
            return candidate
    return None


def _constant_time_equals(presented: str, expected: str) -> bool:
    """Compare two secret strings with ``secrets.compare_digest``.

    ``compare_digest`` raises ``TypeError`` when handed ``str`` arguments
    that contain non-ASCII characters. The presented token comes from a
    request header and is attacker-controlled, so both sides are encoded
    to UTF-8 bytes first. A non-ASCII header value then becomes a plain
    mismatch instead of an exception escaping the auth check.
    ``surrogatepass`` keeps encoding from raising on lone surrogates,
    which a plain-dict caller could pass in. The encoding is injective,
    so the result is True exactly when the two strings are equal.
    """
    return secrets.compare_digest(
        presented.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    )


def require_bearer(
    headers: Mapping[str, str],
    expected_token: str,
) -> Optional[JSONResponse]:
    """Validate the Authorization header against the expected bearer token.

    Args:
        headers: Request headers, either a Starlette ``Headers`` or any
            mapping. The ``Authorization`` key is matched
            case-insensitively.
        expected_token: The configured secret. If it is empty, None or
            whitespace-only, every request is rejected (fail closed).

    Returns:
        None when the header is ``Bearer <token>`` (scheme in any case,
        exactly one space) and ``<token>`` equals ``expected_token``.
        Otherwise the 401 JSONResponse from ``_unauthorized()``, which is
        identical for every failure reason.
    """
    # Fail closed: a missing/blank configured secret must never match.
    if not expected_token or not expected_token.strip():
        return _unauthorized()

    auth = _get_header(headers, "authorization")
    if not auth:
        return _unauthorized()

    # Scheme is case-insensitive and followed by exactly one space. A value
    # shorter than the prefix yields a shorter slice and so also fails here.
    prefix = "bearer "
    if auth[: len(prefix)].lower() != prefix:
        return _unauthorized()

    token = auth[len(prefix):]
    if not token.strip():
        return _unauthorized()

    # compare_digest does not stop at the first differing byte, so the
    # token's value is not revealed through timing. Per the Python docs, a
    # length mismatch may still leak length information (not the value).
    if not _constant_time_equals(token, expected_token):
        return _unauthorized()

    return None
