"""FastAPI service exposing audio inference over HTTP.

Endpoints:
    GET  /api/v1/health     -- model initialization status and loaded components.
    POST /api/v1/inference  -- run the model on a base64-encoded ``.npy`` audio array.
"""

import base64
import io
import logging
import os
import threading
from pathlib import Path

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from inference import InferenceRecipe

# Configure PyTorch behavior - only use supported configs.
# With suppress_errors, TorchDynamo falls back to eager execution instead of
# raising when it fails to compile a frame.
torch._dynamo.config.suppress_errors = True

# NOTE(review): the original comment said these variables *disable*
# optimizations, but TORCH_LOGS / TORCHDYNAMO_VERBOSE / TORCH_COMPILE_DEBUG turn
# ON verbose Dynamo/Inductor debug output. They are also assigned after
# `import torch` (and after the project import above), so any that PyTorch only
# reads at import time may have no effect here -- confirm they are still wanted.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "1"
os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"
# NOTE(review): TORCH_COMPILE does not appear to be a documented PyTorch switch;
# presumably it is read by project code -- confirm.
os.environ["TORCH_COMPILE"] = "0"  # Disable torch.compile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Where model weights are loaded from. Set INFERENCE_MODEL_DIR to override.
MODEL_DIR_ENV_VAR = "INFERENCE_MODEL_DIR"
DEFAULT_MODEL_DIR = "/app/src/models"

app = FastAPI()

# Add CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True
# effectively lets any website make credentialed cross-origin requests;
# restrict allow_origins (or drop credentials) if this API ever relies on
# cookies or other credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    """Inference request payload."""

    # Base64 encoding of a NumPy ``.npy`` file holding a 2-D [C, T] array.
    audio_data: str
    # Sample rate of the input audio (presumably Hz); must be positive.
    sample_rate: int


class AudioResponse(BaseModel):
    """Inference response payload."""

    # Base64 encoding of a NumPy ``.npy`` file holding the model's output audio.
    audio_data: str
    # Text returned by the model alongside the audio, or "" when none.
    text: str = ""


# Model initialization status, reported verbatim by the health endpoint.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

# Global model instance, set by initialize_model().
model = None

# The single model object is shared by every request and may keep per-call
# state, so inference calls are serialized. Before, the async endpoint
# serialized them implicitly by blocking the event loop.
_inference_lock = threading.Lock()


def initialize_model():
    """Load the InferenceRecipe model into the module-level ``model``.

    The model directory comes from the ``INFERENCE_MODEL_DIR`` environment
    variable, defaulting to ``/app/src/models``. CUDA is used when available.

    Returns:
        True on success. False on failure; the error message is then stored in
        ``INITIALIZATION_STATUS["error"]`` and the traceback is logged.
    """
    global model

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing model on device: %s", device)

        # An empty env var is treated like an unset one.
        configured_dir = os.environ.get(MODEL_DIR_ENV_VAR) or DEFAULT_MODEL_DIR
        model_path = os.path.abspath(configured_dir)
        logger.info("Loading models from: %s", model_path)

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")
        if not os.path.isdir(model_path):
            raise RuntimeError(f"Model path {model_path} is not a directory")

        logger.info("Available model files: %s", os.listdir(model_path))

        model = InferenceRecipe(model_path, device=device)
        INITIALIZATION_STATUS["model_loaded"] = True
        INITIALIZATION_STATUS["error"] = None
        logger.info("Model initialized successfully")
        return True
    except Exception as e:  # top-level boundary: record, log, report failure
        INITIALIZATION_STATUS["error"] = str(e)
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Failed to initialize model: %s", e)
        return False


# NOTE: on_event is deprecated in recent FastAPI releases in favor of lifespan
# handlers; kept for compatibility with whichever version is installed.
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup.

    initialize_model() is synchronous, so it blocks the event loop until model
    loading finishes or fails.
    """
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Report model readiness and which model components are loaded.

    ``status`` is "healthy" once the model has loaded, "unhealthy" if
    initialization recorded an error, and "initializing" otherwise.
    """
    if INITIALIZATION_STATUS["model_loaded"]:
        state = "healthy"
    elif INITIALIZATION_STATUS["error"] is not None:
        # Initialization runs once at startup. After a failure the model will
        # never become ready, so don't claim it is still initializing.
        state = "unhealthy"
    else:
        state = "initializing"

    status = {
        "status": state,
        "initialization_status": INITIALIZATION_STATUS,
    }
    if model is not None:
        status.update({
            "device": str(model.device),
            "model_path": str(model.model_path),
            "mimi_loaded": model.mimi is not None,
            "tokenizer_loaded": model.text_tokenizer is not None,
            "lm_loaded": model.lm_gen is not None,
        })
    return status


def _decode_audio(audio_data: str) -> np.ndarray:
    """Decode a base64-encoded ``.npy`` payload into a 2-D [C, T] array.

    Raises:
        ValueError: if the payload is not valid base64, is not a single
            ``.npy`` array, needs pickle to load, or is not 2-D.
    """
    try:
        raw = base64.b64decode(audio_data)
        # The payload is untrusted: never unpickle it.
        array = np.load(io.BytesIO(raw), allow_pickle=False)
    except (ValueError, OSError, EOFError) as exc:  # binascii.Error is a ValueError
        raise ValueError(
            f"Could not decode audio_data as a base64-encoded .npy array: {exc}"
        ) from exc

    if not isinstance(array, np.ndarray):
        # np.load returns an NpzFile for .npz archives, which has no .shape.
        raise ValueError("audio_data must contain a single .npy array, not an .npz archive")
    if len(array.shape) != 2:
        raise ValueError(f"Expected 2D audio array [C,T], got shape {array.shape}")
    return array


# A plain `def` (not `async def`) makes FastAPI run this in its threadpool, so
# the long, blocking model call no longer freezes the event loop.
@app.post("/api/v1/inference")
def inference(request: AudioRequest) -> AudioResponse:
    """Run the model on the request's audio and return its audio/text output.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for a malformed
            payload or non-positive sample rate; 500 if inference itself fails.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    logger.info("Received inference request with sample rate: %s", request.sample_rate)
    if request.sample_rate <= 0:
        raise HTTPException(
            status_code=400,
            detail=f"sample_rate must be a positive integer, got {request.sample_rate}",
        )

    # Client-side payload problems are reported as 400, not 500.
    try:
        audio_array = _decode_audio(request.audio_data)
    except ValueError as exc:
        logger.warning("Rejected inference request: %s", exc)
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    logger.info(
        "Decoded audio array shape: %s, dtype: %s", audio_array.shape, audio_array.dtype
    )

    try:
        # Grad mode is thread-local in PyTorch, and this handler runs in a
        # worker thread, so autograd is disabled explicitly for the call.
        with _inference_lock, torch.no_grad():
            result = model.inference(audio_array, request.sample_rate)
        logger.info("Inference complete. Output shape: %s", result["audio"].shape)

        buffer = io.BytesIO()
        np.save(buffer, result["audio"])
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        # `or ""` also covers a model that returns text=None.
        return AudioResponse(audio_data=audio_b64, text=result.get("text") or "")
    except Exception as e:  # top-level boundary: log with traceback, map to 500
        logger.exception("Inference failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)