diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index bea5806..e2be1fe 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -2,10 +2,7 @@ # -*- encoding: utf-8 -*- import argparse -import copy import os -import re -import socket from pathlib import Path from typing import Dict, Union @@ -482,62 +479,14 @@ def get_config_value(config, key, defalut): return value -def get_node_hostnames(nodes: str): - hostnames = [] - - tmp_nodelist = copy.deepcopy(nodes) - tmp_nodelist = tmp_nodelist.replace(" ", "") - - tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist) - while tmp_res: - tmp_res2 = re.search(",", tmp_nodelist) - if tmp_res2 and tmp_res2.start() < tmp_res.start(): - tmp_nodelist2 = tmp_nodelist.split(",") - count = 0 - while count < len(tmp_nodelist2): - if re.search(r"\[", tmp_nodelist2[count]): - break - hostnames.append(tmp_nodelist2[count]) - count += 1 - tmp_nodelist = ",".join(tmp_nodelist2[count:]) - tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist) - node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()] - prefix = tmp_nodelist[: tmp_res.start()].replace(",", "") - - node_range = re.sub(r"[\[\]]", "", node_range) - - tmplist = node_range.split(",") - - pattern1 = r"^\d+$" - pattern2 = r"^\d+-\d+$" - - for tmpiter in tmplist: - if re.match(pattern1, tmpiter): - hostnames.append(prefix + tmpiter) - elif re.match(pattern2, tmpiter): - begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1]) - hostnames.extend([prefix + str(i) for i in range(begin, end + 1)]) - else: - prefix = "-".join(tmpiter.split("-")[:-1]) + "-" - hostnames.append(tmpiter) - - tmp_nodelist = tmp_nodelist[tmp_res.end() :] - tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist) - - tmplist = tmp_nodelist.split(",") - hostnames.extend(tmplist) - - while "" in hostnames: - hostnames.remove("") - - return hostnames - - def try_bind_numa(launcher): + # get numa node number numa_node_num = numa.info.get_max_node() + 1 + # get total gpu number of current node nvsmi = nvidia_smi.getInstance() total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"]) + # return while total_GPU_per_node is larger than numa_node_num or is not divisible by numa_node_num if total_GPU_per_node <= numa_node_num: return if total_GPU_per_node % numa_node_num != 0: @@ -546,25 +495,17 @@ def try_bind_numa(launcher): if launcher == "torch": global_rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) - local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + world_size = int(os.environ["WORLD_SIZE"]) elif launcher == "slurm": global_rank = int(os.environ["SLURM_PROCID"]) local_rank = int(os.environ["SLURM_LOCALID"]) - local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"] - if "(" in local_world_size_group: - local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0]) - elif "," in local_world_size_group: - local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",") - node_list = get_node_hostnames(os.environ["SLURM_NODELIST"]) - hostname = socket.gethostname() - index = node_list.index(hostname) - local_world_size = int(local_world_size_list[index]) - else: - local_world_size = int(local_world_size_group) + world_size = int(os.environ["SLURM_NPROCS"]) - if total_GPU_per_node != local_world_size: + # return while the number of processes is smaller than one node GPUs num + if world_size < total_GPU_per_node: return + # compute numa id for each locak rank per_numa = total_GPU_per_node // numa_node_num numa_id = local_rank // per_numa