Fedir Zadniprovskyi commited on
Commit
2c38ce0
·
1 Parent(s): 20b7748

feat: api route to download model

Browse files
Files changed (1) hide show
  1. src/faster_whisper_server/main.py +13 -0
src/faster_whisper_server/main.py CHANGED
@@ -25,8 +25,10 @@ from fastapi.websockets import WebSocketState
25
  from faster_whisper import WhisperModel
26
  from faster_whisper.vad import VadOptions, get_speech_timestamps
27
  import huggingface_hub
 
28
  from pydantic import AfterValidator
29
 
 
30
  from faster_whisper_server.asr import FasterWhisperASR
31
  from faster_whisper_server.audio import AudioStream, audio_samples_from_file
32
  from faster_whisper_server.config import (
@@ -108,6 +110,17 @@ def health() -> Response:
108
  return Response(status_code=200, content="OK")
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
111
  @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
112
  def get_running_models() -> dict[str, list[str]]:
113
  return {"models": list(loaded_models.keys())}
 
25
  from faster_whisper import WhisperModel
26
  from faster_whisper.vad import VadOptions, get_speech_timestamps
27
  import huggingface_hub
28
+ from huggingface_hub.hf_api import RepositoryNotFoundError
29
  from pydantic import AfterValidator
30
 
31
+ from faster_whisper_server import hf_utils
32
  from faster_whisper_server.asr import FasterWhisperASR
33
  from faster_whisper_server.audio import AudioStream, audio_samples_from_file
34
  from faster_whisper_server.config import (
 
110
  return Response(status_code=200, content="OK")
111
 
112
 
113
+ @app.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.")
114
+ def pull_model(model_name: str) -> Response:
115
+ if hf_utils.does_local_model_exist(model_name):
116
+ return Response(status_code=200, content="Model already exists")
117
+ try:
118
+ huggingface_hub.snapshot_download(model_name, repo_type="model")
119
+ except RepositoryNotFoundError as e:
120
+ return Response(status_code=404, content=str(e))
121
+ return Response(status_code=201, content="Model downloaded")
122
+
123
+
124
  @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
125
  def get_running_models() -> dict[str, list[str]]:
126
  return {"models": list(loaded_models.keys())}