File size: 2,352 Bytes
7ab6a95
f849f21
9d138e8
 
502bbe0
a2a2ab4
7ab6a95
a2a2ab4
502bbe0
7ab6a95
 
e7873d1
a2a2ab4
7ab6a95
 
 
 
 
 
 
a2a2ab4
 
7ab6a95
 
 
 
 
 
 
 
 
a2a2ab4
 
7ab6a95
a2a2ab4
 
7ab6a95
2131cc5
a2a2ab4
e7873d1
624bf57
9d138e8
a2a2ab4
e7873d1
a2a2ab4
e7873d1
2131cc5
 
fe115a7
 
 
 
a2a2ab4
2131cc5
 
a2a2ab4
9d138e8
7ab6a95
9d138e8
 
e7873d1
 
9d138e8
 
497e57e
 
 
e7873d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configurar caché y gestión de memoria
os.environ["TRANSFORMERS_CACHE"] = "/root/.cache/huggingface/"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Nombre del modelo
model_name = "BSC-LT/ALIA-40b"

# Cargar modelo desde caché si es posible
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"), local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        local_files_only=True,
        device_map="auto",
        offload_folder="offload_cache",
        torch_dtype=torch.bfloat16
    )
    print("Modelo cargado desde caché.")
except Exception as e:
    print("El modelo no se encontró en caché. Descargando...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        device_map="auto",
        offload_folder="offload_cache",
        torch_dtype=torch.bfloat16
    )
    tokenizer.save_pretrained("/root/model_storage/")
    model.save_pretrained("/root/model_storage/")
    print("Modelo guardado en caché para futuras cargas.")

# Mostrar en qué dispositivo está el modelo
print(f"Modelo cargado en: {next(model.parameters()).device}")

def generar_texto(entrada):
    torch.cuda.empty_cache()  # Liberar caché antes de inferencia

    input_ids = tokenizer(entrada, return_tensors="pt").input_ids.to("cuda")

    output = model.generate(
        input_ids,
        max_length=100,
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.2,
        do_sample=True
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Crear la interfaz de Gradio
interfaz = gr.Interface(
    fn=generar_texto,
    inputs=gr.Textbox(lines=2, placeholder="Escribe tu prompt aquí...", interactive=True),
    outputs=gr.Textbox(interactive=True),
    title="Generador de Texto con ALIA-40b",
    description="Este modelo genera texto utilizando ALIA-40b, un modelo LLM entrenado por BSC-LT."
)

if __name__ == "__main__":
    interfaz.launch(share=True, server_port=7860)