File size: 5,071 Bytes
624f97e
 
 
 
bf48682
35eafc3
624f97e
bf48682
624f97e
 
 
bf48682
35eafc3
 
bf48682
35eafc3
bf48682
 
 
624f97e
35eafc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624f97e
 
35eafc3
 
 
 
624f97e
35eafc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf48682
35eafc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from __future__ import annotations

from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING

from faster_whisper import WhisperModel

if TYPE_CHECKING:
    from collections.abc import Callable

    from faster_whisper_server.config import (
        WhisperConfig,
    )

logger = logging.getLogger(__name__)

# TODO: enable concurrent model downloads


class SelfDisposingWhisperModel:
    """Reference-counted wrapper around a lazily-loaded `WhisperModel`.

    Use as a context manager: entering loads the model (if necessary) and
    increments a reference count; exiting decrements it. When the count
    reaches zero the model is scheduled for unload after
    `whisper_config.ttl` seconds (unloaded immediately if `ttl == 0`;
    kept loaded forever if `ttl < 0`).
    """

    def __init__(
        self,
        model_id: str,
        whisper_config: WhisperConfig,
        *,
        on_unload: Callable[[str], None] | None = None,
    ) -> None:
        """
        Args:
            model_id: identifier passed to the `WhisperModel` constructor.
            whisper_config: device/compute settings plus the idle `ttl`.
            on_unload: invoked with `model_id` right after an unload completes.
        """
        self.model_id = model_id
        self.whisper_config = whisper_config
        self.on_unload = on_unload

        self.ref_count: int = 0
        # Re-entrant: `unload`/`_load` may run while the lock is already held
        # (e.g. `_decrement_ref` calls `unload` with ttl == 0).
        self.rlock = threading.RLock()
        # Pending idle-expiry timer, or None when no expiry is scheduled.
        self.expire_timer: threading.Timer | None = None
        self.whisper: WhisperModel | None = None

    def unload(self) -> None:
        """Unload the model, trigger GC, and notify `on_unload`.

        Raises:
            ValueError: if the model is not loaded or is still in use.
        """
        with self.rlock:
            if self.whisper is None:
                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
            if self.ref_count > 0:
                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
            if self.expire_timer:
                self.expire_timer.cancel()
                # FIX: drop the dead timer so later truthiness checks
                # (`_increment_ref`) don't mistake it for a pending one.
                self.expire_timer = None
            self.whisper = None
            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
            gc.collect()
            logger.info(f"Model {self.model_id} unloaded")
            if self.on_unload is not None:
                self.on_unload(self.model_id)

    def _load(self) -> None:
        # Instantiate the underlying model; must only be called while unloaded.
        with self.rlock:
            assert self.whisper is None
            logger.debug(f"Loading model {self.model_id}")
            start = time.perf_counter()
            self.whisper = WhisperModel(
                self.model_id,
                device=self.whisper_config.inference_device,
                device_index=self.whisper_config.device_index,
                compute_type=self.whisper_config.compute_type,
                cpu_threads=self.whisper_config.cpu_threads,
                num_workers=self.whisper_config.num_workers,
            )
            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")

    def _increment_ref(self) -> None:
        # A new user appeared: cancel any pending idle-expiry.
        with self.rlock:
            self.ref_count += 1
            if self.expire_timer:
                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
                self.expire_timer.cancel()
                # FIX: a cancelled timer is no longer pending; clear it so the
                # "cancelling" branch above isn't re-entered for a dead timer.
                self.expire_timer = None
            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")

    def _decrement_ref(self) -> None:
        # Drop one reference; when idle, unload now/later/never depending on ttl.
        with self.rlock:
            self.ref_count -= 1
            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
            if self.ref_count <= 0:
                if self.whisper_config.ttl > 0:
                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
                    self.expire_timer.start()
                elif self.whisper_config.ttl == 0:
                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
                    self.unload()
                else:
                    # Negative ttl means "keep loaded forever".
                    logger.info(f"Model {self.model_id} is idle, not unloading")

    def __enter__(self) -> WhisperModel:
        """Load (if needed) and return the model, pinning it while in use."""
        with self.rlock:
            if self.whisper is None:
                self._load()
            self._increment_ref()
            assert self.whisper is not None
            return self.whisper

    def __exit__(self, *_args) -> None:  # noqa: ANN002
        self._decrement_ref()

class ModelManager:
    """Registry holding at most one `SelfDisposingWhisperModel` per name.

    All mutations of `loaded_models` happen under the non-reentrant
    `self._lock`.
    """

    def __init__(self, whisper_config: WhisperConfig) -> None:
        self.whisper_config = whisper_config
        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
        self._lock = threading.Lock()

    def _handle_model_unload(self, model_name: str) -> None:
        # `on_unload` callback: drop the registry entry once its model unloaded.
        with self._lock:
            if model_name in self.loaded_models:
                del self.loaded_models[model_name]

    def unload_model(self, model_name: str) -> None:
        """Eagerly unload the named model.

        Raises:
            KeyError: if no model with that name is loaded.
        """
        with self._lock:
            model = self.loaded_models.get(model_name)
        if model is None:
            raise KeyError(f"Model {model_name} not found")
        # FIX: call `unload()` only AFTER releasing `self._lock`. `unload()`
        # fires the `on_unload` callback (`_handle_model_unload`), which
        # re-acquires the non-reentrant `self._lock`; calling it while still
        # holding the lock was a guaranteed deadlock.
        model.unload()

    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
        """Return the wrapper for `model_name`, creating it if absent."""
        with self._lock:
            if model_name in self.loaded_models:
                logger.debug(f"{model_name} model already loaded")
                return self.loaded_models[model_name]
            self.loaded_models[model_name] = SelfDisposingWhisperModel(
                model_name,
                self.whisper_config,
                on_unload=self._handle_model_unload,
            )
            return self.loaded_models[model_name]