jhj0517 commited on
Commit
ed88f88
·
1 Parent(s): 5cbd5e7

Fix device bug

Browse files
modules/whisper/faster_whisper_inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import time
3
  import numpy as np
 
4
  from typing import BinaryIO, Union, Tuple, List
5
  import faster_whisper
6
  from faster_whisper.vad import VadOptions
@@ -25,6 +26,7 @@ class FasterWhisperInference(WhisperBase):
25
  args=args
26
  )
27
  self.model_paths = self.get_model_paths()
 
28
  self.available_models = self.model_paths.keys()
29
  self.available_compute_types = ctranslate2.get_supported_compute_types(
30
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
@@ -155,3 +157,12 @@ class FasterWhisperInference(WhisperBase):
155
  if model_name not in whisper.available_models():
156
  model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
157
  return model_paths
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import time
3
  import numpy as np
4
+ import torch
5
  from typing import BinaryIO, Union, Tuple, List
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
 
26
  args=args
27
  )
28
  self.model_paths = self.get_model_paths()
29
+ self.device = self.get_device()
30
  self.available_models = self.model_paths.keys()
31
  self.available_compute_types = ctranslate2.get_supported_compute_types(
32
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
 
157
  if model_name not in whisper.available_models():
158
  model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
159
  return model_paths
160
+
161
+ @staticmethod
162
+ def get_device():
163
+ if torch.cuda.is_available():
164
+ return "cuda"
165
+ elif torch.backends.mps.is_available():
166
+ return "auto"
167
+ else:
168
+ return "cpu"