jhj0517 commited on
Commit
f6adc1d
·
1 Parent(s): 767d188

add util functions for files

Browse files
modules/utils/files_manager.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fnmatch
3
+
4
+ from gradio.utils import NamedString
5
+
6
+
7
+ def get_media_files(folder_path, include_sub_directory=False):
8
+ video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
9
+ audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a']
10
+ media_extensions = video_extensions + audio_extensions
11
+
12
+ media_files = []
13
+
14
+ if include_sub_directory:
15
+ for root, _, files in os.walk(folder_path):
16
+ for extension in media_extensions:
17
+ media_files.extend(
18
+ os.path.join(root, file) for file in fnmatch.filter(files, extension)
19
+ if os.path.exists(os.path.join(root, file))
20
+ )
21
+ else:
22
+ for extension in media_extensions:
23
+ media_files.extend(
24
+ os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension)
25
+ if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file))
26
+ )
27
+
28
+ return media_files
29
+
30
+
31
+ def format_gradio_files(files: list):
32
+ if not files:
33
+ return files
34
+
35
+ gradio_files = []
36
+ for file in files:
37
+ gradio_files.append(NamedString(file))
38
+ return gradio_files
39
+
modules/whisper/whisper_base.py CHANGED
@@ -12,6 +12,7 @@ from dataclasses import astuple
12
 
13
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
14
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
 
15
  from modules.whisper.whisper_parameter import *
16
  from modules.diarize.diarizer import Diarizer
17
  from modules.vad.silero_vad import SileroVAD
@@ -123,6 +124,7 @@ class WhisperBase(ABC):
123
 
124
  def transcribe_file(self,
125
  files: list,
 
126
  file_format: str,
127
  add_timestamp: bool,
128
  progress=gr.Progress(),
@@ -135,6 +137,9 @@ class WhisperBase(ABC):
135
  ----------
136
  files: list
137
  List of files to transcribe from gr.Files()
 
 
 
138
  file_format: str
139
  Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
140
  add_timestamp: bool
@@ -152,6 +157,10 @@ class WhisperBase(ABC):
152
  Output file path to return to gr.Files()
153
  """
154
  try:
 
 
 
 
155
  files_info = {}
156
  for file in files:
157
  transcribed_segments, time_for_task = self.run(
 
12
 
13
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
14
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
15
+ from modules.utils.files_manager import get_media_files, format_gradio_files
16
  from modules.whisper.whisper_parameter import *
17
  from modules.diarize.diarizer import Diarizer
18
  from modules.vad.silero_vad import SileroVAD
 
124
 
125
  def transcribe_file(self,
126
  files: list,
127
+ input_folder_path: str,
128
  file_format: str,
129
  add_timestamp: bool,
130
  progress=gr.Progress(),
 
137
  ----------
138
  files: list
139
  List of files to transcribe from gr.Files()
140
+ input_folder_path: str
141
+ Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
142
+ this will be used instead.
143
  file_format: str
144
  Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
145
  add_timestamp: bool
 
157
  Output file path to return to gr.Files()
158
  """
159
  try:
160
+ if input_folder_path:
161
+ files = get_media_files(input_folder_path)
162
+ files = format_gradio_files(files)
163
+
164
  files_info = {}
165
  for file in files:
166
  transcribed_segments, time_for_task = self.run(