jhj0517 commited on
Commit
a06971f
·
1 Parent(s): adab100

add parameters in FasterWhisperInference

Browse files
modules/whisper/faster_whisper_inference.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  from typing import BinaryIO, Union, Tuple, List
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
 
8
  import ctranslate2
9
  import whisper
10
  import gradio as gr
@@ -62,6 +63,8 @@ class FasterWhisperInference(WhisperBase):
62
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
63
  self.update_model(params.model_size, params.compute_type, progress)
64
 
 
 
65
  segments, info = self.model.transcribe(
66
  audio=audio,
67
  language=params.lang,
@@ -73,6 +76,22 @@ class FasterWhisperInference(WhisperBase):
73
  patience=params.patience,
74
  temperature=params.temperature,
75
  compression_ratio_threshold=params.compression_ratio_threshold,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  )
77
  progress(0, desc="Loading audio..")
78
 
@@ -147,3 +166,13 @@ class FasterWhisperInference(WhisperBase):
147
  return "cuda"
148
  else:
149
  return "auto"
 
 
 
 
 
 
 
 
 
 
 
5
  from typing import BinaryIO, Union, Tuple, List
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
8
+ import ast
9
  import ctranslate2
10
  import whisper
11
  import gradio as gr
 
63
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
64
  self.update_model(params.model_size, params.compute_type, progress)
65
 
66
+ params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
67
+
68
  segments, info = self.model.transcribe(
69
  audio=audio,
70
  language=params.lang,
 
76
  patience=params.patience,
77
  temperature=params.temperature,
78
  compression_ratio_threshold=params.compression_ratio_threshold,
79
+ length_penalty=params.length_penalty,
80
+ repetition_penalty=params.repetition_penalty,
81
+ no_repeat_ngram_size=params.no_repeat_ngram_size,
82
+ prefix=params.prefix,
83
+ suppress_blank=params.suppress_blank,
84
+ suppress_tokens=params.suppress_tokens,
85
+ max_initial_timestamp=params.max_initial_timestamp,
86
+ word_timestamps=params.word_timestamps,
87
+ prepend_punctuations=params.prepend_punctuations,
88
+ append_punctuations=params.append_punctuations,
89
+ max_new_tokens=params.max_new_tokens,
90
+ chunk_length=params.chunk_length,
91
+ hallucination_silence_threshold=params.hallucination_silence_threshold,
92
+ hotwords=params.hotwords,
93
+ language_detection_threshold=params.language_detection_threshold,
94
+ language_detection_segments=params.language_detection_segments
95
  )
96
  progress(0, desc="Loading audio..")
97
 
 
166
  return "cuda"
167
  else:
168
  return "auto"
169
+
170
+ @staticmethod
171
+ def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
172
+ try:
173
+ suppress_tokens = ast.literal_eval(suppress_tokens_str)
174
+ if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
175
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
176
+ return suppress_tokens
177
+ except Exception as e:
178
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")