File size: 4,011 Bytes
8c12cdc
42343e0
8c12cdc
04d664a
 
 
dc4f25f
323aa51
 
 
8c12cdc
23a3cae
 
 
 
 
8c12cdc
 
 
 
 
 
 
dc4f25f
8c12cdc
 
23a3cae
dc4f25f
04d664a
42343e0
04d664a
 
8c12cdc
 
 
 
04d664a
 
8c12cdc
dc4f25f
8c12cdc
 
23a3cae
dc4f25f
04d664a
42343e0
04d664a
 
8c12cdc
 
 
 
04d664a
 
ec4d8ae
8c12cdc
 
23a3cae
dc4f25f
04d664a
42343e0
04d664a
 
8c12cdc
 
 
 
04d664a
 
ec4d8ae
323aa51
 
23a3cae
04d664a
 
 
323aa51
 
 
 
04d664a
323aa51
 
 
 
 
 
 
 
 
23a3cae
04d664a
 
 
323aa51
 
 
 
04d664a
323aa51
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
from pathlib import Path

import anyio
from httpx import AsyncClient
from httpx_sse import aconnect_sse
import pytest
import srt
import webvtt
import webvtt.vtt

from faster_whisper_server.api_models import (
    CreateTranscriptionResponseJson,
    CreateTranscriptionResponseVerboseJson,
)

# Audio inputs fed to the endpoints. HACK: a single local WAV file for now.
FILE_PATHS = ["audio.wav"]  # HACK
# Endpoints exercised by the parametrized tests; both receive the same
# multipart request in these tests.
ENDPOINTS = [
    "/v1/audio/transcriptions",
    "/v1/audio/translations",
]


# Every (file_path, endpoint) pair, endpoint-major order, used as the
# pytest parametrization for the streaming tests.
parameters = [
    (audio_path, route)
    for route in ENDPOINTS
    for audio_path in FILE_PATHS
]


@pytest.mark.asyncio
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``text`` response over SSE and check that every event carries text.

    The audio file is uploaded as multipart form data with ``stream`` enabled.
    Each event payload must be longer than one character. At least one event
    must arrive, so an empty stream cannot make the test pass vacuously.
    """
    extension = Path(file_path).suffix[1:]
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "text", "stream": True},
    }
    event_count = 0
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            event_count += 1
            # HACK: compare against 1 (not 0) because a space character is always prepended.
            assert len(event.data) > 1, f"event carries no text: {event!r}"
    # Without this, a stream that yields no events would skip the loop and pass.
    assert event_count > 0, "no SSE events were received"


@pytest.mark.asyncio
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``json`` response over SSE and validate every event payload.

    Each event's data is JSON-decoded and passed to
    ``CreateTranscriptionResponseJson``. The test relies on that constructor
    raising when a payload does not match the model. At least one event must
    arrive, so an empty stream cannot make the test pass vacuously.
    """
    extension = Path(file_path).suffix[1:]
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "json", "stream": True},
    }
    event_count = 0
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            event_count += 1
            CreateTranscriptionResponseJson(**json.loads(event.data))
    # Without this, a stream that yields no events would skip the loop and pass.
    assert event_count > 0, "no SSE events were received"


@pytest.mark.asyncio
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``verbose_json`` response over SSE and validate every event payload.

    Each event's data is JSON-decoded and passed to
    ``CreateTranscriptionResponseVerboseJson``. The test relies on that
    constructor raising when a payload does not match the model. At least one
    event must arrive, so an empty stream cannot make the test pass vacuously.
    """
    extension = Path(file_path).suffix[1:]
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "verbose_json", "stream": True},
    }
    event_count = 0
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            event_count += 1
            CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
    # Without this, a stream that yields no events would skip the loop and pass.
    assert event_count > 0, "no SSE events were received"


@pytest.mark.asyncio
async def test_transcription_vtt(aclient: AsyncClient) -> None:
    """Request a non-streamed WebVTT transcription and check that it parses.

    As a control, the same body with its ``WEBVTT`` header replaced must be
    rejected by the parser. This sanity-checks that the successful parse
    actually validated the format.
    """
    async with await anyio.open_file("audio.wav", "rb") as audio_file:
        audio_bytes = await audio_file.read()
    response = await aclient.post(
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio_bytes, "audio/wav")},
        data={"response_format": "vtt", "stream": False},
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/vtt; charset=utf-8"

    vtt_body = response.text
    webvtt.from_string(vtt_body)

    # Negative control: a body without the mandatory header must not parse.
    corrupted_body = vtt_body.replace("WEBVTT", "YO")
    with pytest.raises(webvtt.vtt.MalformedFileError):
        webvtt.from_string(corrupted_body)


@pytest.mark.asyncio
async def test_transcription_srt(aclient: AsyncClient) -> None:
    """Request a non-streamed SRT transcription and check that it parses.

    As a control, every ``1`` in the body (including the first cue's index) is
    replaced with ``YO``. The parser must then raise, which shows the earlier
    parse was a real format check.
    """
    async with await anyio.open_file("audio.wav", "rb") as audio_file:
        audio_bytes = await audio_file.read()
    response = await aclient.post(
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio_bytes, "audio/wav")},
        data={"response_format": "srt", "stream": False},
    )
    assert response.status_code == 200
    assert "text/plain" in response.headers["content-type"]

    srt_body = response.text
    # srt.parse is lazy; materialize it so parse errors surface here.
    list(srt.parse(srt_body))

    corrupted_body = srt_body.replace("1", "YO")
    with pytest.raises(srt.SRTParseError):
        list(srt.parse(corrupted_body))