Fedir Zadniprovskyi committed
Commit aada575 · 1 Parent(s): 6127072
feat: support loading multiple models
Browse files:
- speaches/config.py +20 -18
- speaches/main.py +46 -48
speaches/config.py
CHANGED
@@ -163,39 +163,41 @@ class Language(enum.StrEnum):
 
 
 class WhisperConfig(BaseModel):
-    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)
-    inference_device: Device = Field(
-        default=Device.AUTO
-    )  # ENV: WHISPER_INFERENCE_DEVICE
-    compute_type: Quantization = Field(
-        default=Quantization.DEFAULT
-    )  # ENV: WHISPER_COMPUTE_TYPE
+    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)
+    inference_device: Device = Field(default=Device.AUTO)
+    compute_type: Quantization = Field(default=Quantization.DEFAULT)
 
 
 class Config(BaseSettings):
+    """
+    Configuration for the application. Values can be set via environment variables.
+    Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
+    To populate nested, the environment should be prefixed with the nested field name and an underscore. For example,
+    the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
+    """
+
     model_config = SettingsConfigDict(env_nested_delimiter="_")
 
-    log_level: str = "info"
-    default_language: Language | None = None
-    default_response_format: ResponseFormat =
-    whisper: WhisperConfig = WhisperConfig()  # ENV: WHISPER_*
+    log_level: str = "info"
+    default_language: Language | None = None
+    default_response_format: ResponseFormat = ResponseFormat.JSON
+    whisper: WhisperConfig = WhisperConfig()
+    max_models: int = 1
     """
     Max duration to for the next audio chunk before transcription is finilized and connection is closed.
     """
-    max_no_data_seconds: float = 1.0
-    min_duration: float = 1.0
-    word_timestamp_error_margin: float = 0.2
+    max_no_data_seconds: float = 1.0
+    min_duration: float = 1.0
+    word_timestamp_error_margin: float = 0.2
     """
     Max allowed audio duration without any speech being detected before transcription is finilized and connection is closed.
     """
-    max_inactivity_seconds: float = 2.0
+    max_inactivity_seconds: float = 2.0
     """
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-    inactivity_window_seconds: float = 3.0
+    inactivity_window_seconds: float = 3.0
 
 
 config = Config()
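
Note: the docstring added to Config above describes how environment variables map onto fields. Below is a minimal, standalone sketch of that mapping (not part of the commit), assuming the pydantic-settings package and using plain string fields in place of the project's Model/Device/Quantization enums:

import os

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class WhisperConfig(BaseModel):
    model: str = Field(default="distil-medium.en")  # the real field is a Model enum
    compute_type: str = Field(default="default")    # the real field is a Quantization enum


class Config(BaseSettings):
    model_config = SettingsConfigDict(env_nested_delimiter="_")

    log_level: str = "info"
    max_models: int = 1
    whisper: WhisperConfig = WhisperConfig()


os.environ["LOG_LEVEL"] = "debug"          # top-level field, matched case-insensitively
os.environ["MAX_MODELS"] = "2"             # parsed into the int field max_models
os.environ["WHISPER_MODEL"] = "medium.en"  # nested: WHISPER + "_" + MODEL -> whisper.model

config = Config()
assert config.log_level == "debug"
assert config.max_models == 2
assert config.whisper.model == "medium.en"

With env_nested_delimiter="_", WHISPER_MODEL is split on the underscore and routed to the nested whisper.model field, while flat variables such as LOG_LEVEL and MAX_MODELS match top-level fields directly.
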
speaches/main.py
CHANGED
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 import asyncio
-import logging
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Literal
+from typing import Annotated, Literal, OrderedDict
 
 from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
@@ -19,29 +18,45 @@ from speaches.asr import FasterWhisperASR
 from speaches.audio import AudioStream, audio_samples_from_file
 from speaches.config import (SAMPLES_PER_SECOND, Language, Model,
                              ResponseFormat, config)
-from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (TranscriptionJsonResponse,
                                     TranscriptionVerboseJsonResponse)
 from speaches.transcriber import audio_transcriber
 
+models: OrderedDict[Model, WhisperModel] = OrderedDict()
 
 
+def load_model(model_name: Model) -> WhisperModel:
+    if model_name in models:
+        logger.debug(f"{model_name} model already loaded")
+        return models[model_name]
+    if len(models) >= config.max_models:
+        oldest_model_name = next(iter(models))
+        logger.info(
+            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+        )
+        del models[oldest_model_name]
+    logger.debug(f"Loading {model_name}")
     start = time.perf_counter()
     whisper = WhisperModel(
+        model_name,
        device=config.whisper.inference_device,
         compute_type=config.whisper.compute_type,
     )
-    logger.
-        f"Loaded {
+    logger.info(
+        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds"
     )
+    models[model_name] = whisper
+    return whisper
+
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+    load_model(config.whisper.model)
     yield
+    for model in models.keys():
+        logger.info(f"Unloading {model}")
+        del models[model]
 
 
 app = FastAPI(lifespan=lifespan)
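
Aside: load_model above keeps loaded models in a module-level OrderedDict and, once config.max_models is reached, drops the entry that was inserted first. A minimal sketch of that caching behavior (not part of the commit; FakeModel and MAX_MODELS stand in for WhisperModel and config.max_models):

from collections import OrderedDict


class FakeModel:  # stand-in for faster_whisper.WhisperModel
    def __init__(self, name: str) -> None:
        self.name = name


MAX_MODELS = 1  # stand-in for config.max_models
models: OrderedDict[str, FakeModel] = OrderedDict()


def load_model(model_name: str) -> FakeModel:
    if model_name in models:
        # Cache hit: reuse the already-loaded model.
        return models[model_name]
    if len(models) >= MAX_MODELS:
        # Evict whichever model was inserted first (FIFO, not LRU).
        oldest_model_name = next(iter(models))
        del models[oldest_model_name]
    model = FakeModel(model_name)
    models[model_name] = model
    return model


load_model("distil-medium.en")
load_model("medium.en")  # evicts "distil-medium.en" since MAX_MODELS == 1
assert list(models) == ["medium.en"]

Because cache hits do not call move_to_end, eviction follows insertion order rather than least-recently-used; with the default max_models = 1, requesting a different model always unloads the one currently held.
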
@@ -53,7 +68,7 @@ def health() -> Response:
 
 
 @app.post("/v1/audio/translations")
-async def translate_file(
+def translate_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     prompt: Annotated[str | None, Form()] = None,
@@ -61,11 +76,8 @@ async def translate_file(
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="translate",
@@ -107,7 +119,7 @@ async def translate_file(
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
-async def transcribe_file(
+def transcribe_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     language: Annotated[Language | None, Form()] = config.default_language,
@@ -120,11 +132,8 @@ async def transcribe_file(
     ] = ["segments"],
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="transcribe",
@@ -209,21 +218,6 @@ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
     audio_stream.close()
 
 
-def format_transcription(
-    transcription: Transcription, response_format: ResponseFormat
-) -> str:
-    if response_format == ResponseFormat.TEXT:
-        return transcription.text
-    elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-    elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-
-
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
@@ -234,18 +228,7 @@ async def transcribe_stream(
         ResponseFormat, Query()
     ] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
-    timestamp_granularities: Annotated[
-        list[Literal["segments"] | Literal["words"]],
-        Query(
-            alias="timestamp_granularities[]",
-            description="No-op. Ignored. Only for compatibility.",
-        ),
-    ] = ["segments", "words"],
 ) -> None:
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -254,6 +237,7 @@ async def transcribe_stream(
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
+    whisper = load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
@@ -262,7 +246,21 @@ async def transcribe_stream(
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break
+
+            if response_format == ResponseFormat.TEXT:
+                await ws.send_text(transcription.text)
+            elif response_format == ResponseFormat.JSON:
+                await ws.send_json(
+                    TranscriptionJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                await ws.send_json(
+                    TranscriptionVerboseJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
 
     if not ws.client_state == WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
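
Since the endpoints now call load_model(model), each request can pick its own model. A hypothetical client sketch (not part of the commit), assuming a server listening on http://localhost:8000, an audio.wav file on disk, illustrative model names, and the httpx package:

import httpx

with open("audio.wav", "rb") as f:
    audio = f.read()

with httpx.Client(base_url="http://localhost:8000", timeout=None) as client:
    # First request: the server loads and caches the requested model.
    r1 = client.post(
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio)},
        data={"model": "distil-medium.en"},
    )
    print(r1.json())

    # A different model triggers another load; with max_models = 1 the
    # previously cached model is unloaded first.
    r2 = client.post(
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio)},
        data={"model": "medium.en"},
    )
    print(r2.json())

With the default max_models = 1, the second request causes the first model to be unloaded before the new one is loaded.
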