Spaces:
Running
Running
jhj0517
commited on
Commit
·
57b2878
1
Parent(s):
ea9212d
Fix device error
Browse files
modules/whisper/whisper_base.py
CHANGED
@@ -453,15 +453,34 @@ class WhisperBase(ABC):
|
|
453 |
|
454 |
return time_str.strip()
|
455 |
|
456 |
-
|
457 |
-
def get_device():
|
458 |
if torch.cuda.is_available():
|
459 |
return "cuda"
|
460 |
elif torch.backends.mps.is_available():
|
|
|
|
|
|
|
461 |
return "mps"
|
462 |
else:
|
463 |
return "cpu"
|
464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
@staticmethod
|
466 |
def release_cuda_memory():
|
467 |
"""Release memory"""
|
|
|
453 |
|
454 |
return time_str.strip()
|
455 |
|
456 |
+
def get_device(self):
|
|
|
457 |
if torch.cuda.is_available():
|
458 |
return "cuda"
|
459 |
elif torch.backends.mps.is_available():
|
460 |
+
if self.is_sparse_mps_supported():
|
461 |
+
# MPS is not supported for sparse tensor for now. See : https://github.com/pytorch/pytorch/issues/87886
|
462 |
+
return "cpu"
|
463 |
return "mps"
|
464 |
else:
|
465 |
return "cpu"
|
466 |
|
467 |
+
@staticmethod
|
468 |
+
def is_sparse_mps_supported():
|
469 |
+
if torch.backends.mps.is_available():
|
470 |
+
return False
|
471 |
+
|
472 |
+
try:
|
473 |
+
device = torch.device("mps")
|
474 |
+
sparse_tensor = torch.sparse_coo_tensor(
|
475 |
+
indices=torch.tensor([[0, 1], [2, 3]]),
|
476 |
+
values=torch.tensor([1, 2]),
|
477 |
+
size=(4, 4),
|
478 |
+
device=device
|
479 |
+
)
|
480 |
+
return True
|
481 |
+
except RuntimeError:
|
482 |
+
return False
|
483 |
+
|
484 |
@staticmethod
|
485 |
def release_cuda_memory():
|
486 |
"""Release memory"""
|