Upload folder using huggingface_hub
- Dockerfile +45 -0
- server.py +140 -0
Dockerfile
ADDED
@@ -0,0 +1,45 @@
+FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    DEBIAN_FRONTEND=noninteractive \
+    CUDA_HOME=/usr/local/cuda \
+    PATH=/usr/local/cuda/bin:$PATH \
+    LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
+    NVIDIA_VISIBLE_DEVICES=all \
+    NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    build-essential \
+    ffmpeg \
+    libsndfile1 \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Upgrade pip and install build tools
+RUN python3 -m pip install --upgrade pip setuptools wheel
+
+WORKDIR /app
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+
+# Install PyTorch with CUDA support
+RUN pip3 install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
+
+# Install other requirements
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Copy the rest of the application
+COPY . .
+
+# Create models directory
+RUN mkdir -p /app/models
+
+EXPOSE 8000
+
+CMD ["python3", "server.py"]
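A typical build-and-run invocation for this image (a sketch: the tag name is arbitrary, and `--gpus all` assumes the NVIDIA Container Toolkit is installed on the host):

docker build -t inference-server .
docker run --gpus all -p 8000:8000 inference-server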
server.py
ADDED
@@ -0,0 +1,140 @@
+from fastapi import FastAPI, HTTPException
+import numpy as np
+import torch
+from pydantic import BaseModel
+import base64
+import io
+import os
+import logging
+from pathlib import Path
+from inference import InferenceRecipe
+from fastapi.middleware.cors import CORSMiddleware
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class AudioRequest(BaseModel):
+    audio_data: str  # base64 encoded audio data
+    sample_rate: int
+
+class AudioResponse(BaseModel):
+    audio_data: str  # base64 encoded audio data
+    text: str = ""
+
+# Model initialization status
+INITIALIZATION_STATUS = {
+    "model_loaded": False,
+    "error": None
+}
+
+# Global model instance
+model = None
+
+def initialize_model():
+    """Initialize the model with correct path resolution"""
+    global model, INITIALIZATION_STATUS
+    try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Initializing model on device: {device}")
+
+        # Use absolute path for model loading
+        model_path = os.path.abspath(os.path.join(os.getcwd(), 'models'))
+        logger.info(f"Loading models from: {model_path}")
+
+        if not os.path.exists(model_path):
+            raise RuntimeError(f"Model path {model_path} does not exist")
+
+        # Log available model files for debugging
+        model_files = os.listdir(model_path)
+        logger.info(f"Available model files: {model_files}")
+
+        model = InferenceRecipe(model_path, device=device)
+        INITIALIZATION_STATUS["model_loaded"] = True
+        logger.info("Model initialized successfully")
+        return True
+    except Exception as e:
+        INITIALIZATION_STATUS["error"] = str(e)
+        logger.error(f"Failed to initialize model: {e}")
+        return False
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize model on startup"""
+    initialize_model()
+
+@app.get("/api/v1/health")
+def health_check():
+    """Health check endpoint"""
+    status = {
+        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
+        "initialization_status": INITIALIZATION_STATUS
+    }
+
+    if model is not None:
+        status.update({
+            "device": str(model.device),
+            "model_path": str(model.model_path),
+            "mimi_loaded": model.mimi is not None,
+            "text_tokenizer_loaded": model.text_tokenizer is not None,
+            "lm_loaded": model.lm_gen is not None
+        })
+
+    return status
+
+@app.post("/api/v1/inference")
+async def inference(request: AudioRequest) -> AudioResponse:
+    """Run inference with enhanced error handling and logging"""
+    if not INITIALIZATION_STATUS["model_loaded"]:
+        raise HTTPException(
+            status_code=503,
+            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
+        )
+
+    try:
+        # Log input validation
+        logger.info(f"Received inference request with sample rate: {request.sample_rate}")
+
+        # Decode audio
+        audio_bytes = base64.b64decode(request.audio_data)
+        audio_array = np.load(io.BytesIO(audio_bytes))
+        logger.info(f"Decoded audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")
+
+        # Validate input format
+        if len(audio_array.shape) != 2:
+            raise ValueError(f"Expected 2D audio array [C,T], got shape {audio_array.shape}")
+
+        # Run inference
+        result = model.inference(audio_array, request.sample_rate)
+        logger.info(f"Inference complete. Output shape: {result['audio'].shape}")
+
+        # Encode output
+        buffer = io.BytesIO()
+        np.save(buffer, result['audio'])
+        audio_b64 = base64.b64encode(buffer.getvalue()).decode()
+
+        return AudioResponse(
+            audio_data=audio_b64,
+            text=result.get("text", "")
+        )
+
+    except Exception as e:
+        logger.error(f"Inference failed: {str(e)}", exc_info=True)
+        raise HTTPException(
+            status_code=500,
+            detail=str(e)
+        )
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
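For reference, a minimal client sketch against the endpoints above: it serializes a [C, T] float32 array with np.save, base64-encodes it the way server.py decodes it, and posts to /api/v1/inference. The host, port, and the 24 kHz sample rate are assumptions, and requests is an extra client-side dependency, not part of this commit.

import base64
import io

import numpy as np
import requests  # assumed client-side dependency

# One second of silence as a [channels, samples] array (sample rate assumed)
audio = np.zeros((1, 24000), dtype=np.float32)

# np.save into a buffer, then base64-encode: the inverse of server.py's decoding
buf = io.BytesIO()
np.save(buf, audio)
payload = {
    "audio_data": base64.b64encode(buf.getvalue()).decode(),
    "sample_rate": 24000,
}

resp = requests.post("http://localhost:8000/api/v1/inference", json=payload)
resp.raise_for_status()
body = resp.json()

# Decode the response the same way: base64 -> np.load
audio_out = np.load(io.BytesIO(base64.b64decode(body["audio_data"])))
print(audio_out.shape, body["text"])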