Segmenting Text for Token Limitations - Code Provided for v0.19
#97
by
gmacgmac
- opened
@hexgrad
incredible work on this model, thanks so much.
I've been working on handling the segmentation of sentences and thought it might help to share it with others.
The script also exposes a few other options (speed, per-chunk trimming, and saving individual chunk files).
There are two versions of the segmentation: the active one is more comprehensive, and the commented-out one is left in to give you some ideas.
Place the script at the root of Kokoro v0.19.
I haven't implemented this for the latest version yet, but when I do I will share it...
import re
import os
import argparse
import torch
import numpy as np
import scipy.io.wavfile as wav
from typing import List, Tuple, Optional
from dataclasses import dataclass
from models import build_model
from kokoro import phonemize, tokenize, forward # Import core functions from kokoro.py
@dataclass
class TextChunk:
    """A piece of input text small enough to be synthesized in one pass."""

    # The chunk's text content.
    text: str
    # Zero-based position of the chunk within the chunked input.
    index: int
class TTSProcessor:
    """Chunked text-to-speech driver built on kokoro's core functions.

    Input text is split into word-limited chunks, each chunk is synthesized
    on its own, and the resulting audio is concatenated into one WAV file.
    """

    # Token limit enforced in process_chunk before calling ``forward``.
    MAX_TOKENS = 510

    # Abbreviations whose periods must not be treated as sentence ends.
    _ABBREVIATIONS = (
        "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.",
        "e.g.", "i.e.", "etc.", "U.S.A.", "U.K.",
    )
    # Stand-in for the periods of protected abbreviations. A NUL character is
    # used because, unlike a visible marker, it will not occur in real text.
    _ABBR_PLACEHOLDER = "\x00"
    # Whitespace that follows sentence-ending punctuation.
    _SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')
    # A comma followed by whitespace, unless it sits between two digits
    # (e.g. the "25, 2023" of a date). "1,000" never matches: no whitespace.
    _COMMA_BOUNDARY = re.compile(r'(?<!\d),\s+|,\s+(?![\s\d])')

    def __init__(self, model_path: str, voice_folder: str, voice_name: str):
        """Load the model and the voicepack onto the selected device.

        Args:
            model_path: Checkpoint path handed to ``build_model``.
            voice_folder: Directory that contains the voicepack files.
            voice_name: Voicepack file name (e.g. ``bf_isabella.pt``); its
                first character is used as the language code.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = build_model(model_path, self.device)
        # os.path.join works whether or not voice_folder ends with a
        # separator; map_location lets a voicepack that was saved on a GPU
        # load on a CPU-only machine.
        self.voicepack = torch.load(
            os.path.join(voice_folder, voice_name),
            weights_only=True,
            map_location=self.device,
        ).to(self.device)
        self.lang = voice_name[0]  # Extract language code from voice name
        self.sample_rate = 24000  # Sample rate used for trimming and WAV output

    @staticmethod
    def word_count(text: str) -> int:
        """Return the number of whitespace-separated words in ``text``."""
        return len(text.split())

    @staticmethod
    def split_sentences(text: str) -> List[str]:
        """Split ``text`` after ``.``, ``!`` or ``?`` followed by whitespace.

        Periods belonging to the known abbreviations are protected, so
        "Dr. Smith" stays in one piece. As a consequence, an abbreviation
        that really ends a sentence does not produce a split.

        Returns:
            The non-empty, stripped sentences in their original order.
        """
        placeholder = TTSProcessor._ABBR_PLACEHOLDER
        for abbr in TTSProcessor._ABBREVIATIONS:
            text = text.replace(abbr, abbr.replace(".", placeholder))
        pieces = TTSProcessor._SENTENCE_BOUNDARY.split(text)
        restored = (piece.replace(placeholder, ".").strip() for piece in pieces)
        return [sentence for sentence in restored if sentence]

    @staticmethod
    def split_commas(text: str) -> List[str]:
        """Split ``text`` at commas that are followed by whitespace.

        The separating comma is dropped. Commas inside numbers ("1,000") and
        commas between two digits ("Dec. 25, 2023") are left alone, but they
        do not prevent the rest of the text from being split.

        Returns:
            The non-empty, stripped segments in their original order.
        """
        pieces = TTSProcessor._COMMA_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]

    @staticmethod
    def split_conjunctions(text: str) -> List[str]:
        """Split ``text`` at `` and ``, or at `` but `` when there is no `` and ``.

        Only the first marker found is used. The conjunction is kept at the
        start of each following segment.

        Returns:
            The segments, or ``[text]`` when neither marker occurs.
        """
        for marker in (' and ', ' but '):
            if marker in text:
                head, *tail = text.split(marker)
                word = marker.strip()
                return [head] + [f"{word} {part}" for part in tail]
        return [text]

    @staticmethod
    def force_split(text: str, max_words: int) -> List[str]:
        """Cut ``text`` into consecutive chunks of at most ``max_words`` words.

        Raises:
            ValueError: If ``max_words`` is smaller than 1. A negative value
                would otherwise yield no chunks and silently drop the text.
        """
        if max_words < 1:
            raise ValueError(f"max_words must be at least 1, got {max_words}")
        words = text.split()
        return [
            ' '.join(words[start:start + max_words])
            for start in range(0, len(words), max_words)
        ]

    def chunk_text(self, text: str, max_words: int = 50) -> "List[TextChunk]":
        """Split ``text`` into chunks of at most ``max_words`` words.

        Each non-blank line is tried as a whole. A piece that is too long is
        split by the next strategy in this order: sentences, commas,
        conjunctions, and finally a hard cut every ``max_words`` words.

        Args:
            text: The text to chunk.
            max_words: Upper bound on the number of words per chunk.

        Returns:
            The chunks in reading order, indexed from 0.

        Raises:
            ValueError: If ``max_words`` is smaller than 1.
        """
        if max_words < 1:
            raise ValueError(f"max_words must be at least 1, got {max_words}")

        splitters = (
            self.split_sentences,
            self.split_commas,
            self.split_conjunctions,
        )

        def _fit(piece: str, level: int) -> List[str]:
            # A piece that already fits is never split any further.
            if self.word_count(piece) <= max_words:
                return [piece]
            # Every gentler strategy has been tried: cut by word count.
            if level == len(splitters):
                return self.force_split(piece, max_words)
            fitted: List[str] = []
            for part in splitters[level](piece):
                fitted.extend(_fit(part, level + 1))
            return fitted

        pieces: List[str] = []
        for line in text.split('\n'):
            line = line.strip()
            if line:
                pieces.extend(_fit(line, 0))

        # Blank pieces would only produce "no audio" warnings later on.
        kept = [piece for piece in pieces if piece.strip()]
        return [TextChunk(text=piece, index=i) for i, piece in enumerate(kept)]

    # Alternative (inactive) chunker, kept for reference only.
    # NOTE(review): it is decorated with @staticmethod yet takes ``self``, and
    # ``index=len(chunks)`` restarts inside every recursive call.
    # @staticmethod
    # def chunk_text(self, text: str, split_level: int = 0) -> List[TextChunk]:
    #     """
    #     Recursively chunk text with strict validation and prioritized splits.
    #     Args:
    #         self: The TTSProcessor instance
    #         text: Text to chunk
    #         split_level: Current split level (0=newlines, 1=sentences, 2=conjunctions, 3=force)
    #     """
    #     def _chunk_recursive(text: str, level: int) -> List[TextChunk]:
    #         chunks = []
    #         def validate_segment(segment: str) -> bool:
    #             return len(segment.split()) <= 50 if segment else False
    #         def split_and_validate(segment: str, level: int) -> List[str]:
    #             if level == 0:
    #                 return [s.strip() for s in segment.split('\n') if s.strip()]
    #             elif level == 1:
    #                 parts = re.split(r'([.!?])\s+', segment)
    #                 results = []
    #                 for i in range(0, len(parts)-1, 2):
    #                     if i+1 < len(parts):
    #                         results.append(parts[i] + parts[i+1])
    #                 if len(parts) % 2:
    #                     results.append(parts[-1])
    #                 return results
    #             elif level == 2:
    #                 for marker in [' and ', ' but ']:
    #                     if marker in segment:
    #                         parts = segment.split(marker)
    #                         return [parts[0]] + [marker.strip() + ' ' + p for p in parts[1:]]
    #                 return [segment]
    #             else:  # level == 3
    #                 words = segment.split()
    #                 return [' '.join(words[:50]), ' '.join(words[50:])]
    #         segments = split_and_validate(text, level)
    #         for segment in segments:
    #             if not segment:
    #                 continue
    #             if validate_segment(segment):
    #                 if not segment[-1] in '.!?':
    #                     segment += '.'
    #                 chunks.append(TextChunk(text=segment, index=len(chunks)))
    #             else:
    #                 chunks.extend(_chunk_recursive(segment, level + 1))
    #         return chunks
    #     # Start the recursive process
    #     return _chunk_recursive(text, split_level)

    def trim_audio(self, audio: np.ndarray, trim_start_ms: int = 0, trim_end_ms: int = 0) -> np.ndarray:
        """Remove the given number of milliseconds from both ends of ``audio``.

        Negative durations are treated as 0 rather than being interpreted as
        from-the-end slice offsets. Trimming more than the audio contains
        yields an empty result.

        Args:
            audio: Samples to trim, sliced along the first axis.
            trim_start_ms: Milliseconds to drop from the start.
            trim_end_ms: Milliseconds to drop from the end.

        Returns:
            The trimmed samples, or ``audio`` itself when nothing is trimmed.
        """
        # Convert milliseconds to samples, never going below zero.
        start_samples = max(0, int(trim_start_ms * self.sample_rate / 1000))
        end_samples = max(0, int(trim_end_ms * self.sample_rate / 1000))
        if start_samples == 0 and end_samples == 0:
            return audio
        # A "-0" end index would select nothing, so only slice the end when
        # there is something to remove.
        if end_samples > 0:
            return audio[start_samples:-end_samples]
        return audio[start_samples:]

    def process_chunk(self, text: str, speed: float = 1.0) -> Optional[np.ndarray]:
        """Synthesize one chunk of text with kokoro's core functions.

        Args:
            text: The chunk to synthesize.
            speed: Speed value passed through to ``forward``.

        Returns:
            Whatever ``forward`` returns (presumably a NumPy array of samples;
            confirm against kokoro.py), or ``None`` when the text produces no
            tokens.
        """
        ps = phonemize(text, self.lang)
        tokens = tokenize(ps)
        if not tokens:
            return None
        if len(tokens) > self.MAX_TOKENS:
            print(f"Warning: Text exceeds {self.MAX_TOKENS} tokens, truncating")
            tokens = tokens[:self.MAX_TOKENS]
        # The reference style is selected by the number of tokens.
        ref_s = self.voicepack[len(tokens)]
        return forward(self.model, tokens, ref_s, speed)

    def process_text(
        self,
        text: str,
        output_folder: str,
        output_prefix: str = "",
        speed: float = 1.0,
        trim_start_ms: int = 0,
        trim_end_ms: int = 0,
        save_individual: bool = False,
        final_output_name: str = "combined_output.wav"
    ) -> None:
        """Chunk ``text``, synthesize every chunk, and write the combined WAV.

        Args:
            text: The full text to synthesize.
            output_folder: Directory for the output files; created if missing.
            output_prefix: Prefix for the per-chunk file names.
            speed: Speed value passed to ``process_chunk``.
            trim_start_ms: Milliseconds trimmed from the start of each chunk.
            trim_end_ms: Milliseconds trimmed from the end of each chunk.
            save_individual: Also write each chunk as
                ``<output_prefix><index>.wav``.
            final_output_name: File name of the combined output.
        """
        os.makedirs(output_folder, exist_ok=True)
        all_audio = []
        for chunk in self.chunk_text(text):
            print('')
            print('processing chunk')
            print(chunk)
            print('')
            audio = self.process_chunk(chunk.text, speed)
            if audio is None:
                print(f"Warning: No audio generated for chunk {chunk.index}")
                continue
            audio = self.trim_audio(audio, trim_start_ms, trim_end_ms)
            if save_individual:
                output_path = os.path.join(
                    output_folder, f"{output_prefix}{chunk.index}.wav"
                )
                wav.write(output_path, self.sample_rate, audio)
            # Collected for the final concatenation.
            all_audio.append(audio)
            print(f"Processed chunk {chunk.index}: {chunk.text}")

        if not all_audio:
            # Say so explicitly instead of finishing without any output file.
            print("Warning: No audio generated; combined file not written")
            return
        combined_audio = np.concatenate(all_audio)
        final_path = os.path.join(output_folder, final_output_name)
        wav.write(final_path, self.sample_rate, combined_audio)
        print(f"Created combined audio file: {final_path}")
def main():
    """Parse the command-line arguments and run the batch TTS pipeline."""
    parser = argparse.ArgumentParser(description="Batch TTS Processing")
    parser.add_argument("--text", type=str, required=True, help="Input text to process")
    parser.add_argument("--voice-folder", type=str, default="voices_0.19/")
    parser.add_argument("--voice", type=str, required=True)
    parser.add_argument("--model", type=str, default="kokoro-v0_19.pth")
    parser.add_argument("--output-folder", type=str, default="output")
    parser.add_argument("--output-prefix", type=str, default="")
    parser.add_argument("--speed", type=float, default=1.0)
    parser.add_argument("--trim-start", type=int, default=0, help="Milliseconds to trim from start of each chunk")
    parser.add_argument("--trim-end", type=int, default=0, help="Milliseconds to trim from end of each chunk")
    parser.add_argument("--save-individual", action="store_true", help="Save individual chunk files")
    parser.add_argument("--output-name", type=str, default="combined_output.wav", help="Name for final combined file")
    args = parser.parse_args()
    print('')

    # Build the processor once; it loads the model and the voicepack.
    processor = TTSProcessor(
        model_path=args.model,
        voice_folder=args.voice_folder,
        voice_name=args.voice
    )

    # Chunk, synthesize, and write the output file(s).
    processor.process_text(
        text=args.text,
        output_folder=args.output_folder,
        output_prefix=args.output_prefix,
        speed=args.speed,
        trim_start_ms=args.trim_start,
        trim_end_ms=args.trim_end,
        save_individual=args.save_individual,
        final_output_name=args.output_name
    )
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
To use...
cd ~/Kokoro-82M &&
conda run -n tts python inference_pro_v2.py \
--voice-folder "voices/" \
--voice "bf_isabella.pt" \
--model "kokoro-v0_19.pth" \
--output-folder "/Users/me/n8n/output/tts/" \
--output-name "FileName.wav" \
--trim-start 200 \
--trim-end 200 \
--text " [TEXT GOES HERE] "