from functools import cached_property
from pathlib import Path
from threading import Lock
from typing import ClassVar, Optional

import numpy as np
from loguru import logger
from numpy.typing import NDArray
from sentence_transformers.SentenceTransformer import SentenceTransformer
from transformers import AutoTokenizer

from rag_demo.settings import settings


class SingletonMeta(type):
    """
    A thread-safe implementation of the Singleton pattern, applied as a
    metaclass.
    """

    _instances: ClassVar[dict] = {}

    # A lock object used to synchronize threads during first access to the
    # Singleton.
    _lock: Lock = Lock()

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the values of the `__init__` arguments do not
        affect the returned instance.
        """
        # Imagine the program has just been launched. Since there is no
        # Singleton instance yet, multiple threads can reach this point at
        # almost the same time. The first of them will acquire the lock and
        # proceed, while the rest will wait here.
        with cls._lock:
            # The first thread to acquire the lock reaches this conditional,
            # goes inside and creates the Singleton instance. Once it leaves
            # the lock block, a thread that was waiting on the lock may enter
            # this section. But since the Singleton entry is already
            # initialized, that thread won't create a new object.
            if cls not in cls._instances:
                instance = super().__call__(*args, **kwargs)
                cls._instances[cls] = instance
        return cls._instances[cls]
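
# A minimal usage sketch of the metaclass (`AppConfig` is a hypothetical
# example class, not part of this module). Every construction of a class that
# uses SingletonMeta returns the first-created instance:
#
#     class AppConfig(metaclass=SingletonMeta):
#         def __init__(self, value: int = 0) -> None:
#             self.value = value
#
#     a = AppConfig(1)
#     b = AppConfig(2)  # constructor arguments are ignored on later calls
#     assert a is b and a.value == 1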


class EmbeddingModelSingleton(metaclass=SingletonMeta):
    """
    A singleton class that wraps a pre-trained transformer model for
    generating embeddings of input text.
    """

    def __init__(
        self,
        model_id: str = settings.TEXT_EMBEDDING_MODEL_ID,
        device: str = settings.RAG_MODEL_DEVICE,
        cache_dir: Optional[Path] = None,
    ) -> None:
        self._model_id = model_id
        self._device = device

        self._model = SentenceTransformer(
            self._model_id,
            device=self._device,
            cache_folder=str(cache_dir) if cache_dir else None,
        )
        # Inference only: put the model in eval mode to disable dropout and
        # other training-time behavior.
        self._model.eval()

    @property
    def model_id(self) -> str:
        """
        Returns the identifier of the pre-trained transformer model in use.

        Returns:
            str: The identifier of the pre-trained transformer model in use.
        """
        return self._model_id

    @cached_property
    def embedding_size(self) -> int:
        """
        Returns the size of the embeddings generated by the pre-trained
        transformer model.

        Returns:
            int: The size of the embeddings generated by the pre-trained
            transformer model.
        """
        # Embed an empty string once and cache the resulting vector's length.
        dummy_embedding = self._model.encode("")
        return dummy_embedding.shape[0]

    @property
    def max_input_length(self) -> int:
        """
        Returns the maximum length of input text to tokenize.

        Returns:
            int: The maximum length of input text to tokenize.
        """
        return self._model.max_seq_length

    @property
    def tokenizer(self) -> AutoTokenizer:
        """
        Returns the tokenizer used to tokenize input text.

        Returns:
            AutoTokenizer: The tokenizer used to tokenize input text.
        """
        return self._model.tokenizer

    def __call__(
        self, input_text: str | list[str], to_list: bool = True
    ) -> NDArray[np.float32] | list[float] | list[list[float]]:
        """
        Generates embeddings for the input text using the pre-trained
        transformer model.

        Args:
            input_text (str | list[str]): The input text, or a batch of
                texts, to generate embeddings for.
            to_list (bool): Whether to return the embeddings as a list
                instead of a numpy array. Defaults to True.

        Returns:
            NDArray[np.float32] | list: The embeddings generated for the
            input text, or an empty list/array if encoding fails.
        """
        try:
            embeddings = self._model.encode(input_text)
        except Exception:
            logger.exception(
                f"Error generating embeddings for {self._model_id=} and {input_text=}"
            )
            return [] if to_list else np.array([])

        if to_list:
            embeddings = embeddings.tolist()

        return embeddings
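

if __name__ == "__main__":
    # A minimal usage sketch, assuming settings.TEXT_EMBEDDING_MODEL_ID names
    # a sentence-transformers checkpoint that is available locally or can be
    # downloaded (e.g. "sentence-transformers/all-MiniLM-L6-v2").
    model = EmbeddingModelSingleton()

    # The metaclass guarantees repeated construction returns the same object.
    assert model is EmbeddingModelSingleton()

    # A single string yields one embedding as a plain Python list ...
    vector = model("Retrieval-augmented generation in one sentence.")
    assert len(vector) == model.embedding_size

    # ... while a batch with to_list=False yields a (batch, dim) numpy array.
    batch = model(["first chunk", "second chunk"], to_list=False)
    assert batch.shape == (2, model.embedding_size)

    # The exposed tokenizer and max_input_length can be used to check that a
    # chunk fits the model's context window before embedding it.
    n_tokens = len(model.tokenizer("first chunk")["input_ids"])
    assert n_tokens <= model.max_input_length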