import WsWebSocket from "ws";
import { Buffer } from "node:buffer";
import type {
  AudioFrame,
  ToolInvocation,
  VoiceProviderEvents,
  VoiceProviderKind,
  VoiceProviderSession,
} from "./types.ts";

/**
 * Minimal structural WebSocket surface used by {@link RealtimeWebSocketSession}.
 *
 * The default factory casts a `ws` WebSocket to this shape; any object that
 * provides these members can be supplied through a custom
 * {@link WebSocketFactory}.
 */
export interface WebSocketLike {
  /**
   * Connection state. Not read by RealtimeWebSocketSession in this file.
   * NOTE(review): presumably follows the standard WebSocket readyState values
   * (0 CONNECTING .. 3 CLOSED); confirm for custom implementations.
   */
  readyState: number;
  /** Sends one frame. RealtimeWebSocketSession only ever sends JSON strings. */
  send(data: string | ArrayBuffer | Uint8Array): void;
  /** Closes the connection; the session calls this with code 1000 and a reason string. */
  close(code?: number, reason?: string): void;
  /**
   * Registers a listener for a connection event. The session reads `data`
   * from "message" events; the "error" event object is only stringified.
   */
  addEventListener(
    type: "open" | "message" | "error" | "close",
    listener: (event: unknown) => void,
  ): void;
}

/**
 * Creates a {@link WebSocketLike} connection.
 *
 * The parameters mirror the `ws` constructor `(url, protocols, options)` that
 * the default factory forwards to. RealtimeWebSocketSession calls the factory
 * with `protocols` undefined and `options.headers` set to its configured headers.
 */
export type WebSocketFactory = (
  url: string,
  protocols?: string | string[],
  options?: { headers?: Record<string, string> },
) => WebSocketLike;

/**
 * Configuration for {@link RealtimeWebSocketSession}.
 *
 * Everything provider-specific (the shape of outgoing events and how incoming
 * frames are interpreted) is supplied through the callbacks below.
 */
export interface RealtimeWebSocketSessionOptions {
  /** Exposed unchanged as `session.provider`. */
  provider: VoiceProviderKind;
  /** Exposed unchanged as `session.sessionId`. */
  sessionId: string;
  /** WebSocket endpoint handed to the factory. */
  url: string;
  /**
   * Handed to the factory as `options.headers`. With the default `ws` factory
   * these become headers on the connection handshake request.
   */
  headers: Record<string, string>;
  /** Receives parsed messages plus socket-level error and close notifications. */
  events: VoiceProviderEvents;
  /** Builds the event sent once, as JSON, as soon as the socket opens. */
  makeSessionUpdate: () => Record<string, unknown>;
  /** Builds the event sent for each `sendAudio` call. */
  makeAppendAudio: (frame: AudioFrame) => Record<string, unknown>;
  /** Builds the event(s) sent by `sendText`; an array is sent as one message per element, in order. */
  makeTextMessage: (text: string) => RealtimeClientEvent | RealtimeClientEvent[];
  /** Builds the event(s) sent by `sendToolResult`; an array is sent as one message per element, in order. */
  makeToolResult: (toolCallId: string, result: unknown) => RealtimeClientEvent | RealtimeClientEvent[];
  /** Builds the event sent by `interrupt`. */
  makeInterrupt: () => Record<string, unknown>;
  /**
   * Converts one incoming frame into zero or more normalized messages.
   * Receives the JSON-decoded payload when the frame data is a string or a
   * Uint8Array; otherwise receives the raw `data` value unchanged.
   */
  parseMessage: (data: unknown) => ParsedRealtimeMessage[];
  /** Overrides socket construction; when omitted a `ws`-based factory is used. */
  webSocketFactory?: WebSocketFactory;
}

/**
 * Provider-neutral message produced by `parseMessage`.
 *
 * RealtimeWebSocketSession dispatches each variant to the matching
 * VoiceProviderEvents callback, in the order `parseMessage` returned them.
 */
export type ParsedRealtimeMessage =
  // Dispatched to events.onAudio(frame).
  | { type: "audio"; frame: AudioFrame }
  // Dispatched to events.onTranscript({ role, text, isFinal }).
  | { type: "transcript"; role: "caller" | "agent"; text: string; isFinal: boolean }
  // Dispatched to events.onToolCall(call).
  | { type: "tool_call"; call: ToolInvocation }
  // Dispatched to events.onError(error).
  | { type: "error"; error: Error }
  // Dispatched to events.onClose(reason); does not close the socket itself.
  | { type: "close"; reason?: string };

/** One outgoing client event; each is serialized with JSON.stringify and sent as its own message. */
type RealtimeClientEvent = Record<string, unknown>;

/**
 * Default {@link WebSocketFactory}: opens the connection with the `ws`
 * package, using an empty subprotocol list when none is given and forwarding
 * the optional headers.
 */
function defaultWebSocketFactory(
  url: string,
  protocols?: string | string[],
  options?: { headers?: Record<string, string> },
): WebSocketLike {
  const subprotocols = protocols ?? [];
  const headers = options?.headers;
  const socket = new WsWebSocket(url, subprotocols, { headers });
  // The double cast bypasses type checking against WebSocketLike; the `ws`
  // instance is relied upon to provide the members WebSocketLike declares.
  return socket as unknown as WebSocketLike;
}

/**
 * Encodes raw bytes as a standard (padded, non-URL-safe) base64 string.
 * Only the bytes visible through the given view are encoded.
 */
function encodeBase64(bytes: Uint8Array): string {
  const asBuffer = Buffer.from(bytes);
  return asBuffer.toString("base64");
}

/**
 * Decodes a base64 string into a freshly allocated plain Uint8Array.
 *
 * The bytes are copied out of the intermediate Node Buffer, so callers get a
 * standalone typed array rather than a Buffer.
 */
export function decodeBase64(value: string): Uint8Array {
  const decoded = Buffer.from(value, "base64");
  const bytes = new Uint8Array(decoded.length);
  bytes.set(decoded);
  return bytes;
}

/** Base64-encodes the raw bytes carried in an audio frame's `data` field. */
export function audioFrameToBase64(frame: AudioFrame): string {
  const { data } = frame;
  return encodeBase64(data);
}

/**
 * {@link VoiceProviderSession} implemented over a single realtime WebSocket.
 *
 * The session owns the connection lifecycle and JSON framing. Everything
 * provider-specific (which events are sent and how incoming frames are
 * interpreted) comes from the `make*` and `parseMessage` callbacks in
 * {@link RealtimeWebSocketSessionOptions}.
 *
 * The socket is created eagerly in the constructor. The session-update event
 * is sent as soon as the socket opens, and every outgoing call waits for that
 * first. If the socket errors or closes before opening, pending and future
 * outgoing calls reject instead of hanging.
 *
 * Failures raised while handling socket events are delivered to
 * `events.onError` rather than escaping as uncaught exceptions or unhandled
 * rejections. These include malformed JSON, a throwing `parseMessage`, and
 * rejecting event callbacks.
 */
export class RealtimeWebSocketSession implements VoiceProviderSession {
  /** Shared decoder for binary frames; `decode` without `{ stream: true }` keeps no state between calls. */
  private static readonly textDecoder = new TextDecoder();

  private readonly ws: WebSocketLike;
  /** Resolves once the session update has been sent on open; rejects if the socket fails before that. */
  private readonly openPromise: Promise<void>;

  readonly provider: VoiceProviderKind;
  readonly sessionId: string;

  constructor(private readonly opts: RealtimeWebSocketSessionOptions) {
    this.provider = opts.provider;
    this.sessionId = opts.sessionId;
    const factory = opts.webSocketFactory ?? defaultWebSocketFactory;
    this.ws = factory(opts.url, undefined, { headers: opts.headers });

    let opened = false;
    this.openPromise = new Promise<void>((resolve, reject) => {
      this.ws.addEventListener("open", () => {
        opened = true;
        try {
          this.sendJson(this.opts.makeSessionUpdate());
          resolve();
        } catch (error: unknown) {
          // A throwing makeSessionUpdate must not escape from the socket's
          // event dispatch; fail the session's sends instead.
          const err = RealtimeWebSocketSession.toError(error);
          this.reportError(err);
          reject(err);
        }
      });
      this.ws.addEventListener("error", (event) => {
        const err = new Error(
          `websocket error: ${RealtimeWebSocketSession.describeErrorEvent(event)}`,
        );
        this.reportError(err);
        // No-op if the promise already settled (i.e. an error after open).
        reject(err);
      });
      this.ws.addEventListener("close", () => {
        if (!opened) {
          // Without this, a socket that closes without ever opening (and
          // without an error event) would leave every send* awaiting forever.
          reject(new Error("websocket closed before it was opened"));
        }
        this.runCallback(() => this.opts.events.onClose?.());
      });
      this.ws.addEventListener("message", (event) => {
        const data = (event as { data?: unknown }).data;
        this.runCallback(() => this.handleMessage(data));
      });
    });
    // Nothing awaits openPromise until the first send* call. Mark it handled
    // so an early connection failure (already delivered to onError) is not an
    // unhandled rejection, which terminates Node by default. Callers that do
    // await openPromise still observe the rejection.
    void this.openPromise.catch(() => undefined);
  }

  /**
   * Sends one audio frame once the socket is open.
   * @throws Rejects if the socket failed or closed before opening.
   */
  async sendAudio(frame: AudioFrame): Promise<void> {
    await this.openPromise;
    this.sendJson(this.opts.makeAppendAudio(frame));
  }

  /**
   * Sends a text message (one or more client events) once the socket is open.
   * @throws Rejects if the socket failed or closed before opening.
   */
  async sendText(text: string): Promise<void> {
    await this.openPromise;
    this.sendEvents(this.opts.makeTextMessage(text));
  }

  /**
   * Sends the result of a tool call (one or more client events) once the socket is open.
   * @throws Rejects if the socket failed or closed before opening.
   */
  async sendToolResult(toolCallId: string, result: unknown): Promise<void> {
    await this.openPromise;
    this.sendEvents(this.opts.makeToolResult(toolCallId, result));
  }

  /**
   * Sends the provider's interrupt event once the socket is open.
   * @throws Rejects if the socket failed or closed before opening.
   */
  async interrupt(): Promise<void> {
    await this.openPromise;
    this.sendJson(this.opts.makeInterrupt());
  }

  /**
   * Starts a normal (code 1000) close with the given reason. Resolves right
   * away; it does not wait for the socket's close event.
   */
  async close(reason = "voice session closed"): Promise<void> {
    this.ws.close(1000, reason);
  }

  /** Serializes one event as JSON and sends it as a single text frame. */
  private sendJson(value: Record<string, unknown>): void {
    this.ws.send(JSON.stringify(value));
  }

  /** Sends one event, or each element of an array as its own message, in order. */
  private sendEvents(value: RealtimeClientEvent | RealtimeClientEvent[]): void {
    for (const event of Array.isArray(value) ? value : [value]) {
      this.sendJson(event);
    }
  }

  /**
   * Decodes one incoming frame, lets the provider parse it, and dispatches
   * each resulting message to its callback in order, awaiting each one.
   * Throws or rejects on malformed JSON, parse failures, or callback failures;
   * the caller routes those to onError.
   */
  private async handleMessage(raw: unknown): Promise<void> {
    const decoded = RealtimeWebSocketSession.decodeFrame(raw);
    for (const message of this.opts.parseMessage(decoded)) {
      switch (message.type) {
        case "audio":
          await this.opts.events.onAudio?.(message.frame);
          break;
        case "transcript":
          await this.opts.events.onTranscript?.({
            role: message.role,
            text: message.text,
            isFinal: message.isFinal,
          });
          break;
        case "tool_call":
          await this.opts.events.onToolCall?.(message.call);
          break;
        case "error":
          await this.opts.events.onError?.(message.error);
          break;
        case "close":
          await this.opts.events.onClose?.(message.reason);
          break;
        default: {
          // Compile-time exhaustiveness check only; at runtime an unknown
          // message type is ignored, as before.
          const unhandled: never = message;
          void unhandled;
        }
      }
    }
  }

  /**
   * Runs an event callback without blocking the socket listener. The callback
   * starts synchronously. A synchronous throw or a rejected promise is routed
   * to onError instead of escaping as an uncaught exception or unhandled
   * rejection.
   */
  private runCallback(callback: () => unknown): void {
    const run = async (): Promise<void> => {
      await callback();
    };
    void run().catch((error: unknown) => {
      this.reportError(RealtimeWebSocketSession.toError(error));
    });
  }

  /**
   * Delivers `err` to onError. onError is the last-resort sink, so if it
   * throws or rejects there is nowhere left to report that failure. It is
   * dropped deliberately rather than allowed to crash the process.
   */
  private reportError(err: Error): void {
    try {
      void Promise.resolve(this.opts.events.onError?.(err)).catch(() => undefined);
    } catch {
      // onError threw synchronously; dropped for the reason above.
    }
  }

  /**
   * Converts a raw frame into the value handed to parseMessage. String and
   * binary (Uint8Array / ArrayBuffer) frames are decoded as UTF-8 JSON, which
   * throws on malformed input. Any other value is passed through unchanged.
   */
  private static decodeFrame(raw: unknown): unknown {
    if (typeof raw === "string") {
      return JSON.parse(raw);
    }
    if (raw instanceof Uint8Array || raw instanceof ArrayBuffer) {
      return JSON.parse(RealtimeWebSocketSession.textDecoder.decode(raw));
    }
    return raw;
  }

  /** Normalizes an unknown thrown value into an Error. */
  private static toError(value: unknown): Error {
    return value instanceof Error ? value : new Error(String(value));
  }

  /**
   * Builds a best-effort readable description of a WebSocket "error" event.
   * `String(event)` alone usually gives an uninformative "[object ...]" tag,
   * so this prefers an Error instance, a non-empty `message`, or a wrapped
   * `error` when the event carries one.
   */
  private static describeErrorEvent(event: unknown): string {
    if (event instanceof Error) {
      return event.message;
    }
    if (typeof event === "object" && event !== null) {
      const { message, error } = event as { message?: unknown; error?: unknown };
      if (typeof message === "string" && message !== "") {
        return message;
      }
      if (error instanceof Error) {
        return error.message;
      }
    }
    return String(event);
  }
}
