|
import gradio as gr |
|
import torch |
|
from transformers import pipeline |
|
import logging |
|
|
|
|
|
# Configure the root logger; all logging calls in this module go through it.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Hugging Face Hub model identifiers for each stage of the pipeline.
ARTICLE_GENERATOR_MODEL = "gpt2"
SUMMARIZER_MODEL = "Falconsai/text_summarization"
TITLE_GENERATOR_MODEL = "czearing/article-title-generator"

# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {DEVICE}")

# Models are created eagerly at import time, so startup blocks until all
# three pipelines have downloaded/loaded their weights.
logging.info("Initializing models...")
text_generator = pipeline(
    "text-generation", model=ARTICLE_GENERATOR_MODEL, device=DEVICE
)
summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=DEVICE)
title_generator = pipeline(
    "text2text-generation", model=TITLE_GENERATOR_MODEL, device=DEVICE
)
logging.info("Models initialized successfully")
|
|
|
|
|
def generate_article(query, max_new_tokens):
    """Generate article text from a topic prompt via the text-generation pipeline.

    Args:
        query: The topic prompt the article is seeded from.
        max_new_tokens: Cap on how many new tokens the model may produce.

    Returns:
        The generated article text (includes the original prompt, as
        text-generation pipelines echo their input).
    """
    logging.info(f"Generating article for query: {query}")
    candidates = text_generator(
        query,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    article = candidates[0]["generated_text"]
    logging.debug(f"Generated article: {article[:100]}...")
    return article
|
|
|
|
|
def generate_title(article):
    """Generate a headline for *article* with the title-generation pipeline.

    Args:
        article: The article text to derive a title from.

    Returns:
        The generated title string.
    """
    logging.info("Generating title")
    # truncation=True guards against articles longer than the title model's
    # maximum input length — the upstream generator can emit up to 1000 new
    # tokens (see the UI slider), which would otherwise overflow the encoder.
    title = title_generator(article, num_return_sequences=1, truncation=True)[0][
        "generated_text"
    ]
    logging.debug(f"Generated title: {title}")
    return title
|
|
|
|
|
def generate_summary(article):
    """Summarize *article* with the summarization pipeline.

    Args:
        article: The article text to condense.

    Returns:
        The generated summary string.
    """
    logging.info("Generating summary")
    summary = summarizer(
        article,
        do_sample=False,  # deterministic decoding for reproducible summaries
        # truncation=True keeps inputs longer than the summarizer's max
        # input length from raising a runtime error; overly long articles
        # are clipped instead.
        truncation=True,
    )[0]["summary_text"]
    logging.debug(f"Generated summary: {summary}")
    return summary
|
|
|
|
|
def generate_blog_post(query, max_new_tokens):
    """Run the full pipeline for *query*: article body, then title, then summary.

    Args:
        query: The blog-post topic prompt.
        max_new_tokens: Generation-length cap forwarded to the article stage.

    Returns:
        A ``(title, summary, article)`` tuple, matching the order of the
        Gradio output components.
    """
    logging.info("Starting blog post generation")

    logging.info("Generating article")
    body = generate_article(query, max_new_tokens)

    logging.info("Generating title")
    headline = generate_title(body)

    logging.info("Generating summary")
    abstract = generate_summary(body)

    logging.info("Blog post generation completed")
    return headline, abstract, body
|
|
|
|
|
# Build the Gradio UI.
# NOTE(review): nesting `gr.Blocks` inside another Blocks context (the
# title/body/summary sub-blocks below) is unusual — `gr.Group`/`gr.Column`
# is the documented grouping container. Kept as-is to preserve current
# behavior; confirm the nested Blocks render as intended.
with gr.Blocks() as iface:
    gr.Markdown("# Blog Post Generator")
    gr.Markdown(
        "Enter a topic, and I'll generate a blog post with a title and summary!"
    )

    # Topic input.
    with gr.Row():
        input_prompt = gr.Textbox(
            label="Input Prompt", lines=2, placeholder="Enter your blog post topic..."
        )

    with gr.Row():
        generate_button = gr.Button("Generate Blog Post", size="sm")

    # Clickable example prompts that populate the input box.
    gr.Examples(
        examples=[
            "The future of artificial intelligence in healthcare",
            "Top 10 travel destinations for nature lovers",
            "How to start a successful online business in 2024",
            "The impact of climate change on global food security",
        ],
        inputs=input_prompt,
    )

    with gr.Row():
        # Left (wider) column: title and article body.
        with gr.Column(scale=2):
            with gr.Blocks() as title_block:
                gr.Markdown("## Title")
                title_output = gr.Textbox(label="Title")

            with gr.Blocks() as body_block:
                gr.Markdown("## Body")
                article_output = gr.Textbox(label="Article", lines=30)
                with gr.Accordion("Options", open=False):
                    # Caps GPT-2 generation length; wired into
                    # generate_blog_post as max_new_tokens.
                    max_new_tokens = gr.Slider(
                        minimum=20,
                        maximum=1000,
                        value=500,
                        step=10,
                        label="Max New Tokens",
                    )

        # Right (narrower) column: the summary.
        with gr.Column(scale=1):
            with gr.Blocks() as summary_block:
                gr.Markdown("## Summary")
                summary_output = gr.Textbox(label="Summary", lines=5)

    # Wire the button to the full generation pipeline; output order must
    # match generate_blog_post's return tuple (title, summary, article).
    generate_button.click(
        generate_blog_post,
        inputs=[input_prompt, max_new_tokens],
        outputs=[title_output, summary_output, article_output],
    )

logging.info("Launching Gradio interface")
# queue() enables request queuing so long-running generations are serialized
# rather than rejected or timed out.
iface.queue().launch()
|
|