feat: bind numa

pull/320/head
li126com 2023-09-22 12:19:06 +08:00
parent ad73355215
commit ea3d333144
1 changed files with 8 additions and 67 deletions

View File

@ -2,10 +2,7 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import argparse import argparse
import copy
import os import os
import re
import socket
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Dict, Union
@ -482,62 +479,14 @@ def get_config_value(config, key, defalut):
return value return value
def get_node_hostnames(nodes: str):
hostnames = []
tmp_nodelist = copy.deepcopy(nodes)
tmp_nodelist = tmp_nodelist.replace(" ", "")
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
while tmp_res:
tmp_res2 = re.search(",", tmp_nodelist)
if tmp_res2 and tmp_res2.start() < tmp_res.start():
tmp_nodelist2 = tmp_nodelist.split(",")
count = 0
while count < len(tmp_nodelist2):
if re.search(r"\[", tmp_nodelist2[count]):
break
hostnames.append(tmp_nodelist2[count])
count += 1
tmp_nodelist = ",".join(tmp_nodelist2[count:])
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
node_range = re.sub(r"[\[\]]", "", node_range)
tmplist = node_range.split(",")
pattern1 = r"^\d+$"
pattern2 = r"^\d+-\d+$"
for tmpiter in tmplist:
if re.match(pattern1, tmpiter):
hostnames.append(prefix + tmpiter)
elif re.match(pattern2, tmpiter):
begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
else:
prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
hostnames.append(tmpiter)
tmp_nodelist = tmp_nodelist[tmp_res.end() :]
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
tmplist = tmp_nodelist.split(",")
hostnames.extend(tmplist)
while "" in hostnames:
hostnames.remove("")
return hostnames
def try_bind_numa(launcher): def try_bind_numa(launcher):
# get numa node number
numa_node_num = numa.info.get_max_node() + 1 numa_node_num = numa.info.get_max_node() + 1
# get total gpu number of current node
nvsmi = nvidia_smi.getInstance() nvsmi = nvidia_smi.getInstance()
total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"]) total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
# return while total_GPU_per_node is larger than numa_node_num or is not divisible by numa_node_num
if total_GPU_per_node <= numa_node_num: if total_GPU_per_node <= numa_node_num:
return return
if total_GPU_per_node % numa_node_num != 0: if total_GPU_per_node % numa_node_num != 0:
@ -546,25 +495,17 @@ def try_bind_numa(launcher):
if launcher == "torch": if launcher == "torch":
global_rank = int(os.environ["RANK"]) global_rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"]) local_rank = int(os.environ["LOCAL_RANK"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) world_size = int(os.environ["WORLD_SIZE"])
elif launcher == "slurm": elif launcher == "slurm":
global_rank = int(os.environ["SLURM_PROCID"]) global_rank = int(os.environ["SLURM_PROCID"])
local_rank = int(os.environ["SLURM_LOCALID"]) local_rank = int(os.environ["SLURM_LOCALID"])
local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"] world_size = int(os.environ["SLURM_NPROCS"])
if "(" in local_world_size_group:
local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
elif "," in local_world_size_group:
local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
hostname = socket.gethostname()
index = node_list.index(hostname)
local_world_size = int(local_world_size_list[index])
else:
local_world_size = int(local_world_size_group)
if total_GPU_per_node != local_world_size: # return while the number of processes is smaller than one node GPUs num
if world_size < total_GPU_per_node:
return return
# compute numa id for each locak rank
per_numa = total_GPU_per_node // numa_node_num per_numa = total_GPU_per_node // numa_node_num
numa_id = local_rank // per_numa numa_id = local_rank // per_numa