|
|
|
@ -138,8 +138,14 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
|
|
|
|
|
seed (int, optional): Specified random seed for every process. Defaults to 1024. |
|
|
|
|
verbose (bool, optional): Whether to print logs. Defaults to True. |
|
|
|
|
""" |
|
|
|
|
rank = int(os.environ['SLURM_PROCID']) |
|
|
|
|
world_size = int(os.environ['SLURM_NPROCS']) |
|
|
|
|
try: |
|
|
|
|
rank = int(os.environ['SLURM_PROCID']) |
|
|
|
|
world_size = int(os.environ['SLURM_NPROCS']) |
|
|
|
|
except KeyError as e: |
|
|
|
|
raise RuntimeError( |
|
|
|
|
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
launch(config=config, |
|
|
|
|
rank=rank, |
|
|
|
|
world_size=world_size, |
|
|
|
@ -167,9 +173,15 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
|
|
|
|
|
seed (int, optional): Specified random seed for every process. Defaults to 1024. |
|
|
|
|
verbose (bool, optional): Whether to print logs. Defaults to True. |
|
|
|
|
""" |
|
|
|
|
rank = int(os.environ['OMPI_COMM_WORLD_RANK']) |
|
|
|
|
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) |
|
|
|
|
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) |
|
|
|
|
try: |
|
|
|
|
rank = int(os.environ['OMPI_COMM_WORLD_RANK']) |
|
|
|
|
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) |
|
|
|
|
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) |
|
|
|
|
except KeyError as e: |
|
|
|
|
raise RuntimeError( |
|
|
|
|
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
launch(config=config, |
|
|
|
|
local_rank=local_rank, |
|
|
|
|
rank=rank, |
|
|
|
@ -194,11 +206,17 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
|
|
|
|
|
seed (int, optional): Specified random seed for every process. Defaults to 1024. |
|
|
|
|
verbose (bool, optional): Whether to print logs. Defaults to True. |
|
|
|
|
""" |
|
|
|
|
rank = int(os.environ['RANK']) |
|
|
|
|
local_rank = int(os.environ['LOCAL_RANK']) |
|
|
|
|
world_size = int(os.environ['WORLD_SIZE']) |
|
|
|
|
host = os.environ['MASTER_ADDR'] |
|
|
|
|
port = int(os.environ['MASTER_PORT']) |
|
|
|
|
try: |
|
|
|
|
rank = int(os.environ['RANK']) |
|
|
|
|
local_rank = int(os.environ['LOCAL_RANK']) |
|
|
|
|
world_size = int(os.environ['WORLD_SIZE']) |
|
|
|
|
host = os.environ['MASTER_ADDR'] |
|
|
|
|
port = int(os.environ['MASTER_PORT']) |
|
|
|
|
except KeyError as e: |
|
|
|
|
raise RuntimeError( |
|
|
|
|
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
launch(config=config, |
|
|
|
|
local_rank=local_rank, |
|
|
|
|
rank=rank, |
|
|
|
|