Fedir Zadniprovskyi committed
Commit 4bdd7f2 · 1 Parent(s): db7bf9a

feat: further improve openai compatibility + refactor

Files changed (4):
  1. speaches/asr.py +3 -11
  2. speaches/main.py +60 -31
  3. speaches/server_models.py +68 -4
  4. speaches/utils.py +14 -0
speaches/asr.py CHANGED
@@ -3,28 +3,20 @@ import time
 from typing import Iterable
 
 from faster_whisper import transcribe
-from pydantic import BaseModel
 
 from speaches.audio import Audio
-from speaches.config import Language
 from speaches.core import Transcription, Word
 from speaches.logger import logger
 
 
-class TranscribeOpts(BaseModel):
-    language: Language | None
-    vad_filter: bool
-    condition_on_previous_text: bool
-
-
 class FasterWhisperASR:
     def __init__(
         self,
         whisper: transcribe.WhisperModel,
-        transcribe_opts: TranscribeOpts,
+        **kwargs,
     ) -> None:
         self.whisper = whisper
-        self.transcribe_opts = transcribe_opts
+        self.transcribe_opts = kwargs
 
     def _transcribe(
         self,
@@ -36,7 +28,7 @@ class FasterWhisperASR:
             audio.data,
             initial_prompt=prompt,
             word_timestamps=True,
-            **self.transcribe_opts.model_dump(),
+            **self.transcribe_opts,
         )
         words = words_from_whisper_segments(segments)
         for word in words:
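
With TranscribeOpts gone, the class simply forwards whatever keyword arguments it was constructed with to faster-whisper's transcribe call. A minimal construction sketch, not part of the commit (the model size and option values are illustrative):

from faster_whisper import WhisperModel

from speaches.asr import FasterWhisperASR

whisper = WhisperModel("tiny.en")  # hypothetical model choice
# Any keyword argument below is stored in self.transcribe_opts and passed
# straight through to WhisperModel.transcribe.
asr = FasterWhisperASR(
    whisper,
    language="en",
    vad_filter=True,
    condition_on_previous_text=False,
)
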
speaches/main.py CHANGED
@@ -5,17 +5,18 @@ import logging
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated
+from typing import Annotated, Literal
 
-from fastapi import (Depends, FastAPI, Response, UploadFile, WebSocket,
+from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 
-from speaches.asr import FasterWhisperASR, TranscribeOpts
+from speaches import utils
+from speaches.asr import FasterWhisperASR
 from speaches.audio import AudioStream, audio_samples_from_file
-from speaches.config import SAMPLES_PER_SECOND, Language, config
+from speaches.config import SAMPLES_PER_SECOND, Language, Model, config
 from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (ResponseFormat, TranscriptionJsonResponse,
@@ -48,32 +49,40 @@ def health() -> Response:
     return Response(status_code=200, content="Everything is peachy!")
 
 
-async def transcription_parameters(
-    language: Language = Language.EN,
-    vad_filter: bool = True,
-    condition_on_previous_text: bool = False,
-) -> TranscribeOpts:
-    return TranscribeOpts(
-        language=language,
-        vad_filter=vad_filter,
-        condition_on_previous_text=condition_on_previous_text,
-    )
-
-
-TranscribeParams = Annotated[TranscribeOpts, Depends(transcription_parameters)]
-
-
+# https://platform.openai.com/docs/api-reference/audio/createTranscription
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
 async def transcribe_file(
-    file: UploadFile,
-    transcription_opts: TranscribeParams,
-    response_format: ResponseFormat = ResponseFormat.JSON,
-) -> str:
-    asr = FasterWhisperASR(whisper, transcription_opts)
-    audio_samples = audio_samples_from_file(file.file)
-    audio = AudioStream(audio_samples)
-    transcription, _ = await asr.transcribe(audio)
-    return format_transcription(transcription, response_format)
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[Model, Form()] = config.whisper.model,
+    language: Annotated[Language | None, Form()] = None,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = ResponseFormat.JSON,
+    temperature: Annotated[float, Form()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segments"] | Literal["words"]],
+        Form(alias="timestamp_granularities[]"),
+    ] = ["segments"],
+):
+    assert (
+        model == config.whisper.model
+    ), "Specifying a model that is different from the default is not supported yet."
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        language=language,
+        initial_prompt=prompt,
+        word_timestamps="words" in timestamp_granularities,
+        temperature=temperature,
+    )
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:
+        return utils.segments_text(segments)
+    elif response_format == ResponseFormat.JSON:
+        return TranscriptionJsonResponse.from_segments(segments)
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return TranscriptionVerboseJsonResponse.from_segments(
+            segments, transcription_info
+        )
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -135,11 +144,31 @@ def format_transcription(
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
-    transcription_opts: TranscribeParams,
-    response_format: ResponseFormat = ResponseFormat.JSON,
+    model: Annotated[Model, Query()] = config.whisper.model,
+    language: Annotated[Language | None, Query()] = None,
+    prompt: Annotated[str | None, Query()] = None,
+    response_format: Annotated[ResponseFormat, Query()] = ResponseFormat.JSON,
+    temperature: Annotated[float, Query()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segments"] | Literal["words"]],
+        Query(
+            alias="timestamp_granularities[]",
+            description="No-op. Ignored. Only for compatibility.",
+        ),
+    ] = ["segments", "words"],
 ) -> None:
+    assert (
+        model == config.whisper.model
+    ), "Specifying a model that is different from the default is not supported yet."
     await ws.accept()
-    asr = FasterWhisperASR(whisper, transcription_opts)
+    transcribe_opts = {
+        "language": language,
+        "initial_prompt": prompt,
+        "temperature": temperature,
+        "vad_filter": True,
+        "condition_on_previous_text": False,
+    }
+    asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
         tg.create_task(audio_receiver(ws, audio_stream))
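
For context, a client-side sketch of the reworked HTTP endpoint, not part of the commit. It assumes the server runs on localhost:8000, that a sample.wav exists, and that ResponseFormat.VERBOSE_JSON serializes to the string "verbose_json":

import httpx

with open("sample.wav", "rb") as audio_file:
    response = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": ("sample.wav", audio_file, "audio/wav")},
        data={
            "response_format": "verbose_json",
            "temperature": "0.0",
            # Repeated form field, matching the "timestamp_granularities[]" alias.
            "timestamp_granularities[]": ["words"],
        },
    )
response.raise_for_status()
print(response.json()["text"])
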
speaches/server_models.py CHANGED
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 import enum
 
-from faster_whisper.transcribe import Segment, Word
+from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
 from pydantic import BaseModel
 
+from speaches import utils
 from speaches.core import Transcription
 
 
@@ -21,6 +22,10 @@ class ResponseFormat(enum.StrEnum):
 class TranscriptionJsonResponse(BaseModel):
     text: str
 
+    @classmethod
+    def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse:
+        return cls(text=utils.segments_text(segments))
+
     @classmethod
     def from_transcription(
         cls, transcription: Transcription
@@ -28,14 +33,73 @@ class TranscriptionJsonResponse(BaseModel):
         return cls(text=transcription.text)
 
 
+class WordObject(BaseModel):
+    start: float
+    end: float
+    word: str
+    probability: float
+
+    @classmethod
+    def from_word(cls, word: Word) -> WordObject:
+        return cls(
+            start=word.start,
+            end=word.end,
+            word=word.word,
+            probability=word.probability,
+        )
+
+
+class SegmentObject(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+
+    @classmethod
+    def from_segment(cls, segment: Segment) -> SegmentObject:
+        return cls(
+            id=segment.id,
+            seek=segment.seek,
+            start=segment.start,
+            end=segment.end,
+            text=segment.text,
+            tokens=segment.tokens,
+            temperature=segment.temperature,
+            avg_logprob=segment.avg_logprob,
+            compression_ratio=segment.compression_ratio,
+            no_speech_prob=segment.no_speech_prob,
+        )
+
+
 # https://platform.openai.com/docs/api-reference/audio/verbose-json-object
 class TranscriptionVerboseJsonResponse(BaseModel):
     task: str = "transcribe"
     language: str
     duration: float
     text: str
-    words: list[Word]
-    segments: list[Segment]
+    words: list[WordObject]
+    segments: list[SegmentObject]
+
+    @classmethod
+    def from_segments(
+        cls, segments: list[Segment], transcription_info: TranscriptionInfo
+    ) -> TranscriptionVerboseJsonResponse:
+        return cls(
+            language=transcription_info.language,
+            duration=transcription_info.duration,
+            text=utils.segments_text(segments),
+            segments=[SegmentObject.from_segment(segment) for segment in segments],
+            words=[
+                WordObject.from_word(word)
+                for word in utils.words_from_segments(segments)
+            ],
+        )
 
     @classmethod
     def from_transcription(
@@ -46,7 +110,7 @@ class TranscriptionVerboseJsonResponse(BaseModel):
             duration=transcription.duration,
             text=transcription.text,
             words=[
-                Word(
+                WordObject(
                     start=word.start,
                     end=word.end,
                     word=word.text,
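
To illustrate how the new from_segments constructors are meant to be fed, a short sketch that runs faster-whisper directly, not part of the commit (model size and audio path are assumptions):

from faster_whisper import WhisperModel

from speaches.server_models import TranscriptionVerboseJsonResponse

model = WhisperModel("tiny.en")
# word_timestamps=True populates segment.words, which from_segments needs
# to fill the top-level words list.
segments, info = model.transcribe("sample.wav", word_timestamps=True)
response = TranscriptionVerboseJsonResponse.from_segments(list(segments), info)
print(response.model_dump_json(indent=2))
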
speaches/utils.py ADDED
@@ -0,0 +1,14 @@
+from faster_whisper.transcribe import Segment, Word
+
+
+def segments_text(segments: list[Segment]) -> str:
+    return "".join(segment.text for segment in segments).strip()
+
+
+def words_from_segments(segments: list[Segment]) -> list[Word]:
+    words = []
+    for segment in segments:
+        if segment.words is None:
+            continue
+        words.extend(segment.words)
+    return words
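
A quick usage sketch for the two helpers, not part of the commit (model size and file path are assumptions):

from faster_whisper import WhisperModel

from speaches import utils

segments, _ = WhisperModel("tiny.en").transcribe("sample.wav", word_timestamps=True)
segments = list(segments)
print(utils.segments_text(segments))             # full transcript text, stripped
print(len(utils.words_from_segments(segments)))  # number of word-level timestamps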