jhj0517 committed · Commit 6d9d096 · 1 Parent(s): f9b7286
add args for local model path
Browse files
- app.py +4 -0
- modules/faster_whisper_inference.py +2 -2
- modules/whisper_Inference.py +2 -2
app.py
CHANGED
@@ -17,8 +17,10 @@ class App:
         self.app = gr.Blocks(css=CSS, theme=self.args.theme)
         self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
         if isinstance(self.whisper_inf, FasterWhisperInference):
+            self.whisper_inf.model_dir = args.faster_whisper_model_dir
             print("Use Faster Whisper implementation")
         else:
+            self.whisper_inf.model_dir = args.whisper_model_dir
             print("Use Open AI Whisper implementation")
         print(f"Device \"{self.whisper_inf.device}\" is detected")
         self.nllb_inf = NLLBInference()
@@ -296,6 +298,8 @@ parser.add_argument('--password', type=str, default=None, help='Gradio authentic
 parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
 parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
 parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
+parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
+parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
 _args = parser.parse_args()
 
 if __name__ == "__main__":
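The two new flags let a launch point the WebUI at an existing local model directory instead of the defaults under models/. A minimal sketch of how the parsed values behave, assuming only the two argparse definitions added above (the override path is hypothetical):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    # Same two definitions as the lines added to app.py above.
    parser.add_argument('--whisper_model_dir', type=str,
                        default=os.path.join("models", "Whisper"),
                        help='Directory path of the whisper model')
    parser.add_argument('--faster_whisper_model_dir', type=str,
                        default=os.path.join("models", "Whisper", "faster-whisper"),
                        help='Directory path of the faster-whisper model')

    # Hypothetical override; any flag left out falls back to its default.
    args = parser.parse_args(['--whisper_model_dir', '/data/whisper-models'])
    print(args.whisper_model_dir)         # /data/whisper-models
    print(args.faster_whisper_model_dir)  # models/Whisper/faster-whisper (POSIX default)

At launch this corresponds to something like: python app.py --whisper_model_dir /data/whisper-models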
modules/faster_whisper_inference.py
CHANGED
@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
         self.available_compute_types = ctranslate2.get_supported_compute_types(
             "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.
+        self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
 
     def transcribe_file(self,
                         files: list,
@@ -311,7 +311,7 @@ class FasterWhisperInference(BaseInterface):
         self.model = faster_whisper.WhisperModel(
             device=self.device,
             model_size_or_path=model_size,
-            download_root=
+            download_root=self.model_dir,
             compute_type=self.current_compute_type
         )
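For context, download_root is an existing faster_whisper.WhisperModel parameter: converted CTranslate2 checkpoints are downloaded into (and later loaded from) that directory instead of the default Hugging Face cache, which is what makes the new self.model_dir override work. A minimal standalone sketch; the model size, device, compute type, and audio path are illustrative:

    import faster_whisper

    # Fetches the model into download_root on first use; a pre-populated
    # local directory is reused on later runs.
    model = faster_whisper.WhisperModel(
        model_size_or_path="base",                      # illustrative size
        device="cpu",
        compute_type="float32",
        download_root="models/Whisper/faster-whisper",  # the default set above
    )

    segments, info = model.transcribe("audio.wav")      # hypothetical input file
    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")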
modules/whisper_Inference.py
CHANGED
@@ -26,7 +26,7 @@ class WhisperInference(BaseInterface):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.available_compute_types = ["float16", "float32"]
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.
+        self.model_dir = os.path.join("models", "Whisper")
 
     def transcribe_file(self,
                         files: list,
@@ -288,7 +288,7 @@ class WhisperInference(BaseInterface):
         self.model = whisper.load_model(
             name=model_size,
             device=self.device,
-            download_root=
+            download_root=self.model_dir
         )
 
     @staticmethod
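Likewise, download_root is an existing parameter of whisper.load_model in openai-whisper: the model checkpoint (a .pt file) is downloaded into that directory instead of the default ~/.cache/whisper. A minimal standalone sketch; the model size and audio path are illustrative:

    import whisper

    # The checkpoint lands in download_root on first use and is loaded
    # from there afterwards, so a local model directory can be reused.
    model = whisper.load_model(
        name="base",                     # illustrative size
        device="cpu",
        download_root="models/Whisper",  # the default set above
    )

    result = model.transcribe("audio.wav")  # hypothetical input file
    print(result["text"])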