import os
import torch
import gradio as gr
from train_optimized import GPT, GPTConfig
from huggingface_hub import hf_hub_download
import json
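
# Gradio demo for a character-level GPT trained on Shakespeare. The model
# weights and config are downloaded from the Hugging Face Hub
# (jatingocodeo/shakespeare-decoder) on first use.
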
# Cache for model and tokenizer
MODEL = None
CHARS = None
STOI = None
ITOS = None

def initialize():
    global MODEL, CHARS, STOI, ITOS
    if MODEL is None:
        print("Loading model and tokenizer...")

        # Download model files from HF Hub
        config_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="config.json")
        model_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="pytorch_model.bin")

        # Load config
        with open(config_path, 'r') as f:
            config_dict = json.load(f)

        # Initialize model with config
        config = GPTConfig(
            vocab_size=config_dict['vocab_size'],
            n_layer=config_dict['n_layer'],
            n_head=config_dict['n_head'],
            n_embd=config_dict['n_embd'],
            block_size=config_dict['block_size'],
            dropout=config_dict['dropout'],
            bias=config_dict['bias']
        )
        model = GPT(config)

        # Load model weights
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        MODEL = model

        # Initialize tokenizer
        # Download input.txt from the repository
        try:
            input_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="input.txt")
            with open(input_path, 'r', encoding='utf-8') as f:
                text = f.read()
        except Exception:
            # Fallback to Shakespeare text if input.txt is not in the repo
            text = """
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
"""
        CHARS = sorted(list(set(text)))
        STOI = {ch: i for i, ch in enumerate(CHARS)}
        ITOS = {i: ch for i, ch in enumerate(CHARS)}
        print("Model and tokenizer loaded successfully!")

def generate_text(
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50
):
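    """Continue `prompt` with the cached model.

    max_new_tokens caps the number of generated characters, temperature
    rescales the logits before sampling, and top_k restricts sampling to
    the k most likely characters at each step.
    """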
    # Initialize if not already done
    if MODEL is None:
        initialize()

    # Encode the prompt
    encode = lambda s: [STOI[c] for c in s]
    decode = lambda l: ''.join([ITOS[i] for i in l])
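    # decode inverts encode, e.g. decode(encode("To be")) == "To be" for
    # text made only of in-vocabulary characters.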

    try:
        # Convert prompt to tensor
        x = torch.tensor(encode(prompt), dtype=torch.long)[None, ...]

        # Generate
        with torch.no_grad():
            y = MODEL.generate(x, max_new_tokens, temperature, top_k)[0]

        # Decode and return
        generated_text = decode(y.tolist())
        return generated_text
    except KeyError:
        return "Error: The prompt contains characters that are not in the training data. Please use only standard English characters."
    except Exception as e:
        return f"Error generating text: {str(e)}"
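
# Loading at import time (rather than lazily inside generate_text) keeps the
# first request fast at the cost of a slower startup.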
# Initialize on startup
initialize()

# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=5
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=10,
            maximum=500,
            value=100,
            step=10
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            value=0.8,
            step=0.1
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=100,
            value=50,
            step=1
        ),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Shakespeare GPT",
    description="""
    This is a GPT model trained on Shakespeare's text. Enter a prompt and the model will continue it in Shakespeare's style.

    Parameters:
    - Max New Tokens: Maximum number of tokens to generate
    - Temperature: Higher values make the output more random; lower values make it more deterministic
    - Top-k: Number of highest-probability tokens considered at each step
    """,
    examples=[
        ["To be, or not to be,", 100, 0.8, 50],
        ["Friends, Romans, countrymen,", 150, 0.7, 40],
        ["Now is the winter of", 200, 0.9, 30],
    ]
)

if __name__ == "__main__":
    demo.launch()