mirror of https://github.com/InternLM/InternLM
feat:add bind numa
parent
ca5858eb85
commit
cb34b9325b
|
@ -2,7 +2,10 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
|
@ -44,6 +47,7 @@ def get_default_parser():
|
|||
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
||||
parser.add_argument("--seed", type=int, default=1024)
|
||||
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
|
||||
parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.")
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -478,8 +482,58 @@ def get_config_value(config, key, defalut):
|
|||
return value
|
||||
|
||||
|
||||
def try_bind_numa(launcher):
|
||||
def get_node_hostnames(nodes: str):
|
||||
hostnames = []
|
||||
|
||||
tmp_nodelist = copy.deepcopy(nodes)
|
||||
tmp_nodelist = tmp_nodelist.replace(" ", "")
|
||||
|
||||
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
|
||||
while tmp_res:
|
||||
tmp_res2 = re.search(",", tmp_nodelist)
|
||||
if tmp_res2 and tmp_res2.start() < tmp_res.start():
|
||||
tmp_nodelist2 = tmp_nodelist.split(",")
|
||||
count = 0
|
||||
while count < len(tmp_nodelist2):
|
||||
if re.search(r"\[", tmp_nodelist2[count]):
|
||||
break
|
||||
hostnames.append(tmp_nodelist2[count])
|
||||
count += 1
|
||||
tmp_nodelist = ",".join(tmp_nodelist2[count:])
|
||||
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
|
||||
node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
|
||||
prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
|
||||
|
||||
node_range = re.sub(r"[\[\]]", "", node_range)
|
||||
|
||||
tmplist = node_range.split(",")
|
||||
|
||||
pattern1 = r"^\d+$"
|
||||
pattern2 = r"^\d+-\d+$"
|
||||
|
||||
for tmpiter in tmplist:
|
||||
if re.match(pattern1, tmpiter):
|
||||
hostnames.append(prefix + tmpiter)
|
||||
elif re.match(pattern2, tmpiter):
|
||||
begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
|
||||
hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
|
||||
else:
|
||||
prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
|
||||
hostnames.append(tmpiter)
|
||||
|
||||
tmp_nodelist = tmp_nodelist[tmp_res.end() :]
|
||||
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
|
||||
|
||||
tmplist = tmp_nodelist.split(",")
|
||||
hostnames.extend(tmplist)
|
||||
|
||||
while "" in hostnames:
|
||||
hostnames.remove("")
|
||||
|
||||
return hostnames
|
||||
|
||||
|
||||
def try_bind_numa(launcher):
|
||||
numa_node_num = numa.info.get_max_node() + 1
|
||||
nvsmi = nvidia_smi.getInstance()
|
||||
total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
|
||||
|
@ -491,16 +545,24 @@ def try_bind_numa(launcher):
|
|||
|
||||
if launcher == "torch":
|
||||
global_rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
nnodes = int(os.environ["NUM_NODES"])
|
||||
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
||||
elif launcher == "slurm":
|
||||
global_rank = int(os.environ["SLURM_PROCID"])
|
||||
world_size = int(os.environ["SLURM_NPROCS"])
|
||||
local_rank = int(os.environ["SLURM_LOCALID"])
|
||||
nnodes = int(os.environ["SLURM_NNODES"])
|
||||
local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
|
||||
if "(" in local_world_size_group:
|
||||
local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
|
||||
elif "," in local_world_size_group:
|
||||
local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
|
||||
node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
|
||||
hostname = socket.gethostname()
|
||||
index = node_list.index(hostname)
|
||||
local_world_size = int(local_world_size_list[index])
|
||||
else:
|
||||
local_world_size = int(local_world_size_group)
|
||||
|
||||
if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node:
|
||||
if total_GPU_per_node != local_world_size:
|
||||
return
|
||||
|
||||
per_numa = total_GPU_per_node // numa_node_num
|
||||
|
|
3
train.py
3
train.py
|
@ -291,7 +291,8 @@ if __name__ == "__main__":
|
|||
args = parse_args()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
try_bind_numa(args.launcher)
|
||||
if args.bind_numa:
|
||||
try_bind_numa(args.launcher)
|
||||
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
|
|
Loading…
Reference in New Issue