tezuesh committed
Commit 1f42983 · verified · 1 Parent(s): 4768721

Upload folder using huggingface_hub

Files changed (2):
  1. Dockerfile +45 -0
  2. server.py +140 -0
Dockerfile ADDED
@@ -0,0 +1,45 @@
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     DEBIAN_FRONTEND=noninteractive \
+     CUDA_HOME=/usr/local/cuda \
+     PATH=/usr/local/cuda/bin:$PATH \
+     LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
+     NVIDIA_VISIBLE_DEVICES=all \
+     NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     python3 \
+     python3-pip \
+     python3-dev \
+     build-essential \
+     ffmpeg \
+     libsndfile1 \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Upgrade pip and install build tools
+ RUN python3 -m pip install --upgrade pip setuptools wheel
+
+ WORKDIR /app
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install PyTorch with CUDA support
+ RUN pip3 install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
+
+ # Install other requirements
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application
+ COPY . .
+
+ # Create models directory
+ RUN mkdir -p /app/models
+
+ EXPOSE 8000
+
+ CMD ["python3", "server.py"]
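The image is built and started in the usual way, for example with docker build -t moshi-server . followed by docker run --gpus all -p 8000:8000 moshi-server (the image name is a placeholder, not part of the commit). Since the model loads asynchronously on startup, a client may want to poll the health endpoint before sending work. A minimal readiness-probe sketch, assuming the requests package is installed client-side and the container is listening on localhost:8000:

import time
import requests  # assumed client-side dependency, not in the server image

def wait_until_ready(base_url="http://localhost:8000", timeout=300):
    """Hypothetical probe: poll /api/v1/health until the model reports loaded."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            status = requests.get(f"{base_url}/api/v1/health", timeout=5).json()
            if status.get("status") == "healthy":
                return status
        except requests.RequestException:
            pass  # container may still be starting up
        time.sleep(2)
    raise TimeoutError("server did not become healthy in time")

if __name__ == "__main__":
    print(wait_until_ready())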
server.py ADDED
@@ -0,0 +1,140 @@
+ from fastapi import FastAPI, HTTPException
+ import numpy as np
+ import torch
+ from pydantic import BaseModel
+ import base64
+ import io
+ import os
+ import logging
+ from pathlib import Path
+ from inference import InferenceRecipe
+ from fastapi.middleware.cors import CORSMiddleware
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ class AudioRequest(BaseModel):
+     audio_data: str  # base64-encoded audio data
+     sample_rate: int
+
+ class AudioResponse(BaseModel):
+     audio_data: str  # base64-encoded audio data
+     text: str = ""
+
+ # Model initialization status
+ INITIALIZATION_STATUS = {
+     "model_loaded": False,
+     "error": None
+ }
+
+ # Global model instance
+ model = None
+
+ def initialize_model():
+     """Initialize the model with correct path resolution"""
+     global model, INITIALIZATION_STATUS
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Initializing model on device: {device}")
+
+         # Use absolute path for model loading
+         model_path = os.path.abspath(os.path.join(os.getcwd(), 'models'))
+         logger.info(f"Loading models from: {model_path}")
+
+         if not os.path.exists(model_path):
+             raise RuntimeError(f"Model path {model_path} does not exist")
+
+         # Log available model files for debugging
+         model_files = os.listdir(model_path)
+         logger.info(f"Available model files: {model_files}")
+
+         model = InferenceRecipe(model_path, device=device)
+         INITIALIZATION_STATUS["model_loaded"] = True
+         logger.info("Model initialized successfully")
+         return True
+     except Exception as e:
+         INITIALIZATION_STATUS["error"] = str(e)
+         logger.error(f"Failed to initialize model: {e}")
+         return False
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize model on startup"""
+     initialize_model()
+
+ @app.get("/api/v1/health")
+ def health_check():
+     """Health check endpoint"""
+     status = {
+         "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
+         "initialization_status": INITIALIZATION_STATUS
+     }
+
+     if model is not None:
+         status.update({
+             "device": str(model.device),
+             "model_path": str(model.model_path),
+             "mimi_loaded": model.mimi is not None,
+             "text_tokenizer_loaded": model.text_tokenizer is not None,
+             "lm_loaded": model.lm_gen is not None
+         })
+
+     return status
+
+ @app.post("/api/v1/inference")
+ async def inference(request: AudioRequest) -> AudioResponse:
+     """Run inference with enhanced error handling and logging"""
+     if not INITIALIZATION_STATUS["model_loaded"]:
+         raise HTTPException(
+             status_code=503,
+             detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
+         )
+
+     try:
+         # Log input validation
+         logger.info(f"Received inference request with sample rate: {request.sample_rate}")
+
+         # Decode audio
+         audio_bytes = base64.b64decode(request.audio_data)
+         audio_array = np.load(io.BytesIO(audio_bytes))
+         logger.info(f"Decoded audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")
+
+         # Validate input format
+         if len(audio_array.shape) != 2:
+             raise ValueError(f"Expected 2D audio array [C,T], got shape {audio_array.shape}")
+
+         # Run inference
+         result = model.inference(audio_array, request.sample_rate)
+         logger.info(f"Inference complete. Output shape: {result['audio'].shape}")
+
+         # Encode output
+         buffer = io.BytesIO()
+         np.save(buffer, result['audio'])
+         audio_b64 = base64.b64encode(buffer.getvalue()).decode()
+
+         return AudioResponse(
+             audio_data=audio_b64,
+             text=result.get("text", "")
+         )
+
+     except Exception as e:
+         logger.error(f"Inference failed: {str(e)}", exc_info=True)
+         raise HTTPException(
+             status_code=500,
+             detail=str(e)
+         )
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
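The request and response models above imply a simple wire format: a 2D NumPy array serialized with np.save, then base64-encoded into the JSON body. The sketch below exercises /api/v1/inference under those assumptions; the requests dependency, the 24 kHz sample rate, and the localhost URL are illustrative, not part of the commit:

import base64
import io

import numpy as np
import requests  # assumed client-side dependency

# Hypothetical example: send one second of silence (1 channel, 24 kHz)
# through the inference endpoint and decode the reply.
SAMPLE_RATE = 24000  # assumed rate; use whatever the model expects
audio = np.zeros((1, SAMPLE_RATE), dtype=np.float32)  # [C, T], as the server validates

# Serialize with np.save and base64-encode, mirroring the server's np.load
buf = io.BytesIO()
np.save(buf, audio)
payload = {
    "audio_data": base64.b64encode(buf.getvalue()).decode(),
    "sample_rate": SAMPLE_RATE,
}

resp = requests.post("http://localhost:8000/api/v1/inference", json=payload)
resp.raise_for_status()
body = resp.json()

# The response carries audio in the same np.save/base64 format, plus optional text
out = np.load(io.BytesIO(base64.b64decode(body["audio_data"])))
print("output shape:", out.shape, "text:", body["text"])

Shipping np.save output rather than raw PCM keeps the array's shape and dtype intact across the wire, which is what lets the server check the [C, T] layout before running inference.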