mirror of https://github.com/InternLM/InternLM
feat: bind numa
parent ad73355215
commit ea3d333144
@@ -2,10 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import argparse
-import copy
 import os
-import re
-import socket
 from pathlib import Path
 from typing import Dict, Union
 
@@ -482,62 +479,14 @@ def get_config_value(config, key, defalut):
     return value
 
 
-def get_node_hostnames(nodes: str):
-    hostnames = []
-
-    tmp_nodelist = copy.deepcopy(nodes)
-    tmp_nodelist = tmp_nodelist.replace(" ", "")
-
-    tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
-    while tmp_res:
-        tmp_res2 = re.search(",", tmp_nodelist)
-        if tmp_res2 and tmp_res2.start() < tmp_res.start():
-            tmp_nodelist2 = tmp_nodelist.split(",")
-            count = 0
-            while count < len(tmp_nodelist2):
-                if re.search(r"\[", tmp_nodelist2[count]):
-                    break
-                hostnames.append(tmp_nodelist2[count])
-                count += 1
-            tmp_nodelist = ",".join(tmp_nodelist2[count:])
-            tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
-        node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
-        prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
-
-        node_range = re.sub(r"[\[\]]", "", node_range)
-
-        tmplist = node_range.split(",")
-
-        pattern1 = r"^\d+$"
-        pattern2 = r"^\d+-\d+$"
-
-        for tmpiter in tmplist:
-            if re.match(pattern1, tmpiter):
-                hostnames.append(prefix + tmpiter)
-            elif re.match(pattern2, tmpiter):
-                begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
-                hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
-            else:
-                prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
-                hostnames.append(tmpiter)
-
-        tmp_nodelist = tmp_nodelist[tmp_res.end() :]
-        tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
-
-    tmplist = tmp_nodelist.split(",")
-    hostnames.extend(tmplist)
-
-    while "" in hostnames:
-        hostnames.remove("")
-
-    return hostnames
-
-
 def try_bind_numa(launcher):
+    # get the number of NUMA nodes
     numa_node_num = numa.info.get_max_node() + 1
+    # get the total number of GPUs on the current node
     nvsmi = nvidia_smi.getInstance()
     total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
 
+    # return if total_GPU_per_node is not larger than numa_node_num, or is not divisible by it
     if total_GPU_per_node <= numa_node_num:
         return
     if total_GPU_per_node % numa_node_num != 0:
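For reference, the deleted get_node_hostnames helper expanded SLURM's compressed
nodelist syntax, e.g. "node[1-3,7],gpu-node9", into individual hostnames so a rank
could locate its own node's entry. A minimal re-sketch of that expansion, simplified
from the removed code above (expand_nodelist is an illustrative name, not part of
the repo):

    import re

    def expand_nodelist(nodes: str) -> list:
        # simplified sketch of the removed get_node_hostnames:
        # expand each "prefix[a-b,c]" group into prefixa..prefixb, prefixc
        hostnames = []
        for prefix, body, plain in re.findall(r"([\w\-]+)\[([\d,\-]+)\]|([\w\-]+)", nodes.replace(" ", "")):
            if plain:  # a bare hostname with no bracket group
                hostnames.append(plain)
                continue
            for part in body.split(","):
                if "-" in part:  # numeric range "a-b"
                    begin, end = map(int, part.split("-"))
                    hostnames.extend(prefix + str(i) for i in range(begin, end + 1))
                else:  # single numeric suffix
                    hostnames.append(prefix + part)
        return hostnames

    # expand_nodelist("node[1-3,7],gpu-node9")
    # -> ['node1', 'node2', 'node3', 'node7', 'gpu-node9']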
@@ -546,25 +495,17 @@ def try_bind_numa(launcher):
     if launcher == "torch":
         global_rank = int(os.environ["RANK"])
         local_rank = int(os.environ["LOCAL_RANK"])
-        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+        world_size = int(os.environ["WORLD_SIZE"])
     elif launcher == "slurm":
         global_rank = int(os.environ["SLURM_PROCID"])
         local_rank = int(os.environ["SLURM_LOCALID"])
-        local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
-        if "(" in local_world_size_group:
-            local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
-        elif "," in local_world_size_group:
-            local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
-            node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
-            hostname = socket.gethostname()
-            index = node_list.index(hostname)
-            local_world_size = int(local_world_size_list[index])
-        else:
-            local_world_size = int(local_world_size_group)
+        world_size = int(os.environ["SLURM_NPROCS"])
 
-    if total_GPU_per_node != local_world_size:
+    # return if the number of processes is smaller than the number of GPUs on one node
+    if world_size < total_GPU_per_node:
         return
 
+    # compute the numa id for each local rank
     per_numa = total_GPU_per_node // numa_node_num
     numa_id = local_rank // per_numa
 
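The deleted slurm branch existed because SLURM_TASKS_PER_NODE is not a plain
integer: SLURM compresses it into forms such as "8(x4)" (8 tasks on each of 4
nodes) or "8,4" (one count per node), which is why the old code had to look up
this host's position in the expanded nodelist. Reading the total task count from
SLURM_NPROCS and comparing it against the per-node GPU count sidesteps that
parsing. A sketch of the formats the removed branch handled (local_tasks is a
hypothetical helper mirroring the old logic):

    def local_tasks(tasks_per_node: str, node_index: int) -> int:
        # mirrors the removed SLURM_TASKS_PER_NODE parsing
        if "(" in tasks_per_node:  # e.g. "8(x4)": 8 tasks on each of 4 nodes
            return int(tasks_per_node.split("(")[0])
        if "," in tasks_per_node:  # e.g. "8,4": per-node counts, pick this host's
            return int(tasks_per_node.split(",")[node_index])
        return int(tasks_per_node)  # e.g. "8": a single node

    assert local_tasks("8(x4)", 0) == 8
    assert local_tasks("8,4", 1) == 4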
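The numa_id arithmetic maps consecutive local ranks onto NUMA nodes in blocks:
with 8 GPUs and 2 NUMA nodes, per_numa is 4, so local ranks 0-3 land on NUMA
node 0 and ranks 4-7 on node 1. The diff is truncated before the actual binding
call; with py-libnuma (the package providing numa.info.get_max_node above), the
final step presumably looks something like this sketch (assumed, not shown in
this diff):

    import numa  # py-libnuma

    per_numa = 8 // 2        # worked example: 4 consecutive local ranks per NUMA node
    numa_id = 5 // per_numa  # local rank 5 -> NUMA node 1

    # assumed final step: pin this process's CPUs and memory to its NUMA node
    numa.schedule.run_on_nodes(numa_id)
    numa.memory.set_membind_nodes(numa_id)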