from __future__ import annotations

from typing import TYPE_CHECKING

from faster_whisper_server.audio import Audio, AudioStream
from faster_whisper_server.config import config
from faster_whisper_server.core import (
    Transcription,
    Word,
    common_prefix,
    to_full_sentences,
)
from faster_whisper_server.logger import logger

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

    from faster_whisper_server.asr import FasterWhisperASR


class LocalAgreement:
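    """LocalAgreement policy (see the whisper_streaming link below): a word is
    confirmed once it appears at the same position in two consecutive
    transcription passes over the growing audio buffer."""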
    def __init__(self) -> None:
        self.unconfirmed = Transcription()

    def merge(self, confirmed: Transcription, incoming: Transcription) -> list[Word]:
        # https://github.com/ufal/whisper_streaming/blob/main/whisper_online.py#L264
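        # Drop incoming words that fall before the already-confirmed audio,
        # with a 0.1 s tolerance for slight timestamp overlap.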
        incoming = incoming.after(confirmed.end - 0.1)
        prefix = common_prefix(incoming.words, self.unconfirmed.words)
        logger.debug(f"Confirmed: {confirmed.text}")
        logger.debug(f"Unconfirmed: {self.unconfirmed.text}")
        logger.debug(f"Incoming: {incoming.text}")

        if len(incoming.words) > len(prefix):
            self.unconfirmed = Transcription(incoming.words[len(prefix) :])
        else:
            self.unconfirmed = Transcription()

        return prefix

    @classmethod
    def prompt(cls, confirmed: Transcription) -> str | None:
        sentences = to_full_sentences(confirmed.words)
        if len(sentences) == 0:
            return None
        return sentences[-1].text

    # TODO: better name
    @classmethod
    def needs_audio_after(cls, confirmed: Transcription) -> float:
        full_sentences = to_full_sentences(confirmed.words)
        return full_sentences[-1].end if len(full_sentences) > 0 else 0.0


def needs_audio_after(confirmed: Transcription) -> float:
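    """Return the end timestamp of the last completed sentence; audio before
    this point no longer needs to be re-transcribed."""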
    full_sentences = to_full_sentences(confirmed.words)
    return full_sentences[-1].end if len(full_sentences) > 0 else 0.0


def prompt(confirmed: Transcription) -> str | None:
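    """Return the last completed sentence of the confirmed transcription,
    used as the prompt for the next transcription pass."""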
    sentences = to_full_sentences(confirmed.words)
    if len(sentences) == 0:
        return None
    return sentences[-1].text


async def audio_transcriber(
    asr: FasterWhisperASR,
    audio_stream: AudioStream,
) -> AsyncGenerator[Transcription, None]:
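    """Incrementally transcribe an audio stream, yielding the confirmed transcription as it grows.

    Each new chunk is appended to the full audio buffer, the tail of the buffer
    (everything after the last completed sentence) is re-transcribed with the last
    confirmed sentence as the prompt, and only words that agree across two
    consecutive passes (LocalAgreement) are confirmed and yielded.
    """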
    local_agreement = LocalAgreement()
    full_audio = Audio()
    confirmed = Transcription()
    async for chunk in audio_stream.chunks(config.min_duration):
        full_audio.extend(chunk)
        audio = full_audio.after(needs_audio_after(confirmed))
        transcription, _ = await asr.transcribe(audio, prompt(confirmed))
        new_words = local_agreement.merge(confirmed, transcription)
        if len(new_words) > 0:
            confirmed.extend(new_words)
            yield confirmed
    logger.debug("Flushing...")
    confirmed.extend(local_agreement.unconfirmed.words)
    yield confirmed
    logger.info("Audio transcriber finished")