File size: 4,542 Bytes
3b5aca7
f75feb2
bfb4432
3b5aca7
bfb4432
 
 
 
 
 
 
 
 
 
 
3b5aca7
d2e0b39
bfb4432
 
 
 
 
 
 
 
 
 
 
 
f75feb2
bfb4432
 
f75feb2
bfb4432
 
 
 
 
 
 
d2e0b39
 
 
 
 
bfb4432
 
 
 
 
 
 
 
d2e0b39
 
 
bfb4432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f75feb2
bfb4432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f75feb2
 
 
 
 
 
 
bfb4432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from pydantic import BaseModel, Field
from story_beam_search.stories_generator import StoryGenerationSystem

# Genres offered in the "Genre" dropdown, listed in the order they are shown.
genre_choices = [
    "children", "mystery", "adventure",
    "sci-fi", "fantasy", "romance",
    "comedy", "drama", "horror",
]


class InputModel(BaseModel):
    """
    Validated parameters for a single story-generation request.

    The numeric bounds are enforced here because Gradio does not appear to
    enforce slider ranges on submitted values.
    """

    # Opening text that the generated stories continue from.
    prompt: str
    # Genre name; the UI offers the entries of ``genre_choices``.
    genre: str
    # How many candidate stories to generate and rank (2-7 inclusive).
    num_stories: int = Field(default=3, ge=2, le=7)
    # Sampling temperature copied into the beam-search config (0.7-3.5).
    temperature: float = Field(default=2.5, ge=0.7, le=3.5)
    # Maximum generation length copied into the beam-search config (30-200).
    max_length: int = Field(default=60, ge=30, le=200)


def create_story_generation_interface() -> gr.Interface:
    """
    Build the Gradio interface that generates and scores stories.

    One ``StoryGenerationSystem`` is created and initialized here and is then
    reused by every request handled by the returned interface.

    Returns:
        A configured ``gr.Interface``; the caller is responsible for launching it.
    """
    # Initialize the story generation system once; all requests share it.
    system = StoryGenerationSystem()
    system.initialize()

    def generate_stories(
        prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
    ) -> tuple[str, str]:
        """
        Generate and evaluate stories based on user input.

        Args:
            prompt: Beginning of the story to continue.
            genre: Genre name selected in the dropdown.
            num_stories: Number of stories to generate and rank.
            temperature: Sampling temperature for beam search.
            max_length: Maximum generation length for beam search.

        Returns:
            A tuple ``(detailed_scores, story_texts)`` of two strings: a
            per-story score report and the numbered story texts joined by
            blank lines.

        Raises:
            pydantic.ValidationError: If a numeric input lies outside the
                bounds declared on ``InputModel``.
        """
        # Validate inputs. Gradio seems to validate choices but not the range
        # of the values, so the numeric bounds are enforced here.
        input_values = InputModel(
            prompt=prompt,
            genre=genre,
            num_stories=num_stories,
            temperature=temperature,
            max_length=max_length,
        )

        # Update beam search config with user parameters. NOTE: this mutates
        # the config of the shared ``system``, so the values persist until the
        # next request overwrites them.
        system.beam_search.config.temperature = input_values.temperature
        system.beam_search.config.max_length = input_values.max_length

        # Generate and evaluate stories; each item is a (story, scores) pair,
        # presumably ordered best-first by generate_and_evaluate.
        ranked_stories = system.generate_and_evaluate(
            input_values.prompt,
            input_values.genre,
            num_stories=input_values.num_stories,
        )

        # Collect the pieces in lists and join once, instead of repeated
        # string concatenation.
        separator = "-" * 50
        score_blocks: list[str] = []
        story_texts: list[str] = []

        for i, (story, scores) in enumerate(ranked_stories, 1):
            score_blocks.append(
                f"Story {i}:\n"
                f"Total Score: {scores.total:.3f}\n"
                f"Coherence: {scores.coherence:.3f}\n"
                f"Fluency: {scores.fluency:.3f}\n"
                f"Genre Alignment: {scores.genre_alignment:.3f}\n"
                f"{separator}\n"
            )
            story_texts.append(f"Story {i}:\n{story}\n")

        return "".join(score_blocks), "\n".join(story_texts)

    # Define interface components
    prompt_input = gr.Textbox(
        label="Story Prompt",
        placeholder="Enter the beginning of your story...",
        lines=3,
    )

    genre_input = gr.Dropdown(
        choices=genre_choices,
        label="Genre",
        value="fantasy",
    )

    # Slider bounds mirror the corresponding InputModel field constraints.
    num_stories_input = gr.Slider(
        minimum=2, maximum=7, value=3, step=1, label="Number of Stories to Generate"
    )

    temperature_input = gr.Slider(
        minimum=0.7, maximum=3.5, value=2.5, step=0.1, label="Temperature (Creativity)"
    )

    # Range 30-200 matches InputModel.max_length. A step of 10 keeps both
    # bounds, the default (60) and the example values (150, 200) on the
    # slider's grid; the previous minimum=40/step=20 excluded 150.
    max_length_input = gr.Slider(
        minimum=30, maximum=200, value=60, step=10, label="Maximum Length"
    )

    # Output components
    scores_output = gr.Textbox(label="Detailed Scores", lines=10, interactive=False)

    stories_output = gr.Textbox(label="Generated Stories", lines=15, interactive=False)

    # Create the interface
    interface = gr.Interface(
        fn=generate_stories,
        inputs=[
            prompt_input,
            genre_input,
            num_stories_input,
            temperature_input,
            max_length_input,
        ],
        outputs=[scores_output, stories_output],
        title="AI Story Generator",
        description="""
        Generate creative stories using AI! Enter a prompt and choose your preferences.
        The system will generate multiple stories and evaluate them based on coherence,
        fluency, and genre alignment.
        """,
        examples=[
            [
                "Once upon a time in a magical forest, the trees whispered secrets, and moonlight revealed hidden paths to a realm where time stood still.",
                "fantasy",
                3,
                1.8,
                150,
            ],
            [
                "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
                "mystery",
                3,
                2.7,
                200,
            ],
        ],
        theme=gr.themes.Soft(),
    )

    return interface


if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server, showing
    # Python errors in the browser.
    create_story_generation_interface().launch(show_error=True)