ssalb commited on
Commit
f75feb2
·
1 Parent(s): d2e0b39

Update space with latest code and dependencies on Fri Jan 3 18:00:11 UTC 2025

Browse files
LICENSE CHANGED
@@ -1,6 +1,6 @@
1
  MIT License
2
 
3
- Copyright (c) 2023 Salvador Salazar
4
 
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
  of this software and associated documentation files (the "Software"), to deal
 
1
  MIT License
2
 
3
+ Copyright (c) 2025 Salvador Salazar
4
 
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
  of this software and associated documentation files (the "Software"), to deal
app.py CHANGED
@@ -1,8 +1,6 @@
1
  import gradio as gr
2
- from typing import Literal
3
- from pydantic import BaseModel, Field, constr
4
  from story_beam_search.stories_generator import StoryGenerationSystem
5
- from typing import Tuple, List
6
 
7
  genre_choices = [
8
  "children",
@@ -29,10 +27,10 @@ def create_story_generation_interface() -> gr.Interface:
29
  # Initialize the story generation system
30
  system = StoryGenerationSystem()
31
  system.initialize()
32
-
33
  def generate_stories(
34
  prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
35
- ) -> Tuple[str, List[str]]:
36
  """
37
  Generate and evaluate stories based on user input.
38
  Returns a tuple of (detailed_scores, story_texts).
@@ -96,7 +94,7 @@ def create_story_generation_interface() -> gr.Interface:
96
  )
97
 
98
  max_length_input = gr.Slider(
99
- minimum=30, maximum=200, value=60, step=30, label="Maximum Length"
100
  )
101
 
102
  # Output components
@@ -122,7 +120,13 @@ def create_story_generation_interface() -> gr.Interface:
122
  fluency, and genre alignment.
123
  """,
124
  examples=[
125
- ["Once upon a time in a magical forest,", "fantasy", 3, 1.8, 150],
 
 
 
 
 
 
126
  [
127
  "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
128
  "mystery",
 
1
  import gradio as gr
2
+ from pydantic import BaseModel, Field
 
3
  from story_beam_search.stories_generator import StoryGenerationSystem
 
4
 
5
  genre_choices = [
6
  "children",
 
27
  # Initialize the story generation system
28
  system = StoryGenerationSystem()
29
  system.initialize()
30
+
31
  def generate_stories(
32
  prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
33
+ ) -> tuple[str, list[str]]:
34
  """
35
  Generate and evaluate stories based on user input.
36
  Returns a tuple of (detailed_scores, story_texts).
 
94
  )
95
 
96
  max_length_input = gr.Slider(
97
+ minimum=40, maximum=200, value=60, step=20, label="Maximum Length"
98
  )
99
 
100
  # Output components
 
120
  fluency, and genre alignment.
121
  """,
122
  examples=[
123
+ [
124
+ "Once upon a time in a magical forest, the trees whispered secrets, and moonlight revealed hidden paths to a realm where time stood still.",
125
+ "fantasy",
126
+ 3,
127
+ 1.8,
128
+ 150,
129
+ ],
130
  [
131
  "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
132
  "mystery",
requirements.txt CHANGED
@@ -47,7 +47,7 @@ ruff==0.8.5 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
47
  safehttpx==0.1.6 ; python_full_version == "3.10.13"
48
  safetensors==0.5.0 ; python_full_version == "3.10.13"
49
  scikit-learn==1.6.0 ; python_full_version == "3.10.13"
50
- scipy==1.14.1 ; python_full_version == "3.10.13"
51
  semantic-version==2.10.0 ; python_full_version == "3.10.13"
52
  shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
53
  six==1.17.0 ; python_full_version == "3.10.13"
@@ -59,7 +59,7 @@ tokenizers==0.21.0 ; python_full_version == "3.10.13"
59
  tomlkit==0.13.2 ; python_full_version == "3.10.13"
60
  torch==2.4.0 ; python_full_version == "3.10.13"
61
  tqdm==4.67.1 ; python_full_version == "3.10.13"
62
- transformers @ git+https://github.com/huggingface/transformers.git@42865860ec6dc135972d9555753cb7ee17f51fb4 ; python_full_version == "3.10.13"
63
  triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
64
  typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
65
  typing-extensions==4.12.2 ; python_full_version == "3.10.13"
 
47
  safehttpx==0.1.6 ; python_full_version == "3.10.13"
48
  safetensors==0.5.0 ; python_full_version == "3.10.13"
49
  scikit-learn==1.6.0 ; python_full_version == "3.10.13"
50
+ scipy==1.15.0 ; python_full_version == "3.10.13"
51
  semantic-version==2.10.0 ; python_full_version == "3.10.13"
52
  shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
53
  six==1.17.0 ; python_full_version == "3.10.13"
 
59
  tomlkit==0.13.2 ; python_full_version == "3.10.13"
60
  torch==2.4.0 ; python_full_version == "3.10.13"
61
  tqdm==4.67.1 ; python_full_version == "3.10.13"
62
+ transformers @ git+https://github.com/huggingface/transformers.git@e5fd865ebae062b7cf03a81b8c6affeb39f30bec ; python_full_version == "3.10.13"
63
  triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
64
  typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
65
  typing-extensions==4.12.2 ; python_full_version == "3.10.13"
story_beam_search/beam_search.py CHANGED
@@ -36,31 +36,27 @@ class BeamSearchGenerator:
36
  self, prompt: str, genre: str, evaluator: StoryScorer
37
  ) -> list[str]:
38
  """
39
- Generate story continuations using multiple iterations of beam search.
40
  """
41
-
42
- # Adding some instructions to the prompt. These are removed in the end
43
  instructions = (
44
  f"Continue the following story in the {genre} genre, "
45
  "ensuring coherence with the tone, characters, and narrative established so far:\n"
46
  )
47
  instructions_len = len(instructions)
48
 
49
- stories = self._generate_single_iteration(instructions + prompt)
50
  ranked_stories = evaluator.evaluate_multiple(
51
  [story[instructions_len:] for story in stories]
52
  )
53
-
54
  stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
55
 
56
  if stories:
57
  for _ in range(self.config.num_iterations):
58
- all_stories = []
59
- for story in stories:
60
- continuations = self._generate_single_iteration(
61
- instructions + story
62
- )
63
- all_stories.extend(continuations)
64
  ranked_stories = evaluator.evaluate_multiple(
65
  [story[instructions_len:] for story in all_stories]
66
  )
@@ -70,23 +66,50 @@ class BeamSearchGenerator:
70
 
71
  return stories
72
 
73
- def _generate_single_iteration(self, prompt: str) -> list[str]:
74
  """
75
- Generate multiple continuations for a single iteration using beam search.
76
  """
77
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
78
- input_ids = inputs["input_ids"]
79
- attention_mask = inputs["attention_mask"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- self.config.continuation_length = (
82
- len(input_ids[0]) + self.config.max_length // self.config.num_iterations
 
83
  )
84
 
 
85
  with torch.no_grad():
86
  outputs = self.model.generate(
87
- input_ids=input_ids,
88
- attention_mask=attention_mask,
89
- max_length=self.config.continuation_length,
90
  num_beams=self.config.num_beams,
91
  num_return_sequences=self.config.num_return_sequences,
92
  early_stopping=True,
 
36
  self, prompt: str, genre: str, evaluator: StoryScorer
37
  ) -> list[str]:
38
  """
39
+ Generate story continuations using parallel beam search iterations.
40
  """
 
 
41
  instructions = (
42
  f"Continue the following story in the {genre} genre, "
43
  "ensuring coherence with the tone, characters, and narrative established so far:\n"
44
  )
45
  instructions_len = len(instructions)
46
 
47
+ stories = self._generate_batch([instructions + prompt])
48
  ranked_stories = evaluator.evaluate_multiple(
49
  [story[instructions_len:] for story in stories]
50
  )
 
51
  stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
52
 
53
  if stories:
54
  for _ in range(self.config.num_iterations):
55
+ # Prepare all prompts for batch processing
56
+ all_prompts = [instructions + story for story in stories]
57
+ # Generate all continuations in one batch
58
+ all_stories = self._generate_batch(all_prompts)
59
+
 
60
  ranked_stories = evaluator.evaluate_multiple(
61
  [story[instructions_len:] for story in all_stories]
62
  )
 
66
 
67
  return stories
68
 
69
+ def _generate_batch(self, prompts: list[str]) -> list[str]:
70
  """
71
+ Generate multiple continuations for multiple prompts in a single batch.
72
  """
73
+ # Tokenize all prompts
74
+ tokenized = [self.tokenizer(prompt, return_tensors="pt") for prompt in prompts]
75
+
76
+ # Pad input_ids and attention_masks to same length
77
+ max_length = max(inputs["input_ids"].size(1) for inputs in tokenized)
78
+ padded_input_ids = []
79
+ padded_attention_masks = []
80
+
81
+ for inputs in tokenized:
82
+ input_ids = inputs["input_ids"][0]
83
+ attention_mask = inputs["attention_mask"][0]
84
+
85
+ # Pad to max_length
86
+ padding_length = max_length - input_ids.size(0)
87
+ if padding_length > 0:
88
+ input_ids = torch.cat(
89
+ [input_ids, torch.zeros(padding_length, dtype=torch.long)]
90
+ )
91
+ attention_mask = torch.cat(
92
+ [attention_mask, torch.zeros(padding_length, dtype=torch.long)]
93
+ )
94
+
95
+ padded_input_ids.append(input_ids)
96
+ padded_attention_masks.append(attention_mask)
97
+
98
+ # Stack into batches
99
+ input_ids_batch = torch.stack(padded_input_ids).to(self.device)
100
+ attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
101
 
102
+ # Calculate continuation length
103
+ continuation_length = (
104
+ max_length + self.config.max_length // self.config.num_iterations
105
  )
106
 
107
+ # Generate all continuations in one pass
108
  with torch.no_grad():
109
  outputs = self.model.generate(
110
+ input_ids=input_ids_batch,
111
+ attention_mask=attention_mask_batch,
112
+ max_length=continuation_length,
113
  num_beams=self.config.num_beams,
114
  num_return_sequences=self.config.num_return_sequences,
115
  early_stopping=True,
story_beam_search/scoring.py CHANGED
@@ -9,7 +9,7 @@ from sklearn.metrics.pairwise import cosine_similarity
9
  class StoryScorer(Protocol):
10
  """Protocol defining the interface for story scoring components."""
11
 
12
- def score(self, story: str) -> float:
13
  """Return a score between 0 and 1."""
14
  ...
15
 
@@ -28,40 +28,72 @@ class CoherenceScorer(StoryScorer):
28
  model: PreTrainedModel,
29
  tokenizer: PreTrainedTokenizer,
30
  device: torch.device,
31
- max_pairs: int = 3,
32
  ):
33
  self.model = model
34
  self.tokenizer = tokenizer
35
  self.device = device
36
- self.max_pairs = max_pairs
37
 
38
- def score(self, story: str) -> float:
39
  """Calculate coherence score based on sentences cosine similarity."""
 
40
 
41
- sentences = [s.strip() for s in story.split(".") if s.strip()]
 
 
 
42
 
43
- embeddings = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Generate embeddings for each sentence
46
- for sentence in sentences:
47
- inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
48
  with torch.no_grad():
49
- outputs = self.model(**inputs)
50
- last_hidden_state = outputs.hidden_states[-1]
51
- emb = last_hidden_state[:, 0, :]
52
- embeddings.append(emb.cpu().numpy())
53
-
54
- # Calculate cosine similarity between adjacent embeddings
55
- coherence_scores = []
56
- for i in range(len(embeddings) - 1):
57
- sim = cosine_similarity(embeddings[i], embeddings[i + 1])[0][0]
58
- coherence_scores.append(sim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Average coherence score
61
- avg_coherence = (
62
- sum(coherence_scores) / len(coherence_scores) if coherence_scores else 0.0
63
- )
64
- return avg_coherence
65
 
66
 
67
  class FluencyScorer(StoryScorer):
@@ -70,61 +102,119 @@ class FluencyScorer(StoryScorer):
70
  model: PreTrainedModel,
71
  tokenizer: PreTrainedTokenizer,
72
  device: torch.device,
 
73
  ):
74
  self.model = model
75
  self.tokenizer = tokenizer
76
  self.device = device
77
-
78
- def score(self, story: str) -> float:
79
- # Mask each token in the story and calculate the probability of the original token
80
- # Fluency is measured by the average probability of each token in the story
81
- inputs = self.tokenizer(story, return_tensors="pt").to(self.device)
82
- input_ids = inputs.input_ids
 
 
 
 
 
 
 
 
 
 
83
  mask_token_id = self.tokenizer.mask_token_id
84
 
85
- if mask_token_id is None:
86
- self.tokenizer.mask_token = "[MASK]"
87
- mask_token_id = self.tokenizer.encode(self.tokenizer.mask_token)[0]
 
 
 
88
 
89
- fluency_scores = []
90
- for i in range(1, input_ids.size(1) - 1):
91
- masked_input_ids = input_ids.clone()
92
- masked_input_ids[0, i] = mask_token_id
93
 
94
- with torch.no_grad():
95
- outputs = self.model(input_ids=masked_input_ids)
96
- logits = outputs.logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- original_token_id = input_ids[0, i]
99
- token_probability = logits[0, i].softmax(dim=-1)[original_token_id].item()
100
- fluency_scores.append(token_probability)
101
 
102
- avg_fluency = (
103
- sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0
104
- )
105
- return avg_fluency
106
 
107
 
108
  class GenreAlignmentScorer(StoryScorer):
109
- def __init__(self, pipeline: Pipeline, genre: str):
110
  self.pipeline = pipeline
111
  self.genre = genre
 
112
 
113
- def score(self, story: str) -> float:
114
  if not self.genre:
115
- return 0.5
116
-
117
- # Evaluate by sentence to check whether the genre is maintained throughout
118
- sentences = [s.strip() for s in story.split(".") if s.strip()]
119
- results = []
120
- for sentence in sentences:
121
- result = self.pipeline(
122
- sentence, candidate_labels=[self.genre], multi_label=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  )
124
- results.append(result["scores"][0])
125
 
126
- avg_core = sum(results) / len(results) if results else 0.0
127
- return avg_core
 
 
 
 
 
 
 
128
 
129
 
130
  class StoryEvaluator:
@@ -140,43 +230,32 @@ class StoryEvaluator:
140
  self.genre_scorer = genre_scorer
141
  self.weights = weights
142
 
143
- def evaluate(self, story: str, max_scores: list[float]) -> CombinedScore:
144
- coherence = self.coherence_scorer.score(story)
145
- fluency = self.fluency_scorer.score(story)
146
- genre_alignment = self.genre_scorer.score(story)
147
-
148
- max_scores[0] = np.max([max_scores[0], coherence])
149
- max_scores[1] = np.max([max_scores[1], fluency])
150
- max_scores[2] = np.max([max_scores[2], genre_alignment])
151
-
152
- return CombinedScore(
153
- coherence=coherence,
154
- fluency=fluency,
155
- genre_alignment=genre_alignment,
156
- total=0,
157
- )
158
-
159
  def evaluate_multiple(self, stories: list[str]) -> list[tuple[str, CombinedScore]]:
160
- """Evaluate multiple stories and return them sorted by total score."""
161
-
162
- # Scores are normalized by the max scores on every evaluation
163
- # This is to ensure that the scores are comparable between each other, as they are originally on different scales
164
-
165
- # Reset max scores
166
- max_scores = [0.0, 0.0, 0.0]
167
-
168
- scored_stories = [
169
- (story, self.evaluate(story, max_scores)) for story in stories
170
- ]
171
-
172
- # Normalize scores
173
- for _, scores in scored_stories:
174
- scores.coherence, scores.fluency, scores.genre_alignment = np.divide(
175
- [scores.coherence, scores.fluency, scores.genre_alignment],
176
- max_scores,
 
 
 
 
177
  )
178
  scores.total = np.dot(
179
  [scores.coherence, scores.fluency, scores.genre_alignment], self.weights
180
  )
 
181
 
182
  return sorted(scored_stories, key=lambda x: x[1].total, reverse=True)
 
9
  class StoryScorer(Protocol):
10
  """Protocol defining the interface for story scoring components."""
11
 
12
+ def score(self, stories: list[str]) -> float:
13
  """Return a score between 0 and 1."""
14
  ...
15
 
 
28
  model: PreTrainedModel,
29
  tokenizer: PreTrainedTokenizer,
30
  device: torch.device,
31
+ batch_size: int = 32,
32
  ):
33
  self.model = model
34
  self.tokenizer = tokenizer
35
  self.device = device
36
+ self.batch_size = batch_size
37
 
38
+ def score(self, stories: list[str]) -> list[float]:
39
  """Calculate coherence score based on sentences cosine similarity."""
40
+ all_scores = []
41
 
42
+ # Split stories into sentences for coherence scoring
43
+ sentences_list = [
44
+ [s.strip() for s in story.split(".") if s.strip()] for story in stories
45
+ ]
46
 
47
+ # Collect all sentence pairs that need embedding
48
+ all_sentence_pairs = []
49
+ story_boundaries = [] # Track where each story's sentences end
50
+ current_position = 0
51
+
52
+ for sentences in sentences_list:
53
+ pairs_count = len(sentences) - 1
54
+ all_sentence_pairs.extend(zip(sentences[:-1], sentences[1:]))
55
+ story_boundaries.append(current_position + pairs_count)
56
+ current_position += pairs_count
57
+
58
+ # Process sentence pairs in batches
59
+ all_embeddings = []
60
+ for i in range(0, len(all_sentence_pairs), self.batch_size):
61
+ batch_pairs = all_sentence_pairs[i : i + self.batch_size]
62
+ # Flatten pairs for batch processing
63
+ batch_sentences = [sent for pair in batch_pairs for sent in pair]
64
+
65
+ # Tokenize batch
66
+ inputs = self.tokenizer(
67
+ batch_sentences, padding=True, truncation=True, return_tensors="pt"
68
+ ).to(self.device)
69
 
 
 
 
70
  with torch.no_grad():
71
+ outputs = self.model(**inputs, output_hidden_states=True)
72
+ batch_embeddings = outputs.hidden_states[-1][
73
+ :, 0, :
74
+ ] # Get CLS token embeddings
75
+ all_embeddings.extend(batch_embeddings.cpu().numpy())
76
+
77
+ # Calculate coherence scores for each story
78
+ current_idx = 0
79
+ for boundary in story_boundaries:
80
+ story_pairs_count = boundary - current_idx
81
+ story_scores = []
82
+
83
+ for i in range(story_pairs_count):
84
+ idx = current_idx + i
85
+ first_emb = all_embeddings[idx * 2].reshape(1, -1)
86
+ second_emb = all_embeddings[idx * 2 + 1].reshape(1, -1)
87
+ sim = cosine_similarity(first_emb, second_emb)[0][0]
88
+ story_scores.append(sim)
89
+
90
+ avg_coherence = (
91
+ sum(story_scores) / len(story_scores) if story_scores else 0.0
92
+ )
93
+ all_scores.append(avg_coherence)
94
+ current_idx = boundary
95
 
96
+ return all_scores
 
 
 
 
97
 
98
 
99
  class FluencyScorer(StoryScorer):
 
102
  model: PreTrainedModel,
103
  tokenizer: PreTrainedTokenizer,
104
  device: torch.device,
105
+ batch_size: int = 32,
106
  ):
107
  self.model = model
108
  self.tokenizer = tokenizer
109
  self.device = device
110
+ self.batch_size = batch_size
111
+
112
+ # Set up padding token if it doesn't exist
113
+ if self.tokenizer.pad_token is None:
114
+ if self.tokenizer.eos_token is not None:
115
+ self.tokenizer.pad_token = self.tokenizer.eos_token
116
+ else:
117
+ # Add padding token to tokenizer only
118
+ self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
119
+
120
+ # Set up mask token if it doesn't exist
121
+ if self.tokenizer.mask_token is None:
122
+ self.tokenizer.add_special_tokens({"mask_token": "[MASK]"})
123
+
124
+ def score(self, stories: list[str]) -> list[float]:
125
+ all_scores = []
126
  mask_token_id = self.tokenizer.mask_token_id
127
 
128
+ # Process stories in batches
129
+ for i in range(0, len(stories), self.batch_size):
130
+ batch_stories = stories[i : i + self.batch_size]
131
+ batch_inputs = self.tokenizer(
132
+ batch_stories, padding=True, truncation=True, return_tensors="pt"
133
+ ).to(self.device)
134
 
135
+ batch_scores = []
 
 
 
136
 
137
+ # For each story in the batch
138
+ for j in range(len(batch_stories)):
139
+ story_scores = []
140
+ input_ids = batch_inputs.input_ids[j : j + 1] # Keep batch dimension
141
+ attention_mask = batch_inputs.attention_mask[
142
+ j : j + 1
143
+ ] # Get attention mask
144
+
145
+ # Only process tokens that aren't padding
146
+ valid_tokens = attention_mask[0].sum().item()
147
+
148
+ # For each token in the story (excluding padding)
149
+ for k in range(1, valid_tokens - 1):
150
+ masked_input_ids = input_ids.clone()
151
+ masked_input_ids[0, k] = mask_token_id
152
+
153
+ with torch.no_grad():
154
+ outputs = self.model(
155
+ input_ids=masked_input_ids, attention_mask=attention_mask
156
+ )
157
+ logits = outputs.logits
158
+
159
+ original_token_id = input_ids[0, k]
160
+ token_probability = (
161
+ logits[0, k].softmax(dim=-1)[original_token_id].item()
162
+ )
163
+ story_scores.append(token_probability)
164
+
165
+ avg_fluency = (
166
+ sum(story_scores) / len(story_scores) if story_scores else 0.0
167
+ )
168
+ batch_scores.append(avg_fluency)
169
 
170
+ all_scores.extend(batch_scores)
 
 
171
 
172
+ return all_scores
 
 
 
173
 
174
 
175
  class GenreAlignmentScorer(StoryScorer):
176
+ def __init__(self, pipeline: Pipeline, genre: str, batch_size: int = 32):
177
  self.pipeline = pipeline
178
  self.genre = genre
179
+ self.batch_size = batch_size
180
 
181
+ def score(self, stories: list[str]) -> list[float]:
182
  if not self.genre:
183
+ return [0.5] * len(stories)
184
+
185
+ all_scores = []
186
+ # Split all stories into sentences
187
+ all_sentences = []
188
+ story_boundaries = []
189
+ current_position = 0
190
+
191
+ for story in stories:
192
+ sentences = [s.strip() for s in story.split(".") if s.strip()]
193
+ all_sentences.extend(sentences)
194
+ story_boundaries.append(current_position + len(sentences))
195
+ current_position += len(sentences)
196
+
197
+ # Process sentences in batches
198
+ all_sentence_scores = []
199
+ for i in range(0, len(all_sentences), self.batch_size):
200
+ batch_sentences = all_sentences[i : i + self.batch_size]
201
+ results = self.pipeline(
202
+ batch_sentences,
203
+ candidate_labels=[self.genre],
204
+ multi_label=True,
205
+ batch_size=self.batch_size,
206
  )
207
+ all_sentence_scores.extend([r["scores"][0] for r in results])
208
 
209
+ # Calculate average score for each story
210
+ current_idx = 0
211
+ for boundary in story_boundaries:
212
+ story_scores = all_sentence_scores[current_idx:boundary]
213
+ avg_score = sum(story_scores) / len(story_scores) if story_scores else 0.0
214
+ all_scores.append(avg_score)
215
+ current_idx = boundary
216
+
217
+ return all_scores
218
 
219
 
220
  class StoryEvaluator:
 
230
  self.genre_scorer = genre_scorer
231
  self.weights = weights
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def evaluate_multiple(self, stories: list[str]) -> list[tuple[str, CombinedScore]]:
234
+ """Evaluate multiple stories in batches and return them sorted by total score."""
235
+
236
+ # Get all scores in parallel using batch processing
237
+ coherence_scores = self.coherence_scorer.score(stories)
238
+ fluency_scores = self.fluency_scorer.score(stories)
239
+ genre_scores = self.genre_scorer.score(stories)
240
+
241
+ # Find max scores for normalization
242
+ max_scores = [max(coherence_scores), max(fluency_scores), max(genre_scores)]
243
+
244
+ # Create scored stories
245
+ scored_stories = []
246
+ for i, story in enumerate(stories):
247
+ scores = CombinedScore(
248
+ coherence=(
249
+ coherence_scores[i] / max_scores[0] if max_scores[0] != 0 else 0
250
+ ),
251
+ fluency=fluency_scores[i] / max_scores[1] if max_scores[1] != 0 else 0,
252
+ genre_alignment=(
253
+ genre_scores[i] / max_scores[2] if max_scores[2] != 0 else 0
254
+ ),
255
  )
256
  scores.total = np.dot(
257
  [scores.coherence, scores.fluency, scores.genre_alignment], self.weights
258
  )
259
+ scored_stories.append((story, scores))
260
 
261
  return sorted(scored_stories, key=lambda x: x[1].total, reverse=True)
story_beam_search/stories_generator.py CHANGED
@@ -160,10 +160,10 @@ class StoryGenerationSystem:
160
  prompt_segments = re.split(r"[^a-zA-Z0-9 ]+", prompt)
161
  prompt_segments = list(set(prompt_segments))
162
 
163
- storyness_score = self.storyness.score(prompt)
164
  for segment in prompt_segments:
165
  if segment.strip():
166
- injection_score = self.injection_guard.score(segment)
167
  if storyness_score < 0.2 or injection_score > 0.2:
168
  print("Potential prompt injection detected.")
169
  print(f"storyness_score: {storyness_score}")
 
160
  prompt_segments = re.split(r"[^a-zA-Z0-9 ]+", prompt)
161
  prompt_segments = list(set(prompt_segments))
162
 
163
+ storyness_score = self.storyness.score([prompt])[0]
164
  for segment in prompt_segments:
165
  if segment.strip():
166
+ injection_score = self.injection_guard.score([segment])[0]
167
  if storyness_score < 0.2 or injection_score > 0.2:
168
  print("Potential prompt injection detected.")
169
  print(f"storyness_score: {storyness_score}")