mirror of https://github.com/InternLM/InternLM
feat:add numa
parent ab513e1ddd
commit ca5858eb85
@@ -6,7 +6,9 @@ import os
 from pathlib import Path
 from typing import Dict, Union
 
+import numa
 import torch
+from pynvml.smi import nvidia_smi
 
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
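The two new imports are optional runtime dependencies: numa supplies the NUMA topology queries and binding calls used further down, and pynvml's nvidia_smi wrapper is used to count the GPUs on the node. The short check below is a sketch for illustration, not part of the commit; it only exercises calls that appear in the diff.

# Sketch, not part of the commit: confirm the new dependencies are importable
# and print the topology they report on this machine.
try:
    import numa                        # NUMA topology and binding helpers
    from pynvml.smi import nvidia_smi  # NVML wrapper, used to count GPUs
except ImportError as exc:
    print(f"NUMA binding prerequisites missing: {exc}")
else:
    numa_node_num = numa.info.get_max_node() + 1
    total_gpu = len(nvidia_smi.getInstance().DeviceQuery("memory.total")["gpu"])
    print(f"numa nodes: {numa_node_num}, gpus on this node: {total_gpu}")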
@@ -35,7 +37,7 @@ def get_default_parser():
         help="launcher for launching distributed environment",
     )
     parser.add_argument("--host", type=str, help="the master address for distributed training")
-    parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
+    parser.add_argument("--port", type=int, default=9999, help="the master port for distributed training")
     parser.add_argument("--world_size", type=int, help="world size for distributed training")
     parser.add_argument("--rank", type=int, help="rank for the default process group")
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
@@ -474,3 +476,42 @@ def get_config_value(config, key, defalut):
     except KeyError:
         value = defalut
     return value
+
+
+def try_bind_numa(launcher):
+
+    numa_node_num = numa.info.get_max_node() + 1
+    nvsmi = nvidia_smi.getInstance()
+    total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
+
+    if total_GPU_per_node <= numa_node_num:
+        return
+    if total_GPU_per_node % numa_node_num != 0:
+        return
+
+    if launcher == "torch":
+        global_rank = int(os.environ["RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+        local_rank = int(os.environ["LOCAL_RANK"])
+        nnodes = int(os.environ["NUM_NODES"])
+    elif launcher == "slurm":
+        global_rank = int(os.environ["SLURM_PROCID"])
+        world_size = int(os.environ["SLURM_NPROCS"])
+        local_rank = int(os.environ["SLURM_LOCALID"])
+        nnodes = int(os.environ["SLURM_NNODES"])
+
+    if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node:
+        return
+
+    per_numa = total_GPU_per_node // numa_node_num
+    numa_id = local_rank // per_numa
+
+    try:
+        from numa import memory, schedule
+
+        schedule.run_on_nodes(numa_id)
+        memory.set_membind_nodes(numa_id)
+    except (AttributeError, ImportError):
+        return
+    else:
+        print(f"Rank: {global_rank} success bind process to numa node: {numa_id}", flush=True)
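The binding rule added above is plain integer division: GPUs are assumed to be spread evenly across NUMA nodes, so local rank r is pinned to node r // (gpus_per_node // numa_nodes). The snippet below is an illustrative sketch of that mapping with assumed counts (8 GPUs, 2 NUMA nodes); the numbers are not taken from the commit.

# Illustrative only: the local_rank -> NUMA node mapping computed by try_bind_numa,
# with assumed hardware counts.
total_GPU_per_node = 8   # assumption: what nvidia_smi would report on one node
numa_node_num = 2        # assumption: numa.info.get_max_node() + 1

per_numa = total_GPU_per_node // numa_node_num   # GPUs sharing one NUMA node
for local_rank in range(total_GPU_per_node):
    numa_id = local_rank // per_numa
    print(f"local_rank {local_rank} -> numa node {numa_id}")
# ranks 0-3 -> node 0, ranks 4-7 -> node 1; the real function then calls
# schedule.run_on_nodes(numa_id) and memory.set_membind_nodes(numa_id).

Note the early returns: binding is skipped entirely when the node has no more GPUs than NUMA nodes, when the GPU count is not divisible by the NUMA node count, or when the numa bindings are unavailable, so the change degrades to a no-op on unsupported topologies.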
train.py (4 changed lines)
@@ -14,7 +14,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
-from internlm.initialize import initialize_distributed_env
+from internlm.initialize import initialize_distributed_env, try_bind_numa
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message
@ -291,6 +291,8 @@ if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
hostname = socket.gethostname()
|
hostname = socket.gethostname()
|
||||||
|
|
||||||
|
try_bind_numa(args.launcher)
|
||||||
|
|
||||||
# initialize distributed environment
|
# initialize distributed environment
|
||||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||||
assert hasattr(gpc, "config") and gpc.config is not None
|
assert hasattr(gpc, "config") and gpc.config is not None
|
||||||
|
|
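try_bind_numa runs before initialize_distributed_env and reads the process layout directly from the launcher's environment variables rather than from gpc, so the process can be pinned before the distributed context exists. A hedged usage sketch, with placeholder values that torchrun would normally set (it assumes the internlm package from this commit is importable):

# Illustrative only: drive try_bind_numa() the way train.py does, with
# placeholder environment values normally provided by torchrun.
import os

from internlm.initialize import try_bind_numa

os.environ.setdefault("RANK", "0")        # global rank
os.environ.setdefault("WORLD_SIZE", "8")  # total number of processes
os.environ.setdefault("LOCAL_RANK", "0")  # rank within this node
os.environ.setdefault("NUM_NODES", "1")   # node count (torch launcher path)

try_bind_numa("torch")  # returns silently on unsupported topologies
# The slurm path reads SLURM_PROCID, SLURM_NPROCS, SLURM_LOCALID and SLURM_NNODES instead.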