File size: 4,130 Bytes
dc4f25f
3e15f14
 
 
 
 
 
 
 
 
 
 
 
 
 
dc4f25f
3e15f14
 
 
dc4f25f
3e15f14
79f1f8d
dc4f25f
79f1f8d
 
 
 
3e15f14
dc4f25f
79f1f8d
3e15f14
79f1f8d
3e15f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc4f25f
3e15f14
 
 
 
79f1f8d
 
 
5aa421e
79f1f8d
dc4f25f
79f1f8d
 
dc4f25f
79f1f8d
dc4f25f
79f1f8d
 
 
 
3e15f14
 
 
 
 
 
 
 
79f1f8d
3e15f14
dc4f25f
3e15f14
79f1f8d
3e15f14
dc4f25f
3e15f14
 
 
 
 
 
 
 
 
79f1f8d
 
3e15f14
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from collections.abc import Generator
import os

import gradio as gr
import httpx
from httpx_sse import connect_sse

from faster_whisper_server.config import Config, Task

TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions"
TRANSLATION_ENDPOINT = "/v1/audio/translations"


def create_gradio_demo(config: Config) -> gr.Blocks:
    """Build the Gradio playground UI backed by the local transcription server.

    The UI posts uploaded audio to the server's OpenAI-compatible
    ``/v1/audio/transcriptions`` / ``/v1/audio/translations`` endpoints and
    shows the (optionally streamed) text result.

    Args:
        config: Server configuration; ``config.whisper.model`` is used as the
            initially selected model in the dropdown.

    Returns:
        The assembled ``gr.Blocks`` demo, ready to be launched.
    """
    host = os.getenv("UVICORN_HOST", "0.0.0.0")
    port = int(os.getenv("UVICORN_PORT", "8000"))
    # NOTE: worth looking into generated clients
    # timeout=None: transcription of long audio can exceed any fixed timeout.
    http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)

    def endpoint_for(task: Task) -> str:
        """Map a task to its API endpoint, failing loudly on unknown tasks.

        Previously the non-streaming path had no ``else`` branch, so an
        unrecognized task left ``endpoint`` unbound and raised a confusing
        ``NameError``; this raises a clear ``ValueError`` instead.
        """
        if task == Task.TRANSCRIBE:
            return TRANSCRIPTION_ENDPOINT
        if task == Task.TRANSLATE:
            return TRANSLATION_ENDPOINT
        raise ValueError(f"Unknown task: {task}")

    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
        """Gradio callback: yield the (growing) transcription text."""
        if stream:
            # Accumulate chunks so the textbox always shows the full text so far.
            previous_transcription = ""
            for transcription in transcribe_audio_streaming(file_path, task, temperature, model):
                previous_transcription += transcription
                yield previous_transcription
        else:
            yield transcribe_audio(file_path, task, temperature, model)

    def transcribe_audio(file_path: str, task: Task, temperature: float, model: str) -> str:
        """Send the whole file and return the complete transcription text."""
        with open(file_path, "rb") as file:
            response = http_client.post(
                endpoint_for(task),
                files={"file": file},
                data={
                    "model": model,
                    "response_format": "text",
                    "temperature": temperature,
                },
            )

        response.raise_for_status()
        return response.text

    def transcribe_audio_streaming(
        file_path: str, task: Task, temperature: float, model: str
    ) -> Generator[str, None, None]:
        """Stream transcription chunks from the server via server-sent events."""
        with open(file_path, "rb") as file:
            kwargs = {
                "files": {"file": file},
                "data": {
                    "response_format": "text",
                    "temperature": temperature,
                    "model": model,
                    "stream": True,
                },
            }
            with connect_sse(http_client, "POST", endpoint_for(task), **kwargs) as event_source:
                for event in event_source.iter_sse():
                    yield event.data

    def update_model_dropdown() -> gr.Dropdown:
        """Refresh the model dropdown from the server's ``/v1/models`` listing."""
        res = http_client.get("/v1/models")
        res_data = res.json()
        models: list[str] = [model["id"] for model in res_data["data"]]
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # and a missing configured model is a real server/config mismatch.
        if config.whisper.model not in models:
            raise RuntimeError(f"Configured model {config.whisper.model!r} not reported by server")
        # Sorted lists (not a set) so the dropdown ordering is deterministic
        # across runs; recommended (Systran) models are listed first.
        recommended_models = sorted(model for model in models if model.startswith("Systran"))
        other_models = [model for model in models if not model.startswith("Systran")]
        models = recommended_models + other_models
        return gr.Dropdown(
            # no idea why it's complaining
            choices=models,  # pyright: ignore[reportArgumentType]
            label="Model",
            value=config.whisper.model,
        )

    model_dropdown = gr.Dropdown(
        choices=[config.whisper.model],
        label="Model",
        value=config.whisper.model,
    )
    task_dropdown = gr.Dropdown(
        choices=[task.value for task in Task],
        label="Task",
        value=Task.TRANSCRIBE,
    )
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
    stream_checkbox = gr.Checkbox(label="Stream", value=True)
    with gr.Interface(
        title="Whisper Playground",
        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
        inputs=[
            gr.Audio(type="filepath"),
            model_dropdown,
            task_dropdown,
            temperature_slider,
            stream_checkbox,
        ],
        fn=handler,
        outputs="text",
    ) as demo:
        # Populate the dropdown with the server's actual model list on page load.
        demo.load(update_model_dropdown, inputs=None, outputs=model_dropdown)
    return demo