From 7a64fae33ad09bfe7f87798a468da4f45ca1a11c Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 26 Apr 2022 10:00:03 +0800 Subject: [PATCH] [doc] improved error messages in initialize (#872) --- colossalai/initialize.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index bdbc96681..fb0de4e20 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -138,8 +138,14 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], seed (int, optional): Specified random seed for every process. Defaults to 1024. verbose (bool, optional): Whether to print logs. Defaults to True. """ - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NPROCS']) + try: + rank = int(os.environ['SLURM_PROCID']) + world_size = int(os.environ['SLURM_NPROCS']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" + ) + launch(config=config, rank=rank, world_size=world_size, @@ -167,9 +173,15 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], seed (int, optional): Specified random seed for every process. Defaults to 1024. verbose (bool, optional): Whether to print logs. Defaults to True. """ - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + try: + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" + ) + launch(config=config, local_rank=local_rank, rank=rank, @@ -194,11 +206,17 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], seed (int, optional): Specified random seed for every process. Defaults to 1024. verbose (bool, optional): Whether to print logs. Defaults to True. """ - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + try: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + launch(config=config, local_rank=local_rank, rank=rank,