Fedir Zadniprovskyi committed
Commit 8c12cdc · 1 Parent(s): 8900179

fix: streaming doesn't use sse #15

examples/youtube/script.sh CHANGED
@@ -14,7 +14,7 @@ docker run --detach --gpus=all --publish 8000:8000 --volume ~/.cache/huggingface
 youtube-dl --extract-audio --audio-format mp3 -o the-evolution-of-the-operating-system.mp3 'https://www.youtube.com/watch?v=1lG7lFLXBIs'
 
 # Make a request to the API to transcribe the audio. The response will be streamed to the terminal and saved to a file. The video is 30 minutes long, so it might take a while to transcribe, especially if you are running this on a CPU. `Systran/faster-distil-whisper-large-v3` takes ~30 seconds on Nvidia L4. `Systran/faster-whisper-tiny.en` takes ~1 minute on Ryzen 7 7700X. The .txt file in the example was transcribed using `Systran/faster-distil-whisper-large-v3`.
-curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "stream=true" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
+curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
 
 # Here I'm using `aichat` which is a CLI LLM client. You could use any other client that supports attaching/uploading files. https://github.com/sigoden/aichat
 aichat -m openai:gpt-4o -f the-evolution-of-the-operating-system.txt 'What companies are mentioned in the following Youtube video transcription? Responed with just a list of names'
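Note that the curl example above no longer passes stream=true: with this commit the streaming path emits Server-Sent Events rather than raw text, so the plain request is the convenient one to pipe into tee. If you do want the streamed variant from Python, a minimal sketch, assuming the server is running on localhost:8000 and the httpx and httpx-sse packages are installed (the new test file below consumes the stream the same way):

import httpx
from httpx_sse import connect_sse

with httpx.Client(base_url="http://localhost:8000", timeout=None) as client:
    with open("the-evolution-of-the-operating-system.mp3", "rb") as f:
        kwargs = {
            "files": {"file": f},
            "data": {"language": "en", "response_format": "text", "stream": True},
        }
        # connect_sse forwards the request kwargs and parses the
        # text/event-stream response into individual events.
        with connect_sse(
            client, "POST", "/v1/audio/transcriptions", **kwargs
        ) as event_source:
            for event in event_source.iter_sse():
                print(event.data, end="", flush=True)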
faster_whisper_server/main.py CHANGED
@@ -4,7 +4,7 @@ import asyncio
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Literal, OrderedDict
+from typing import Annotated, Generator, Literal, OrderedDict
 
 import huggingface_hub
 from fastapi import (
@@ -127,6 +127,10 @@ def get_model(model_name: str) -> ModelObject:
     )
 
 
+def format_as_sse(data: str) -> str:
+    return f"data: {data}\n\n"
+
+
 @app.post("/v1/audio/translations")
 def translate_file(
     file: Annotated[UploadFile, Form()],
@@ -146,19 +150,6 @@ def translate_file(
         vad_filter=True,
     )
 
-    def segment_responses():
-        for segment in segments:
-            if response_format == ResponseFormat.TEXT:
-                yield segment.text
-            elif response_format == ResponseFormat.JSON:
-                yield TranscriptionJsonResponse.from_segments(
-                    [segment]
-                ).model_dump_json()
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                yield TranscriptionVerboseJsonResponse.from_segment(
-                    segment, transcription_info
-                ).model_dump_json()
-
     if not stream:
         segments = list(segments)
         logger.info(
@@ -173,6 +164,21 @@ def translate_file(
             segments, transcription_info
         )
     else:
+
+        def segment_responses() -> Generator[str, None, None]:
+            for segment in segments:
+                if response_format == ResponseFormat.TEXT:
+                    data = segment.text
+                elif response_format == ResponseFormat.JSON:
+                    data = TranscriptionJsonResponse.from_segments(
+                        [segment]
+                    ).model_dump_json()
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    data = TranscriptionVerboseJsonResponse.from_segment(
+                        segment, transcription_info
+                    ).model_dump_json()
+                yield format_as_sse(data)
+
         return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
@@ -204,22 +210,6 @@ def transcribe_file(
         vad_filter=True,
     )
 
-    def segment_responses():
-        for segment in segments:
-            logger.info(
-                f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
-            )
-            if response_format == ResponseFormat.TEXT:
-                yield segment.text
-            elif response_format == ResponseFormat.JSON:
-                yield TranscriptionJsonResponse.from_segments(
-                    [segment]
-                ).model_dump_json()
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                yield TranscriptionVerboseJsonResponse.from_segment(
-                    segment, transcription_info
-                ).model_dump_json()
-
     if not stream:
         segments = list(segments)
         logger.info(
@@ -234,6 +224,24 @@ def transcribe_file(
            segments, transcription_info
        )
    else:
+
+        def segment_responses() -> Generator[str, None, None]:
+            for segment in segments:
+                logger.info(
+                    f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
+                )
+                if response_format == ResponseFormat.TEXT:
+                    data = segment.text
+                elif response_format == ResponseFormat.JSON:
+                    data = TranscriptionJsonResponse.from_segments(
+                        [segment]
+                    ).model_dump_json()
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    data = TranscriptionVerboseJsonResponse.from_segment(
+                        segment, transcription_info
+                    ).model_dump_json()
+                yield format_as_sse(data)
+
         return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
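The core of the fix: the old generators yielded bare segment payloads even though the response was declared as text/event-stream. Valid SSE frames each message as a `data:` field terminated by a blank line, which is what the new format_as_sse helper produces. A minimal sketch of the framing together with an illustrative (not production) parser for the resulting stream:

def format_as_sse(data: str) -> str:
    # An SSE event is one or more "field: value" lines ending with a blank line.
    return f"data: {data}\n\n"


# Illustrative only: recover the payloads from a raw text/event-stream body.
def parse_sse_data(body: str) -> list[str]:
    payloads = []
    for event in body.split("\n\n"):
        for line in event.splitlines():
            if line.startswith("data: "):
                payloads.append(line[len("data: "):])
    return payloads


assert parse_sse_data(format_as_sse("foo") + format_as_sse("bar")) == ["foo", "bar"]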
tests/sse_test.py ADDED
@@ -0,0 +1,82 @@
+import json
+import os
+from typing import Generator
+
+import pytest
+from fastapi.testclient import TestClient
+from httpx_sse import connect_sse
+
+from faster_whisper_server.main import app
+from faster_whisper_server.server_models import (
+    TranscriptionJsonResponse,
+    TranscriptionVerboseJsonResponse,
+)
+
+
+@pytest.fixture()
+def client() -> Generator[TestClient, None, None]:
+    with TestClient(app) as client:
+        yield client
+
+
+FILE_PATHS = ["audio.wav"]  # HACK
+ENDPOINTS = [
+    "/v1/audio/transcriptions",
+    "/v1/audio/translations",
+]
+
+
+parameters = [
+    (file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS
+]
+
+
+@pytest.mark.parametrize("file_path,endpoint", parameters)
+def test_streaming_transcription_text(
+    client: TestClient, file_path: str, endpoint: str
+):
+    extension = os.path.splitext(file_path)[1]
+    with open(file_path, "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
+        "data": {"response_format": "text", "stream": True},
+    }
+    with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
+        for event in event_source.iter_sse():
+            print(event)
+            assert (
+                len(event.data) > 1
+            )  # HACK: 1 because of the space character that's always prepended
+
+
+@pytest.mark.parametrize("file_path,endpoint", parameters)
+def test_streaming_transcription_json(
+    client: TestClient, file_path: str, endpoint: str
+):
+    extension = os.path.splitext(file_path)[1]
+    with open(file_path, "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
+        "data": {"response_format": "json", "stream": True},
+    }
+    with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
+        for event in event_source.iter_sse():
+            TranscriptionJsonResponse(**json.loads(event.data))
+
+
+@pytest.mark.parametrize("file_path,endpoint", parameters)
+def test_streaming_transcription_verbose_json(
+    client: TestClient, file_path: str, endpoint: str
+):
+    extension = os.path.splitext(file_path)[1]
+    with open(file_path, "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
+        "data": {"response_format": "verbose_json", "stream": True},
+    }
+    with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
+        for event in event_source.iter_sse():
+            TranscriptionVerboseJsonResponse(**json.loads(event.data))
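Usage note: these tests expect an audio.wav file in the working directory (hence the HACK comment) and require the httpx-sse package alongside the server's own dependencies; they are standard pytest tests, so a plain pytest invocation picks them up.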