akshit-g's picture
add : files
d3cd5c1
raw
history blame
402 Bytes
import torch
# Pinned revision identifier, formatted as a YYYY-MM-DD date string.
# NOTE(review): presumably the model/weights revision callers should load —
# confirm against the call sites that read this constant.
LATEST_REVISION: str = "2024-08-26"
def detect_device():
    """
    Detect the best available torch device and the dtype to use with it.

    Backends are checked in order: CUDA first, then Apple MPS, and finally
    the CPU as the fallback.

    Returns:
        A ``(torch.device, torch.dtype)`` tuple: ``float16`` when running on
        CUDA or MPS, ``float32`` when falling back to the CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    # ``torch.backends.mps`` only exists in PyTorch 1.12+; look it up
    # defensively so older installs fall through to the CPU branch instead
    # of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps"), torch.float16
    return torch.device("cpu"), torch.float32