# Spaces: Running
import asyncio
import logging
from functools import lru_cache
from threading import Lock
from typing import Dict, List, Optional

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, root_validator
class EmbeddingRequest(BaseModel):
    """Request body for the embedding endpoint.

    The input text may be sent under any of three field names so that
    HuggingFace-style (``inputs``), OpenAI-style (``input``) and Ollama-style
    (``prompt``) clients can share one endpoint. Before field validation, the
    first non-empty of them (checked in that order) is copied into ``inputs``.
    """

    # Kept only for API compatibility: the served model is always
    # jinaai/jina-embeddings-v3 regardless of the value a client sends.
    model: str = Field(
        default="jinaai/jina-embeddings-v3",
        description="此参数仅用于API兼容,实际模型固定为jinaai/jina-embeddings-v3",
        frozen=True  # cannot be reassigned after the model is built
    )
    # Three accepted spellings of the input text.
    inputs: Optional[str] = Field(None, description="输入文本(兼容HuggingFace格式)")
    input: Optional[str] = Field(None, description="输入文本(兼容OpenAI格式)")
    prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")

    # Without this decorator pydantic never calls the method, and requests
    # using only ``input`` or ``prompt`` would arrive with ``inputs=None``.
    @root_validator(pre=True)
    def merge_input_fields(cls, values):
        """Copy the first non-empty of inputs/input/prompt into ``inputs``.

        Args:
            values: the raw, not-yet-validated request payload.

        Returns:
            A new dict with ``inputs`` populated.

        Raises:
            ValueError: if none of the three fields has a non-empty value;
                pydantic reports it as a validation error.
        """
        if not isinstance(values, dict):
            # Not a JSON object: let pydantic's own type validation reject it
            # instead of failing here with an AttributeError.
            return values
        # Work on a copy so the caller's raw payload is not mutated.
        values = dict(values)
        for field in ("inputs", "input", "prompt"):
            if values.get(field):
                values["inputs"] = values[field]
                break
        else:
            raise ValueError("必须提供 inputs/input/prompt 任一字段")
        return values
class EmbeddingResponse(BaseModel):
    """Response body: a status string and a list of embedding vectors."""

    # Both fields are required (``...`` means "no default").
    status: str = Field(...)
    embeddings: List[List[float]] = Field(...)
class EmbeddingService:
    """Load jinaai/jina-embeddings-v3 and turn text into embedding vectors.

    The checkpoint name is hard-coded, so the service always uses the same
    model whatever a request names. Inference runs on CPU, and tokenizer and
    model calls are serialized through a single lock.
    """

    def __init__(self):
        self._true_model_name = "jinaai/jina-embeddings-v3"  # hard-coded model name
        self.max_length = 512  # tokenizer truncation length, in tokens
        self.device = torch.device("cpu")
        # Set by initialize(); both stay None until the model is loaded.
        self.model = None
        self.tokenizer = None
        # Serializes inference: get_embedding is called from executor threads
        # (REST handler) and from the Gradio handler.
        self.lock = Lock()
        self.setup_logging()
        torch.set_num_threads(4)  # CPU optimization: cap intra-op threads

    def setup_logging(self):
        """Configure root logging at INFO level and create this module's logger."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    async def initialize(self):
        """Load the tokenizer and model unless they are already loaded.

        Safe to call more than once: later calls return immediately instead of
        loading the checkpoint again. (This module calls it both from
        ``__main__`` and from ``startup_event``.)

        Raises:
            Exception: any error from ``transformers`` while loading; it is
                logged with its traceback and re-raised.
        """
        if self.model is not None and self.tokenizer is not None:
            return
        try:
            # Deferred import: transformers is only needed at load time.
            from transformers import AutoTokenizer, AutoModel
            self.tokenizer = AutoTokenizer.from_pretrained(
                self._true_model_name,
                trust_remote_code=True
            )
            self.model = AutoModel.from_pretrained(
                self._true_model_name,
                trust_remote_code=True
            ).to(self.device)
            self.model.eval()
            # Grad mode is thread-local in PyTorch, so this affects only the
            # calling thread; get_embedding also wraps inference in no_grad().
            torch.set_grad_enabled(False)
            self.logger.info(f"强制加载模型: {self._true_model_name}")
        except Exception as e:
            self.logger.exception(f"模型初始化失败: {str(e)}")
            raise

    def get_embedding(self, text: str) -> List[float]:
        """Return the mean-pooled last-hidden-state embedding of ``text``.

        Args:
            text: input string, truncated to ``self.max_length`` tokens.

        Returns:
            One embedding vector as a list of Python floats.

        Raises:
            RuntimeError: if ``initialize()`` has not loaded the model yet.
            Exception: any tokenizer or model error; it is logged and re-raised.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                f"Model {self._true_model_name} is not loaded; call initialize() first"
            )
        with self.lock:
            try:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.max_length,
                    padding=True
                )
                with torch.no_grad():
                    # Average token states over dim=1. This assumes
                    # last_hidden_state is (batch, seq_len, hidden), as in
                    # standard transformers outputs.
                    outputs = self.model(**inputs).last_hidden_state.mean(dim=1)
                # A single text was encoded, so return the first (only) row.
                return outputs.numpy().tolist()[0]
            except Exception as e:
                self.logger.exception(f"生成嵌入向量失败: {str(e)}")
                raise
# One shared service instance, used by both the REST handlers and the Gradio UI.
embedding_service = EmbeddingService()

# FastAPI application with fully permissive CORS.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# nothing shown in this file uses cookies or auth, so confirm credentials
# are actually needed.
app = FastAPI()
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
async def generate_embeddings(request: EmbeddingRequest):
    """Embed the request text and wrap the vector in an EmbeddingResponse.

    NOTE(review): no route decorator is visible on this handler, so as shown
    it is not registered on ``app``. Confirm the intended ``@app.post(...)`` path.

    The text is read from ``inputs``, falling back to ``input`` and then
    ``prompt``. This is the same order EmbeddingRequest.merge_input_fields uses,
    so the handler still works when the request model has not merged the
    fields. Inference is blocking, so it runs in the default thread-pool
    executor rather than on the event loop.

    Raises:
        HTTPException: 422 if no non-empty input text was supplied; 500
            wrapping any error raised while computing the embedding.
    """
    text = request.inputs or request.input or request.prompt
    if not text:
        # Checked outside the try so this 422 is not rewrapped as a 500.
        raise HTTPException(
            status_code=422,
            detail="必须提供 inputs/input/prompt 任一字段",
        )
    try:
        loop = asyncio.get_running_loop()
        embedding = await loop.run_in_executor(
            None,
            embedding_service.get_embedding,
            text,
        )
        return EmbeddingResponse(
            status="success",
            embeddings=[embedding]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def root():
    """Return a status payload: liveness, the fixed model name and the device.

    NOTE(review): no route decorator is visible on this handler, so as shown
    it is not registered on ``app``.
    """
    svc = embedding_service
    return dict(
        status="active",
        true_model=svc._true_model_name,
        device=str(svc.device),
    )
def gradio_interface(text: str) -> Dict:
    """Gradio callback: embed ``text`` and return a JSON-serializable dict.

    On failure it returns ``{"status": "error", "message": ...}`` instead of
    raising, so the UI shows the error text.
    """
    try:
        vector = embedding_service.get_embedding(text)
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "embeddings": [vector]}
# Gradio UI that wraps gradio_interface: a text box in, JSON out.
_ui_input = gr.Textbox(lines=3, label="输入文本")
_ui_output = gr.JSON(label="嵌入向量结果")
_ui_example_text = (
    "Represent this sentence for searching relevant passages: "
    "The sky is blue because of Rayleigh scattering"
)
iface = gr.Interface(
    fn=gradio_interface,
    inputs=_ui_input,
    outputs=_ui_output,
    title="Jina Embeddings V3",
    description="强制使用jinaai/jina-embeddings-v3模型(无视请求中的model参数)",
    examples=[[_ui_example_text]],
)
async def startup_event():
    """Load the embedding model when the ASGI application starts.

    NOTE(review): no startup-hook decorator is visible on this coroutine, so as
    shown it is not registered on ``app``.
    """
    service = embedding_service
    await service.initialize()
def _serve() -> None:
    """Load the model, mount the Gradio UI at /ui, then run uvicorn on port 7860."""
    # Load the weights before the server starts accepting requests.
    # NOTE(review): if startup_event is also registered as a startup hook,
    # initialize() runs again once uvicorn starts.
    asyncio.run(embedding_service.initialize())
    gr.mount_gradio_app(app, iface, path="/ui")
    # One worker: uvicorn needs an import string, not an app object, to run
    # multiple workers.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)


if __name__ == "__main__":
    _serve()