feat: add numa

pull/320/head
li126com 2023-09-18 21:10:16 +08:00
parent ab513e1ddd
commit ca5858eb85
2 changed files with 45 additions and 2 deletions

@@ -6,7 +6,9 @@ import os
 from pathlib import Path
 from typing import Dict, Union
+import numa
 import torch
+from pynvml.smi import nvidia_smi
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
@@ -35,7 +37,7 @@ def get_default_parser():
         help="launcher for launching distributed environment",
     )
     parser.add_argument("--host", type=str, help="the master address for distributed training")
-    parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
+    parser.add_argument("--port", type=int, default=9999, help="the master port for distributed training")
     parser.add_argument("--world_size", type=int, help="world size for distributed training")
     parser.add_argument("--rank", type=int, help="rank for the default process group")
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
@@ -474,3 +476,42 @@ def get_config_value(config, key, defalut):
     except KeyError:
         value = defalut
     return value
+
+
+def try_bind_numa(launcher):
+    # Bind this process (and its memory allocations) to the NUMA node that
+    # hosts its GPU, so host memory stays local to the device.
+    numa_node_num = numa.info.get_max_node() + 1
+    nvsmi = nvidia_smi.getInstance()
+    total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
+
+    # Binding only makes sense when the GPUs divide evenly across NUMA nodes.
+    if total_GPU_per_node <= numa_node_num:
+        return
+    if total_GPU_per_node % numa_node_num != 0:
+        return
+
+    if launcher == "torch":
+        global_rank = int(os.environ["RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+        local_rank = int(os.environ["LOCAL_RANK"])
+        nnodes = int(os.environ["NUM_NODES"])
+    elif launcher == "slurm":
+        global_rank = int(os.environ["SLURM_PROCID"])
+        world_size = int(os.environ["SLURM_NPROCS"])
+        local_rank = int(os.environ["SLURM_LOCALID"])
+        nnodes = int(os.environ["SLURM_NNODES"])
+
+    # Skip binding unless ranks are evenly distributed with one process per GPU.
+    if world_size % nnodes != 0 or world_size // nnodes != total_GPU_per_node:
+        return
+
+    # Map this local rank to the NUMA node that owns its GPU.
+    per_numa = total_GPU_per_node // numa_node_num
+    numa_id = local_rank // per_numa
+
+    try:
+        from numa import memory, schedule
+        schedule.run_on_nodes(numa_id)
+        memory.set_membind_nodes(numa_id)
+    except (AttributeError, ImportError):
+        return
+    else:
+        print(f"Rank: {global_rank} successfully bound process to numa node: {numa_id}", flush=True)

@@ -14,7 +14,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
-from internlm.initialize import initialize_distributed_env
+from internlm.initialize import initialize_distributed_env, try_bind_numa
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message
@ -291,6 +291,8 @@ if __name__ == "__main__":
args = parse_args()
hostname = socket.gethostname()
try_bind_numa(args.launcher)
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None