import json

import torch
import gradio as gr
from huggingface_hub import hf_hub_download

from train_optimized import GPT, GPTConfig

# Cache for the model and tokenizer so they are only loaded once
MODEL = None
CHARS = None
STOI = None
ITOS = None

def initialize():
    global MODEL, CHARS, STOI, ITOS
    if MODEL is None:
        print("Loading model and tokenizer...")

        # Download model files from the Hugging Face Hub
        config_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="config.json")
        model_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="pytorch_model.bin")

        # Load config
        with open(config_path, 'r') as f:
            config_dict = json.load(f)

        # Initialize the model from the config
        config = GPTConfig(
            vocab_size=config_dict['vocab_size'],
            n_layer=config_dict['n_layer'],
            n_head=config_dict['n_head'],
            n_embd=config_dict['n_embd'],
            block_size=config_dict['block_size'],
            dropout=config_dict['dropout'],
            bias=config_dict['bias']
        )
        model = GPT(config)

        # Load the model weights
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        MODEL = model

        # Build the character-level tokenizer from the training text.
        # Download input.txt from the repository; if it is missing, fall
        # back to a short Shakespeare snippet (note: the fallback vocabulary
        # only covers the characters that appear in that snippet).
        try:
            input_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="input.txt")
            with open(input_path, 'r', encoding='utf-8') as f:
                text = f.read()
        except Exception:
            text = """
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?
"""

        CHARS = sorted(set(text))
        STOI = {ch: i for i, ch in enumerate(CHARS)}
        ITOS = {i: ch for i, ch in enumerate(CHARS)}

        print("Model and tokenizer loaded successfully!")

def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
    # Initialize lazily in case startup initialization did not run
    if MODEL is None:
        initialize()

    # Character-level encode/decode built from the vocabulary above
    encode = lambda s: [STOI[c] for c in s]
    decode = lambda l: ''.join(ITOS[i] for i in l)

    try:
        # Convert the prompt to a (1, T) tensor of token ids
        x = torch.tensor(encode(prompt), dtype=torch.long)[None, ...]

        # Generate
        with torch.no_grad():
            y = MODEL.generate(x, max_new_tokens, temperature, top_k)[0]

        # Decode the token ids back to text
        return decode(y.tolist())
    except KeyError:
        return ("Error: The prompt contains characters that are not in the "
                "training data. Please use only standard English characters.")
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Initialize on startup
initialize()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=5
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=10,
            maximum=500,
            value=100,
            step=10
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            value=0.8,
            step=0.1
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=100,
            value=50,
            step=1
        ),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Shakespeare GPT",
    description="""
    This is a GPT model trained on Shakespeare's text. Enter a prompt and the model will continue it in Shakespeare's style.

    Parameters:
    - Max New Tokens: Maximum number of tokens to generate
    - Temperature: Higher values make the output more random; lower values make it more deterministic
    - Top-k: Number of highest-probability tokens considered at each sampling step
    """,
    examples=[
        ["To be, or not to be,", 100, 0.8, 50],
        ["Friends, Romans, countrymen,", 150, 0.7, 40],
        ["Now is the winter of", 200, 0.9, 30],
    ]
)

if __name__ == "__main__":
    demo.launch()
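# ---------------------------------------------------------------------------
# Note on MODEL.generate: the call in generate_text() assumes a nanoGPT-style
# sampling method on the GPT class from train_optimized.py, with the signature
# generate(idx, max_new_tokens, temperature, top_k). A minimal sketch of that
# assumed interface is shown below as a comment, for reference only -- the
# actual implementation lives in train_optimized.py and may differ:
#
#   @torch.no_grad()
#   def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
#       for _ in range(max_new_tokens):
#           # Crop the running context to the model's block size
#           idx_cond = idx[:, -self.config.block_size:]
#           logits, _ = self(idx_cond)
#           # Rescale the last-position logits by the temperature
#           logits = logits[:, -1, :] / temperature
#           if top_k is not None:
#               # Mask out everything below the k-th highest logit
#               v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
#               logits[logits < v[:, [-1]]] = float('-inf')
#           # Sample the next token and append it to the sequence
#           probs = torch.nn.functional.softmax(logits, dim=-1)
#           idx_next = torch.multinomial(probs, num_samples=1)
#           idx = torch.cat((idx, idx_next), dim=1)
#       return idx
# ---------------------------------------------------------------------------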