jhj0517 commited on
Commit
57b2878
·
1 Parent(s): ea9212d

Fix device error

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +21 -2
modules/whisper/whisper_base.py CHANGED
@@ -453,15 +453,34 @@ class WhisperBase(ABC):
453
 
454
  return time_str.strip()
455
 
456
- @staticmethod
457
- def get_device():
458
  if torch.cuda.is_available():
459
  return "cuda"
460
  elif torch.backends.mps.is_available():
 
 
 
461
  return "mps"
462
  else:
463
  return "cpu"
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  @staticmethod
466
  def release_cuda_memory():
467
  """Release memory"""
 
453
 
454
  return time_str.strip()
455
 
456
+ def get_device(self):
 
457
  if torch.cuda.is_available():
458
  return "cuda"
459
  elif torch.backends.mps.is_available():
460
+ if self.is_sparse_mps_supported():
461
+ # MPS is not supported for sparse tensor for now. See : https://github.com/pytorch/pytorch/issues/87886
462
+ return "cpu"
463
  return "mps"
464
  else:
465
  return "cpu"
466
 
467
+ @staticmethod
468
+ def is_sparse_mps_supported():
469
+ if torch.backends.mps.is_available():
470
+ return False
471
+
472
+ try:
473
+ device = torch.device("mps")
474
+ sparse_tensor = torch.sparse_coo_tensor(
475
+ indices=torch.tensor([[0, 1], [2, 3]]),
476
+ values=torch.tensor([1, 2]),
477
+ size=(4, 4),
478
+ device=device
479
+ )
480
+ return True
481
+ except RuntimeError:
482
+ return False
483
+
484
  @staticmethod
485
  def release_cuda_memory():
486
  """Release memory"""