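"""Gradio playground for faster-whisper-server.

Wraps the server's OpenAI-compatible /v1/audio/transcriptions and
/v1/audio/translations endpoints in a small web UI, with optional
streaming of results over server-sent events.
"""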
from collections.abc import Generator
import os

import gradio as gr
import httpx
from httpx_sse import connect_sse
from openai import OpenAI

from faster_whisper_server.config import Config, Task

TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions"
TRANSLATION_ENDPOINT = "/v1/audio/translations"


def create_gradio_demo(config: Config) -> gr.Blocks:
    host = os.getenv("UVICORN_HOST", "0.0.0.0")
    port = int(os.getenv("UVICORN_PORT", "8000"))
    # NOTE: worth looking into generated clients
    http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)
    openai_client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key="cant-be-empty")

    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
        if stream:
            # Accumulate streamed chunks so the output textbox shows the
            # transcript growing as events arrive.
            previous_transcription = ""
            for transcription in transcribe_audio_streaming(file_path, task, temperature, model):
                previous_transcription += transcription
                yield previous_transcription
        else:
            yield transcribe_audio(file_path, task, temperature, model)

    def transcribe_audio(file_path: str, task: Task, temperature: float, model: str) -> str:
        if task == Task.TRANSCRIBE:
            endpoint = TRANSCRIPTION_ENDPOINT
        elif task == Task.TRANSLATE:
            endpoint = TRANSLATION_ENDPOINT
        else:
            # Guard against `endpoint` being unbound if `Task` ever grows.
            raise ValueError(f"Unsupported task: {task}")

        with open(file_path, "rb") as file:
            response = http_client.post(
                endpoint,
                files={"file": file},
                data={
                    "model": model,
                    "response_format": "text",
                    "temperature": temperature,
                },
            )

        response.raise_for_status()
        return response.text

    def transcribe_audio_streaming(
        file_path: str, task: Task, temperature: float, model: str
    ) -> Generator[str, None, None]:
        with open(file_path, "rb") as file:
            kwargs = {
                "files": {"file": file},
                "data": {
                    "response_format": "text",
                    "temperature": temperature,
                    "model": model,
                    "stream": True,
                },
            }
            endpoint = TRANSCRIPTION_ENDPOINT if task == Task.TRANSCRIBE else TRANSLATION_ENDPOINT
            with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
                for event in event_source.iter_sse():
                    yield event.data

    def update_model_dropdown() -> gr.Dropdown:
        models = openai_client.models.list().data
        model_names: list[str] = [model.id for model in models]
        assert config.whisper.model in model_names
        # List Systran-namespace models first; sort the set so the dropdown
        # order is deterministic across refreshes.
        recommended_models = {model for model in model_names if model.startswith("Systran")}
        other_models = [model for model in model_names if model not in recommended_models]
        model_names = sorted(recommended_models) + other_models
        return gr.Dropdown(
            # pyright flags the `choices` type here; it appears to be a false positive.
            choices=model_names,  # pyright: ignore[reportArgumentType]
            label="Model",
            value=config.whisper.model,
        )

    model_dropdown = gr.Dropdown(
        choices=[config.whisper.model],
        label="Model",
        value=config.whisper.model,
    )
    task_dropdown = gr.Dropdown(
        choices=[task.value for task in Task],
        label="Task",
        value=Task.TRANSCRIBE,
    )
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
    stream_checkbox = gr.Checkbox(label="Stream", value=True)
    with gr.Interface(
        title="Whisper Playground",
        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
        inputs=[
            gr.Audio(type="filepath"),
            model_dropdown,
            task_dropdown,
            temperature_slider,
            stream_checkbox,
        ],
        fn=handler,
        outputs="text",
    ) as demo:
        demo.load(update_model_dropdown, inputs=None, outputs=model_dropdown)
    return demo
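

if __name__ == "__main__":
    # Minimal standalone-launch sketch, an assumption rather than part of the
    # original module: it presumes `Config` is constructible with defaults and
    # that the API server is already running on UVICORN_HOST:UVICORN_PORT.
    create_gradio_demo(Config()).launch()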