|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from pathlib import Path |
|
from time import sleep |
|
|
|
import wget |
|
from pytorch_lightning.plugins.environments import LightningEnvironment |
|
from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry |
|
|
|
from nemo.utils import logging |
|
|
|
|
|
def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str:
    """
    Helper function to download pre-trained weights from the cloud

    Args:
        url: (str) URL of storage. The request is issued to ``url + filename`` (plain concatenation), so ``url``
            should normally end with "/".
        filename: (str) what to download.
        subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can
            be empty
        cache_dir: (str or Path) a cache directory where to download. If not present, this function will attempt
            to create it. If None (default), then it will be $HOME/.cache/torch/NeMo
        refresh_cache: (bool) if True and cached file is present, it will delete it and re-fetch

    Returns:
        If successful - local path to the downloaded (or cached) file, as a string
        else - empty string

    Raises:
        ValueError: if every download attempt raised an exception.
    """

    # Normalize to Path so that both str and Path values of cache_dir can be joined below.
    if cache_dir is None:
        cache_location = Path.home() / ".cache/torch/NeMo"
    else:
        cache_location = Path(cache_dir)

    destination = cache_location / subfolder if subfolder is not None else cache_location
    os.makedirs(destination, exist_ok=True)

    destination_file = destination / filename

    if os.path.exists(destination_file):
        logging.info(f"Found existing object {destination_file}.")
        if refresh_cache:
            logging.info("Asked to refresh the cache.")
            logging.info(f"Deleting file: {destination_file}")
            os.remove(destination_file)
        else:
            logging.info(f"Re-using file from: {destination_file}")
            return str(destination_file)

    wget_uri = url + filename
    logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}")

    max_attempts = 3
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            wget.download(wget_uri, str(destination_file))
        except Exception as err:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
            last_error = err
            logging.info(f"Download from cloud failed. Attempt {attempt} of {max_attempts}")
            sleep(0.05)
            continue
        # The download call returned without raising; confirm the file actually landed on disk.
        return str(destination_file) if os.path.exists(destination_file) else ""

    raise ValueError("Not able to download url right now, please try again.") from last_error
|
|
|
|
|
class SageMakerDDPStrategy(DDPStrategy):
    """DDP strategy whose cluster environment is read from process environment variables.

    ``world_size`` comes from ``WORLD_SIZE`` and ``global_rank`` from ``RANK``; both are looked up each time they
    are queried. Any cluster environment assigned to the strategy is ignored.
    """

    @property
    def cluster_environment(self):
        """Return a new ``LightningEnvironment`` whose world size / global rank read the environment variables."""

        def _env_int(key):
            # Raises KeyError when the variable is missing, exactly like a direct os.environ lookup.
            return int(os.environ[key])

        environment = LightningEnvironment()
        environment.world_size = lambda: _env_int("WORLD_SIZE")
        environment.global_rank = lambda: _env_int("RANK")
        return environment

    @cluster_environment.setter
    def cluster_environment(self, env):
        # Deliberately a no-op: the getter always builds its own environment from the env vars,
        # so whatever environment is assigned here is discarded.
        return
|
|
|
|
|
def initialize_sagemaker() -> None:
    """
    Helper function to initiate sagemaker with NeMo.
    This function installs libraries that NeMo requires for the ASR toolkit + initializes sagemaker ddp.

    Steps performed:
        1. Registers ``SageMakerDDPStrategy`` under the strategy name ``'smddp'`` with
           ``process_group_backend="smddp"``.
        2. Patches ``torchmetrics.Metric.__hash__`` (see ``_patch_torch_metrics``).
        3. If both ``RANK`` and ``WORLD_SIZE`` are set: initializes the smdistributed process group, installs the
           system libraries from local rank 0 only, then has every rank wait at a barrier. Otherwise the
           libraries are installed directly in this process.
    """

    StrategyRegistry.register(
        name='smddp', strategy=SageMakerDDPStrategy, process_group_backend="smddp", find_unused_parameters=False,
    )

    def _install_system_libraries() -> None:
        # Fixed command string (no external input), so running it through the shell is acceptable here.
        exit_status = os.system('chmod 777 /tmp && apt-get update && apt-get install -y libsndfile1 ffmpeg')
        if exit_status != 0:
            # Best-effort install: report the failure instead of silently ignoring it, but do not abort.
            logging.warning(f"Installing system libraries (libsndfile1, ffmpeg) failed with exit status {exit_status}")

    def _patch_torch_metrics() -> None:
        """
        Patches torchmetrics to not rely on internal state.
        This is because sagemaker DDP overrides the `__init__` function of the modules to do automatic-partitioning.
        """
        from torchmetrics import Metric

        def __new_hash__(self):
            # Identity-based hash: class name + object id, independent of any metric state.
            hash_vals = [self.__class__.__name__, id(self)]
            return hash(tuple(hash_vals))

        Metric.__hash__ = __new_hash__

    _patch_torch_metrics()

    if os.environ.get("RANK") and os.environ.get("WORLD_SIZE"):
        import smdistributed.dataparallel.torch.distributed as dist

        # Imported only for its side effects; never referenced directly.
        # NOTE(review): presumably this is what registers the "smddp" backend with torch — confirm.
        import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

        dist.init_process_group()

        # Install once per node, from local rank 0. Running apt-get from several processes at once contends
        # for the dpkg lock, and gating on a non-zero local rank would skip nodes that run a single process.
        if dist.get_local_rank() == 0:
            _install_system_libraries()
        # Every rank waits here until the local main process has finished installing.
        return dist.barrier()
    _install_system_libraries()
    return
|
|