import gradio as gr
import subprocess
import os
import shutil
import tempfile
import torch
import logging
import numpy as np
import re
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('yue_generation.log'),
        logging.StreamHandler()
    ]
)

def optimize_gpu_settings():
    if torch.cuda.is_available():
        # Optimize GPU memory management
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.deterministic = False

        # Memory settings tuned for the L40S
        torch.cuda.empty_cache()
        torch.cuda.set_device(0)

        # CUDA stream optimization
        torch.cuda.Stream(0)

        # Memory allocation optimization
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
        logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

        # L40S-specific settings
        if 'L40S' in torch.cuda.get_device_name(0):
            torch.cuda.set_per_process_memory_fraction(0.95)

def analyze_lyrics(lyrics, repeat_chorus=2):
    lines = [line.strip() for line in lyrics.split('\n') if line.strip()]

    sections = {
        'verse': 0,
        'chorus': 0,
        'bridge': 0,
        'total_lines': len(lines)
    }

    current_section = None
    last_section_start = 0  # defensive default in case no section tag appears
    section_lines = {
        'verse': [],
        'chorus': [],
        'bridge': []
    }

    for i, line in enumerate(lines):
        lower_line = line.lower()

        # Handle section tags
        if '[verse]' in lower_line:
            if current_section:
                # Store the previous section's lines
                section_lines[current_section].extend(lines[last_section_start:i])
            current_section = 'verse'
            sections['verse'] += 1
            last_section_start = i + 1
            continue
        elif '[chorus]' in lower_line:
            if current_section:
                section_lines[current_section].extend(lines[last_section_start:i])
            current_section = 'chorus'
            sections['chorus'] += 1
            last_section_start = i + 1
            continue
        elif '[bridge]' in lower_line:
            if current_section:
                section_lines[current_section].extend(lines[last_section_start:i])
            current_section = 'bridge'
            sections['bridge'] += 1
            last_section_start = i + 1
            continue

    # Append the lines of the final section
    if current_section and last_section_start < len(lines):
        section_lines[current_section].extend(lines[last_section_start:])

    # Handle chorus repetition
    if sections['chorus'] > 0 and repeat_chorus > 1:
        original_chorus = section_lines['chorus'][:]
        for _ in range(repeat_chorus - 1):
            section_lines['chorus'].extend(original_chorus)

    # Log per-section line counts
    logging.info(f"Section line counts - Verse: {len(section_lines['verse'])}, "
                 f"Chorus: {len(section_lines['chorus'])}, "
                 f"Bridge: {len(section_lines['bridge'])}")

    return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), len(lines), section_lines

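# Illustrative trace (an editorial example, not part of the original script):
# with the default repeat_chorus=2, the input
#
#     [verse]
#     a
#     b
#     [chorus]
#     c
#
# yields sections == {'verse': 1, 'chorus': 1, 'bridge': 0, 'total_lines': 5},
# section_lines['verse'] == ['a', 'b'], and section_lines['chorus'] == ['c', 'c']
# (the single chorus line is duplicated once by the repetition handling above).
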
def calculate_generation_params(lyrics):
    sections, total_sections, total_lines, section_lines = analyze_lyrics(lyrics)

    # Base time per line, in seconds
    time_per_line = {
        'verse': 4,   # 4 seconds per verse line
        'chorus': 6,  # 6 seconds per chorus line
        'bridge': 5   # 5 seconds per bridge line
    }

    # Estimated duration for each section (including the final section)
    section_durations = {}
    for section_type in ['verse', 'chorus', 'bridge']:
        lines_count = len(section_lines[section_type])
        section_durations[section_type] = lines_count * time_per_line[section_type]

    # Total duration, with headroom added
    total_duration = sum(duration for duration in section_durations.values())
    total_duration = max(60, int(total_duration * 1.2))  # add 20% headroom

    # Token budget (extra tokens reserved for the final section)
    base_tokens = 3000
    tokens_per_line = 200
    extra_tokens = 1000  # extra tokens for the final section

    total_tokens = base_tokens + (total_lines * tokens_per_line) + extra_tokens

    # Segment count (an extra segment when a chorus is present)
    if sections['chorus'] > 0:
        num_segments = 4  # 4 segments when a chorus is present
    else:
        num_segments = 3  # 3 segments when there is no chorus

    # Cap the token count (raised limit)
    max_tokens = min(12000, total_tokens)

    return {
        'max_tokens': max_tokens,
        'num_segments': num_segments,
        'sections': sections,
        'section_lines': section_lines,
        'estimated_duration': total_duration,
        'section_durations': section_durations,
        'has_chorus': sections['chorus'] > 0
    }

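# Worked numbers (illustrative only), continuing the small example above:
# verse = 2 lines * 4 s = 8 s, chorus (after repetition) = 2 lines * 6 s = 12 s,
# so total_duration = max(60, int(20 * 1.2)) = 60 s; total_tokens =
# 3000 + 5 * 200 + 1000 = 5000, well under the 12000 cap; and num_segments = 4
# because a chorus is present.
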
def detect_and_select_model(text):
    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):    # Korean
        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
    elif re.search(r'[\u4e00-\u9fff]', text):               # Chinese
        return "m-a-p/YuE-s1-7B-anneal-zh-cot"
    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):  # Japanese
        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
    else:                                                   # default: English
        return "m-a-p/YuE-s1-7B-anneal-en-cot"

def install_flash_attn():
    try:
        if not torch.cuda.is_available():
            logging.warning("GPU not available, skipping flash-attn installation")
            return False

        cuda_version = torch.version.cuda
        if cuda_version is None:
            logging.warning("CUDA not available, skipping flash-attn installation")
            return False

        logging.info(f"Detected CUDA version: {cuda_version}")

        try:
            import flash_attn
            logging.info("flash-attn already installed")
            return True
        except ImportError:
            logging.info("Installing flash-attn...")
            subprocess.run(
                ["pip", "install", "flash-attn", "--no-build-isolation"],
                check=True,
                capture_output=True
            )
            logging.info("flash-attn installed successfully!")
            return True

    except Exception as e:
        logging.warning(f"Failed to install flash-attn: {e}")
        return False

def initialize_system():
    optimize_gpu_settings()

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        futures.append(executor.submit(install_flash_attn))

        from huggingface_hub import snapshot_download
        folder_path = './inference/xcodec_mini_infer'
        os.makedirs(folder_path, exist_ok=True)
        logging.info(f"Created folder at: {folder_path}")

        futures.append(executor.submit(
            snapshot_download,
            repo_id="m-a-p/xcodec_mini_infer",
            local_dir="./inference/xcodec_mini_infer",
            resume_download=True
        ))

        for future in futures:
            future.result()

    try:
        os.chdir("./inference")
        logging.info(f"Working directory changed to: {os.getcwd()}")
    except FileNotFoundError as e:
        logging.error(f"Directory error: {e}")
        raise

@lru_cache(maxsize=100)
def get_cached_file_path(content_hash, prefix):
    return create_temp_file(content_hash, prefix)

def empty_output_folder(output_dir):
    try:
        shutil.rmtree(output_dir)
        os.makedirs(output_dir)
        logging.info(f"Output folder cleaned: {output_dir}")
    except Exception as e:
        logging.error(f"Error cleaning output folder: {e}")
        raise

def create_temp_file(content, prefix, suffix=".txt"):
    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
    # Normalize line endings and guarantee a trailing blank line
    content = content.strip() + "\n\n"
    content = content.replace("\r\n", "\n").replace("\r", "\n")
    temp_file.write(content)
    temp_file.close()
    logging.debug(f"Temporary file created: {temp_file.name}")
    return temp_file.name

def get_last_mp3_file(output_dir):
    mp3_files = [f for f in os.listdir(output_dir) if f.endswith('.mp3')]
    if not mp3_files:
        logging.warning("No MP3 files found")
        return None
    mp3_files_with_path = [os.path.join(output_dir, f) for f in mp3_files]
    mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
    return mp3_files_with_path[0]

def get_audio_duration(file_path):
    try:
        import librosa
        duration = librosa.get_duration(path=file_path)
        return duration
    except Exception as e:
        logging.error(f"Failed to get audio duration: {e}")
        return None

def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
    genre_txt_path = None
    lyrics_txt_path = None

    try:
        model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
        logging.info(f"Selected model: {model_path}")
        logging.info(f"Lyrics analysis: {params}")

        has_chorus = params['sections']['chorus'] > 0
        estimated_duration = params.get('estimated_duration', 90)

        # Configure segment and token counts
        if has_chorus:
            actual_max_tokens = min(12000, int(config['max_tokens'] * 1.3))  # 30% more tokens
            actual_num_segments = min(5, params['num_segments'] + 2)         # extra segments
        else:
            actual_max_tokens = min(10000, int(config['max_tokens'] * 1.2))
            actual_num_segments = min(4, params['num_segments'] + 1)

        logging.info(f"Estimated duration: {estimated_duration} seconds")
        logging.info(f"Has chorus sections: {has_chorus}")
        logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")

        genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
        lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")

        output_dir = "./output"
        os.makedirs(output_dir, exist_ok=True)
        empty_output_folder(output_dir)

        # Revised command: unsupported arguments removed
        command = [
            "python", "infer.py",
            "--stage1_model", model_path,
            "--stage2_model", "m-a-p/YuE-s2-1B-general",
            "--genre_txt", genre_txt_path,
            "--lyrics_txt", lyrics_txt_path,
            "--run_n_segments", str(actual_num_segments),
            "--stage2_batch_size", "16",
            "--output_dir", output_dir,
            "--cuda_idx", "0",
            "--max_new_tokens", str(actual_max_tokens),
            "--disable_offload_model"  # added for GPU memory optimization
        ]

        env = os.environ.copy()
        if torch.cuda.is_available():
            env.update({
                "CUDA_VISIBLE_DEVICES": "0",
                "CUDA_HOME": "/usr/local/cuda",
                "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
                "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
                "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
                "CUDA_LAUNCH_BLOCKING": "0"
            })

        # Handle the transformers cache migration
        try:
            from transformers.utils import move_cache
            move_cache()
        except Exception as e:
            logging.warning(f"Cache migration warning (non-critical): {e}")

        process = subprocess.run(
            command,
            env=env,
            check=False,
            capture_output=True,
            text=True
        )

        logging.info(f"Command output: {process.stdout}")
        if process.stderr:
            logging.error(f"Command error: {process.stderr}")

        if process.returncode != 0:
            logging.error(f"Command failed with return code: {process.returncode}")
            logging.error(f"Command: {' '.join(command)}")
            raise RuntimeError(f"Inference failed: {process.stderr}")

        last_mp3 = get_last_mp3_file(output_dir)

        if last_mp3:
            try:
                duration = get_audio_duration(last_mp3)
                logging.info(f"Generated audio file: {last_mp3}")
                if duration:
                    logging.info(f"Audio duration: {duration:.2f} seconds")
                    logging.info(f"Expected duration: {estimated_duration} seconds")
                    if duration < estimated_duration * 0.8:
                        logging.warning(f"Generated audio is shorter than expected: "
                                        f"{duration:.2f}s < {estimated_duration:.2f}s")
            except Exception as e:
                logging.warning(f"Failed to get audio duration: {e}")
            return last_mp3
        else:
            logging.warning("No output audio file generated")
            return None

    except Exception as e:
        logging.error(f"Inference error: {e}")
        raise
    finally:
        for path in [genre_txt_path, lyrics_txt_path]:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                    logging.debug(f"Removed temporary file: {path}")
                except Exception as e:
                    logging.warning(f"Failed to remove temporary file {path}: {e}")

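# For illustration (hypothetical values, not from a real run): for Korean lyrics
# with a chorus, the command list above would resolve to something like
#
#     python infer.py --stage1_model m-a-p/YuE-s1-7B-anneal-jp-kr-cot \
#         --stage2_model m-a-p/YuE-s2-1B-general \
#         --genre_txt /tmp/genre_XXXX.txt --lyrics_txt /tmp/lyrics_XXXX.txt \
#         --run_n_segments 5 --stage2_batch_size 16 --output_dir ./output \
#         --cuda_idx 0 --max_new_tokens 9750 --disable_offload_model
#
# where 9750 = int(7500 * 1.3), assuming analyze_lyrics produced max_tokens=5000
# and the chorus multiplier in optimize_model_selection (below) raised it to 7500.
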
def optimize_model_selection(lyrics, genre):
    model_path = detect_and_select_model(lyrics)
    params = calculate_generation_params(lyrics)

    has_chorus = params['sections']['chorus'] > 0
    tokens_per_segment = params['max_tokens'] // params['num_segments']

    model_config = {
        "m-a-p/YuE-s1-7B-anneal-en-cot": {
            "max_tokens": params['max_tokens'],
            "temperature": 0.8,
            "batch_size": 16,
            "num_segments": params['num_segments'],
            "estimated_duration": params['estimated_duration']
        },
        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
            "max_tokens": params['max_tokens'],
            "temperature": 0.7,
            "batch_size": 16,
            "num_segments": params['num_segments'],
            "estimated_duration": params['estimated_duration']
        },
        "m-a-p/YuE-s1-7B-anneal-zh-cot": {
            "max_tokens": params['max_tokens'],
            "temperature": 0.7,
            "batch_size": 16,
            "num_segments": params['num_segments'],
            "estimated_duration": params['estimated_duration']
        }
    }

    if has_chorus:
        for config in model_config.values():
            config['max_tokens'] = int(config['max_tokens'] * 1.5)

    return model_path, model_config[model_path], params

def main():
    with gr.Blocks() as demo:
        with gr.Column():
            gr.Markdown("# Open SUNO: Full-Song Generation (Multi-Language Support)")

            with gr.Row():
                with gr.Column():
                    genre_txt = gr.Textbox(
                        label="Genre",
                        placeholder="Enter music genre and style descriptions..."
                    )
                    lyrics_txt = gr.Textbox(
                        label="Lyrics (Supports English, Korean, Japanese, Chinese)",
                        placeholder="Enter song lyrics with [verse], [chorus], [bridge] tags...",
                        lines=10
                    )

                with gr.Column():
                    num_segments = gr.Number(
                        label="Number of Song Segments (Auto-adjusted based on lyrics)",
                        value=2,
                        minimum=1,
                        maximum=4,
                        step=1,
                        interactive=False
                    )
                    max_new_tokens = gr.Slider(
                        label="Max New Tokens (Auto-adjusted based on lyrics)",
                        minimum=500,
                        maximum=32000,
                        step=500,
                        value=4000,
                        interactive=False
                    )
                    with gr.Row():
                        duration_info = gr.Label(label="Estimated Duration")
                        sections_info = gr.Label(label="Section Information")
                    submit_btn = gr.Button("Generate Music", variant="primary")
                    music_out = gr.Audio(label="Generated Audio")

            gr.Examples(
                examples=[
                    [
                        "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                        """[verse]
In the quiet of the evening, shadows start to fall
Whispers of the night wind echo through the hall
Lost within the silence, I hear your gentle voice
Guiding me back homeward, making my heart rejoice

[chorus]
Don't let this moment fade, hold me close tonight
With you here beside me, everything's alright
Can't imagine life alone, don't want to let you go
Stay with me forever, let our love just flow

[verse]
In the quiet of the evening, shadows start to fall
Whispers of the night wind echo through the hall
Lost within the silence, I hear your gentle voice
Guiding me back homeward, making my heart rejoice

[chorus]
Don't let this moment fade, hold me close tonight
With you here beside me, everything's alright
Can't imagine life alone, don't want to let you go
Stay with me forever, let our love just flow"""
                    ],
                    [
                        "K-pop bright energetic synth dance electronic",
                        """[verse]
언젠가 마주한 눈빛 속에서

[chorus]
다시 한 번 내게 말해줘

[verse]
어두운 밤을 지날 때마다

[chorus]
다시 한 번 내게 말해줘
"""
                    ]
                ],
                inputs=[genre_txt, lyrics_txt]
            )

        initialize_system()

        def update_info(lyrics):
            if not lyrics:
                return "No lyrics entered", "No sections detected"
            params = calculate_generation_params(lyrics)
            duration = params['estimated_duration']
            sections = params['sections']
            return (
                f"Estimated duration: {duration:.1f} seconds",
                f"Verses: {sections['verse']}, Chorus: {sections['chorus']} "
                f"(Expected full length including chorus)"
            )

        lyrics_txt.change(
            fn=update_info,
            inputs=[lyrics_txt],
            outputs=[duration_info, sections_info]
        )

        submit_btn.click(
            fn=infer,
            inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
            outputs=[music_out]
        )

    return demo

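# Note on the wiring above (editorial comment): num_segments and max_new_tokens
# are passed to infer() only for signature compatibility with the click handler;
# infer() ignores them and recomputes both from the lyrics via
# optimize_model_selection(). The two widgets are therefore non-interactive and
# purely informational, as their labels indicate.
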
if __name__ == "__main__":
    demo = main()
    demo.queue(max_size=20).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_api=True,
        show_error=True,
        max_threads=8
    )
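# Quick sanity check (a minimal sketch; the module name below assumes this file
# is saved as app.py). The lyric-analysis path is pure Python, so it can be
# exercised without a GPU, model downloads, or launching the Gradio server:
#
#     >>> from app import calculate_generation_params
#     >>> p = calculate_generation_params("[verse]\na\nb\n[chorus]\nc")
#     >>> p['num_segments'], p['has_chorus']
#     (4, True)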