Text-to-Speech
English

Segmenting Text for Token Limitations - Code Provided for v0.19

#97
by gmacgmac - opened

@hexgrad, incredible work on this model. Thanks so much!
I've been working on segmenting text into sentences and thought it might help to share it with others.
There are also some other options in here.

There are two versions of the segmentation: the active one is more comprehensive, but the commented-out one may still give you some ideas.
Place the script at the root of your Kokoro v0.19 checkout.

I haven't implemented this for the latest version yet, but I will share it when I do.

import re
import os
import argparse
import torch
import numpy as np
import scipy.io.wavfile as wav
from typing import List, Tuple, Optional
from dataclasses import dataclass
from models import build_model
from kokoro import phonemize, tokenize, forward  # Import core functions from kokoro.py

@dataclass()
class TextChunk:
    """A single piece of segmented input text plus its position.

    ``chunk_text`` assigns ``index`` as the 0-based order in which chunks are
    produced; ``process_text`` also uses it to name per-chunk WAV files.
    """

    text: str  # text to synthesize for this chunk
    index: int  # 0-based position of the chunk within the full input

class TTSProcessor:
    """Segment long text, synthesize each segment with Kokoro, and join the audio.

    Token lists longer than ``MAX_TOKENS`` are truncated before synthesis, so
    long inputs are first split into word-bounded chunks at the most natural
    break available: newlines, sentence ends, commas, conjunctions and, as a
    last resort, a hard word-count split.
    """

    # Longest token list passed to kokoro's forward(); longer lists are truncated.
    MAX_TOKENS = 510

    def __init__(self, model_path: str, voice_folder: str, voice_name: str):
        """
        Load the model and voicepack once for reuse across chunks.

        Args:
            model_path: Path to the Kokoro model checkpoint.
            voice_folder: Directory containing voice files (trailing slash optional).
            voice_name: Voice file name, e.g. "bf_isabella.pt"; its first
                character is used as the language code.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = build_model(model_path, self.device)
        # os.path.join works whether or not voice_folder ends with a separator.
        voice_path = os.path.join(voice_folder, voice_name)
        # Load to CPU first so a voicepack saved from a GPU also loads on
        # CPU-only hosts, then move it to the working device.
        self.voicepack = torch.load(
            voice_path,
            map_location='cpu',
            weights_only=True
        ).to(self.device)
        # Language code is the first character of the voice file name; using
        # basename keeps this correct if voice_name includes a subdirectory.
        self.lang = os.path.basename(voice_name)[0]
        self.sample_rate = 24000  # Standard sample rate for this model

    @staticmethod
    def word_count(text: str) -> int:
        """Return the number of whitespace-separated words in ``text``."""
        return len(text.split())

    @staticmethod
    def split_sentences(text: str) -> List[str]:
        """
        Split text into sentences at `.`, `?` or `!` followed by whitespace.

        Periods inside a small fixed list of common abbreviations ("Dr.",
        "e.g.", "U.K.", ...) are protected so they do not end a sentence.
        Returns stripped, non-empty sentences with their punctuation kept.
        """
        # Ordered tuple (not a set) so replacement order is deterministic.
        abbreviations = (
            "U.S.A.", "U.K.", "Mrs.", "Mr.", "Ms.", "Dr.", "Prof.",
            "e.g.", "i.e.", "etc.",
        )
        # NUL stands in for protected periods: unlike a marker such as
        # "<abbr>", it will not collide with real input text.
        placeholder = "\x00"
        for abbr in abbreviations:
            text = text.replace(abbr, abbr.replace(".", placeholder))

        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.replace(placeholder, ".").strip() for s in sentences if s.strip()]

    @staticmethod
    def split_commas(text: str) -> List[str]:
        """
        Split text at a comma followed by whitespace (`, `).

        Thousands separators such as "1,000" never match because there is no
        whitespace after the comma. A comma with a digit on both sides (the
        "25, 2023" in "Dec. 25, 2023", or "1, 2, 3") is left intact so dates
        and numeric lists are not broken up.
        """
        # Split unless the comma sits between two digits. The lookahead uses
        # [\s\d] so backtracking over several spaces cannot create a split
        # just before a digit.
        segments = re.split(r'(?<!\d),\s+|,\s+(?![\s\d])', text)
        return [s.strip() for s in segments if s.strip()]

    @staticmethod
    def split_conjunctions(text: str) -> List[str]:
        """
        Split text before the first conjunction found (` and `, then ` but `).

        The conjunction is kept at the start of each following piece, e.g.
        "cats and dogs" -> ["cats", "and dogs"]. Only the first marker present
        is used. Returns ``[text]`` if no marker occurs; empty pieces are dropped.
        """
        for marker in (' and ', ' but '):
            if marker in text:
                head, *tails = text.split(marker)
                word = marker.strip()
                pieces = [head] + [f"{word} {tail}" for tail in tails]
                return [p.strip() for p in pieces if p.strip()]
        return [text]

    @staticmethod
    def force_split(text: str, max_words: int) -> List[str]:
        """
        Split text into consecutive chunks of at most ``max_words`` words.

        Raises:
            ValueError: If ``max_words`` is less than 1; previously a negative
                value silently returned no chunks, dropping the text.
        """
        if max_words < 1:
            raise ValueError(f"max_words must be >= 1, got {max_words}")
        words = text.split()
        return [' '.join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

    def chunk_text(self, text: str, max_words: int = 50) -> List["TextChunk"]:
        """
        Split text into chunks of at most ``max_words`` words at natural breaks.

        Every newline starts a new chunk. A line that is still too long is
        split at sentence ends; pieces still too long are split at commas,
        then at "and"/"but", then by a hard word count. Each finer splitter
        is applied only to pieces that still exceed the limit. Empty pieces
        are dropped.

        Args:
            text: Input text.
            max_words: Maximum whitespace-separated words per chunk (>= 1).

        Returns:
            Chunks in reading order, with ``index`` counting from 0.

        Raises:
            ValueError: If ``max_words`` is less than 1.
        """
        if max_words < 1:
            raise ValueError(f"max_words must be >= 1, got {max_words}")

        # Progressively finer split strategies; force_split's output always
        # fits, so recursion depth is bounded by len(splitters).
        splitters = (
            self.split_sentences,
            self.split_commas,
            self.split_conjunctions,
            lambda s: self.force_split(s, max_words),
        )
        chunks = []

        def _add(piece: str, level: int) -> None:
            """Emit ``piece`` if it fits; otherwise split it with ``splitters[level]``."""
            piece = piece.strip()
            if not piece:
                return
            if self.word_count(piece) <= max_words or level >= len(splitters):
                chunks.append(TextChunk(text=piece, index=len(chunks)))
                return
            for sub in splitters[level](piece):
                _add(sub, level + 1)

        for line in text.split('\n'):
            _add(line, 0)
        return chunks

    # NOTE: earlier, inactive version of chunk_text kept for reference only.
    # @staticmethod
    # def chunk_text(self, text: str, split_level: int = 0) -> List[TextChunk]:
    #     """
    #     Recursively chunk text with strict validation and prioritized splits.
    #     Args:
    #         self: The TTSProcessor instance
    #         text: Text to chunk
    #         split_level: Current split level (0=newlines, 1=sentences, 2=conjunctions, 3=force)
    #     """
    #     def _chunk_recursive(text: str, level: int) -> List[TextChunk]:
    #         chunks = []

    #         def validate_segment(segment: str) -> bool:
    #             return len(segment.split()) <= 50 if segment else False

    #         def split_and_validate(segment: str, level: int) -> List[str]:
    #             if level == 0:
    #                 return [s.strip() for s in segment.split('\n') if s.strip()]

    #             elif level == 1:
    #                 parts = re.split(r'([.!?])\s+', segment)
    #                 results = []
    #                 for i in range(0, len(parts)-1, 2):
    #                     if i+1 < len(parts):
    #                         results.append(parts[i] + parts[i+1])
    #                 if len(parts) % 2:
    #                     results.append(parts[-1])
    #                 return results

    #             elif level == 2:
    #                 for marker in [' and ', ' but ']:
    #                     if marker in segment:
    #                         parts = segment.split(marker)
    #                         return [parts[0]] + [marker.strip() + ' ' + p for p in parts[1:]]
    #                 return [segment]

    #             else:  # level == 3
    #                 words = segment.split()
    #                 return [' '.join(words[:50]), ' '.join(words[50:])]

    #         segments = split_and_validate(text, level)

    #         for segment in segments:
    #             if not segment:
    #                 continue

    #             if validate_segment(segment):
    #                 if not segment[-1] in '.!?':
    #                     segment += '.'
    #                 chunks.append(TextChunk(text=segment, index=len(chunks)))
    #             else:
    #                 chunks.extend(_chunk_recursive(segment, level + 1))

    #         return chunks

    #     # Start the recursive process
    #     return _chunk_recursive(text, split_level)

    def trim_audio(self, audio: np.ndarray, trim_start_ms: int = 0, trim_end_ms: int = 0) -> np.ndarray:
        """
        Trim ``trim_start_ms`` / ``trim_end_ms`` milliseconds off the ends of ``audio``.

        Negative values are treated as 0. If the two trims overlap, the
        result is an empty array. With both values 0, ``audio`` is returned
        unchanged (same object).
        """
        if trim_start_ms == 0 and trim_end_ms == 0:
            return audio

        # Convert milliseconds to sample counts; clamp so negative input
        # cannot turn into a "from the end" slice.
        start = max(int(trim_start_ms * self.sample_rate / 1000), 0)
        end = max(int(trim_end_ms * self.sample_rate / 1000), 0)
        # An explicit stop index avoids the audio[start:-0] pitfall.
        stop = max(len(audio) - end, 0)
        return audio[start:stop]

    def process_chunk(self, text: str, speed: float = 1.0) -> Optional[np.ndarray]:
        """
        Synthesize one chunk of text with kokoro's phonemize/tokenize/forward.

        Returns:
            The audio returned by ``forward``, or None if the text produced
            no tokens. Token lists longer than ``MAX_TOKENS`` are truncated
            with a warning.
        """
        ps = phonemize(text, self.lang)
        tokens = tokenize(ps)

        if not tokens:
            return None

        if len(tokens) > self.MAX_TOKENS:
            print(f"Warning: Text exceeds {self.MAX_TOKENS} tokens ({len(tokens)}), truncating")
            tokens = tokens[:self.MAX_TOKENS]

        # Reference style is selected by token count.
        # NOTE(review): assumes voicepack has at least MAX_TOKENS + 1 entries
        # along its first axis — confirm against the voice files.
        ref_s = self.voicepack[len(tokens)]
        return forward(self.model, tokens, ref_s, speed)

    def process_text(
        self,
        text: str,
        output_folder: str,
        output_prefix: str = "",
        speed: float = 1.0,
        trim_start_ms: int = 0,
        trim_end_ms: int = 0,
        save_individual: bool = False,
        final_output_name: str = "combined_output.wav",
        max_words: int = 50,
    ) -> None:
        """
        Chunk ``text``, synthesize every chunk, and write the combined WAV.

        Args:
            text: Full input text, segmented with ``chunk_text``.
            output_folder: Directory for all WAV output (created if missing).
            output_prefix: Filename prefix for per-chunk files.
            speed: Speech speed passed to kokoro's ``forward``.
            trim_start_ms: Milliseconds trimmed from the start of each chunk.
            trim_end_ms: Milliseconds trimmed from the end of each chunk.
            save_individual: Also write each chunk as ``{prefix}{index}.wav``.
            final_output_name: File name of the combined WAV.
            max_words: Maximum words per chunk, forwarded to ``chunk_text``.
        """
        os.makedirs(output_folder, exist_ok=True)
        all_audio = []

        for chunk in self.chunk_text(text, max_words=max_words):
            audio = self.process_chunk(chunk.text, speed)
            if audio is None:
                print(f"Warning: No audio generated for chunk {chunk.index}")
                continue

            audio = self.trim_audio(audio, trim_start_ms, trim_end_ms)

            if save_individual:
                output_path = os.path.join(output_folder, f"{output_prefix}{chunk.index}.wav")
                wav.write(output_path, self.sample_rate, audio)

            all_audio.append(audio)
            print(f"Processed chunk {chunk.index}: {chunk.text}")

        if not all_audio:
            print("Warning: No audio was generated; combined file not written")
            return

        combined_audio = np.concatenate(all_audio)
        final_path = os.path.join(output_folder, final_output_name)
        wav.write(final_path, self.sample_rate, combined_audio)
        print(f"Created combined audio file: {final_path}")

def main():
    """Parse command-line options, then chunk, synthesize and save the given text."""
    parser = argparse.ArgumentParser(description="Batch TTS Processing")
    add = parser.add_argument

    # Input text, voice and model selection.
    add("--text", type=str, required=True, help="Input text to process")
    add("--voice-folder", type=str, default="voices_0.19/")
    add("--voice", type=str, required=True)
    add("--model", type=str, default="kokoro-v0_19.pth")
    # Output location and naming.
    add("--output-folder", type=str, default="output")
    add("--output-prefix", type=str, default="")
    # Synthesis and post-processing controls.
    add("--speed", type=float, default=1.0)
    add("--trim-start", type=int, default=0, help="Milliseconds to trim from start of each chunk")
    add("--trim-end", type=int, default=0, help="Milliseconds to trim from end of each chunk")
    add("--save-individual", action="store_true", help="Save individual chunk files")
    add("--output-name", type=str, default="combined_output.wav", help="Name for final combined file")

    opts = parser.parse_args()
    print('')

    # Load the model and voicepack a single time, then reuse them for every chunk.
    tts = TTSProcessor(
        model_path=opts.model,
        voice_folder=opts.voice_folder,
        voice_name=opts.voice,
    )

    run_options = dict(
        text=opts.text,
        output_folder=opts.output_folder,
        output_prefix=opts.output_prefix,
        speed=opts.speed,
        trim_start_ms=opts.trim_start,
        trim_end_ms=opts.trim_end,
        save_individual=opts.save_individual,
        final_output_name=opts.output_name,
    )
    tts.process_text(**run_options)

if __name__ == "__main__":
    main()

Usage (run from the Kokoro root; the script is saved there as `inference_pro_v2.py`):

cd ~/Kokoro-82M &&
conda run -n tts python inference_pro_v2.py \
  --voice-folder "voices/" \
  --voice "bf_isabella.pt" \
  --model "kokoro-v0_19.pth" \
  --output-folder "/Users/me/n8n/output/tts/" \
  --output-name "FileName.wav" \
  --trim-start 200 \
  --trim-end 200 \
  --text " [TEXT GOES HERE] "

Sign up or log in to comment