|
|
|
@ -138,8 +138,14 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
|
|
|
|
|
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
|
|
|
|
verbose (bool, optional): Whether to print logs. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
rank = int(os.environ['SLURM_PROCID'])
|
|
|
|
|
world_size = int(os.environ['SLURM_NPROCS'])
|
|
|
|
|
except KeyError as e:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
launch(config=config,
|
|
|
|
|
rank=rank,
|
|
|
|
|
world_size=world_size,
|
|
|
|
@ -167,9 +173,15 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
|
|
|
|
|
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
|
|
|
|
verbose (bool, optional): Whether to print logs. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
|
|
|
|
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
|
|
|
|
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
|
|
|
|
except KeyError as e:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
launch(config=config,
|
|
|
|
|
local_rank=local_rank,
|
|
|
|
|
rank=rank,
|
|
|
|
@ -194,11 +206,17 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
|
|
|
|
|
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
|
|
|
|
verbose (bool, optional): Whether to print logs. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
rank = int(os.environ['RANK'])
|
|
|
|
|
local_rank = int(os.environ['LOCAL_RANK'])
|
|
|
|
|
world_size = int(os.environ['WORLD_SIZE'])
|
|
|
|
|
host = os.environ['MASTER_ADDR']
|
|
|
|
|
port = int(os.environ['MASTER_PORT'])
|
|
|
|
|
except KeyError as e:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
launch(config=config,
|
|
|
|
|
local_rank=local_rank,
|
|
|
|
|
rank=rank,
|
|
|
|
|