"""Helpers and fixtures for exercising the audio transcription endpoints."""

from collections.abc import Generator
import json
import os
import time

from fastapi.testclient import TestClient
import pytest
from starlette.testclient import WebSocketTestSession

from faster_whisper_server.config import BYTES_PER_SECOND
from faster_whisper_server.server_models import TranscriptionVerboseJsonResponse

SIMILARITY_THRESHOLD = 0.97
AUDIO_FILES_LIMIT = 5
AUDIO_FILE_DIR = "tests/data"
TRANSCRIBE_ENDPOINT = "/v1/audio/transcriptions?response_format=verbose_json"


@pytest.fixture()
def ws(client: TestClient) -> Generator[WebSocketTestSession, None, None]:
    """Yield a websocket session connected to ``TRANSCRIBE_ENDPOINT``.

    The session is opened through the ``client`` fixture (defined outside this
    file) and is closed when the ``with`` block exits after the test finishes.
    """
    with client.websocket_connect(TRANSCRIBE_ENDPOINT) as ws:
        yield ws


def get_audio_file_paths() -> list[str]:
    """Return the paths of the first ``AUDIO_FILES_LIMIT`` entries in ``AUDIO_FILE_DIR``.

    Entries are sorted by name *before* the limit is applied: ``os.listdir``
    returns names in arbitrary order, so slicing first would select a
    platform-dependent subset of the files.

    Raises:
        OSError: If ``AUDIO_FILE_DIR`` cannot be listed (e.g. it does not exist).
    """
    filenames = sorted(os.listdir(AUDIO_FILE_DIR))[:AUDIO_FILES_LIMIT]
    return [os.path.join(AUDIO_FILE_DIR, filename) for filename in filenames]


# Evaluated at import time so the paths can be used for test parametrization.
file_paths = get_audio_file_paths()


def stream_audio_data(
    ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
) -> None:
    """Send ``data`` over ``ws`` in chunks, sleeping after each chunk.

    Args:
        ws: Open websocket session the chunks are written to.
        data: Raw bytes to send.
        chunk_size: Maximum number of bytes sent per websocket message.
        speed: Divisor applied to the per-chunk delay; values above 1.0
            shorten the pauses between chunks.
    """
    for start in range(0, len(data), chunk_size):
        chunk = data[start : start + chunk_size]
        ws.send_bytes(chunk)
        # Pause for the time this chunk represents at BYTES_PER_SECOND,
        # scaled down by ``speed``.
        delay = len(chunk) / BYTES_PER_SECOND / speed
        time.sleep(delay)


def transcribe_audio_data(client: TestClient, data: bytes) -> TranscriptionVerboseJsonResponse:
    """POST ``data`` as a raw audio file to ``TRANSCRIBE_ENDPOINT`` and parse the reply.

    Args:
        client: Test client used to issue the request.
        data: Audio bytes uploaded as ``audio.raw`` with content type ``audio/raw``.

    Returns:
        The response body loaded into a ``TranscriptionVerboseJsonResponse``.
    """
    response = client.post(
        TRANSCRIBE_ENDPOINT,
        files={"file": ("audio.raw", data, "audio/raw")},
    )
    # NOTE(review): the body is decoded twice, which implies the endpoint
    # returns a JSON-encoded *string* containing JSON - confirm against the server.
    data = json.loads(response.json())  # TODO: figure this out
    return TranscriptionVerboseJsonResponse(**data)  # pyright: ignore[reportCallIssue]


# NOTE: the test below is disabled. Re-enabling it requires names that are not
# imported in this file (`threading`, `SequenceMatcher`, `WebSocketDisconnect`).
#
# @pytest.mark.parametrize("file_path", file_paths)
# def test_ws_audio_transcriptions(
#     client: TestClient, ws: WebSocketTestSession, file_path: str
# ):
#     with open(file_path, "rb") as file:
#         data = file.read()
#
#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore  # noqa: PGH003
#     thread = threading.Thread(
#         target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
#     )
#     thread.start()
#     while True:
#         try:
#             streaming_transcription = TranscriptionVerboseJsonResponse(
#                 **ws.receive_json()
#             )
#         except WebSocketDisconnect:
#             break
#     file_transcription = transcribe_audio_data(client, data)
#     s = SequenceMatcher(
#         lambda x: x == " ", file_transcription.text, streaming_transcription.text
#     )
#     assert (
#         s.ratio() > SIMILARITY_THRESHOLD
#     ), f"\nExpected: {file_transcription.text}\nReceived: {streaming_transcription.text}"