File size: 2,352 Bytes
7ab6a95 f849f21 9d138e8 502bbe0 a2a2ab4 7ab6a95 a2a2ab4 502bbe0 7ab6a95 e7873d1 a2a2ab4 7ab6a95 a2a2ab4 7ab6a95 a2a2ab4 7ab6a95 a2a2ab4 7ab6a95 2131cc5 a2a2ab4 e7873d1 624bf57 9d138e8 a2a2ab4 e7873d1 a2a2ab4 e7873d1 2131cc5 fe115a7 a2a2ab4 2131cc5 a2a2ab4 9d138e8 7ab6a95 9d138e8 e7873d1 9d138e8 497e57e e7873d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Configurar caché y gestión de memoria
os.environ["TRANSFORMERS_CACHE"] = "/root/.cache/huggingface/"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Nombre del modelo
model_name = "BSC-LT/ALIA-40b"
# Cargar modelo desde caché si es posible
try:
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"), local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
local_files_only=True,
device_map="auto",
offload_folder="offload_cache",
torch_dtype=torch.bfloat16
)
print("Modelo cargado desde caché.")
except Exception as e:
print("El modelo no se encontró en caché. Descargando...")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
device_map="auto",
offload_folder="offload_cache",
torch_dtype=torch.bfloat16
)
tokenizer.save_pretrained("/root/model_storage/")
model.save_pretrained("/root/model_storage/")
print("Modelo guardado en caché para futuras cargas.")
# Mostrar en qué dispositivo está el modelo
print(f"Modelo cargado en: {next(model.parameters()).device}")
def generar_texto(entrada):
torch.cuda.empty_cache() # Liberar caché antes de inferencia
input_ids = tokenizer(entrada, return_tensors="pt").input_ids.to("cuda")
output = model.generate(
input_ids,
max_length=100,
temperature=0.1,
top_p=0.95,
repetition_penalty=1.2,
do_sample=True
)
return tokenizer.decode(output[0], skip_special_tokens=True)
# Crear la interfaz de Gradio
interfaz = gr.Interface(
fn=generar_texto,
inputs=gr.Textbox(lines=2, placeholder="Escribe tu prompt aquí...", interactive=True),
outputs=gr.Textbox(interactive=True),
title="Generador de Texto con ALIA-40b",
description="Este modelo genera texto utilizando ALIA-40b, un modelo LLM entrenado por BSC-LT."
)
if __name__ == "__main__":
interfaz.launch(share=True, server_port=7860)
|