# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from pathlib import Path from time import sleep import wget from pytorch_lightning.plugins.environments import LightningEnvironment from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry from nemo.utils import logging def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str: """ Helper function to download pre-trained weights from the cloud Args: url: (str) URL of storage filename: (str) what to download. The request will be issued to url/filename subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can be empty cache_dir: (str) a cache directory where to download. If not present, this function will attempt to create it. If None (default), then it will be $HOME/.cache/torch/NeMo refresh_cache: (bool) if True and cached file is present, it will delete it and re-fetch Returns: If successful - absolute local path to the downloaded file else - empty string """ # try: if cache_dir is None: cache_location = Path.joinpath(Path.home(), ".cache/torch/NeMo") else: cache_location = cache_dir if subfolder is not None: destination = Path.joinpath(cache_location, subfolder) else: destination = cache_location if not os.path.exists(destination): os.makedirs(destination, exist_ok=True) destination_file = Path.joinpath(destination, filename) if os.path.exists(destination_file): logging.info(f"Found existing object {destination_file}.") if refresh_cache: logging.info("Asked to refresh the cache.") logging.info(f"Deleting file: {destination_file}") os.remove(destination_file) else: logging.info(f"Re-using file from: {destination_file}") return str(destination_file) # download file wget_uri = url + filename logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}") # NGC links do not work everytime so we try and wait i = 0 max_attempts = 3 while i < max_attempts: i += 1 try: wget.download(wget_uri, str(destination_file)) if os.path.exists(destination_file): return destination_file else: return "" except: logging.info(f"Download from cloud failed. Attempt {i} of {max_attempts}") sleep(0.05) continue raise ValueError("Not able to download url right now, please try again.") class SageMakerDDPStrategy(DDPStrategy): @property def cluster_environment(self): env = LightningEnvironment() env.world_size = lambda: int(os.environ["WORLD_SIZE"]) env.global_rank = lambda: int(os.environ["RANK"]) return env @cluster_environment.setter def cluster_environment(self, env): # prevents Lightning from overriding the Environment required for SageMaker pass def initialize_sagemaker() -> None: """ Helper function to initiate sagemaker with NeMo. This function installs libraries that NeMo requires for the ASR toolkit + initializes sagemaker ddp. """ StrategyRegistry.register( name='smddp', strategy=SageMakerDDPStrategy, process_group_backend="smddp", find_unused_parameters=False, ) def _install_system_libraries() -> None: os.system('chmod 777 /tmp && apt-get update && apt-get install -y libsndfile1 ffmpeg') def _patch_torch_metrics() -> None: """ Patches torchmetrics to not rely on internal state. This is because sagemaker DDP overrides the `__init__` function of the modules to do automatic-partitioning. """ from torchmetrics import Metric def __new_hash__(self): hash_vals = [self.__class__.__name__, id(self)] return hash(tuple(hash_vals)) Metric.__hash__ = __new_hash__ _patch_torch_metrics() if os.environ.get("RANK") and os.environ.get("WORLD_SIZE"): import smdistributed.dataparallel.torch.distributed as dist # has to be imported, as it overrides torch modules and such when DDP is enabled. import smdistributed.dataparallel.torch.torch_smddp dist.init_process_group() if dist.get_local_rank(): _install_system_libraries() return dist.barrier() # wait for main process _install_system_libraries() return