import numpy as np
import torch
from typing import List, Union, Dict
from collections import defaultdict
from transformers import AutoTokenizer
from scipy.sparse import vstack, csr_matrix
import onnxruntime as ort

class ONNXInferenceModel:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
        max_length: int = 8192,
        use_fp16: bool = True,
        device: str = "cuda"
    ):
        self.max_length = max_length
        self.use_fp16 = use_fp16
        self.device = device

        
        # CUDA execution provider, with the memory arena capped at 5 GB
        providers = [('CUDAExecutionProvider', {
            'device_id': 0,
            'arena_extend_strategy': 'kSameAsRequested',
            'gpu_mem_limit': 5 * 1024 * 1024 * 1024,
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
            'do_copy_in_default_stream': True,
        })]

        # Session options aimed at keeping ONNX Runtime's memory footprint small
        so = ort.SessionOptions()
        so.enable_mem_pattern = True
        so.enable_mem_reuse = True
        so.add_session_config_entry("memory.enable_memory_arena_shrinkage", "cpu:0;gpu:0")
        so.add_session_config_entry('session.use_device_allocator_for_initializers', "1")
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self.session = ort.InferenceSession(model_path, providers=providers, sess_options=so)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

    def _process_token_weights(self, token_weights: np.ndarray, input_ids: list) -> csr_matrix:
        # Collapse per-token weights into a 1 x vocab_size sparse row, keeping the
        # maximum weight seen for each token id and dropping special tokens.
        result = defaultdict(float)
        unused_tokens = set([
            self.tokenizer.cls_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.pad_token_id,
            self.tokenizer.unk_token_id,
        ])
        for w, idx in zip(token_weights, input_ids):
            if idx not in unused_tokens and w > 0:
                result[idx] = max(result[idx], float(w))
        
        indices = list(result.keys())
        data = list(result.values())
        return csr_matrix((data, ([0] * len(indices), indices)), shape=(1, self.tokenizer.vocab_size), dtype=np.float64)

    def encode(self, sentences: Union[List[str], str], batch_size: int = 12, 
               return_dense: bool = True, return_sparse: bool = False) -> Dict[str, Union[List[np.ndarray], csr_matrix]]:
        if isinstance(sentences, str):
            sentences = [sentences]

        dense_embeddings = []
        sparse_embeddings = []

        for i in range(0, len(sentences), batch_size):
            batch = sentences[i:i+batch_size]
            inputs = self.tokenizer(
                batch,
                padding="longest",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"  
            ).to(self.device)  

            
            # ONNX Runtime consumes numpy arrays, so move the tokenized batch back to the host
            ort_inputs = {k: v.cpu().numpy() for k, v in inputs.items()}
            ort_outputs = self.session.run(None, ort_inputs)

            if return_dense:
                # First model output: dense sentence embeddings for the batch
                batch_dense = ort_outputs[0]
                if self.use_fp16:
                    batch_dense = batch_dense.astype(np.float16)
                dense_embeddings.extend(batch_dense)

            if return_sparse:
                # Second model output: per-token weights from the sparse (lexical) head
                sparse_vecs = ort_outputs[1]
                for j, input_ids in enumerate(inputs["input_ids"].cpu().numpy()):
                    sparse_embeddings.append(self._process_token_weights(sparse_vecs[j], input_ids.tolist()))

        result = {}
        if return_dense:
            result["dense"] = dense_embeddings
        if return_sparse:
            result["sparse"] = vstack(sparse_embeddings)

        return result



tokenizer_path = "path/to/tokenizer"   # directory (or Hub id) containing the tokenizer files
model_path = "path/to/model.onnx"      # local path to the exported ONNX model
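
# Optional: if the files are not already on disk, one way to obtain them is to pull
# the repository snapshot from the Hub and point the paths at it. This is only a
# minimal sketch: the "model.onnx" filename is an assumption -- check the file
# listing of the export you are actually using.
from huggingface_hub import snapshot_download
import os

local_dir = snapshot_download("hiauiarau/bge-m3-onnx-O4")
tokenizer_path = local_dir                           # tokenizer files shipped with the export
model_path = os.path.join(local_dir, "model.onnx")   # assumed ONNX filename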

onnx_model = ONNXInferenceModel(model_path=model_path, tokenizer_path=tokenizer_path, use_fp16=True)



sentences = ["Hi"']
embeddings = onnx_model.encode(sentences, return_dense=True, return_sparse=True)
print(embeddings)

Expected output (abridged):

{'dense': [array([-0.0251  ,  0.03464 , -0.04285 , ..., -0.02548 ,  0.004963,
       -0.034   ], dtype=float16)],
 'sparse': <1x250002 sparse matrix of type '<class 'numpy.float64'>'
    with 7 stored elements in Compressed Sparse Row format>}
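
The sparse result is one CSR row per sentence over the tokenizer vocabulary, and the dense result is a plain vector per sentence, so both can be inspected directly. Below is a minimal sketch of doing that with the onnx_model instance above; the helper name sparse_row_to_tokens is illustrative, not part of the model.

def sparse_row_to_tokens(row, tokenizer):
    # Map the non-zero vocabulary ids of one CSR row back to token strings.
    tokens = tokenizer.convert_ids_to_tokens(row.indices.tolist())
    return dict(zip(tokens, row.data.tolist()))

out = onnx_model.encode(
    ["What is BGE-M3?", "BGE-M3 is a multilingual embedding model."],
    return_dense=True, return_sparse=True,
)

# Lexical view of the first sentence: token -> weight
print(sparse_row_to_tokens(out["sparse"][0], onnx_model.tokenizer))

# Cosine similarity between the two dense embeddings
a, b = (np.asarray(v, dtype=np.float32) for v in out["dense"])
print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))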