From 7a64fae33ad09bfe7f87798a468da4f45ca1a11c Mon Sep 17 00:00:00 2001
From: Frank Lee <somerlee.9@gmail.com>
Date: Tue, 26 Apr 2022 10:00:03 +0800
Subject: [PATCH] [doc] improved error messages in initialize (#872)

---
 colossalai/initialize.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index bdbc96681..fb0de4e20 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -138,8 +138,14 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
         seed (int, optional): Specified random seed for every process. Defaults to 1024.
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
-    rank = int(os.environ['SLURM_PROCID'])
-    world_size = int(os.environ['SLURM_NPROCS'])
+    try:
+        rank = int(os.environ['SLURM_PROCID'])
+        world_size = int(os.environ['SLURM_NPROCS'])
+    except KeyError as e:
+        raise RuntimeError(
+            f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
+        )
+
     launch(config=config,
            rank=rank,
            world_size=world_size,
@@ -167,9 +173,15 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
         seed (int, optional): Specified random seed for every process. Defaults to 1024.
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
-    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
-    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
-    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+    try:
+        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+    except KeyError as e:
+        raise RuntimeError(
+            f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
+        )
+
     launch(config=config,
            local_rank=local_rank,
            rank=rank,
@@ -194,11 +206,17 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
         seed (int, optional): Specified random seed for every process. Defaults to 1024.
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
-    rank = int(os.environ['RANK'])
-    local_rank = int(os.environ['LOCAL_RANK'])
-    world_size = int(os.environ['WORLD_SIZE'])
-    host = os.environ['MASTER_ADDR']
-    port = int(os.environ['MASTER_PORT'])
+    try:
+        rank = int(os.environ['RANK'])
+        local_rank = int(os.environ['LOCAL_RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
+        host = os.environ['MASTER_ADDR']
+        port = int(os.environ['MASTER_PORT'])
+    except KeyError as e:
+        raise RuntimeError(
+            f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
+        )
+
     launch(config=config,
            local_rank=local_rank,
            rank=rank,