mirror of https://github.com/InternLM/InternLM
feat: bind numa
parent
ad73355215
commit
ea3d333144
|
@ -2,10 +2,7 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
|
@ -482,62 +479,14 @@ def get_config_value(config, key, defalut):
|
|||
return value
|
||||
|
||||
|
||||
def get_node_hostnames(nodes: str):
|
||||
hostnames = []
|
||||
|
||||
tmp_nodelist = copy.deepcopy(nodes)
|
||||
tmp_nodelist = tmp_nodelist.replace(" ", "")
|
||||
|
||||
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
|
||||
while tmp_res:
|
||||
tmp_res2 = re.search(",", tmp_nodelist)
|
||||
if tmp_res2 and tmp_res2.start() < tmp_res.start():
|
||||
tmp_nodelist2 = tmp_nodelist.split(",")
|
||||
count = 0
|
||||
while count < len(tmp_nodelist2):
|
||||
if re.search(r"\[", tmp_nodelist2[count]):
|
||||
break
|
||||
hostnames.append(tmp_nodelist2[count])
|
||||
count += 1
|
||||
tmp_nodelist = ",".join(tmp_nodelist2[count:])
|
||||
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
|
||||
node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()]
|
||||
prefix = tmp_nodelist[: tmp_res.start()].replace(",", "")
|
||||
|
||||
node_range = re.sub(r"[\[\]]", "", node_range)
|
||||
|
||||
tmplist = node_range.split(",")
|
||||
|
||||
pattern1 = r"^\d+$"
|
||||
pattern2 = r"^\d+-\d+$"
|
||||
|
||||
for tmpiter in tmplist:
|
||||
if re.match(pattern1, tmpiter):
|
||||
hostnames.append(prefix + tmpiter)
|
||||
elif re.match(pattern2, tmpiter):
|
||||
begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1])
|
||||
hostnames.extend([prefix + str(i) for i in range(begin, end + 1)])
|
||||
else:
|
||||
prefix = "-".join(tmpiter.split("-")[:-1]) + "-"
|
||||
hostnames.append(tmpiter)
|
||||
|
||||
tmp_nodelist = tmp_nodelist[tmp_res.end() :]
|
||||
tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist)
|
||||
|
||||
tmplist = tmp_nodelist.split(",")
|
||||
hostnames.extend(tmplist)
|
||||
|
||||
while "" in hostnames:
|
||||
hostnames.remove("")
|
||||
|
||||
return hostnames
|
||||
|
||||
|
||||
def try_bind_numa(launcher):
|
||||
# get numa node number
|
||||
numa_node_num = numa.info.get_max_node() + 1
|
||||
# get total gpu number of current node
|
||||
nvsmi = nvidia_smi.getInstance()
|
||||
total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
|
||||
|
||||
# return while total_GPU_per_node is larger than numa_node_num or is not divisible by numa_node_num
|
||||
if total_GPU_per_node <= numa_node_num:
|
||||
return
|
||||
if total_GPU_per_node % numa_node_num != 0:
|
||||
|
@ -546,25 +495,17 @@ def try_bind_numa(launcher):
|
|||
if launcher == "torch":
|
||||
global_rank = int(os.environ["RANK"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
elif launcher == "slurm":
|
||||
global_rank = int(os.environ["SLURM_PROCID"])
|
||||
local_rank = int(os.environ["SLURM_LOCALID"])
|
||||
local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
|
||||
if "(" in local_world_size_group:
|
||||
local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
|
||||
elif "," in local_world_size_group:
|
||||
local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
|
||||
node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
|
||||
hostname = socket.gethostname()
|
||||
index = node_list.index(hostname)
|
||||
local_world_size = int(local_world_size_list[index])
|
||||
else:
|
||||
local_world_size = int(local_world_size_group)
|
||||
world_size = int(os.environ["SLURM_NPROCS"])
|
||||
|
||||
if total_GPU_per_node != local_world_size:
|
||||
# return while the number of processes is smaller than one node GPUs num
|
||||
if world_size < total_GPU_per_node:
|
||||
return
|
||||
|
||||
# compute numa id for each locak rank
|
||||
per_numa = total_GPU_per_node // numa_node_num
|
||||
numa_id = local_rank // per_numa
|
||||
|
||||
|
|
Loading…
Reference in New Issue