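"""Gradio app that generates a blog post (title, summary, and body) from a topic prompt using Hugging Face pipelines."""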
import gradio as gr
import torch
from transformers import pipeline
import logging

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

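# Model checkpoints for each generation stage.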
ARTICLE_GENERATOR_MODEL = "gpt2"
SUMMARIZER_MODEL = "Falconsai/text_summarization"
TITLE_GENERATOR_MODEL = "czearing/article-title-generator"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {DEVICE}")

logging.info("Initializing models...")
text_generator = pipeline(
    "text-generation", model=ARTICLE_GENERATOR_MODEL, device=DEVICE
)
summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=DEVICE)
title_generator = pipeline(
    "text2text-generation", model=TITLE_GENERATOR_MODEL, device=DEVICE
)
logging.info("Models initialized successfully")


def generate_article(query, max_new_tokens):
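    """Continue the prompt with the text-generation pipeline, producing up to max_new_tokens tokens."""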
    logging.info(f"Generating article for query: {query}")
    article = text_generator(
        query,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )[0]["generated_text"]
    logging.debug(f"Generated article: {article[:100]}...")
    return article


def generate_title(article):
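    """Generate a title for the article with the text2text-generation pipeline."""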
    logging.info("Generating title")
    title = title_generator(article, num_return_sequences=1)[0]["generated_text"]
    logging.debug(f"Generated title: {title}")
    return title


def generate_summary(article):
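    """Summarize the article with the summarization pipeline, using greedy decoding (do_sample=False)."""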
    logging.info("Generating summary")
    summary = summarizer(article, do_sample=False)[0]["summary_text"]
    logging.debug(f"Generated summary: {summary}")
    return summary


def generate_blog_post(query, max_new_tokens):
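    """Generate the article body from the query, then derive a title and summary from it."""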
    logging.info("Starting blog post generation")

    logging.info("Generating article")
    article = generate_article(query, max_new_tokens)

    logging.info("Generating title")
    title = generate_title(article)

    logging.info("Generating summary")
    summary = generate_summary(article)

    logging.info("Blog post generation completed")
    return title, summary, article


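# Gradio UI: prompt input, generate button, example prompts, and output panels.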
with gr.Blocks() as iface:
    gr.Markdown("# Blog Post Generator")
    gr.Markdown(
        "Enter a topic, and I'll generate a blog post with a title and summary!"
    )

    with gr.Row():
        input_prompt = gr.Textbox(
            label="Input Prompt", lines=2, placeholder="Enter your blog post topic..."
        )

    with gr.Row():
        generate_button = gr.Button("Generate Blog Post", size="sm")

    gr.Examples(
        examples=[
            "The future of artificial intelligence in healthcare",
            "Top 10 travel destinations for nature lovers",
            "How to start a successful online business in 2024",
            "The impact of climate change on global food security",
        ],
        inputs=input_prompt,
    )

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("## Title")
                title_output = gr.Textbox(label="Title")

            with gr.Group():
                gr.Markdown("## Body")
                article_output = gr.Textbox(label="Article", lines=30)
                with gr.Accordion("Options", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=20,
                        maximum=1000,
                        value=500,
                        step=10,
                        label="Max New Tokens",
                    )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("## Summary")
                summary_output = gr.Textbox(label="Summary", lines=5)

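    # Wire the button to the full generation pipeline.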
    generate_button.click(
        generate_blog_post,
        inputs=[input_prompt, max_new_tokens],
        outputs=[title_output, summary_output, article_output],
    )

logging.info("Launching Gradio interface")
iface.queue().launch()