Fedir Zadniprovskyi commited on
Commit
f3632d1
·
1 Parent(s): 608e57c

chore: handle "whisper-1" model name

Browse files
Files changed (1) hide show
  1. faster_whisper_server/main.py +23 -6
faster_whisper_server/main.py CHANGED
@@ -11,6 +11,7 @@ from fastapi import (
11
  FastAPI,
12
  Form,
13
  HTTPException,
 
14
  Query,
15
  Response,
16
  UploadFile,
@@ -22,6 +23,7 @@ from fastapi.websockets import WebSocketState
22
  from faster_whisper import WhisperModel
23
  from faster_whisper.vad import VadOptions, get_speech_timestamps
24
  from huggingface_hub.hf_api import ModelInfo
 
25
 
26
  from faster_whisper_server import utils
27
  from faster_whisper_server.asr import FasterWhisperASR
@@ -85,7 +87,7 @@ def health() -> Response:
85
  return Response(status_code=200, content="OK")
86
 
87
 
88
- @app.get("/v1/models", response_model=list[ModelObject])
89
  def get_models() -> list[ModelObject]:
90
  models = huggingface_hub.list_models(library="ctranslate2")
91
  models = [
@@ -101,8 +103,8 @@ def get_models() -> list[ModelObject]:
101
  return models
102
 
103
 
104
- @app.get("/v1/models/{model_name:path}", response_model=ModelObject)
105
- def get_model(model_name: str) -> ModelObject:
106
  models = list(
107
  huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
108
  )
@@ -131,10 +133,25 @@ def format_as_sse(data: str) -> str:
131
  return f"data: {data}\n\n"
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  @app.post("/v1/audio/translations")
135
  def translate_file(
136
  file: Annotated[UploadFile, Form()],
137
- model: Annotated[str, Form()] = config.whisper.model,
138
  prompt: Annotated[str | None, Form()] = None,
139
  response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
140
  temperature: Annotated[float, Form()] = 0.0,
@@ -187,7 +204,7 @@ def translate_file(
187
  @app.post("/v1/audio/transcriptions")
188
  def transcribe_file(
189
  file: Annotated[UploadFile, Form()],
190
- model: Annotated[str, Form()] = config.whisper.model,
191
  language: Annotated[Language | None, Form()] = config.default_language,
192
  prompt: Annotated[str | None, Form()] = None,
193
  response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
@@ -289,7 +306,7 @@ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
289
  @app.websocket("/v1/audio/transcriptions")
290
  async def transcribe_stream(
291
  ws: WebSocket,
292
- model: Annotated[str, Query()] = config.whisper.model,
293
  language: Annotated[Language | None, Query()] = config.default_language,
294
  response_format: Annotated[
295
  ResponseFormat, Query()
 
11
  FastAPI,
12
  Form,
13
  HTTPException,
14
+ Path,
15
  Query,
16
  Response,
17
  UploadFile,
 
23
  from faster_whisper import WhisperModel
24
  from faster_whisper.vad import VadOptions, get_speech_timestamps
25
  from huggingface_hub.hf_api import ModelInfo
26
+ from pydantic import AfterValidator
27
 
28
  from faster_whisper_server import utils
29
  from faster_whisper_server.asr import FasterWhisperASR
 
87
  return Response(status_code=200, content="OK")
88
 
89
 
90
+ @app.get("/v1/models")
91
  def get_models() -> list[ModelObject]:
92
  models = huggingface_hub.list_models(library="ctranslate2")
93
  models = [
 
103
  return models
104
 
105
 
106
+ @app.get("/v1/models/{model_name:path}")
107
+ def get_model(model_name: Annotated[str, Path()]) -> ModelObject:
108
  models = list(
109
  huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
110
  )
 
133
  return f"data: {data}\n\n"
134
 
135
 
136
+ def handle_default_openai_model(model_name: str) -> str:
137
+ """This exists because some callers may not be able override the default("whisper-1") model name.
138
+ For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
139
+ """
140
+ if model_name == "whisper-1":
141
+ logger.info(
142
+ f"{model_name} is not a valid model name. Using {config.whisper.model} instead."
143
+ )
144
+ return config.whisper.model
145
+ return model_name
146
+
147
+
148
+ ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
149
+
150
+
151
  @app.post("/v1/audio/translations")
152
  def translate_file(
153
  file: Annotated[UploadFile, Form()],
154
+ model: Annotated[ModelName, Form()] = config.whisper.model,
155
  prompt: Annotated[str | None, Form()] = None,
156
  response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
157
  temperature: Annotated[float, Form()] = 0.0,
 
204
  @app.post("/v1/audio/transcriptions")
205
  def transcribe_file(
206
  file: Annotated[UploadFile, Form()],
207
+ model: Annotated[ModelName, Form()] = config.whisper.model,
208
  language: Annotated[Language | None, Form()] = config.default_language,
209
  prompt: Annotated[str | None, Form()] = None,
210
  response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
 
306
  @app.websocket("/v1/audio/transcriptions")
307
  async def transcribe_stream(
308
  ws: WebSocket,
309
+ model: Annotated[ModelName, Query()] = config.whisper.model,
310
  language: Annotated[Language | None, Query()] = config.default_language,
311
  response_format: Annotated[
312
  ResponseFormat, Query()