from collections.abc import Generator
from functools import lru_cache
import json
import logging
from pathlib import Path
import typing
from typing import Any, Literal

import huggingface_hub
from huggingface_hub.constants import HF_HUB_CACHE
from pydantic import BaseModel

from faster_whisper_server.api_models import Model

logger = logging.getLogger(__name__)

LIBRARY_NAME = "ctranslate2"
TASK_NAME = "automatic-speech-recognition"


def does_local_model_exist(model_id: str) -> bool:
    return any(model_id == model.repo_id for model, _ in list_local_whisper_models())
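
# Usage sketch (hypothetical IDs; assumes `Systran/faster-whisper-small` was
# previously downloaded into the local Hugging Face cache -- no network access
# is needed):
#
#     >>> does_local_model_exist("Systran/faster-whisper-small")
#     True
#     >>> does_local_model_exist("not/a-cached-model")
#     False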


def list_whisper_models() -> Generator[Model, None, None]:
    models = huggingface_hub.list_models(library=LIBRARY_NAME, tags=TASK_NAME, cardData=True)
    models = list(models)
    models.sort(key=lambda model: model.downloads or -1, reverse=True)
    for model in models:
        assert model.created_at is not None
        assert model.card_data is not None
        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
        if model.card_data.language is None:
            language = []
        elif isinstance(model.card_data.language, str):
            language = [model.card_data.language]
        else:
            language = model.card_data.language
        transformed_model = Model(
            id=model.id,
            created=int(model.created_at.timestamp()),
            object_="model",
            owned_by=model.id.split("/")[0],
            language=language,
        )
        yield transformed_model
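
# Usage sketch: `list_whisper_models` queries the Hugging Face Hub API, so it
# requires network access; the output below is illustrative, not a guaranteed
# ordering:
#
#     >>> next(iter(list_whisper_models()))
#     Model(id='Systran/faster-whisper-large-v3', ...)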


def list_local_whisper_models() -> (
    Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]
):
    hf_cache = huggingface_hub.scan_cache_dir()
    hf_models = [repo for repo in hf_cache.repos if repo.repo_type == "model"]
    for model in hf_models:
        # `revisions` is an unordered set; each cached model is assumed to have a single revision.
        revision = next(iter(model.revisions))
        cached_readme_file = next((f for f in revision.files if f.file_name == "README.md"), None)
        if cached_readme_file:
            readme_file_path = Path(cached_readme_file.file_path)
        else:
            # NOTE: the README.md doesn't get downloaded when `WhisperModel` is called
            logger.debug(f"Model {model.repo_id} does not have a README.md file. Downloading it.")
            readme_file_path = Path(huggingface_hub.hf_hub_download(model.repo_id, "README.md"))

        model_card = huggingface_hub.ModelCard.load(readme_file_path)
        model_card_data = typing.cast(huggingface_hub.ModelCardData, model_card.data)
        if (
            model_card_data.library_name == LIBRARY_NAME
            and model_card_data.tags is not None
            and TASK_NAME in model_card_data.tags
        ):
            yield model, model_card_data
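
# Usage sketch: output depends entirely on what `scan_cache_dir` finds in the
# local cache:
#
#     >>> for repo, card_data in list_local_whisper_models():
#     ...     print(repo.repo_id, card_data.language)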


def get_whisper_models() -> Generator[Model, None, None]:
    # Behaviorally identical to `list_whisper_models`; kept as a thin alias so existing callers keep working.
    yield from list_whisper_models()


class PiperModel(BaseModel):
    id: str
    object: Literal["model"] = "model"
    created: int
    owned_by: Literal["rhasspy"] = "rhasspy"
    path: Path
    config_path: Path


def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
    if cache_dir is None:
        cache_dir = HF_HUB_CACHE

    cache_dir = Path(cache_dir).expanduser().resolve()
    if not cache_dir.exists():
        raise huggingface_hub.CacheNotFound(
            f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",  # noqa: E501
            cache_dir=cache_dir,
        )

    if cache_dir.is_file():
        raise ValueError(
            f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."  # noqa: E501
        )

    for repo_path in cache_dir.iterdir():
        if not repo_path.is_dir():
            continue
        if repo_path.name == ".locks":  # skip './.locks/' folder
            continue
        repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
        repo_type = repo_type[:-1]  # "models" -> "model"
        repo_id = repo_id.replace("--", "/")  # "google--fleurs" -> "google/fleurs"
        if repo_type != "model":
            continue
        if model_id == repo_id:
            return repo_path

    return None
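
# The HF cache stores repos as `<cache_dir>/<repo_type>s--<org>--<name>`
# (e.g. `models--rhasspy--piper-voices`), which is the naming scheme the
# `split`/`replace` above reverses. Usage sketch (path is illustrative):
#
#     >>> get_model_path("rhasspy/piper-voices")
#     PosixPath('.../models--rhasspy--piper-voices')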


def list_model_files(
    model_id: str, glob_pattern: str = "**/*", *, cache_dir: str | Path | None = None
) -> Generator[Path, None, None]:
    repo_path = get_model_path(model_id, cache_dir=cache_dir)
    if repo_path is None:
        return
    snapshots_path = repo_path / "snapshots"
    if not snapshots_path.exists():
        return
    yield from snapshots_path.glob(glob_pattern)
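
# Usage sketch: yields matching files from the repo's cached snapshot(s); the
# glob pattern below is illustrative:
#
#     >>> for f in list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx"):
#     ...     print(f.name)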


def list_piper_models() -> Generator[PiperModel, None, None]:
    model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
    for model_weights_file in model_weights_files:
        model_config_file = model_weights_file.with_suffix(".json")
        yield PiperModel(
            id=model_weights_file.name,
            created=int(model_weights_file.stat().st_mtime),
            path=model_weights_file,
            config_path=model_config_file,
        )
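
# Usage sketch (voice IDs depend on which `.onnx` files are cached locally):
#
#     >>> for voice in list_piper_models():
#     ...     print(voice.id, voice.path)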


# NOTE: It's debatable whether caching should be done here or by the caller. Should be revisited.


@lru_cache
def read_piper_voices_config() -> dict[str, Any]:
    voices_file = next(list_model_files("rhasspy/piper-voices", glob_pattern="**/voices.json"), None)
    if voices_file is None:
        raise FileNotFoundError("Could not find voices.json file")  # noqa: EM101
    return json.loads(voices_file.read_text())
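
# For reference, `voices.json` maps voice keys to metadata roughly like the
# abridged (and assumed) shape below; this function itself only promises a
# `dict[str, Any]` keyed by voice name:
#
#     {
#         "en_US-amy-medium": {"name": "amy", "quality": "medium", ...},
#         ...
#     }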


@lru_cache
def get_piper_voice_model_file(voice: str) -> Path:
    model_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx"), None)
    if model_file is None:
        raise FileNotFoundError(f"Could not find model file for '{voice}' voice")
    return model_file


class PiperVoiceConfigAudio(BaseModel):
    sample_rate: int
    quality: int


class PiperVoiceConfig(BaseModel):
    audio: PiperVoiceConfigAudio
    # NOTE: there are more fields in the config, but we don't care about them


@lru_cache
def read_piper_voice_config(voice: str) -> PiperVoiceConfig:
    model_config_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx.json"), None)
    if model_config_file is None:
        raise FileNotFoundError(f"Could not find config file for '{voice}' voice")
    return PiperVoiceConfig.model_validate_json(model_config_file.read_text())
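

if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming the relevant repos are already in
    # the local HF cache; adjust or remove as needed).
    logging.basicConfig(level=logging.DEBUG)
    for repo, card_data in list_local_whisper_models():
        print(f"cached whisper model: {repo.repo_id} (languages: {card_data.language})")
    for piper_model in list_piper_models():
        print(f"cached piper voice: {piper_model.id} -> {piper_model.path}")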