Fedir Zadniprovskyi committed
Commit 14908c1 · 1 Parent(s): d9a361c

extract segments to response logic

Files changed (1)
  1. faster_whisper_server/main.py +48 -62
faster_whisper_server/main.py CHANGED
@@ -4,7 +4,7 @@ import asyncio
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Generator, Literal, OrderedDict
+from typing import Annotated, Generator, Iterable, Literal, OrderedDict
 
 import huggingface_hub
 from fastapi import (
@@ -21,6 +21,7 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
+from faster_whisper.transcribe import Segment, TranscriptionInfo
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from huggingface_hub.hf_api import ModelInfo
 from pydantic import AfterValidator
@@ -132,10 +133,48 @@ def get_model(
     )
 
 
+def segments_to_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:
+        return utils.segments_text(segments)
+    elif response_format == ResponseFormat.JSON:
+        return TranscriptionJsonResponse.from_segments(segments)
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return TranscriptionVerboseJsonResponse.from_segments(
+            segments, transcription_info
+        )
+
+
 def format_as_sse(data: str) -> str:
     return f"data: {data}\n\n"
 
 
+def segments_to_streaming_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> StreamingResponse:
+    def segment_responses() -> Generator[str, None, None]:
+        for segment in segments:
+            if response_format == ResponseFormat.TEXT:
+                data = segment.text
+            elif response_format == ResponseFormat.JSON:
+                data = TranscriptionJsonResponse.from_segments(
+                    [segment]
+                ).model_dump_json()
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                data = TranscriptionVerboseJsonResponse.from_segment(
+                    segment, transcription_info
+                ).model_dump_json()
+            yield format_as_sse(data)
+
+    return StreamingResponse(segment_responses(), media_type="text/event-stream")
+
+
 def handle_default_openai_model(model_name: str) -> str:
     """This exists because some callers may not be able override the default("whisper-1") model name.
     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
@@ -168,7 +207,6 @@ def translate_file(
     | TranscriptionVerboseJsonResponse
     | StreamingResponse
 ):
-    start = time.perf_counter()
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -178,36 +216,12 @@ def translate_file(
         vad_filter=True,
     )
 
-    if not stream:
-        segments = list(segments)
-        logger.info(
-            f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
+    if stream:
+        return segments_to_streaming_response(
+            segments, transcription_info, response_format
         )
-        if response_format == ResponseFormat.TEXT:
-            return utils.segments_text(segments)
-        elif response_format == ResponseFormat.JSON:
-            return TranscriptionJsonResponse.from_segments(segments)
-        elif response_format == ResponseFormat.VERBOSE_JSON:
-            return TranscriptionVerboseJsonResponse.from_segments(
-                segments, transcription_info
-            )
     else:
-
-        def segment_responses() -> Generator[str, None, None]:
-            for segment in segments:
-                if response_format == ResponseFormat.TEXT:
-                    data = segment.text
-                elif response_format == ResponseFormat.JSON:
-                    data = TranscriptionJsonResponse.from_segments(
-                        [segment]
-                    ).model_dump_json()
-                elif response_format == ResponseFormat.VERBOSE_JSON:
-                    data = TranscriptionVerboseJsonResponse.from_segment(
-                        segment, transcription_info
-                    ).model_dump_json()
-                yield format_as_sse(data)
-
-        return StreamingResponse(segment_responses(), media_type="text/event-stream")
+        return segments_to_response(segments, transcription_info, response_format)
 
 
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
@@ -234,7 +248,6 @@ def transcribe_file(
     | TranscriptionVerboseJsonResponse
     | StreamingResponse
 ):
-    start = time.perf_counter()
     whisper = load_model(model)
    segments, transcription_info = whisper.transcribe(
         file.file,
@@ -246,39 +259,12 @@ def transcribe_file(
         vad_filter=True,
     )
 
-    if not stream:
-        segments = list(segments)
-        logger.info(
-            f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
+    if stream:
+        return segments_to_streaming_response(
+            segments, transcription_info, response_format
         )
-        if response_format == ResponseFormat.TEXT:
-            return utils.segments_text(segments)
-        elif response_format == ResponseFormat.JSON:
-            return TranscriptionJsonResponse.from_segments(segments)
-        elif response_format == ResponseFormat.VERBOSE_JSON:
-            return TranscriptionVerboseJsonResponse.from_segments(
-                segments, transcription_info
-            )
     else:
-
-        def segment_responses() -> Generator[str, None, None]:
-            for segment in segments:
-                logger.info(
-                    f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
-                )
-                if response_format == ResponseFormat.TEXT:
-                    data = segment.text
-                elif response_format == ResponseFormat.JSON:
-                    data = TranscriptionJsonResponse.from_segments(
-                        [segment]
-                    ).model_dump_json()
-                elif response_format == ResponseFormat.VERBOSE_JSON:
-                    data = TranscriptionVerboseJsonResponse.from_segment(
-                        segment, transcription_info
-                    ).model_dump_json()
-                yield format_as_sse(data)
-
-        return StreamingResponse(segment_responses(), media_type="text/event-stream")
+        return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
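
Not part of this commit, but useful for seeing the streaming branch end to end: segments_to_streaming_response() emits one Server-Sent Events record per segment, with format_as_sse() framing each payload as "data: <payload>" followed by a blank line, served with media type text/event-stream. The sketch below consumes that stream with httpx. The base URL, the OpenAI-style /v1/audio/transcriptions route, the model/stream form fields, and the model id are assumptions about the surrounding app, not things shown in this diff.

# Illustrative client sketch only, not from this commit. Assumed: server on
# localhost:8000, OpenAI-style route, "stream" accepted as a form field, and a
# Hugging Face model id. The "data: ...\n\n" framing matches format_as_sse().
import httpx

with httpx.Client(base_url="http://localhost:8000") as client, open("audio.wav", "rb") as f:
    with client.stream(
        "POST",
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", f)},
        data={
            "model": "Systran/faster-whisper-small",  # assumed model id
            "response_format": "json",
            "stream": True,
        },
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            if line.startswith("data: "):
                # One segment per event, e.g. {"text": "..."} for the JSON format.
                print(line[len("data: "):])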
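
When stream is not requested, the same endpoints go through segments_to_response() and return a single body, so an OpenAI-compatible client works unchanged. A minimal sketch, assuming the API is mounted under /v1 and that any placeholder API key is accepted (both assumptions, not shown in this diff):

# Illustrative non-streaming request via the openai client. base_url, api_key,
# and the model id are assumptions about the deployment, not part of this diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
with open("audio.wav", "rb") as f:
    transcript = client.audio.transcriptions.create(
        model="Systran/faster-whisper-small",  # assumed model id
        file=f,
        response_format="json",
    )
print(transcript.text)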