# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from time import sleep

import wget
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry

from nemo.utils import logging


def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str:
"""
Helper function to download pre-trained weights from the cloud
Args:
url: (str) URL of storage
filename: (str) what to download. The request will be issued to url/filename
subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can
be empty
cache_dir: (str) a cache directory where to download. If not present, this function will attempt to create it.
If None (default), then it will be $HOME/.cache/torch/NeMo
refresh_cache: (bool) if True and cached file is present, it will delete it and re-fetch
Returns:
If successful - absolute local path to the downloaded file
else - empty string
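
    Example:
        A minimal illustrative call. The URL and filename below are
        placeholders, not real assets::

            path = maybe_download_from_cloud(
                url="https://example.com/checkpoints/", filename="model.nemo", subfolder="demo",
            )
            if path:
                print(f"Checkpoint cached at {path}")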
"""
    if cache_dir is None:
        cache_location = Path.joinpath(Path.home(), ".cache/torch/NeMo")
    else:
        # wrap in Path so joinpath below works when a plain string is passed
        cache_location = Path(cache_dir)
if subfolder is not None:
destination = Path.joinpath(cache_location, subfolder)
else:
destination = cache_location
if not os.path.exists(destination):
os.makedirs(destination, exist_ok=True)
destination_file = Path.joinpath(destination, filename)
if os.path.exists(destination_file):
logging.info(f"Found existing object {destination_file}.")
if refresh_cache:
logging.info("Asked to refresh the cache.")
logging.info(f"Deleting file: {destination_file}")
os.remove(destination_file)
else:
logging.info(f"Re-using file from: {destination_file}")
return str(destination_file)
# download file
wget_uri = url + filename
logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}")
    # NGC links do not work every time, so retry a few times with a short wait
    i = 0
    max_attempts = 3
    while i < max_attempts:
        i += 1
        try:
            wget.download(wget_uri, str(destination_file))
            if os.path.exists(destination_file):
                return str(destination_file)
            else:
                return ""
        except Exception:
            logging.info(f"Download from cloud failed. Attempt {i} of {max_attempts}")
            sleep(0.05)
            continue
    raise ValueError("Unable to download the file right now, please try again later.")


class SageMakerDDPStrategy(DDPStrategy):
    """
    DDP strategy for SageMaker that builds the cluster environment from the
    ``WORLD_SIZE`` and ``RANK`` environment variables SageMaker sets.
    """

    @property
    def cluster_environment(self):
env = LightningEnvironment()
env.world_size = lambda: int(os.environ["WORLD_SIZE"])
env.global_rank = lambda: int(os.environ["RANK"])
return env
@cluster_environment.setter
def cluster_environment(self, env):
# prevents Lightning from overriding the Environment required for SageMaker
pass


def initialize_sagemaker() -> None:
    """
    Helper function to initialize SageMaker with NeMo.
    This function installs the system libraries that NeMo's ASR toolkit requires and initializes SageMaker DDP.
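
    Example:
        A minimal sketch of a SageMaker training entry point. The Trainer
        arguments are illustrative assumptions, not prescribed values::

            from pytorch_lightning import Trainer

            initialize_sagemaker()
            trainer = Trainer(strategy="smddp", accelerator="gpu", devices=-1)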
"""
StrategyRegistry.register(
name='smddp', strategy=SageMakerDDPStrategy, process_group_backend="smddp", find_unused_parameters=False,
)

    def _install_system_libraries() -> None:
        os.system('chmod 777 /tmp && apt-get update && apt-get install -y libsndfile1 ffmpeg')

    def _patch_torch_metrics() -> None:
        """
        Patches torchmetrics so that ``Metric.__hash__`` does not rely on internal state.
        This is needed because SageMaker DDP overrides the `__init__` function of modules to do automatic partitioning.
        """
        from torchmetrics import Metric

        def __new_hash__(self):
            hash_vals = [self.__class__.__name__, id(self)]
            return hash(tuple(hash_vals))

        Metric.__hash__ = __new_hash__

    _patch_torch_metrics()

if os.environ.get("RANK") and os.environ.get("WORLD_SIZE"):
import smdistributed.dataparallel.torch.distributed as dist
# has to be imported, as it overrides torch modules and such when DDP is enabled.
import smdistributed.dataparallel.torch.torch_smddp
dist.init_process_group()
if dist.get_local_rank():
_install_system_libraries()
return dist.barrier() # wait for main process
_install_system_libraries()
return
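

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the NeMo API). The
# environment variable names below are hypothetical placeholders, so nothing
# is downloaded unless they are set explicitly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_url = os.environ.get("NEMO_EXAMPLE_URL", "")
    example_filename = os.environ.get("NEMO_EXAMPLE_FILENAME", "")
    if example_url and example_filename:
        local_path = maybe_download_from_cloud(url=example_url, filename=example_filename, subfolder="examples")
        logging.info(f"Downloaded file to: {local_path}")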