File size: 2,808 Bytes
313814b
 
 
 
 
 
 
 
 
 
 
 
39ee116
 
 
313814b
 
a9ee91b
 
 
313814b
 
 
 
 
 
 
 
a9ee91b
 
 
 
 
 
313814b
 
 
a9ee91b
 
313814b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0feed8
313814b
a9ee91b
313814b
 
 
d0feed8
313814b
 
 
a9ee91b
 
 
313814b
 
a9ee91b
d0feed8
a9ee91b
 
 
 
 
 
d0feed8
 
 
a9ee91b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
import os
import threading
import time
from difflib import SequenceMatcher
from typing import Generator

import pytest
from fastapi import WebSocketDisconnect
from fastapi.testclient import TestClient
from starlette.testclient import WebSocketTestSession

from faster_whisper_server.config import BYTES_PER_SECOND
from faster_whisper_server.main import app
from faster_whisper_server.server_models import TranscriptionVerboseJsonResponse

SIMILARITY_THRESHOLD = 0.97
AUDIO_FILES_LIMIT = 5
AUDIO_FILE_DIR = "tests/data"
TRANSCRIBE_ENDPOINT = "/v1/audio/transcriptions?response_format=verbose_json"


@pytest.fixture()
def client() -> Generator[TestClient, None, None]:
    with TestClient(app) as client:
        yield client


@pytest.fixture()
def ws(client: TestClient) -> Generator[WebSocketTestSession, None, None]:
    with client.websocket_connect(TRANSCRIBE_ENDPOINT) as ws:
        yield ws


def get_audio_file_paths():
    file_paths = []
    directory = "tests/data"
    for filename in sorted(os.listdir(directory)[:AUDIO_FILES_LIMIT]):
        file_paths.append(os.path.join(directory, filename))
    return file_paths


file_paths = get_audio_file_paths()


def stream_audio_data(
    ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
):
    for i in range(0, len(data), chunk_size):
        ws.send_bytes(data[i : i + chunk_size])
        delay = len(data[i : i + chunk_size]) / BYTES_PER_SECOND / speed
        time.sleep(delay)


def transcribe_audio_data(
    client: TestClient, data: bytes
) -> TranscriptionVerboseJsonResponse:
    response = client.post(
        TRANSCRIBE_ENDPOINT,
        files={"file": ("audio.raw", data, "audio/raw")},
    )
    data = json.loads(response.json())  # TODO: figure this out
    return TranscriptionVerboseJsonResponse(**data)  # type: ignore


@pytest.mark.parametrize("file_path", file_paths)
def test_ws_audio_transcriptions(
    client: TestClient, ws: WebSocketTestSession, file_path: str
):
    with open(file_path, "rb") as file:
        data = file.read()

    streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore
    thread = threading.Thread(
        target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
    )
    thread.start()
    while True:
        try:
            streaming_transcription = TranscriptionVerboseJsonResponse(
                **ws.receive_json()
            )
        except WebSocketDisconnect:
            break
    file_transcription = transcribe_audio_data(client, data)
    s = SequenceMatcher(
        lambda x: x == " ", file_transcription.text, streaming_transcription.text
    )
    assert (
        s.ratio() > SIMILARITY_THRESHOLD
    ), f"\nExpected: {file_transcription.text}\nReceived: {streaming_transcription.text}"