jhj0517 committed on
Commit
f282b7c
·
unverified ·
2 Parent(s): 9f11092 6074f61

Merge pull request #14 from damho1104/add-to-remove-input-file-when-finish

Browse files
Files changed (1) hide show
  1. modules/whisper_Inference.py +109 -88
modules/whisper_Inference.py CHANGED
@@ -22,20 +22,80 @@ class WhisperInference:
22
  def progress_callback(progress_value):
23
  progress(progress_value, desc="Transcribing..")
24
 
25
- if model_size != self.current_model_size or self.model is None:
26
- progress(0, desc="Initializing Model..")
27
- self.current_model_size = model_size
28
- self.model = whisper.load_model(name=model_size, download_root="models/Whisper")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- if lang == "Automatic Detection":
31
- lang = None
 
32
 
33
- progress(0, desc="Loading Audio..")
 
34
 
35
- files_info = {}
36
- for fileobj in fileobjs:
 
 
 
37
 
38
- audio = whisper.load_audio(fileobj.name)
 
 
 
 
 
39
 
40
  translatable_model = ["large", "large-v1", "large-v2"]
41
  if istranslate and self.current_model_size in translatable_model:
@@ -47,9 +107,7 @@ class WhisperInference:
47
 
48
  progress(1, desc="Completed!")
49
 
50
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
51
- file_name = file_name[:-9]
52
- file_name = safe_filename(file_name)
53
  timestamp = datetime.now().strftime("%m%d%H%M%S")
54
  output_path = f"outputs/{file_name}-{timestamp}"
55
 
@@ -60,57 +118,14 @@ class WhisperInference:
60
  subtitle = get_vtt(result["segments"])
61
  write_file(subtitle, f"{output_path}.vtt")
62
 
63
- files_info[file_name] = subtitle
64
-
65
- total_result = ''
66
- for file_name, subtitle in files_info.items():
67
- total_result += '------------------------------------\n'
68
- total_result += f'{file_name}\n\n'
69
- total_result += f'{subtitle}'
70
-
71
- return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
72
-
73
- def transcribe_youtube(self, youtubelink,
74
- model_size, lang, subformat, istranslate,
75
- progress=gr.Progress()):
76
-
77
- def progress_callback(progress_value):
78
- progress(progress_value, desc="Transcribing..")
79
-
80
- if model_size != self.current_model_size or self.model is None:
81
- progress(0, desc="Initializing Model..")
82
- self.current_model_size = model_size
83
- self.model = whisper.load_model(name=model_size, download_root="models/Whisper")
84
-
85
- if lang == "Automatic Detection":
86
- lang = None
87
-
88
- progress(0, desc="Loading Audio from Youtube..")
89
- yt = get_ytdata(youtubelink)
90
- audio = whisper.load_audio(get_ytaudio(yt))
91
-
92
- translatable_model = ["large", "large-v1", "large-v2"]
93
- if istranslate and self.current_model_size in translatable_model:
94
- result = self.model.transcribe(audio=audio, language=lang, verbose=False, task="translate",
95
- progress_callback=progress_callback)
96
- else:
97
- result = self.model.transcribe(audio=audio, language=lang, verbose=False,
98
- progress_callback=progress_callback)
99
-
100
- progress(1, desc="Completed!")
101
-
102
- file_name = safe_filename(yt.title)
103
- timestamp = datetime.now().strftime("%m%d%H%M%S")
104
- output_path = f"outputs/{file_name}-{timestamp}"
105
-
106
- if subformat == "SRT":
107
- subtitle = get_srt(result["segments"])
108
- write_file(subtitle, f"{output_path}.srt")
109
- elif subformat == "WebVTT":
110
- subtitle = get_vtt(result["segments"])
111
- write_file(subtitle, f"{output_path}.vtt")
112
-
113
- return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
114
 
115
  def transcribe_mic(self, micaudio,
116
  model_size, lang, subformat, istranslate,
@@ -119,34 +134,40 @@ class WhisperInference:
119
  def progress_callback(progress_value):
120
  progress(progress_value, desc="Transcribing..")
121
 
122
- if model_size != self.current_model_size or self.model is None:
123
- progress(0, desc="Initializing Model..")
124
- self.current_model_size = model_size
125
- self.model = whisper.load_model(name=model_size, download_root="models/Whisper")
 
126
 
127
- if lang == "Automatic Detection":
128
- lang = None
129
 
130
- progress(0, desc="Loading Audio..")
131
 
132
- translatable_model = ["large", "large-v1", "large-v2"]
133
- if istranslate and self.current_model_size in translatable_model:
134
- result = self.model.transcribe(audio=micaudio, language=lang, verbose=False, task="translate",
135
- progress_callback=progress_callback)
136
- else:
137
- result = self.model.transcribe(audio=micaudio, language=lang, verbose=False,
138
- progress_callback=progress_callback)
139
 
140
- progress(1, desc="Completed!")
141
 
142
- timestamp = datetime.now().strftime("%m%d%H%M%S")
143
- output_path = f"outputs/Mic-{timestamp}"
144
 
145
- if subformat == "SRT":
146
- subtitle = get_srt(result["segments"])
147
- write_file(subtitle, f"{output_path}.srt")
148
- elif subformat == "WebVTT":
149
- subtitle = get_vtt(result["segments"])
150
- write_file(subtitle, f"{output_path}.vtt")
151
 
152
- return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
 
 
 
 
 
22
  def progress_callback(progress_value):
23
  progress(progress_value, desc="Transcribing..")
24
 
25
+ try:
26
+ if model_size != self.current_model_size or self.model is None:
27
+ progress(0, desc="Initializing Model..")
28
+ self.current_model_size = model_size
29
+ self.model = whisper.load_model(name=model_size, download_root="models/Whisper")
30
+
31
+ if lang == "Automatic Detection":
32
+ lang = None
33
+
34
+ progress(0, desc="Loading Audio..")
35
+
36
+ files_info = {}
37
+ for fileobj in fileobjs:
38
+
39
+ audio = whisper.load_audio(fileobj.name)
40
+
41
+ translatable_model = ["large", "large-v1", "large-v2"]
42
+ if istranslate and self.current_model_size in translatable_model:
43
+ result = self.model.transcribe(audio=audio, language=lang, verbose=False, task="translate",
44
+ progress_callback=progress_callback)
45
+ else:
46
+ result = self.model.transcribe(audio=audio, language=lang, verbose=False,
47
+ progress_callback=progress_callback)
48
+
49
+ progress(1, desc="Completed!")
50
+
51
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
52
+ file_name = file_name[:-9]
53
+ file_name = safe_filename(file_name)
54
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
55
+ output_path = f"outputs/{file_name}-{timestamp}"
56
+
57
+ if subformat == "SRT":
58
+ subtitle = get_srt(result["segments"])
59
+ write_file(subtitle, f"{output_path}.srt")
60
+ elif subformat == "WebVTT":
61
+ subtitle = get_vtt(result["segments"])
62
+ write_file(subtitle, f"{output_path}.vtt")
63
+
64
+ files_info[file_name] = subtitle
65
+
66
+ total_result = ''
67
+ for file_name, subtitle in files_info.items():
68
+ total_result += '------------------------------------\n'
69
+ total_result += f'{file_name}\n\n'
70
+ total_result += f'{subtitle}'
71
+
72
+ return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
73
+ except Exception as e:
74
+ return str(e)
75
+ finally:
76
+ for fileobj in fileobjs:
77
+ if os.path.exists(fileobj.name):
78
+ os.remove(fileobj.name)
79
 
80
+ def transcribe_youtube(self, youtubelink,
81
+ model_size, lang, subformat, istranslate,
82
+ progress=gr.Progress()):
83
 
84
+ def progress_callback(progress_value):
85
+ progress(progress_value, desc="Transcribing..")
86
 
87
+ try:
88
+ if model_size != self.current_model_size or self.model is None:
89
+ progress(0, desc="Initializing Model..")
90
+ self.current_model_size = model_size
91
+ self.model = whisper.load_model(name=model_size, download_root="models/Whisper")
92
 
93
+ if lang == "Automatic Detection":
94
+ lang = None
95
+
96
+ progress(0, desc="Loading Audio from Youtube..")
97
+ yt = get_ytdata(youtubelink)
98
+ audio = whisper.load_audio(get_ytaudio(yt))
99
 
100
  translatable_model = ["large", "large-v1", "large-v2"]
101
  if istranslate and self.current_model_size in translatable_model:
 
107
 
108
  progress(1, desc="Completed!")
109
 
110
+ file_name = safe_filename(yt.title)
 
 
111
  timestamp = datetime.now().strftime("%m%d%H%M%S")
112
  output_path = f"outputs/{file_name}-{timestamp}"
113
 
 
118
  subtitle = get_vtt(result["segments"])
119
  write_file(subtitle, f"{output_path}.vtt")
120
 
121
+ return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
122
+ except Exception as e:
123
+ return str(e)
124
+ finally:
125
+ yt = get_ytdata(youtubelink)
126
+ file_path = get_ytaudio(yt)
127
+ if os.path.exists(file_path):
128
+ os.remove(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  def transcribe_mic(self, micaudio,
131
  model_size, lang, subformat, istranslate,
 
134
  def progress_callback(progress_value):
135
  progress(progress_value, desc="Transcribing..")
136
 
137
+ try:
138
+ if model_size != self.current_model_size or self.model is None:
139
+ progress(0, desc="Initializing Model..")
140
+ self.current_model_size = model_size
141
+ self.model = whisper.load_model(name=model_size, download_root="models/Whisper")
142
 
143
+ if lang == "Automatic Detection":
144
+ lang = None
145
 
146
+ progress(0, desc="Loading Audio..")
147
 
148
+ translatable_model = ["large", "large-v1", "large-v2"]
149
+ if istranslate and self.current_model_size in translatable_model:
150
+ result = self.model.transcribe(audio=micaudio, language=lang, verbose=False, task="translate",
151
+ progress_callback=progress_callback)
152
+ else:
153
+ result = self.model.transcribe(audio=micaudio, language=lang, verbose=False,
154
+ progress_callback=progress_callback)
155
 
156
+ progress(1, desc="Completed!")
157
 
158
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
159
+ output_path = f"outputs/Mic-{timestamp}"
160
 
161
+ if subformat == "SRT":
162
+ subtitle = get_srt(result["segments"])
163
+ write_file(subtitle, f"{output_path}.srt")
164
+ elif subformat == "WebVTT":
165
+ subtitle = get_vtt(result["segments"])
166
+ write_file(subtitle, f"{output_path}.vtt")
167
 
168
+ return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
169
+ except Exception as e:
170
+ print(str(e))
171
+ finally:
172
+ if os.path.exists(micaudio):
173
+ os.remove(micaudio)