jhj0517 commited on
Commit
64a4fe1
·
2 Parent(s): e30abef 855124f

Merge branch 'master' into feature/add-bgm-tab

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +20 -0
modules/whisper/whisper_base.py CHANGED
@@ -458,10 +458,30 @@ class WhisperBase(ABC):
458
  if torch.cuda.is_available():
459
  return "cuda"
460
  elif torch.backends.mps.is_available():
 
 
 
461
  return "mps"
462
  else:
463
  return "cpu"
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  @staticmethod
466
  def release_cuda_memory():
467
  """Release memory"""
 
458
  if torch.cuda.is_available():
459
  return "cuda"
460
  elif torch.backends.mps.is_available():
461
+ if not WhisperBase.is_sparse_api_supported():
462
+ # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
463
+ return "cpu"
464
  return "mps"
465
  else:
466
  return "cpu"
467
 
468
+ @staticmethod
469
+ def is_sparse_api_supported():
470
+ if not torch.backends.mps.is_available():
471
+ return False
472
+
473
+ try:
474
+ device = torch.device("mps")
475
+ sparse_tensor = torch.sparse_coo_tensor(
476
+ indices=torch.tensor([[0, 1], [2, 3]]),
477
+ values=torch.tensor([1, 2]),
478
+ size=(4, 4),
479
+ device=device
480
+ )
481
+ return True
482
+ except RuntimeError:
483
+ return False
484
+
485
  @staticmethod
486
  def release_cuda_memory():
487
  """Release memory"""