jhj0517 commited on
Commit
ca772aa
·
1 Parent(s): 57b2878

Rename function and use static method

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +5 -4
modules/whisper/whisper_base.py CHANGED
@@ -453,19 +453,20 @@ class WhisperBase(ABC):
453
 
454
  return time_str.strip()
455
 
456
- def get_device(self):
 
457
  if torch.cuda.is_available():
458
  return "cuda"
459
  elif torch.backends.mps.is_available():
460
- if self.is_sparse_mps_supported():
461
- # MPS is not supported for sparse tensor for now. See : https://github.com/pytorch/pytorch/issues/87886
462
  return "cpu"
463
  return "mps"
464
  else:
465
  return "cpu"
466
 
467
  @staticmethod
468
- def is_sparse_mps_supported():
469
  if torch.backends.mps.is_available():
470
  return False
471
 
 
453
 
454
  return time_str.strip()
455
 
456
+ @staticmethod
457
+ def get_device():
458
  if torch.cuda.is_available():
459
  return "cuda"
460
  elif torch.backends.mps.is_available():
461
+ if not WhisperBase.is_sparse_api_supported():
462
+ # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
463
  return "cpu"
464
  return "mps"
465
  else:
466
  return "cpu"
467
 
468
  @staticmethod
469
+ def is_sparse_api_supported():
470
  if torch.backends.mps.is_available():
471
  return False
472