sanbo
update sth. at 2025-02-03 21:03:19
bf8b09b
import asyncio
import logging
import torch
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, root_validator
from typing import List, Dict, Optional
from functools import lru_cache
from threading import Lock
import uvicorn
class EmbeddingRequest(BaseModel):
    """Request body accepted by every embedding route.

    The text may be sent under any of three field names (``inputs``,
    ``input`` or ``prompt``) so HuggingFace-, OpenAI- and Ollama-style
    clients can all call the service. A pre-validator copies the first
    non-empty one into ``inputs``, which is what the routes read.
    """

    # Kept only for API compatibility; the service always uses its own
    # hard-coded model. NOTE(review): ``frozen=True`` forbids assigning the
    # field on an instance after validation; it does not stop a client from
    # sending a different value.
    model: str = Field(
        default="jinaai/jina-embeddings-v3",
        description="此参数仅用于API兼容,实际模型固定为jinaai/jina-embeddings-v3",
        frozen=True,
    )
    # Three accepted aliases for the input text.
    inputs: Optional[str] = Field(None, description="输入文本(兼容HuggingFace格式)")
    input: Optional[str] = Field(None, description="输入文本(兼容OpenAI格式)")
    prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")

    @root_validator(pre=True)
    def merge_input_fields(cls, values):
        """Copy the first truthy of inputs/input/prompt into ``inputs``.

        Raises:
            ValueError: if none of the three fields holds a truthy value.
        """
        from collections.abc import Mapping

        # A pre-validator receives the raw payload. If it is not a mapping
        # (e.g. a JSON array), hand it back untouched so pydantic reports a
        # normal validation error. Calling ``.get`` on it would raise an
        # AttributeError, which pydantic does not convert.
        if not isinstance(values, Mapping):
            return values
        for field in ("inputs", "input", "prompt"):
            if values.get(field):
                # Build a new dict rather than mutating the caller's payload.
                merged = dict(values)
                merged["inputs"] = values[field]
                return merged
        raise ValueError("必须提供 inputs/input/prompt 任一字段")
class EmbeddingResponse(BaseModel):
    """Response body returned by every embedding route."""

    # Outcome marker; the routes set it to "success".
    status: str = Field(...)
    # List of embedding vectors; the routes put a single vector in it.
    embeddings: List[List[float]] = Field(...)
class EmbeddingService:
    """Wrap one hard-coded Hugging Face embedding model for CPU inference.

    ``initialize`` must be awaited before ``get_embedding`` is used.
    """

    def __init__(self):
        self._true_model_name = "jinaai/jina-embeddings-v3"  # hard-coded model name
        self.max_length = 512  # tokenizer truncation length
        self.device = torch.device("cpu")
        self.model = None
        self.tokenizer = None
        # Serialises model calls (get_embedding may run in executor threads).
        self.lock = Lock()
        self.setup_logging()
        torch.set_num_threads(4)  # CPU tuning: cap intra-op threads

    def setup_logging(self):
        """Configure root logging and bind ``self.logger``."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    async def initialize(self):
        """Load the tokenizer and model; a no-op if both are already loaded.

        The guard makes repeated calls (e.g. from more than one startup path)
        cheap instead of loading the model weights again.

        Raises:
            Exception: whatever ``transformers`` raises while loading; it is
                logged and re-raised.
        """
        if self.model is not None and self.tokenizer is not None:
            return
        try:
            from transformers import AutoTokenizer, AutoModel
            tokenizer = AutoTokenizer.from_pretrained(
                self._true_model_name,
                trust_remote_code=True
            )
            model = AutoModel.from_pretrained(
                self._true_model_name,
                trust_remote_code=True
            ).to(self.device)
            model.eval()
            # Grad mode is thread-local in PyTorch: this only affects the
            # calling thread. get_embedding additionally uses torch.no_grad()
            # for the executor threads it runs in.
            torch.set_grad_enabled(False)
            # Publish only after both parts loaded, so a failed load can be retried.
            self.tokenizer = tokenizer
            self.model = model
            self.logger.info(f"强制加载模型: {self._true_model_name}")
        except Exception as e:
            self.logger.error(f"模型初始化失败: {str(e)}")
            raise

    @lru_cache(maxsize=1000)
    def get_embedding(self, text: str) -> List[float]:
        """Return the mean-pooled last hidden state for *text* as a list.

        Results are memoised per (instance, text) by ``lru_cache``; callers
        receive the shared cached list and must not mutate it.

        Raises:
            RuntimeError: if ``initialize`` has not loaded the model yet.
        """
        if self.model is None or self.tokenizer is None:
            # Fail clearly instead of "'NoneType' object is not callable".
            # Exceptions are not cached by lru_cache, so later calls can succeed.
            raise RuntimeError("模型尚未初始化,请先调用 initialize()")
        with self.lock:
            try:
                encoded = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.max_length,
                    padding=True
                )
                with torch.no_grad():
                    pooled = self.model(**encoded).last_hidden_state.mean(dim=1)
                # Batch of one; Tensor.tolist() also works off-CPU, unlike .numpy().
                return pooled[0].tolist()
            except Exception as e:
                self.logger.error(f"生成嵌入向量失败: {str(e)}")
                raise
# Module-wide singletons shared by the HTTP routes and the Gradio callback.
embedding_service = EmbeddingService()
app = FastAPI()
# Fully permissive CORS. NOTE(review): wildcard origins together with
# allow_credentials=True permit credentialed cross-origin calls from any
# site; confirm this is acceptable before adding any auth or cookies.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
@app.post("/embed", response_model=EmbeddingResponse)
@app.post("/api/embeddings", response_model=EmbeddingResponse)
@app.post("/api/embed", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/generate_embeddings", response_model=EmbeddingResponse)
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
# NOTE(review): the two chat/completions routes also return embeddings,
# not chat completions — confirm this is intended.
@app.post("/api/v1/chat/completions", response_model=EmbeddingResponse)
@app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
async def generate_embeddings(request: EmbeddingRequest):
    """Embed ``request.inputs`` with the fixed model; ``request.model`` is ignored.

    The blocking model call runs in the loop's default executor so the event
    loop is not blocked while it computes.

    Raises:
        HTTPException: status 500 carrying the error text if embedding fails.
    """
    loop = asyncio.get_running_loop()
    try:
        # ``inputs`` already holds whichever of inputs/input/prompt was sent.
        embedding = await loop.run_in_executor(
            None,
            embedding_service.get_embedding,
            request.inputs
        )
        return EmbeddingResponse(
            status="success",
            embeddings=[embedding]
        )
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
async def root():
    """Report service status, the fixed model name and the compute device."""
    service = embedding_service
    info = {"status": "active"}
    info["true_model"] = service._true_model_name
    info["device"] = str(service.device)
    return info
def gradio_interface(text: str) -> Dict:
    """Gradio callback: embed *text* and return a JSON-serialisable dict.

    Failures are not raised; they are returned as
    ``{"status": "error", "message": ...}``.
    """
    try:
        vector = embedding_service.get_embedding(text)
    except Exception as err:
        return {"status": "error", "message": str(err)}
    return {"status": "success", "embeddings": [vector]}
# Gradio front-end around gradio_interface: a 3-line text box as input and
# the returned dict rendered as JSON.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=3, label="输入文本"),
    outputs=gr.JSON(label="嵌入向量结果"),
    title="Jina Embeddings V3",
    description="强制使用jinaai/jina-embeddings-v3模型(无视请求中的model参数)",
    # One clickable example. The text, including its instruction-style
    # prefix, is passed to the model verbatim by gradio_interface.
    examples=[[
        "Represent this sentence for searching relevant passages: "
        "The sky is blue because of Rayleigh scattering"
    ]]
)
@app.on_event("startup")
async def startup_event():
    """Load the embedding model when the ASGI app starts up.

    NOTE: FastAPI documents ``on_event`` as deprecated in favour of lifespan
    handlers.
    """
    service = embedding_service
    await service.initialize()
if __name__ == "__main__":
    # Pre-load the model before serving. NOTE(review): the FastAPI "startup"
    # hook also awaits embedding_service.initialize() when uvicorn starts.
    asyncio.run(embedding_service.initialize())
    # Serve the Gradio UI under /ui on the same ASGI app.
    gr.mount_gradio_app(app, iface, path="/ui")
    server_options = {"host": "0.0.0.0", "port": 7860, "workers": 1}
    uvicorn.run(app, **server_options)