sanbo committed on
Commit
bf8b09b
·
1 Parent(s): e767741

Update app.py: accept inputs/input/prompt request fields and pin the embedding model to jinaai/jina-embeddings-v3 (2025-02-03 21:03:19)

Browse files
Files changed (1) hide show
  1. app.py +38 -17
app.py CHANGED
@@ -4,16 +4,35 @@ import torch
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from pydantic import BaseModel
8
- from typing import List, Dict
9
  from functools import lru_cache
10
- import numpy as np
11
  from threading import Lock
12
  import uvicorn
13
 
14
  class EmbeddingRequest(BaseModel):
15
- input: str
16
- model: str = "jinaai/jina-embeddings-v3"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class EmbeddingResponse(BaseModel):
19
  status: str
@@ -21,7 +40,7 @@ class EmbeddingResponse(BaseModel):
21
 
22
  class EmbeddingService:
23
  def __init__(self):
24
- self.model_name = "jinaai/jina-embeddings-v3"
25
  self.max_length = 512
26
  self.device = torch.device("cpu")
27
  self.model = None
@@ -41,23 +60,22 @@ class EmbeddingService:
41
  try:
42
  from transformers import AutoTokenizer, AutoModel
43
  self.tokenizer = AutoTokenizer.from_pretrained(
44
- self.model_name,
45
  trust_remote_code=True
46
  )
47
  self.model = AutoModel.from_pretrained(
48
- self.model_name,
49
  trust_remote_code=True
50
  ).to(self.device)
51
  self.model.eval()
52
  torch.set_grad_enabled(False)
53
- self.logger.info(f"模型加载成功,使用设备: {self.device}")
54
  except Exception as e:
55
  self.logger.error(f"模型初始化失败: {str(e)}")
56
  raise
57
 
58
  @lru_cache(maxsize=1000)
59
  def get_embedding(self, text: str) -> List[float]:
60
- """同步生成嵌入向量,带缓存"""
61
  with self.lock:
62
  try:
63
  inputs = self.tokenizer(
@@ -85,7 +103,8 @@ app.add_middleware(
85
  allow_methods=["*"],
86
  allow_headers=["*"],
87
  )
88
-
 
89
  @app.post("/api/embed", response_model=EmbeddingResponse)
90
  @app.post("/v1/embeddings", response_model=EmbeddingResponse)
91
  @app.post("/generate_embeddings", response_model=EmbeddingResponse)
@@ -95,11 +114,10 @@ app.add_middleware(
95
  @app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
96
  async def generate_embeddings(request: EmbeddingRequest):
97
  try:
98
- # 使用run_in_executor避免事件循环问题
99
  embedding = await asyncio.get_running_loop().run_in_executor(
100
  None,
101
  embedding_service.get_embedding,
102
- request.input
103
  )
104
  return EmbeddingResponse(
105
  status="success",
@@ -112,7 +130,7 @@ async def generate_embeddings(request: EmbeddingRequest):
112
  async def root():
113
  return {
114
  "status": "active",
115
- "model": embedding_service.model_name,
116
  "device": str(embedding_service.device)
117
  }
118
 
@@ -134,8 +152,11 @@ iface = gr.Interface(
134
  inputs=gr.Textbox(lines=3, label="输入文本"),
135
  outputs=gr.JSON(label="嵌入向量结果"),
136
  title="Jina Embeddings V3",
137
- description="使用jina-embeddings-v3模型生成文本嵌入向量",
138
- examples=[["这是一个测试句子。"]]
 
 
 
139
  )
140
 
141
  @app.on_event("startup")
@@ -145,4 +166,4 @@ async def startup_event():
145
  if __name__ == "__main__":
146
  asyncio.run(embedding_service.initialize())
147
  gr.mount_gradio_app(app, iface, path="/ui")
148
- uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
 
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, Field, root_validator
8
+ from typing import List, Dict, Optional
9
  from functools import lru_cache
 
10
  from threading import Lock
11
  import uvicorn
12
 
13
class EmbeddingRequest(BaseModel):
    """Embedding request accepted on several API-compatible routes.

    The ``model`` field exists only so OpenAI/HF/Ollama-style clients can
    send it; the served model is hard-coded server-side. Exactly one of
    ``inputs`` / ``input`` / ``prompt`` must carry the text; a pre-root
    validator copies whichever is present into ``inputs``.
    """

    # Accepted for API compatibility only — the backend always loads the
    # pinned model, regardless of what the client sends here.
    model: str = Field(
        default="jinaai/jina-embeddings-v3",
        description="此参数仅用于API兼容,实际模型固定为jinaai/jina-embeddings-v3",
        frozen=True  # reject mutation after construction
    )
    # Three aliases for the input text, one per client convention.
    inputs: Optional[str] = Field(None, description="输入文本(兼容HuggingFace格式)")
    input: Optional[str] = Field(None, description="输入文本(兼容OpenAI格式)")
    prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")

    @root_validator(pre=True)
    def merge_input_fields(cls, values):
        """Fold the first non-empty alias into ``inputs``; reject empty requests."""
        text = next(
            (values.get(name) for name in ("inputs", "input", "prompt")
             if values.get(name)),
            None,
        )
        if text is None:
            raise ValueError("必须提供 inputs/input/prompt 任一字段")
        values["inputs"] = text
        return values
36
 
37
  class EmbeddingResponse(BaseModel):
38
  status: str
 
40
 
41
  class EmbeddingService:
42
  def __init__(self):
43
+ self._true_model_name = "jinaai/jina-embeddings-v3" # 硬编码模型名称
44
  self.max_length = 512
45
  self.device = torch.device("cpu")
46
  self.model = None
 
60
  try:
61
  from transformers import AutoTokenizer, AutoModel
62
  self.tokenizer = AutoTokenizer.from_pretrained(
63
+ self._true_model_name,
64
  trust_remote_code=True
65
  )
66
  self.model = AutoModel.from_pretrained(
67
+ self._true_model_name,
68
  trust_remote_code=True
69
  ).to(self.device)
70
  self.model.eval()
71
  torch.set_grad_enabled(False)
72
+ self.logger.info(f"强制加载模型: {self._true_model_name}")
73
  except Exception as e:
74
  self.logger.error(f"模型初始化失败: {str(e)}")
75
  raise
76
 
77
  @lru_cache(maxsize=1000)
78
  def get_embedding(self, text: str) -> List[float]:
 
79
  with self.lock:
80
  try:
81
  inputs = self.tokenizer(
 
103
  allow_methods=["*"],
104
  allow_headers=["*"],
105
  )
106
+ @app.post("/embed", response_model=EmbeddingResponse)
107
+ @app.post("/api/embeddings", response_model=EmbeddingResponse)
108
  @app.post("/api/embed", response_model=EmbeddingResponse)
109
  @app.post("/v1/embeddings", response_model=EmbeddingResponse)
110
  @app.post("/generate_embeddings", response_model=EmbeddingResponse)
 
114
  @app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
115
  async def generate_embeddings(request: EmbeddingRequest):
116
  try:
 
117
  embedding = await asyncio.get_running_loop().run_in_executor(
118
  None,
119
  embedding_service.get_embedding,
120
+ request.inputs # 使用合并后的输入字段
121
  )
122
  return EmbeddingResponse(
123
  status="success",
 
130
  async def root():
131
  return {
132
  "status": "active",
133
+ "true_model": embedding_service._true_model_name,
134
  "device": str(embedding_service.device)
135
  }
136
 
 
152
  inputs=gr.Textbox(lines=3, label="输入文本"),
153
  outputs=gr.JSON(label="嵌入向量结果"),
154
  title="Jina Embeddings V3",
155
+ description="强制使用jinaai/jina-embeddings-v3模型(无视请求中的model参数)",
156
+ examples=[[
157
+ "Represent this sentence for searching relevant passages: "
158
+ "The sky is blue because of Rayleigh scattering"
159
+ ]]
160
  )
161
 
162
  @app.on_event("startup")
 
166
  if __name__ == "__main__":
167
  asyncio.run(embedding_service.initialize())
168
  gr.mount_gradio_app(app, iface, path="/ui")
169
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)