# model-inference / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
import sys
import gc
import time
from contextlib import contextmanager
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)


@contextmanager
def timer(description: str):
    start = time.time()
    yield
    elapsed = time.time() - start
    logger.info(f"{description}: {elapsed:.2f} seconds")

def log_system_info():
    """Log system information for debugging"""
    logger.info(f"Python version: {sys.version}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info("Device: CPU")


print("Starting application...")
log_system_info()

try:
    print("Loading model and tokenizer...")
    model_id = "htigenai/finetune_test"  # Replace with your chosen model ID

    with timer("Loading tokenizer"):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,  # Use fast tokenizer for better performance
            cache_dir='./cache'
        )
    # Causal LM checkpoints often ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Tokenizer loaded successfully")

    with timer("Loading model"):
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map={"": "cpu"},  # Pin all weights to CPU
            cache_dir='./cache'
        )
    model.eval()  # Inference mode: disables dropout
    logger.info("Model loaded successfully")

    def generate_text(prompt, max_tokens=100, temperature=0.7):
        """Generate text based on the input prompt."""
        try:
            logger.info(f"Starting generation for prompt: {prompt[:50]}...")

            with timer("Tokenization"):
                inputs = tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=256
                ).to("cpu")  # Ensure inputs are on CPU

            with timer("Generation"):
                with torch.no_grad():
                    outputs = model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        top_p=0.95,
                        do_sample=True,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        repetition_penalty=1.1,
                    )

            with timer("Decoding"):
                # Note: outputs[0] contains the prompt followed by the completion,
                # so the returned text echoes the prompt.
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            logger.info("Text generation completed successfully")

            # Clean up
            with timer("Cleanup"):
                gc.collect()

            return generated_text
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(
                lines=3,
                placeholder="Enter your prompt here...",
                label="Input Prompt"
            ),
            gr.Slider(
                minimum=20,
                maximum=200,
                value=100,
                step=10,
                label="Max Tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
        ],
        outputs=gr.Textbox(
            label="Generated Response",
            lines=10
        ),
        title="Text Generation Demo",
        description="Enter a prompt to generate text.",
        examples=[
            ["What are your thoughts about cats?", 50, 0.7],
            ["Write a short story about a magical forest", 60, 0.8],
            ["Explain quantum computing to a 5-year-old", 40, 0.5],
        ]
    )

    iface.launch()
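
    # Note (assumption): the no-argument launch() above is fine on Hugging Face
    # Spaces; for a LAN-visible local run you could instead use
    #   iface.launch(server_name="0.0.0.0", server_port=7860)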

except Exception as e:
    logger.error(f"Application startup failed: {str(e)}")
    raise
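
# Running locally (sketch, assuming the imports above are installed):
#   pip install gradio transformers torch
#   python app.py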