jhj0517 commited on
Commit
9234bb0
·
1 Parent(s): 21c511e

add mps device

Browse files
modules/faster_whisper_inference.py CHANGED
@@ -31,7 +31,12 @@ class FasterWhisperInference(BaseInterface):
31
  self.available_models = self.model_paths.keys()
32
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
33
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
34
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
35
  self.available_compute_types = ctranslate2.get_supported_compute_types(
36
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
37
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
 
31
  self.available_models = self.model_paths.keys()
32
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
33
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
34
+ if torch.cuda.is_available():
35
+ self.device = "cuda"
36
+ elif torch.backends.mps.is_available():
37
+ self.device = "mps"
38
+ else:
39
+ self.device = "cpu"
40
  self.available_compute_types = ctranslate2.get_supported_compute_types(
41
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
42
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
modules/whisper_Inference.py CHANGED
@@ -23,6 +23,12 @@ class WhisperInference(BaseInterface):
23
  self.available_models = whisper.available_models()
24
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
  self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
 
 
 
 
 
 
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.available_compute_types = ["float16", "float32"]
28
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
 
23
  self.available_models = whisper.available_models()
24
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
  self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
26
+ if torch.cuda.is_available():
27
+ self.device = "cuda"
28
+ elif torch.backends.mps.is_available():
29
+ self.device = "mps"
30
+ else:
31
+ self.device = "cpu"
32
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
  self.available_compute_types = ["float16", "float32"]
34
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"