Spaces:
Running
Running
Merge branch 'master' into feature/add-bgm-tab
Browse files
modules/whisper/whisper_base.py
CHANGED
@@ -458,10 +458,30 @@ class WhisperBase(ABC):
|
|
458 |
if torch.cuda.is_available():
|
459 |
return "cuda"
|
460 |
elif torch.backends.mps.is_available():
|
|
|
|
|
|
|
461 |
return "mps"
|
462 |
else:
|
463 |
return "cpu"
|
464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
@staticmethod
|
466 |
def release_cuda_memory():
|
467 |
"""Release memory"""
|
|
|
458 |
if torch.cuda.is_available():
|
459 |
return "cuda"
|
460 |
elif torch.backends.mps.is_available():
|
461 |
+
if not WhisperBase.is_sparse_api_supported():
|
462 |
+
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
463 |
+
return "cpu"
|
464 |
return "mps"
|
465 |
else:
|
466 |
return "cpu"
|
467 |
|
468 |
+
@staticmethod
|
469 |
+
def is_sparse_api_supported():
|
470 |
+
if not torch.backends.mps.is_available():
|
471 |
+
return False
|
472 |
+
|
473 |
+
try:
|
474 |
+
device = torch.device("mps")
|
475 |
+
sparse_tensor = torch.sparse_coo_tensor(
|
476 |
+
indices=torch.tensor([[0, 1], [2, 3]]),
|
477 |
+
values=torch.tensor([1, 2]),
|
478 |
+
size=(4, 4),
|
479 |
+
device=device
|
480 |
+
)
|
481 |
+
return True
|
482 |
+
except RuntimeError:
|
483 |
+
return False
|
484 |
+
|
485 |
@staticmethod
|
486 |
def release_cuda_memory():
|
487 |
"""Release memory"""
|