feat:add bind numa

pull/320/head
li126com 2023-09-19 14:34:55 +08:00
parent ca5858eb85
commit cb34b9325b
2 changed files with 70 additions and 7 deletions

View File

@ -2,7 +2,10 @@
# -*- encoding: utf-8 -*-
import argparse
import copy
import os
import re
import socket
from pathlib import Path
from typing import Dict, Union
@ -44,6 +47,7 @@ def get_default_parser():
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
parser.add_argument("--seed", type=int, default=1024)
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.")
return parser
@ -478,8 +482,58 @@ def get_config_value(config, key, defalut):
return value
def try_bind_numa(launcher):
def get_node_hostnames(nodes: str):
hostnames = []
tmp_nodelist = copy.deepcopy(nodes)
tmp_nodelist = tmp_nodelist.replace(" ", "")
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
while tmp_res:
tmp_res2 = re.search(",", tmp_nodelist)
if tmp_res2 and tmp_res2.start() < tmp_res.start():
tmp_nodelist2 = tmp_nodelist.split(",")
count = 0
while count < len(tmp_nodelist2):
if re.search(r"\[", tmp_nodelist2[count]):
break
hostnames.append(tmp_nodelist2[count])
count += 1
tmp_nodelist = ",".join(tmp_nodelist2[count:])
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
node_range = re.sub(r"[\[\]]", "", node_range)
tmplist = node_range.split(",")
pattern1 = r"^\d+$"
pattern2 = r"^\d+-\d+$"
for tmpiter in tmplist:
if re.match(pattern1, tmpiter):
hostnames.append(prefix + tmpiter)
elif re.match(pattern2, tmpiter):
begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
else:
prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
hostnames.append(tmpiter)
tmp_nodelist = tmp_nodelist[tmp_res.end() :]
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
tmplist = tmp_nodelist.split(",")
hostnames.extend(tmplist)
while "" in hostnames:
hostnames.remove("")
return hostnames
def try_bind_numa(launcher):
numa_node_num = numa.info.get_max_node() + 1
nvsmi = nvidia_smi.getInstance()
total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
@ -491,16 +545,24 @@ def try_bind_numa(launcher):
if launcher == "torch":
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
nnodes = int(os.environ["NUM_NODES"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
elif launcher == "slurm":
global_rank = int(os.environ["SLURM_PROCID"])
world_size = int(os.environ["SLURM_NPROCS"])
local_rank = int(os.environ["SLURM_LOCALID"])
nnodes = int(os.environ["SLURM_NNODES"])
local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
if "(" in local_world_size_group:
local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
elif "," in local_world_size_group:
local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
hostname = socket.gethostname()
index = node_list.index(hostname)
local_world_size = int(local_world_size_list[index])
else:
local_world_size = int(local_world_size_group)
if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node:
if total_GPU_per_node != local_world_size:
return
per_numa = total_GPU_per_node // numa_node_num

View File

@ -291,7 +291,8 @@ if __name__ == "__main__":
args = parse_args()
hostname = socket.gethostname()
try_bind_numa(args.launcher)
if args.bind_numa:
try_bind_numa(args.launcher)
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)