feat: bind numa

pull/320/head
li126com 2023-09-22 12:19:06 +08:00
parent ad73355215
commit ea3d333144
1 changed files with 8 additions and 67 deletions

View File

@ -2,10 +2,7 @@
# -*- encoding: utf-8 -*-
import argparse
import copy
import os
import re
import socket
from pathlib import Path
from typing import Dict, Union
@ -482,62 +479,14 @@ def get_config_value(config, key, defalut):
return value
def get_node_hostnames(nodes: str):
hostnames = []
tmp_nodelist = copy.deepcopy(nodes)
tmp_nodelist = tmp_nodelist.replace(" ", "")
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
while tmp_res:
tmp_res2 = re.search(",", tmp_nodelist)
if tmp_res2 and tmp_res2.start() < tmp_res.start():
tmp_nodelist2 = tmp_nodelist.split(",")
count = 0
while count < len(tmp_nodelist2):
if re.search(r"\[", tmp_nodelist2[count]):
break
hostnames.append(tmp_nodelist2[count])
count += 1
tmp_nodelist = ",".join(tmp_nodelist2[count:])
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
node_range = re.sub(r"[\[\]]", "", node_range)
tmplist = node_range.split(",")
pattern1 = r"^\d+$"
pattern2 = r"^\d+-\d+$"
for tmpiter in tmplist:
if re.match(pattern1, tmpiter):
hostnames.append(prefix + tmpiter)
elif re.match(pattern2, tmpiter):
begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
else:
prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
hostnames.append(tmpiter)
tmp_nodelist = tmp_nodelist[tmp_res.end() :]
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
tmplist = tmp_nodelist.split(",")
hostnames.extend(tmplist)
while "" in hostnames:
hostnames.remove("")
return hostnames
def try_bind_numa(launcher):
# get numa node number
numa_node_num = numa.info.get_max_node() + 1
# get total gpu number of current node
nvsmi = nvidia_smi.getInstance()
total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
# return while total_GPU_per_node is larger than numa_node_num or is not divisible by numa_node_num
if total_GPU_per_node <= numa_node_num:
return
if total_GPU_per_node % numa_node_num != 0:
@ -546,25 +495,17 @@ def try_bind_numa(launcher):
if launcher == "torch":
global_rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
world_size = int(os.environ["WORLD_SIZE"])
elif launcher == "slurm":
global_rank = int(os.environ["SLURM_PROCID"])
local_rank = int(os.environ["SLURM_LOCALID"])
local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
if "(" in local_world_size_group:
local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
elif "," in local_world_size_group:
local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
hostname = socket.gethostname()
index = node_list.index(hostname)
local_world_size = int(local_world_size_list[index])
else:
local_world_size = int(local_world_size_group)
world_size = int(os.environ["SLURM_NPROCS"])
if total_GPU_per_node != local_world_size:
# return while the number of processes is smaller than one node GPUs num
if world_size < total_GPU_per_node:
return
# compute numa id for each locak rank
per_numa = total_GPU_per_node // numa_node_num
numa_id = local_rank // per_numa