alia / app.py
repd79's picture
Update app.py
fe115a7 verified
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Configurar caché y gestión de memoria
os.environ["TRANSFORMERS_CACHE"] = "/root/.cache/huggingface/"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Nombre del modelo
model_name = "BSC-LT/ALIA-40b"
# Cargar modelo desde caché si es posible
try:
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"), local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
local_files_only=True,
device_map="auto",
offload_folder="offload_cache",
torch_dtype=torch.bfloat16
)
print("Modelo cargado desde caché.")
except Exception as e:
print("El modelo no se encontró en caché. Descargando...")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
device_map="auto",
offload_folder="offload_cache",
torch_dtype=torch.bfloat16
)
tokenizer.save_pretrained("/root/model_storage/")
model.save_pretrained("/root/model_storage/")
print("Modelo guardado en caché para futuras cargas.")
# Mostrar en qué dispositivo está el modelo
print(f"Modelo cargado en: {next(model.parameters()).device}")
def generar_texto(entrada):
torch.cuda.empty_cache() # Liberar caché antes de inferencia
input_ids = tokenizer(entrada, return_tensors="pt").input_ids.to("cuda")
output = model.generate(
input_ids,
max_length=100,
temperature=0.1,
top_p=0.95,
repetition_penalty=1.2,
do_sample=True
)
return tokenizer.decode(output[0], skip_special_tokens=True)
# Crear la interfaz de Gradio
interfaz = gr.Interface(
fn=generar_texto,
inputs=gr.Textbox(lines=2, placeholder="Escribe tu prompt aquí...", interactive=True),
outputs=gr.Textbox(interactive=True),
title="Generador de Texto con ALIA-40b",
description="Este modelo genera texto utilizando ALIA-40b, un modelo LLM entrenado por BSC-LT."
)
if __name__ == "__main__":
interfaz.launch(share=True, server_port=7860)