import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
import sys
import gc
import time
from contextlib import contextmanager

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

@contextmanager
def timer(description: str):
    """Context manager that logs how long the wrapped block took."""
    start = time.time()
    yield
    elapsed = time.time() - start
    logger.info(f"{description}: {elapsed:.2f} seconds")

def log_system_info():
    """Log system information for debugging."""
    logger.info(f"Python version: {sys.version}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info("Device: CPU")

print("Starting application...")
log_system_info()

try:
    print("Loading model and tokenizer...")
    model_id = "htigenai/finetune_test"  # Replace with your chosen model ID

    with timer("Loading tokenizer"):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,  # Use fast tokenizer for better performance
            cache_dir='./cache'
        )
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Tokenizer loaded successfully")

    with timer("Loading model"):
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map={"": "cpu"},
            cache_dir='./cache'
        )
        model.eval()
        logger.info("Model loaded successfully")

    def generate_text(prompt, max_tokens=100, temperature=0.7):
        """Generate text based on the input prompt."""
        try:
            logger.info(f"Starting generation for prompt: {prompt[:50]}...")

            with timer("Tokenization"):
                inputs = tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=256
                ).to("cpu")  # Ensure inputs are on CPU

            with timer("Generation"):
                with torch.no_grad():
                    outputs = model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        top_p=0.95,
                        do_sample=True,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        repetition_penalty=1.1,
                    )

            with timer("Decoding"):
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            logger.info("Text generation completed successfully")

            # Clean up
            with timer("Cleanup"):
                gc.collect()

            return generated_text

        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(
                lines=3,
                placeholder="Enter your prompt here...",
                label="Input Prompt"
            ),
            gr.Slider(
                minimum=20,
                maximum=200,
                value=100,
                step=10,
                label="Max Tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
        ],
        outputs=gr.Textbox(
            label="Generated Response",
            lines=10
        ),
        title="Text Generation Demo",
        description="Enter a prompt to generate text.",
        examples=[
            ["What are your thoughts about cats?", 50, 0.7],
            ["Write a short story about a magical forest", 60, 0.8],
            ["Explain quantum computing to a 5-year-old", 40, 0.5],
        ]
    )

    iface.launch()

except Exception as e:
    logger.error(f"Application startup failed: {str(e)}")
    raise