From cb34b9325bafdf058d56a35cabf8b9a38be81a8d Mon Sep 17 00:00:00 2001 From: li126com Date: Tue, 19 Sep 2023 14:34:55 +0800 Subject: [PATCH] feat:add bind numa --- internlm/initialize/launch.py | 74 ++++++++++++++++++++++++++++++++--- train.py | 3 +- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index a771df6..2683750 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -2,7 +2,10 @@ # -*- encoding: utf-8 -*- import argparse +import copy import os +import re +import socket from pathlib import Path from typing import Dict, Union @@ -44,6 +47,7 @@ def get_default_parser(): parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") parser.add_argument("--seed", type=int, default=1024) parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") + parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.") return parser @@ -478,8 +482,58 @@ def get_config_value(config, key, defalut): return value -def try_bind_numa(launcher): +def get_node_hostnames(nodes: str): + hostnames = [] + tmp_nodelist = copy.deepcopy(nodes) + tmp_nodelist = tmp_nodelist.replace(" ", "") + + tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist) + while tmp_res: + tmp_res2 = re.search(",", tmp_nodelist) + if tmp_res2 and tmp_res2.start() < tmp_res.start(): + tmp_nodelist2 = tmp_nodelist.split(",") + count = 0 + while count < len(tmp_nodelist2): + if re.search(r"\[", tmp_nodelist2[count]): + break + hostnames.append(tmp_nodelist2[count]) + count += 1 + tmp_nodelist = ",".join(tmp_nodelist2[count:]) + tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist) + node_range = tmp_nodelist[tmp_res.start() : tmp_res.end()] + prefix = tmp_nodelist[: tmp_res.start()].replace(",", "") + + node_range = re.sub(r"[\[\]]", "", node_range) + + tmplist = node_range.split(",") + + pattern1 = r"^\d+$" + pattern2 = r"^\d+-\d+$" + + for tmpiter in tmplist: + if re.match(pattern1, tmpiter): + hostnames.append(prefix + tmpiter) + elif re.match(pattern2, tmpiter): + begin, end = int(tmpiter.split("-")[0]), int(tmpiter.split("-")[1]) + hostnames.extend([prefix + str(i) for i in range(begin, end + 1)]) + else: + prefix = "-".join(tmpiter.split("-")[:-1]) + "-" + hostnames.append(tmpiter) + + tmp_nodelist = tmp_nodelist[tmp_res.end() :] + tmp_res = re.search(r"\[[\d\-,]+\]", tmp_nodelist) + + tmplist = tmp_nodelist.split(",") + hostnames.extend(tmplist) + + while "" in hostnames: + hostnames.remove("") + + return hostnames + + +def try_bind_numa(launcher): numa_node_num = numa.info.get_max_node() + 1 nvsmi = nvidia_smi.getInstance() total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"]) @@ -491,16 +545,24 @@ def try_bind_numa(launcher): if launcher == "torch": global_rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) local_rank = int(os.environ["LOCAL_RANK"]) - nnodes = int(os.environ["NUM_NODES"]) + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) elif launcher == "slurm": global_rank = int(os.environ["SLURM_PROCID"]) - world_size = int(os.environ["SLURM_NPROCS"]) local_rank = int(os.environ["SLURM_LOCALID"]) - nnodes = int(os.environ["SLURM_NNODES"]) + local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"] + if "(" in local_world_size_group: + local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0]) + elif "," in local_world_size_group: + local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",") + node_list = get_node_hostnames(os.environ["SLURM_NODELIST"]) + hostname = socket.gethostname() + index = node_list.index(hostname) + local_world_size = int(local_world_size_list[index]) + else: + local_world_size = int(local_world_size_group) - if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node: + if total_GPU_per_node != local_world_size: return per_numa = total_GPU_per_node // numa_node_num diff --git a/train.py b/train.py index 7441147..ede175a 100644 --- a/train.py +++ b/train.py @@ -291,7 +291,8 @@ if __name__ == "__main__": args = parse_args() hostname = socket.gethostname() - try_bind_numa(args.launcher) + if args.bind_numa: + try_bind_numa(args.launcher) # initialize distributed environment initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)