diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 079c2cb..a771df6 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -6,7 +6,9 @@ import os
 from pathlib import Path
 from typing import Dict, Union
 
+import numa
 import torch
+from pynvml.smi import nvidia_smi
 
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
@@ -35,7 +37,7 @@ def get_default_parser():
         help="launcher for launching distributed environment",
     )
     parser.add_argument("--host", type=str, help="the master address for distributed training")
-    parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
+    parser.add_argument("--port", type=int, default=9999, help="the master port for distributed training")
     parser.add_argument("--world_size", type=int, help="world size for distributed training")
     parser.add_argument("--rank", type=int, help="rank for the default process group")
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
@@ -474,3 +476,42 @@ def get_config_value(config, key, defalut):
     except KeyError:
         value = defalut
     return value
+
+
+def try_bind_numa(launcher):
+
+    numa_node_num = numa.info.get_max_node() + 1
+    nvsmi = nvidia_smi.getInstance()
+    total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
+
+    if total_GPU_per_node <= numa_node_num:
+        return
+    if total_GPU_per_node % numa_node_num != 0:
+        return
+
+    if launcher == "torch":
+        global_rank = int(os.environ["RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+        local_rank = int(os.environ["LOCAL_RANK"])
+        nnodes = int(os.environ["NUM_NODES"])
+    elif launcher == "slurm":
+        global_rank = int(os.environ["SLURM_PROCID"])
+        world_size = int(os.environ["SLURM_NPROCS"])
+        local_rank = int(os.environ["SLURM_LOCALID"])
+        nnodes = int(os.environ["SLURM_NNODES"])
+
+    if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node:
+        return
+
+    per_numa = total_GPU_per_node // numa_node_num
+    numa_id = local_rank // per_numa
+
+    try:
+        from numa import memory, schedule
+
+        schedule.run_on_nodes(numa_id)
+        memory.set_membind_nodes(numa_id)
+    except (AttributeError, ImportError):
+        return
+    else:
+        print(f"Rank: {global_rank} success bind process to numa node: {numa_id}", flush=True)
diff --git a/train.py b/train.py
index ff15354..7441147 100644
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
-from internlm.initialize import initialize_distributed_env
+from internlm.initialize import initialize_distributed_env, try_bind_numa
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message
@@ -291,6 +291,8 @@ if __name__ == "__main__":
     args = parse_args()
     hostname = socket.gethostname()
 
+    try_bind_numa(args.launcher)
+
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None
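
Note: the following is a minimal standalone sketch (not part of the patch) of the rank-to-NUMA mapping that try_bind_numa applies. It assumes py-libnuma and pynvml are installed and uses the Slurm environment variable path as the example; the per-process numa_id is simply the local rank divided by the number of GPUs served by each NUMA node.

    import os

    from numa import info, memory, schedule
    from pynvml.smi import nvidia_smi

    # Number of NUMA nodes on this host and GPUs visible to NVML.
    numa_node_num = info.get_max_node() + 1
    gpus_per_node = len(nvidia_smi.getInstance().DeviceQuery("memory.total")["gpu"])

    # Each NUMA node serves an equal slice of the node's GPUs, so map the
    # local rank onto the NUMA node that owns its slice.
    local_rank = int(os.environ["SLURM_LOCALID"])  # e.g. 0..7 on an 8-GPU node
    per_numa = gpus_per_node // numa_node_num
    numa_id = local_rank // per_numa

    # Pin both CPU scheduling and memory allocation of this process to that node.
    schedule.run_on_nodes(numa_id)
    memory.set_membind_nodes(numa_id)

With 8 GPUs and 2 NUMA nodes, local ranks 0-3 land on NUMA node 0 and ranks 4-7 on node 1, which is why the patch returns early when the GPU count is not a multiple of the NUMA node count.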