from fastapi import FastAPI, HTTPException
import numpy as np
import torch
from pydantic import BaseModel
import base64
import io
import os
import logging
from pathlib import Path
from inference import InferenceRecipe
from fastapi.middleware.cors import CORSMiddleware

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    audio_data: str
    sample_rate: int


class AudioResponse(BaseModel):
    audio_data: str
    text: str = ""


# Model initialization status
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}

# Global model instance
model = None


def initialize_model():
    """Initialize the model with correct path resolution"""
    global model, INITIALIZATION_STATUS
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing model on device: {device}")

        # Critical: use an absolute path for model loading
        model_path = os.path.abspath(os.path.join('/app/src', 'models'))
        logger.info(f"Loading models from: {model_path}")

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        # Log available model files for debugging
        model_files = os.listdir(model_path)
        logger.info(f"Available model files: {model_files}")

        model = InferenceRecipe(model_path, device=device)
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("Model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}")
        return False


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS
    }
    if model is not None:
        status.update({
            "device": str(model.device),
            "model_path": str(model.model_path),
            "mimi_loaded": model.mimi is not None,
            "tokenizer_loaded": model.text_tokenizer is not None,
            "lm_loaded": model.lm_gen is not None
        })
    return status


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with enhanced error handling and logging"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )
    try:
        # Log input validation
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")

        # Decode audio from base64
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))
        logger.info(f"Decoded audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")

        # Validate input format
        if len(audio_array.shape) != 2:
            raise ValueError(f"Expected 2D audio array [C, T], got shape {audio_array.shape}")

        # Run inference
        result = model.inference(audio_array, request.sample_rate)
        logger.info(f"Inference complete. Output shape: {result['audio'].shape}")

        # Encode output audio
        buffer = io.BytesIO()
        np.save(buffer, result['audio'])
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=result.get("text", "")
        )
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
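
# ---------------------------------------------------------------------------
# Example client call: a minimal sketch of the wire format the endpoint
# expects, not part of the server itself. The request carries a [C, T] numpy
# array serialized with np.save and base64-encoded into AudioRequest.audio_data;
# the response is decoded the same way in reverse. The port matches the
# uvicorn.run call above; the 24 kHz sample rate is an assumption for
# illustration, not a value taken from this file.
#
#   import base64, io
#   import numpy as np
#   import requests
#
#   audio = np.zeros((1, 24000), dtype=np.float32)  # one channel, 1 s of silence (assumed 24 kHz)
#   buf = io.BytesIO()
#   np.save(buf, audio)  # same np.save format the server reads back with np.load
#   payload = {
#       "audio_data": base64.b64encode(buf.getvalue()).decode(),
#       "sample_rate": 24000,
#   }
#   resp = requests.post("http://localhost:8000/api/v1/inference", json=payload)
#   resp.raise_for_status()
#   out = np.load(io.BytesIO(base64.b64decode(resp.json()["audio_data"])))
#   print(out.shape, resp.json()["text"])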