Spaces: ssalb (Paused)

ssalb committed f75feb2 (parent: d2e0b39)

Update space with latest code and dependencies on Fri Jan 3 18:00:11 UTC 2025
Browse files
- LICENSE +1 -1
- app.py +11 -7
- requirements.txt +2 -2
- story_beam_search/beam_search.py +44 -21
- story_beam_search/scoring.py +172 -93
- story_beam_search/stories_generator.py +2 -2
LICENSE CHANGED
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c)
+Copyright (c) 2025 Salvador Salazar
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
app.py CHANGED
@@ -1,8 +1,6 @@
 import gradio as gr
-from
-from pydantic import BaseModel, Field, constr
+from pydantic import BaseModel, Field
 from story_beam_search.stories_generator import StoryGenerationSystem
-from typing import Tuple, List
 
 genre_choices = [
     "children",
@@ -29,10 +27,10 @@ def create_story_generation_interface() -> gr.Interface:
     # Initialize the story generation system
     system = StoryGenerationSystem()
    system.initialize()
-
+
    def generate_stories(
        prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
-    ) ->
+    ) -> tuple[str, list[str]]:
        """
        Generate and evaluate stories based on user input.
        Returns a tuple of (detailed_scores, story_texts).
@@ -96,7 +94,7 @@ def create_story_generation_interface() -> gr.Interface:
     )
 
     max_length_input = gr.Slider(
-        minimum=
+        minimum=40, maximum=200, value=60, step=20, label="Maximum Length"
     )
 
     # Output components
@@ -122,7 +120,13 @@ def create_story_generation_interface() -> gr.Interface:
         fluency, and genre alignment.
         """,
         examples=[
-            [
+            [
+                "Once upon a time in a magical forest, the trees whispered secrets, and moonlight revealed hidden paths to a realm where time stood still.",
+                "fantasy",
+                3,
+                1.8,
+                150,
+            ],
             [
                 "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
                 "mystery",
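A note on the new example row: each entry in examples supplies the five inputs in order (prompt, genre, num_stories, temperature, max_length). A minimal standalone sketch of the same gr.Interface pattern (the widgets, ranges, and stand-in function below are illustrative assumptions, not this Space's exact configuration):

import gradio as gr

def generate_stories(prompt: str, genre: str, num_stories: int, temperature: float, max_length: int) -> str:
    # Stand-in for the real system: just echo the settings back.
    return f"{genre} story from {prompt!r} (n={num_stories}, T={temperature}, len={max_length})"

demo = gr.Interface(
    fn=generate_stories,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(["children", "fantasy", "mystery"], label="Genre"),
        gr.Slider(1, 5, step=1, label="Number of Stories"),
        gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
        gr.Slider(40, 200, value=60, step=20, label="Maximum Length"),
    ],
    outputs="text",
    examples=[["Once upon a time in a magical forest", "fantasy", 3, 1.8, 150]],
)

if __name__ == "__main__":
    demo.launch()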
requirements.txt CHANGED
@@ -47,7 +47,7 @@ ruff==0.8.5 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 safehttpx==0.1.6 ; python_full_version == "3.10.13"
 safetensors==0.5.0 ; python_full_version == "3.10.13"
 scikit-learn==1.6.0 ; python_full_version == "3.10.13"
-scipy==1.
+scipy==1.15.0 ; python_full_version == "3.10.13"
 semantic-version==2.10.0 ; python_full_version == "3.10.13"
 shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 six==1.17.0 ; python_full_version == "3.10.13"
@@ -59,7 +59,7 @@ tokenizers==0.21.0 ; python_full_version == "3.10.13"
 tomlkit==0.13.2 ; python_full_version == "3.10.13"
 torch==2.4.0 ; python_full_version == "3.10.13"
 tqdm==4.67.1 ; python_full_version == "3.10.13"
-transformers @ git+https://github.com/huggingface/transformers.git@
+transformers @ git+https://github.com/huggingface/transformers.git@e5fd865ebae062b7cf03a81b8c6affeb39f30bec ; python_full_version == "3.10.13"
 triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
 typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 typing-extensions==4.12.2 ; python_full_version == "3.10.13"
story_beam_search/beam_search.py CHANGED
@@ -36,31 +36,27 @@ class BeamSearchGenerator:
         self, prompt: str, genre: str, evaluator: StoryScorer
     ) -> list[str]:
         """
-        Generate story continuations using
+        Generate story continuations using parallel beam search iterations.
         """
-
-        # Adding some instructions to the prompt. These are removed in the end
         instructions = (
             f"Continue the following story in the {genre} genre, "
             "ensuring coherence with the tone, characters, and narrative established so far:\n"
         )
         instructions_len = len(instructions)
 
-        stories = self.
+        stories = self._generate_batch([instructions + prompt])
         ranked_stories = evaluator.evaluate_multiple(
             [story[instructions_len:] for story in stories]
         )
-
         stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
 
         if stories:
             for _ in range(self.config.num_iterations):
-
-                for story in stories
-
-
-
-                all_stories.extend(continuations)
+                # Prepare all prompts for batch processing
+                all_prompts = [instructions + story for story in stories]
+                # Generate all continuations in one batch
+                all_stories = self._generate_batch(all_prompts)
+
                 ranked_stories = evaluator.evaluate_multiple(
                     [story[instructions_len:] for story in all_stories]
                 )
@@ -70,23 +66,50 @@ class BeamSearchGenerator:
 
         return stories
 
-    def
+    def _generate_batch(self, prompts: list[str]) -> list[str]:
         """
-        Generate multiple continuations for
+        Generate multiple continuations for multiple prompts in a single batch.
         """
-
-
-
+        # Tokenize all prompts
+        tokenized = [self.tokenizer(prompt, return_tensors="pt") for prompt in prompts]
+
+        # Pad input_ids and attention_masks to same length
+        max_length = max(inputs["input_ids"].size(1) for inputs in tokenized)
+        padded_input_ids = []
+        padded_attention_masks = []
 
-
-
+        for inputs in tokenized:
+            input_ids = inputs["input_ids"][0]
+            attention_mask = inputs["attention_mask"][0]
+
+            # Pad to max_length
+            padding_length = max_length - input_ids.size(0)
+            if padding_length > 0:
+                input_ids = torch.cat(
+                    [input_ids, torch.zeros(padding_length, dtype=torch.long)]
+                )
+                attention_mask = torch.cat(
+                    [attention_mask, torch.zeros(padding_length, dtype=torch.long)]
+                )
+
+            padded_input_ids.append(input_ids)
+            padded_attention_masks.append(attention_mask)
+
+        # Stack into batches
+        input_ids_batch = torch.stack(padded_input_ids).to(self.device)
+        attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
+
+        # Calculate continuation length
+        continuation_length = (
+            max_length + self.config.max_length // self.config.num_iterations
         )
 
+        # Generate all continuations in one pass
         with torch.no_grad():
             outputs = self.model.generate(
-                input_ids=
-                attention_mask=
-                max_length=
+                input_ids=input_ids_batch,
+                attention_mask=attention_mask_batch,
+                max_length=continuation_length,
                 num_beams=self.config.num_beams,
                 num_return_sequences=self.config.num_return_sequences,
                 early_stopping=True,
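For readers following the diff: _generate_batch pads every prompt by hand and then calls model.generate once over the stacked batch. A shorter sketch of the same batched-generation idea using the tokenizer's built-in padding (gpt2 and the length budget here are assumptions for illustration, not this repo's model):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is only a stand-in; decoder-only models usually want left padding
# so the generated tokens continue the real prompt rather than pad tokens.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = [
    "Continue the story: the lighthouse went dark at midnight.",
    "Continue the story: the map ended at the river.",
]
batch = tokenizer(prompts, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=batch["input_ids"].size(1) + 40,  # prompt length + continuation budget
        num_beams=3,
        num_return_sequences=3,
        early_stopping=True,
    )

continuations = tokenizer.batch_decode(outputs, skip_special_tokens=True)

One design note: the diff right-pads input_ids with zeros, and generate() appends new tokens after the last position, which is why left padding as in this sketch is the usual choice for batched decoder-only generation.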
story_beam_search/scoring.py CHANGED
@@ -9,7 +9,7 @@ from sklearn.metrics.pairwise import cosine_similarity
 class StoryScorer(Protocol):
     """Protocol defining the interface for story scoring components."""
 
-    def score(self,
+    def score(self, stories: list[str]) -> float:
         """Return a score between 0 and 1."""
         ...
 
@@ -28,40 +28,72 @@ class CoherenceScorer(StoryScorer):
         model: PreTrainedModel,
         tokenizer: PreTrainedTokenizer,
         device: torch.device,
-
+        batch_size: int = 32,
     ):
         self.model = model
         self.tokenizer = tokenizer
         self.device = device
-        self.
+        self.batch_size = batch_size
 
-    def score(self,
+    def score(self, stories: list[str]) -> list[float]:
         """Calculate coherence score based on sentences cosine similarity."""
+        all_scores = []
 
-
-
+        # Split stories into sentences for coherence scoring
+        sentences_list = [
+            [s.strip() for s in story.split(".") if s.strip()] for story in stories
+        ]
 
-        # Generate embeddings for each sentence
-        for sentence in sentences:
-            inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
+        # Collect all sentence pairs that need embedding
+        all_sentence_pairs = []
+        story_boundaries = []  # Track where each story's sentences end
+        current_position = 0
+
+        for sentences in sentences_list:
+            pairs_count = len(sentences) - 1
+            all_sentence_pairs.extend(zip(sentences[:-1], sentences[1:]))
+            story_boundaries.append(current_position + pairs_count)
+            current_position += pairs_count
+
+        # Process sentence pairs in batches
+        all_embeddings = []
+        for i in range(0, len(all_sentence_pairs), self.batch_size):
+            batch_pairs = all_sentence_pairs[i : i + self.batch_size]
+            # Flatten pairs for batch processing
+            batch_sentences = [sent for pair in batch_pairs for sent in pair]
+
+            # Tokenize batch
+            inputs = self.tokenizer(
+                batch_sentences, padding=True, truncation=True, return_tensors="pt"
+            ).to(self.device)
 
             with torch.no_grad():
-                outputs = self.model(**inputs)
-
-
-            embeddings
-
-
-
-
-
-
+                outputs = self.model(**inputs, output_hidden_states=True)
+                batch_embeddings = outputs.hidden_states[-1][
+                    :, 0, :
+                ]  # Get CLS token embeddings
+            all_embeddings.extend(batch_embeddings.cpu().numpy())
+
+        # Calculate coherence scores for each story
+        current_idx = 0
+        for boundary in story_boundaries:
+            story_pairs_count = boundary - current_idx
+            story_scores = []
+
+            for i in range(story_pairs_count):
+                idx = current_idx + i
+                first_emb = all_embeddings[idx * 2].reshape(1, -1)
+                second_emb = all_embeddings[idx * 2 + 1].reshape(1, -1)
+                sim = cosine_similarity(first_emb, second_emb)[0][0]
+                story_scores.append(sim)
 
-
-        avg_coherence = (
-            sum(coherence_scores) / len(coherence_scores) if coherence_scores else 0.0
-        )
-        return avg_coherence
+            avg_coherence = (
+                sum(story_scores) / len(story_scores) if story_scores else 0.0
+            )
+            all_scores.append(avg_coherence)
+            current_idx = boundary
+
+        return all_scores
 
 
 class FluencyScorer(StoryScorer):
@@ -70,61 +102,119 @@ class FluencyScorer(StoryScorer):
         model: PreTrainedModel,
         tokenizer: PreTrainedTokenizer,
         device: torch.device,
+        batch_size: int = 32,
     ):
         self.model = model
         self.tokenizer = tokenizer
         self.device = device
-
-
-        #
-
-
-
+        self.batch_size = batch_size
+
+        # Set up padding token if it doesn't exist
+        if self.tokenizer.pad_token is None:
+            if self.tokenizer.eos_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                # Add padding token to tokenizer only
+                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        # Set up mask token if it doesn't exist
+        if self.tokenizer.mask_token is None:
+            self.tokenizer.add_special_tokens({"mask_token": "[MASK]"})
+
+    def score(self, stories: list[str]) -> list[float]:
+        all_scores = []
         mask_token_id = self.tokenizer.mask_token_id
 
-
-
-
+        # Process stories in batches
+        for i in range(0, len(stories), self.batch_size):
+            batch_stories = stories[i : i + self.batch_size]
+            batch_inputs = self.tokenizer(
+                batch_stories, padding=True, truncation=True, return_tensors="pt"
+            ).to(self.device)
 
-
-        for i in range(1, input_ids.size(1) - 1):
-            masked_input_ids = input_ids.clone()
-            masked_input_ids[0, i] = mask_token_id
+            batch_scores = []
 
-
-
-
+            # For each story in the batch
+            for j in range(len(batch_stories)):
+                story_scores = []
+                input_ids = batch_inputs.input_ids[j : j + 1]  # Keep batch dimension
+                attention_mask = batch_inputs.attention_mask[
+                    j : j + 1
+                ]  # Get attention mask
+
+                # Only process tokens that aren't padding
+                valid_tokens = attention_mask[0].sum().item()
+
+                # For each token in the story (excluding padding)
+                for k in range(1, valid_tokens - 1):
+                    masked_input_ids = input_ids.clone()
+                    masked_input_ids[0, k] = mask_token_id
+
+                    with torch.no_grad():
+                        outputs = self.model(
+                            input_ids=masked_input_ids, attention_mask=attention_mask
+                        )
+                    logits = outputs.logits
 
-
-            token_probability = logits[0, i].softmax(dim=-1)[original_token_id].item()
-            fluency_scores.append(token_probability)
+                    original_token_id = input_ids[0, k]
+                    token_probability = (
+                        logits[0, k].softmax(dim=-1)[original_token_id].item()
+                    )
+                    story_scores.append(token_probability)
+
+                avg_fluency = (
+                    sum(story_scores) / len(story_scores) if story_scores else 0.0
+                )
+                batch_scores.append(avg_fluency)
 
-
-            sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0
-        )
-        return avg_fluency
+            all_scores.extend(batch_scores)
+
+        return all_scores
 
 
 class GenreAlignmentScorer(StoryScorer):
-    def __init__(self, pipeline: Pipeline, genre: str):
+    def __init__(self, pipeline: Pipeline, genre: str, batch_size: int = 32):
         self.pipeline = pipeline
         self.genre = genre
+        self.batch_size = batch_size
 
-    def score(self,
+    def score(self, stories: list[str]) -> list[float]:
         if not self.genre:
-            return 0.5
-
-
-
-
-
-
-
+            return [0.5] * len(stories)
+
+        all_scores = []
+        # Split all stories into sentences
+        all_sentences = []
+        story_boundaries = []
+        current_position = 0
+
+        for story in stories:
+            sentences = [s.strip() for s in story.split(".") if s.strip()]
+            all_sentences.extend(sentences)
+            story_boundaries.append(current_position + len(sentences))
+            current_position += len(sentences)
+
+        # Process sentences in batches
+        all_sentence_scores = []
+        for i in range(0, len(all_sentences), self.batch_size):
+            batch_sentences = all_sentences[i : i + self.batch_size]
+            results = self.pipeline(
+                batch_sentences,
+                candidate_labels=[self.genre],
+                multi_label=True,
+                batch_size=self.batch_size,
             )
-
+            all_sentence_scores.extend([r["scores"][0] for r in results])
 
-
-
+        # Calculate average score for each story
+        current_idx = 0
+        for boundary in story_boundaries:
+            story_scores = all_sentence_scores[current_idx:boundary]
+            avg_score = sum(story_scores) / len(story_scores) if story_scores else 0.0
+            all_scores.append(avg_score)
+            current_idx = boundary
+
+        return all_scores
 
 
 class StoryEvaluator:
@@ -140,43 +230,32 @@ class StoryEvaluator:
         self.genre_scorer = genre_scorer
         self.weights = weights
 
-    def evaluate(self, story: str, max_scores: list[float]) -> CombinedScore:
-        coherence = self.coherence_scorer.score(story)
-        fluency = self.fluency_scorer.score(story)
-        genre_alignment = self.genre_scorer.score(story)
-
-        max_scores[0] = np.max([max_scores[0], coherence])
-        max_scores[1] = np.max([max_scores[1], fluency])
-        max_scores[2] = np.max([max_scores[2], genre_alignment])
-
-        return CombinedScore(
-            coherence=coherence,
-            fluency=fluency,
-            genre_alignment=genre_alignment,
-            total=0,
-        )
-
     def evaluate_multiple(self, stories: list[str]) -> list[tuple[str, CombinedScore]]:
-        """Evaluate multiple stories and return them sorted by total score."""
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
+        """Evaluate multiple stories in batches and return them sorted by total score."""
+
+        # Get all scores in parallel using batch processing
+        coherence_scores = self.coherence_scorer.score(stories)
+        fluency_scores = self.fluency_scorer.score(stories)
+        genre_scores = self.genre_scorer.score(stories)
+
+        # Find max scores for normalization
+        max_scores = [max(coherence_scores), max(fluency_scores), max(genre_scores)]
+
+        # Create scored stories
+        scored_stories = []
+        for i, story in enumerate(stories):
+            scores = CombinedScore(
+                coherence=(
+                    coherence_scores[i] / max_scores[0] if max_scores[0] != 0 else 0
+                ),
+                fluency=fluency_scores[i] / max_scores[1] if max_scores[1] != 0 else 0,
+                genre_alignment=(
+                    genre_scores[i] / max_scores[2] if max_scores[2] != 0 else 0
+                ),
             )
             scores.total = np.dot(
                 [scores.coherence, scores.fluency, scores.genre_alignment], self.weights
             )
+            scored_stories.append((story, scores))
 
         return sorted(scored_stories, key=lambda x: x[1].total, reverse=True)
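For readers following the diff: FluencyScorer computes a pseudo-log-likelihood style score, masking each token in turn and reading off the probability the masked LM assigns to the original token. A self-contained sketch of that core loop (bert-base-uncased is an assumed stand-in; the Space's actual model comes from its own setup):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval()

text = "The detective knelt beside the bloodstained carpet."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]

token_probs = []
for i in range(1, input_ids.size(1) - 1):  # skip [CLS] and [SEP]
    masked = input_ids.clone()
    masked[0, i] = tokenizer.mask_token_id
    with torch.no_grad():
        logits = model(input_ids=masked, attention_mask=inputs["attention_mask"]).logits
    # Probability of the original token at the masked position
    token_probs.append(logits[0, i].softmax(dim=-1)[input_ids[0, i]].item())

fluency = sum(token_probs) / len(token_probs)  # average per-token probability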
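GenreAlignmentScorer builds on the standard transformers zero-shot classification pipeline; a minimal sketch of the per-sentence scoring it performs (the facebook/bart-large-mnli checkpoint is an assumption for illustration):

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

sentences = [
    "The dragon circled the ruined tower at dusk",
    "She filed the quarterly report before lunch",
]
results = classifier(sentences, candidate_labels=["fantasy"], multi_label=True)

# Each result carries parallel "labels"/"scores" lists; with one candidate
# label, the genre score is simply scores[0].
scores = [r["scores"][0] for r in results]
print(sum(scores) / len(scores))  # story-level average, as in the diff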
story_beam_search/stories_generator.py CHANGED
@@ -160,10 +160,10 @@ class StoryGenerationSystem:
         prompt_segments = re.split(r"[^a-zA-Z0-9 ]+", prompt)
         prompt_segments = list(set(prompt_segments))
 
-        storyness_score = self.storyness.score(prompt)
+        storyness_score = self.storyness.score([prompt])[0]
         for segment in prompt_segments:
             if segment.strip():
-                injection_score = self.injection_guard.score(segment)
+                injection_score = self.injection_guard.score([segment])[0]
                 if storyness_score < 0.2 or injection_score > 0.2:
                     print("Potential prompt injection detected.")
                     print(f"storyness_score: {storyness_score}")
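With the scorers now list-in/list-out, single-item call sites wrap the argument and index the first result, as the diff shows. A tiny adapter one might factor out instead (hypothetical helper, not present in the repo):

def score_one(scorer, text: str) -> float:
    """Adapt a batched scorer (list in, list out) to a single string."""
    return scorer.score([text])[0]

# e.g. storyness_score = score_one(self.storyness, prompt)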