diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 0622a6f..ddd01ef 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -16,6 +16,16 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
+# check package
+try:
+    import numa
+    from numa import memory, schedule
+    from pynvml.smi import nvidia_smi
+except (AttributeError, ImportError):
+    get_numa = False
+else:
+    get_numa = True
+
 logger = get_logger(__file__)
 
 
@@ -385,6 +395,8 @@ def launch_from_slurm(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the SLURM environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size)
+
     launch(
         config=config,
         rank=rank,
@@ -418,6 +430,8 @@ def launch_from_torch(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the torch environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size, local_rank=local_rank)
+
     launch(
         config=config,
         local_rank=local_rank,
@@ -448,8 +462,6 @@ def initialize_distributed_env(
         seed (int, optional): Specified random seed for every process. 1024 by default.
     """
 
-    try_bind_numa(launcher)
-
     # close automatic garbage collection
     gc.disable()
 
@@ -489,58 +501,43 @@ def get_config_value(config, key, defalut):
     return value
 
 
-# check pacakge
-try:
-    import numa
-    from numa import memory, schedule
-    from pynvml.smi import nvidia_smi
-except (AttributeError, ImportError):
-    get_numa = False
-    global_rank = int(os.environ["SLURM_PROCID"]) if "SLURM_PROCID" in os.environ else int(os.environ["RANK"])
-    if global_rank == 0:
-        logger.info(
-            "Try bind numa failed! Package import error, if numa is not installed, "
-            "please implement: pip install --upgrade py-libnuma"
-        )
-else:
-    get_numa = True
-
-
-def try_bind_numa(launcher):
+def try_bind_numa(global_rank, world_size, local_rank=None):
     # Early return if numa module not available
     if not get_numa:
-        return
+        if global_rank == 0:
+            logger.info(
+                "Try bind numa failed! Package import error. If numa is not installed, "
+                "please run: pip install --upgrade py-libnuma. Ref: https://pypi.org/project/py-libnuma/"
+            )
 
     # get numa node number
-    numa_node_num = numa.info.get_max_node() + 1
-    # get total gpu number of current node
-    nvsmi = nvidia_smi.getInstance()
-    total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
+    try:
+        numa_node_num = numa.info.get_max_node() + 1
+        # get total gpu number of current node
+        nvsmi = nvidia_smi.getInstance()
+        total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
 
-    # return while total_GPU_per_node is larger than numa_node_num or is not divisible by numa_node_num
-    if total_GPU_per_node <= numa_node_num:
-        return
-    if total_GPU_per_node % numa_node_num != 0:
-        return
+        # return if total_GPU_per_node is no larger than numa_node_num or is not divisible by it
+        if total_GPU_per_node <= numa_node_num:
+            return
+        if total_GPU_per_node % numa_node_num != 0:
+            return
+        # return if the world size is smaller than the number of GPUs on one node
+        if world_size < total_GPU_per_node:
+            return
 
-    if launcher == "torch":
-        global_rank = int(os.environ["RANK"])
-        local_rank = int(os.environ["LOCAL_RANK"])
-        world_size = int(os.environ["WORLD_SIZE"])
-    elif launcher == "slurm":
-        global_rank = int(os.environ["SLURM_PROCID"])
-        local_rank = int(os.environ["SLURM_LOCALID"])
-        world_size = int(os.environ["SLURM_NPROCS"])
+        if local_rank is None:
+            devices_per_node = torch.cuda.device_count()
+            local_rank = global_rank % devices_per_node
 
-    # return while the number of processes is smaller than one node GPUs num
-    if world_size < total_GPU_per_node:
-        return
+        # compute the numa id for each local rank
+        per_numa = total_GPU_per_node // numa_node_num
+        numa_id = local_rank // per_numa
 
-    # compute numa id for each locak rank
-    per_numa = total_GPU_per_node // numa_node_num
-    numa_id = local_rank // per_numa
-
-    # bind numa node
-    schedule.run_on_nodes(numa_id)
-    memory.set_membind_nodes(numa_id)
-    logger.info(f"Rank: {global_rank} success bind process to numa node: {numa_id}")
+        # bind numa node
+        schedule.run_on_nodes(numa_id)
+        memory.set_membind_nodes(numa_id)
+    except Exception:
+        return  # try_bind_numa should not raise an exception
+    else:
+        logger.info(f"Rank: {global_rank} successfully bound process to numa node: {numa_id}")
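
The mapping the patch relies on is: with `total_GPU_per_node` GPUs split evenly across `numa_node_num` NUMA nodes, local rank `r` is bound to node `r // (total_GPU_per_node // numa_node_num)`. Below is a minimal, self-contained sketch of that arithmetic together with the py-libnuma binding calls used in the patch; the helper name `numa_id_for` and the 8-GPU / 2-NUMA-node figures are illustrative assumptions, not part of the change.

```python
# Sketch of the local-rank -> NUMA-node mapping used by try_bind_numa.
# Assumes (hypothetically) a node with 8 GPUs spread evenly over 2 NUMA nodes.

def numa_id_for(local_rank: int, gpus_per_node: int, numa_node_num: int) -> int:
    """Map a local rank to the NUMA node hosting its GPU."""
    per_numa = gpus_per_node // numa_node_num  # GPUs attached to each NUMA node
    return local_rank // per_numa


if __name__ == "__main__":
    # Ranks 0-3 land on NUMA node 0, ranks 4-7 on NUMA node 1.
    assert [numa_id_for(r, 8, 2) for r in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]

    try:
        from numa import memory, schedule  # py-libnuma, as imported in the patch

        node = numa_id_for(local_rank=5, gpus_per_node=8, numa_node_num=2)
        schedule.run_on_nodes(node)     # pin CPU scheduling to that NUMA node
        memory.set_membind_nodes(node)  # bind memory allocation to the same node
    except ImportError:
        pass  # py-libnuma not installed; binding stays best-effort, as in the patch
```

The intent of pinning each rank to the NUMA node closest to its GPU is to keep host-side buffers on the socket that owns the GPU's PCIe link, which is also why the patch treats binding as best-effort (swallowing exceptions) rather than fatal.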