import base64
import io
import logging
import os

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from inference import InferenceRecipe

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Permissive CORS for development; tighten allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    """Inference request: a base64-encoded .npy buffer plus its sample rate."""

    audio_data: str
    sample_rate: int


class AudioResponse(BaseModel):
    """Inference response: a base64-encoded .npy buffer plus optional text."""

    audio_data: str
    text: str = ""
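
# Illustrative sketch (not used by the service): how a client is expected to
# build AudioRequest.audio_data. The inference handler below reverses these
# exact steps with base64.b64decode and np.load, so the payload must be a
# base64-encoded .npy buffer holding a 2-D [channels, samples] array. The
# sample rate and shape here are arbitrary placeholders.
#
#     buf = io.BytesIO()
#     np.save(buf, np.zeros((1, 24000), dtype=np.float32))
#     payload = {
#         "audio_data": base64.b64encode(buf.getvalue()).decode(),
#         "sample_rate": 24000,
#     }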


INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

model = None


def initialize_model():
    """Initialize the model with correct path resolution"""
    global model, INITIALIZATION_STATUS
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing model on device: {device}")

        # /app/src/models is already absolute, so no extra resolution is needed.
        model_path = os.path.join('/app/src', 'models')
        logger.info(f"Loading models from: {model_path}")

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        # Log the directory contents so missing weight files are easy to spot.
        model_files = os.listdir(model_path)
        logger.info(f"Available model files: {model_files}")

        model = InferenceRecipe(model_path, device=device)
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("Model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}")
        return False


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    initialize_model()
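
# Note: recent FastAPI versions deprecate on_event in favor of lifespan
# handlers. An equivalent sketch with the newer API (assuming FastAPI >= 0.93)
# would replace the startup hook above:
#
#     from contextlib import asynccontextmanager
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         initialize_model()
#         yield
#
#     app = FastAPI(lifespan=lifespan)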


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS,
    }

    if model is not None:
        status.update({
            "device": str(model.device),
            "model_path": str(model.model_path),
            "mimi_loaded": model.mimi is not None,
            "tokenizer_loaded": model.text_tokenizer is not None,
            "lm_loaded": model.lm_gen is not None,
        })

    return status
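
# Shape of a fully healthy response (values illustrative):
#
#     {
#         "status": "healthy",
#         "initialization_status": {"model_loaded": true, "error": null},
#         "device": "cuda",
#         "model_path": "/app/src/models",
#         "mimi_loaded": true,
#         "tokenizer_loaded": true,
#         "lm_loaded": true
#     }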


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with enhanced error handling and logging"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    try:
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")

        # Decode the base64 payload back into a numpy array.
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))
        logger.info(f"Decoded audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")

        if audio_array.ndim != 2:
            raise ValueError(f"Expected 2D audio array [C,T], got shape {audio_array.shape}")

        result = model.inference(audio_array, request.sample_rate)
        logger.info(f"Inference complete. Output shape: {result['audio'].shape}")

        # Serialize the output audio the same way the input arrived.
        buffer = io.BytesIO()
        np.save(buffer, result['audio'])
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=result.get("text", ""),
        )
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
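
# End-to-end client sketch using the `requests` library (illustrative; host
# and timeout are assumptions, with the port taken from the uvicorn call
# below, and `payload` built as in the AudioRequest sketch above):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/api/v1/inference",
#         json=payload,
#         timeout=60,
#     )
#     resp.raise_for_status()
#     out = np.load(io.BytesIO(base64.b64decode(resp.json()["audio_data"])))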


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)