import os # Keras Backend 설정: JAX를 가장 먼저 설정해야 합니다. os.environ["KERAS_BACKEND"] = "jax" # 메모리 fragmentation 방지 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" import keras_nlp import keras from jax import config import time import gradio as gr # JAX 환경설정, float32로 설정하여 Gemma 모델의 정확도 향상 config.update("jax_default_matmul_precision", "float32") # 모델 ID 및 LoRA 설정 model_id = "gemma2_instruct_2b_en" # 기본 모델 ID lora_rank = 4 # 허깅페이스 모델 로드 경로 설정 (모델이 로컬에 있다면 로컬 경로도 사용 가능) model_path = "hf://본인_허깅페이스_아이디/모델_저장소_이름" # 예: "hf://my_user/my_gemma_model" token_limit = 128 # 토큰 제한 설정 # 모델 파일 이름 (허깅페이스 스페이스 루트 디렉토리에 있다고 가정) model_file = "model.safetensors" config_file = "config.json" # 글로벌 시간 추적 변수 tick_start = 0 def tick(): """시간 측정 시작.""" global tick_start tick_start = time.time() def tock(): """시간 측정 종료 및 출력.""" print(f"총 소요 시간: {time.time() - tick_start:.2f}s") def load_model_with_explicit_path(model_path, model_file, config_file): """모델을 명시적인 파일 경로를 사용하여 로드합니다.""" try: # 모델이 허깅페이스에 있을 경우 if model_path.startswith("hf://"): gemma_lm_loaded = keras.saving.load_model(model_path) return gemma_lm_loaded # 모델이 로컬에 있을 경우 else: # model_file, config_file 경로 합치기 full_model_path = os.path.join(os.getcwd(), model_file) full_config_path = os.path.join(os.getcwd(), config_file) # load_model 함수 사용 (load_model 함수는 config파일을 인식하지 못함) gemma_lm_loaded = keras.saving.load_model(full_model_path) return gemma_lm_loaded except Exception as e: print(f"모델 로드 실패: {e}") return None # 허깅페이스 모델 로드 (수정됨) gemma_lm_loaded = load_model_with_explicit_path(model_path, model_file, config_file) # 텍스트 생성 함수 (수정됨) def generate_text(prompt): tick() # 시간 측정 시작 # 원본 모델의 토크나이저를 사용하여 토큰화 input_text = f"user\n{prompt}\nmodel\n" output = gemma_lm_loaded.generate(input_text, max_length=token_limit) # 모델 추론 tock() # 시간 측정 종료 및 출력 return output # Gradio 인터페이스 정의 iface = gr.Interface( fn=generate_text, inputs=gr.Textbox(label="사용자 입력"), outputs=gr.Textbox(label="챗봇 응답"), title="Gemma2 챗봇", description="Gemma2 기반의 한국어 대화 모델입니다." ) # Gradio 인터페이스 실행 iface.launch()