import { describe, expect, test } from "bun:test";
import { OpenAiRealtimeProvider, parseOpenAiRealtimeMessage } from "./openai-realtime.ts";
import type { WebSocketFactory, WebSocketLike } from "./realtime-websocket.ts";

describe("OpenAI Realtime provider", () => {
  /**
   * Builds a WebSocket factory that hands out `FakeWebSocket` instances and records
   * each one in `sockets`, so a test can drive the socket (open/message) and inspect
   * the URL, headers, and frames the provider used.
   */
  function recordingFactory(): { factory: WebSocketFactory; sockets: FakeWebSocket[] } {
    const sockets: FakeWebSocket[] = [];
    const factory: WebSocketFactory = (url, _protocols, options) => {
      const socket = new FakeWebSocket(url, options?.headers ?? {});
      sockets.push(socket);
      return socket;
    };
    return { factory, sockets };
  }

  test("parses audio, transcript, and tool-call events", () => {
    const delta = Buffer.from("abc").toString("base64");

    // With no format argument, audio frames are tagged as 16 kHz.
    const audio = parseOpenAiRealtimeMessage({
      type: "response.output_audio.delta",
      delta,
    });
    expect(audio[0]).toMatchObject({
      type: "audio",
      frame: { format: { sampleRateHz: 16000 } },
    });

    // An explicit format is carried through onto the emitted frame.
    const providerRateAudio = parseOpenAiRealtimeMessage(
      {
        type: "response.output_audio.delta",
        delta,
      },
      { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
    );
    expect(providerRateAudio[0]).toMatchObject({
      type: "audio",
      frame: { format: { codec: "pcm16", sampleRateHz: 24000, channels: 1 } },
    });

    const transcript = parseOpenAiRealtimeMessage({
      type: "conversation.item.input_audio_transcription.completed",
      transcript: "hello",
    });
    expect(transcript[0]).toEqual({
      type: "transcript",
      role: "caller",
      text: "hello",
      isFinal: true,
    });

    // Tool-call arguments arrive as a JSON string and are expected parsed into an object.
    const tool = parseOpenAiRealtimeMessage({
      type: "response.function_call_arguments.done",
      call_id: "call-1",
      name: "transfer_to_staff",
      arguments: "{\"reason\":\"urgent\"}",
    });
    expect(tool[0]).toEqual({
      type: "tool_call",
      call: {
        id: "call-1",
        name: "transfer_to_staff",
        arguments: { reason: "urgent" },
      },
    });
  });

  test("connects with model URL, safety header, and session update", async () => {
    const { factory, sockets } = recordingFactory();
    const provider = new OpenAiRealtimeProvider({
      apiKey: "sk-test",
      model: "gpt-realtime",
      webSocketFactory: factory,
    });

    const session = await provider.connect({
      callId: "call-1",
      instructions: "Be safe.",
      inputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      outputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      tools: [],
      safetyIdentifier: "hashed-caller",
      events: {},
    });

    const socket = sockets[0]!;
    socket.open();
    await session.sendText("hello");

    expect(socket.url).toContain("model=gpt-realtime");
    expect(socket.headers.Authorization).toBe("Bearer sk-test");
    expect(socket.headers["OpenAI-Safety-Identifier"]).toBe("hashed-caller");

    // Expected frame order: session.update on open, then the text item, then a response request.
    const [update, item, response] = socket.sent.map((frame) => JSON.parse(frame));
    expect(update.type).toBe("session.update");
    expect(update.session.type).toBe("realtime");
    expect(update.session.audio.input.format).toEqual({
      type: "audio/pcm",
      rate: 24000,
    });
    expect(update.session.audio.output.format).toEqual({
      type: "audio/pcm",
      rate: 24000,
    });
    expect(item.type).toBe("conversation.item.create");
    expect(response.type).toBe("response.create");
  });

  test("handles binary WebSocket message events from the server runtime", async () => {
    const transcripts: string[] = [];
    const { factory, sockets } = recordingFactory();
    const provider = new OpenAiRealtimeProvider({
      apiKey: "sk-test",
      model: "gpt-realtime",
      webSocketFactory: factory,
    });

    await provider.connect({
      callId: "call-1",
      instructions: "Be safe.",
      inputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      outputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      tools: [],
      events: {
        onTranscript: (event) => {
          transcripts.push(event.text);
        },
      },
    });

    const socket = sockets[0]!;
    socket.open();
    // Deliver the JSON event as UTF-8 bytes rather than a string.
    socket.message(
      new TextEncoder().encode(
        JSON.stringify({
          type: "conversation.item.input_audio_transcription.completed",
          transcript: "hello from binary",
        }),
      ),
    );

    expect(transcripts).toEqual(["hello from binary"]);
  });

  test("tags provider audio with the negotiated output sample rate", async () => {
    const audioRates: number[] = [];
    const { factory, sockets } = recordingFactory();
    const provider = new OpenAiRealtimeProvider({
      apiKey: "sk-test",
      model: "gpt-realtime",
      webSocketFactory: factory,
    });

    await provider.connect({
      callId: "call-1",
      instructions: "Be safe.",
      inputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      outputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      tools: [],
      events: {
        onAudio: (frame) => {
          audioRates.push(frame.format.sampleRateHz);
        },
      },
    });

    const socket = sockets[0]!;
    socket.open();
    socket.message(
      JSON.stringify({
        type: "response.output_audio.delta",
        delta: Buffer.from([1, 2, 3, 4]).toString("base64"),
      }),
    );

    expect(audioRates).toEqual([24000]);
  });

  test("can still emit the legacy session shape for xAI compatibility", async () => {
    const { factory, sockets } = recordingFactory();
    const provider = new OpenAiRealtimeProvider({
      apiKey: "xai-test",
      model: "grok-voice-think-fast-1.0",
      sessionShape: "legacy",
      webSocketFactory: factory,
    });

    await provider.connect({
      callId: "call-1",
      instructions: "Be safe.",
      inputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      outputFormat: { codec: "pcm16", sampleRateHz: 24000, channels: 1 },
      tools: [],
      events: {},
    });

    const socket = sockets[0]!;
    socket.open();

    // The legacy shape uses flat string audio formats instead of nested audio.*.format objects.
    const update = JSON.parse(socket.sent[0]!);
    expect(update.session.input_audio_format).toBe("pcm16");
    expect(update.session.output_audio_format).toBe("pcm16");
  });
});

/**
 * In-memory stand-in for a WebSocket, used to drive the provider without a network.
 *
 * Every outbound frame is recorded in `sent` as text. Binary frames are decoded as
 * UTF-8 so tests can `JSON.parse` them. Tests fire `open` and `message` events by hand.
 */
class FakeWebSocket implements WebSocketLike {
  /** Follows WebSocket readyState values: 0 = CONNECTING, 1 = OPEN, 3 = CLOSED. */
  readyState = 0;
  /** Outbound frames in send order, as text. */
  sent: string[] = [];
  private readonly listeners = new Map<string, Array<(event: unknown) => void>>();
  private readonly decoder = new TextDecoder();

  constructor(
    readonly url: string,
    readonly headers: Record<string, string>,
  ) {}

  /**
   * Records an outbound frame.
   *
   * Binary payloads are decoded rather than passed through `String()`.
   * `String()` would render a Uint8Array as "1,2,3" and an ArrayBuffer as
   * "[object ArrayBuffer]".
   */
  send(data: string | ArrayBuffer | Uint8Array): void {
    this.sent.push(typeof data === "string" ? data : this.decoder.decode(data));
  }

  /** Marks the socket closed. No `close` event is dispatched. */
  close(): void {
    this.readyState = 3;
  }

  addEventListener(type: "open" | "message" | "error" | "close", listener: (event: unknown) => void): void {
    const existing = this.listeners.get(type) ?? [];
    existing.push(listener);
    this.listeners.set(type, existing);
  }

  /** Simulates the connection opening: sets readyState to OPEN and fires `open` listeners. */
  open(): void {
    this.readyState = 1;
    this.emit("open", {});
  }

  /** Simulates an inbound frame by firing `message` listeners with a `{ data }` event. */
  message(data: string | Uint8Array | ArrayBuffer): void {
    this.emit("message", { data });
  }

  /** Invokes every listener registered for `type`, in registration order. */
  private emit(type: string, event: unknown): void {
    for (const listener of this.listeners.get(type) ?? []) listener(event);
  }
}
