feat:add bind numa

pull/320/head
li126com 2023-09-19 14:34:55 +08:00
parent ca5858eb85
commit cb34b9325b
2 changed files with 70 additions and 7 deletions

View File

@ -2,7 +2,10 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import argparse import argparse
import copy
import os import os
import re
import socket
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Dict, Union
@ -44,6 +47,7 @@ def get_default_parser():
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
parser.add_argument("--seed", type=int, default=1024) parser.add_argument("--seed", type=int, default=1024)
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.")
return parser return parser
@ -478,8 +482,58 @@ def get_config_value(config, key, defalut):
return value return value
def try_bind_numa(launcher): def get_node_hostnames(nodes: str):
hostnames = []
tmp_nodelist = copy.deepcopy(nodes)
tmp_nodelist = tmp_nodelist.replace(" ", "")
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
while tmp_res:
tmp_res2 = re.search(",", tmp_nodelist)
if tmp_res2 and tmp_res2.start() < tmp_res.start():
tmp_nodelist2 = tmp_nodelist.split(",")
count = 0
while count < len(tmp_nodelist2):
if re.search(r"\[", tmp_nodelist2[count]):
break
hostnames.append(tmp_nodelist2[count])
count += 1
tmp_nodelist = ",".join(tmp_nodelist2[count:])
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
node_range = re.sub(r"[\[\]]", "", node_range)
tmplist = node_range.split(",")
pattern1 = r"^\d+$"
pattern2 = r"^\d+-\d+$"
for tmpiter in tmplist:
if re.match(pattern1, tmpiter):
hostnames.append(prefix + tmpiter)
elif re.match(pattern2, tmpiter):
begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
else:
prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
hostnames.append(tmpiter)
tmp_nodelist = tmp_nodelist[tmp_res.end() :]
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
tmplist = tmp_nodelist.split(",")
hostnames.extend(tmplist)
while "" in hostnames:
hostnames.remove("")
return hostnames
def try_bind_numa(launcher):
numa_node_num = numa.info.get_max_node() + 1 numa_node_num = numa.info.get_max_node() + 1
nvsmi = nvidia_smi.getInstance() nvsmi = nvidia_smi.getInstance()
total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"]) total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
@ -491,16 +545,24 @@ def try_bind_numa(launcher):
if launcher == "torch": if launcher == "torch":
global_rank = int(os.environ["RANK"]) global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"]) local_rank = int(os.environ["LOCAL_RANK"])
nnodes = int(os.environ["NUM_NODES"]) local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
elif launcher == "slurm": elif launcher == "slurm":
global_rank = int(os.environ["SLURM_PROCID"]) global_rank = int(os.environ["SLURM_PROCID"])
world_size = int(os.environ["SLURM_NPROCS"])
local_rank = int(os.environ["SLURM_LOCALID"]) local_rank = int(os.environ["SLURM_LOCALID"])
nnodes = int(os.environ["SLURM_NNODES"]) local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
if "(" in local_world_size_group:
local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
elif "," in local_world_size_group:
local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
hostname = socket.gethostname()
index = node_list.index(hostname)
local_world_size = int(local_world_size_list[index])
else:
local_world_size = int(local_world_size_group)
if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node: if total_GPU_per_node != local_world_size:
return return
per_numa = total_GPU_per_node // numa_node_num per_numa = total_GPU_per_node // numa_node_num

View File

@ -291,7 +291,8 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
hostname = socket.gethostname() hostname = socket.gethostname()
try_bind_numa(args.launcher) if args.bind_numa:
try_bind_numa(args.launcher)
# initialize distributed environment # initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)