Fedir Zadniprovskyi committed on
Commit
04d664a
·
1 Parent(s): ec4d8ae

test: switch to async http client

Browse files
pyproject.toml CHANGED
@@ -30,6 +30,7 @@ dev = [
30
  "basedpyright==1.13.0",
31
  "pytest-xdist==3.6.1",
32
  "pytest-asyncio>=0.24.0",
 
33
  ]
34
 
35
  [build-system]
 
30
  "basedpyright==1.13.0",
31
  "pytest-xdist==3.6.1",
32
  "pytest-asyncio>=0.24.0",
33
+ "anyio>=4.4.0",
34
  ]
35
 
36
  [build-system]
tests/api_timestamp_granularities_test.py CHANGED
@@ -1,5 +1,7 @@
1
  """See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`.""" # noqa: E501
2
 
 
 
3
  from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
4
  from openai import AsyncOpenAI
5
  import pytest
@@ -11,10 +13,10 @@ async def test_api_json_response_format_and_timestamp_granularities_combinations
11
  openai_client: AsyncOpenAI,
12
  timestamp_granularities: TimestampGranularities,
13
  ) -> None:
14
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
15
 
16
  await openai_client.audio.transcriptions.create(
17
- file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
18
  )
19
 
20
 
@@ -24,10 +26,10 @@ async def test_api_verbose_json_response_format_and_timestamp_granularities_comb
24
  openai_client: AsyncOpenAI,
25
  timestamp_granularities: TimestampGranularities,
26
  ) -> None:
27
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
28
 
29
  transcription = await openai_client.audio.transcriptions.create(
30
- file=audio_file,
31
  model="whisper-1",
32
  response_format="verbose_json",
33
  timestamp_granularities=timestamp_granularities,
 
1
  """See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`.""" # noqa: E501
2
 
3
+ from pathlib import Path
4
+
5
  from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
6
  from openai import AsyncOpenAI
7
  import pytest
 
13
  openai_client: AsyncOpenAI,
14
  timestamp_granularities: TimestampGranularities,
15
  ) -> None:
16
+ file_path = Path("audio.wav")
17
 
18
  await openai_client.audio.transcriptions.create(
19
+ file=file_path, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
20
  )
21
 
22
 
 
26
  openai_client: AsyncOpenAI,
27
  timestamp_granularities: TimestampGranularities,
28
  ) -> None:
29
+ file_path = Path("audio.wav")
30
 
31
  transcription = await openai_client.audio.transcriptions.create(
32
+ file=file_path,
33
  model="whisper-1",
34
  response_format="verbose_json",
35
  timestamp_granularities=timestamp_granularities,
tests/conftest.py CHANGED
@@ -18,6 +18,7 @@ def pytest_configure() -> None:
18
  logger.disabled = True
19
 
20
 
 
21
  @pytest.fixture()
22
  def client() -> Generator[TestClient, None, None]:
23
  os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
 
18
  logger.disabled = True
19
 
20
 
21
+ # NOTE: not being used. Keeping just in case
22
  @pytest.fixture()
23
  def client() -> Generator[TestClient, None, None]:
24
  os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
tests/openai_timestamp_granularities_test.py CHANGED
@@ -1,5 +1,7 @@
1
  """OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501
2
 
 
 
3
  from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
4
  from openai import AsyncOpenAI, BadRequestError
5
  import pytest
@@ -12,19 +14,18 @@ async def test_openai_json_response_format_and_timestamp_granularities_combinati
12
  actual_openai_client: AsyncOpenAI,
13
  timestamp_granularities: TimestampGranularities,
14
  ) -> None:
15
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
16
-
17
  if "word" in timestamp_granularities:
18
  with pytest.raises(BadRequestError):
19
  await actual_openai_client.audio.transcriptions.create(
20
- file=audio_file,
21
  model="whisper-1",
22
  response_format="json",
23
  timestamp_granularities=timestamp_granularities,
24
  )
25
  else:
26
  await actual_openai_client.audio.transcriptions.create(
27
- file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
28
  )
29
 
30
 
@@ -35,10 +36,10 @@ async def test_openai_verbose_json_response_format_and_timestamp_granularities_c
35
  actual_openai_client: AsyncOpenAI,
36
  timestamp_granularities: TimestampGranularities,
37
  ) -> None:
38
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
39
 
40
  transcription = await actual_openai_client.audio.transcriptions.create(
41
- file=audio_file,
42
  model="whisper-1",
43
  response_format="verbose_json",
44
  timestamp_granularities=timestamp_granularities,
 
1
  """OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501
2
 
3
+ from pathlib import Path
4
+
5
  from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
6
  from openai import AsyncOpenAI, BadRequestError
7
  import pytest
 
14
  actual_openai_client: AsyncOpenAI,
15
  timestamp_granularities: TimestampGranularities,
16
  ) -> None:
17
+ file_path = Path("audio.wav")
 
18
  if "word" in timestamp_granularities:
19
  with pytest.raises(BadRequestError):
20
  await actual_openai_client.audio.transcriptions.create(
21
+ file=file_path,
22
  model="whisper-1",
23
  response_format="json",
24
  timestamp_granularities=timestamp_granularities,
25
  )
26
  else:
27
  await actual_openai_client.audio.transcriptions.create(
28
+ file=file_path, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
29
  )
30
 
31
 
 
36
  actual_openai_client: AsyncOpenAI,
37
  timestamp_granularities: TimestampGranularities,
38
  ) -> None:
39
+ file_path = Path("audio.wav")
40
 
41
  transcription = await actual_openai_client.audio.transcriptions.create(
42
+ file=file_path,
43
  model="whisper-1",
44
  response_format="verbose_json",
45
  timestamp_granularities=timestamp_granularities,
tests/sse_test.py CHANGED
@@ -1,12 +1,13 @@
1
  import json
2
  import os
3
 
4
- from fastapi.testclient import TestClient
5
  from faster_whisper_server.api_models import (
6
  CreateTranscriptionResponseJson,
7
  CreateTranscriptionResponseVerboseJson,
8
  )
9
- from httpx_sse import connect_sse
 
10
  import pytest
11
  import srt
12
  import webvtt
@@ -22,57 +23,61 @@ ENDPOINTS = [
22
  parameters = [(file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS]
23
 
24
 
 
25
  @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
26
- def test_streaming_transcription_text(client: TestClient, file_path: str, endpoint: str) -> None:
27
  extension = os.path.splitext(file_path)[1]
28
- with open(file_path, "rb") as f:
29
- data = f.read()
30
  kwargs = {
31
  "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
32
  "data": {"response_format": "text", "stream": True},
33
  }
34
- with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
35
- for event in event_source.iter_sse():
36
  print(event)
37
  assert len(event.data) > 1 # HACK: 1 because of the space character that's always prepended
38
 
39
 
 
40
  @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
41
- def test_streaming_transcription_json(client: TestClient, file_path: str, endpoint: str) -> None:
42
  extension = os.path.splitext(file_path)[1]
43
- with open(file_path, "rb") as f:
44
- data = f.read()
45
  kwargs = {
46
  "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
47
  "data": {"response_format": "json", "stream": True},
48
  }
49
- with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
50
- for event in event_source.iter_sse():
51
  CreateTranscriptionResponseJson(**json.loads(event.data))
52
 
53
 
 
54
  @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
55
- def test_streaming_transcription_verbose_json(client: TestClient, file_path: str, endpoint: str) -> None:
56
  extension = os.path.splitext(file_path)[1]
57
- with open(file_path, "rb") as f:
58
- data = f.read()
59
  kwargs = {
60
  "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
61
  "data": {"response_format": "verbose_json", "stream": True},
62
  }
63
- with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
64
- for event in event_source.iter_sse():
65
  CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
66
 
67
 
68
- def test_transcription_vtt(client: TestClient) -> None:
69
- with open("audio.wav", "rb") as f:
70
- data = f.read()
 
71
  kwargs = {
72
  "files": {"file": ("audio.wav", data, "audio/wav")},
73
  "data": {"response_format": "vtt", "stream": False},
74
  }
75
- response = client.post("/v1/audio/transcriptions", **kwargs)
76
  assert response.status_code == 200
77
  assert response.headers["content-type"] == "text/vtt; charset=utf-8"
78
  text = response.text
@@ -82,14 +87,15 @@ def test_transcription_vtt(client: TestClient) -> None:
82
  webvtt.from_string(text)
83
 
84
 
85
- def test_transcription_srt(client: TestClient) -> None:
86
- with open("audio.wav", "rb") as f:
87
- data = f.read()
 
88
  kwargs = {
89
  "files": {"file": ("audio.wav", data, "audio/wav")},
90
  "data": {"response_format": "srt", "stream": False},
91
  }
92
- response = client.post("/v1/audio/transcriptions", **kwargs)
93
  assert response.status_code == 200
94
  assert "text/plain" in response.headers["content-type"]
95
 
 
1
  import json
2
  import os
3
 
4
+ import anyio
5
  from faster_whisper_server.api_models import (
6
  CreateTranscriptionResponseJson,
7
  CreateTranscriptionResponseVerboseJson,
8
  )
9
+ from httpx import AsyncClient
10
+ from httpx_sse import aconnect_sse
11
  import pytest
12
  import srt
13
  import webvtt
 
23
  parameters = [(file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS]
24
 
25
 
26
+ @pytest.mark.asyncio()
27
  @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
28
+ async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
29
  extension = os.path.splitext(file_path)[1]
30
+ async with await anyio.open_file(file_path, "rb") as f:
31
+ data = await f.read()
32
  kwargs = {
33
  "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
34
  "data": {"response_format": "text", "stream": True},
35
  }
36
+ async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
37
+ async for event in event_source.aiter_sse():
38
  print(event)
39
  assert len(event.data) > 1 # HACK: 1 because of the space character that's always prepended
40
 
41
 
42
+ @pytest.mark.asyncio()
43
  @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
44
+ async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
45
  extension = os.path.splitext(file_path)[1]
46
+ async with await anyio.open_file(file_path, "rb") as f:
47
+ data = await f.read()
48
  kwargs = {
49
  "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
50
  "data": {"response_format": "json", "stream": True},
51
  }
52
+ async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
53
+ async for event in event_source.aiter_sse():
54
  CreateTranscriptionResponseJson(**json.loads(event.data))
55
 
56
 
57
+ @pytest.mark.asyncio()
58
  @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
59
+ async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
60
  extension = os.path.splitext(file_path)[1]
61
+ async with await anyio.open_file(file_path, "rb") as f:
62
+ data = await f.read()
63
  kwargs = {
64
  "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
65
  "data": {"response_format": "verbose_json", "stream": True},
66
  }
67
+ async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
68
+ async for event in event_source.aiter_sse():
69
  CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
70
 
71
 
72
+ @pytest.mark.asyncio()
73
+ async def test_transcription_vtt(aclient: AsyncClient) -> None:
74
+ async with await anyio.open_file("audio.wav", "rb") as f:
75
+ data = await f.read()
76
  kwargs = {
77
  "files": {"file": ("audio.wav", data, "audio/wav")},
78
  "data": {"response_format": "vtt", "stream": False},
79
  }
80
+ response = await aclient.post("/v1/audio/transcriptions", **kwargs)
81
  assert response.status_code == 200
82
  assert response.headers["content-type"] == "text/vtt; charset=utf-8"
83
  text = response.text
 
87
  webvtt.from_string(text)
88
 
89
 
90
+ @pytest.mark.asyncio()
91
+ async def test_transcription_srt(aclient: AsyncClient) -> None:
92
+ async with await anyio.open_file("audio.wav", "rb") as f:
93
+ data = await f.read()
94
  kwargs = {
95
  "files": {"file": ("audio.wav", data, "audio/wav")},
96
  "data": {"response_format": "srt", "stream": False},
97
  }
98
+ response = await aclient.post("/v1/audio/transcriptions", **kwargs)
99
  assert response.status_code == 200
100
  assert "text/plain" in response.headers["content-type"]
101
 
uv.lock CHANGED
@@ -295,6 +295,7 @@ client = [
295
  { name = "keyboard" },
296
  ]
297
  dev = [
 
298
  { name = "basedpyright" },
299
  { name = "pytest" },
300
  { name = "pytest-asyncio" },
@@ -306,6 +307,7 @@ dev = [
306
 
307
  [package.metadata]
308
  requires-dist = [
 
309
  { name = "basedpyright", marker = "extra == 'dev'", specifier = "==1.13.0" },
310
  { name = "fastapi", specifier = "==0.112.4" },
311
  { name = "faster-whisper", specifier = "==1.0.3" },
 
295
  { name = "keyboard" },
296
  ]
297
  dev = [
298
+ { name = "anyio" },
299
  { name = "basedpyright" },
300
  { name = "pytest" },
301
  { name = "pytest-asyncio" },
 
307
 
308
  [package.metadata]
309
  requires-dist = [
310
+ { name = "anyio", marker = "extra == 'dev'", specifier = ">=4.4.0" },
311
  { name = "basedpyright", marker = "extra == 'dev'", specifier = "==1.13.0" },
312
  { name = "fastapi", specifier = "==0.112.4" },
313
  { name = "faster-whisper", specifier = "==1.0.3" },