mirror of https://github.com/InternLM/InternLM
				
				
				
			feat: add bind numa
							parent
							
								
									ca5858eb85
								
							
						
					
					
						commit
						cb34b9325b
					
				| 
						 | 
				
			
			@ -2,7 +2,10 @@
 | 
			
		|||
# -*- encoding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import copy
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
import socket
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Dict, Union
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -44,6 +47,7 @@ def get_default_parser():
 | 
			
		|||
    parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
 | 
			
		||||
    parser.add_argument("--seed", type=int, default=1024)
 | 
			
		||||
    parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
 | 
			
		||||
    parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.")
 | 
			
		||||
    return parser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -478,8 +482,58 @@ def get_config_value(config, key, defalut):
 | 
			
		|||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_bind_numa(launcher):
 | 
			
		||||
def get_node_hostnames(nodes: str):
    """Expand a SLURM-style nodelist into a flat list of hostnames.

    Handles comma-separated plain names and bracketed ranges, e.g.
    ``"a1,node[1,3-5],b2"`` -> ``["a1", "node1", "node3", "node4", "node5", "b2"]``.
    Zero-padded range bounds keep their width: ``"node[08-10]"`` ->
    ``["node08", "node09", "node10"]`` (the original ``str(int(...))`` dropped
    leading zeros, which broke padded SLURM nodelists).

    Args:
        nodes: nodelist string, typically ``os.environ["SLURM_NODELIST"]``.

    Returns:
        list[str]: expanded hostnames, with empty entries removed.
    """
    hostnames = []

    # str is immutable, so a plain reference (not deepcopy) is safe here.
    nodelist = nodes.replace(" ", "")

    bracket_pat = r"\[[\d\-,]+\]"
    bracket = re.search(bracket_pat, nodelist)
    while bracket:
        # Plain hostnames that appear before the first bracketed group are
        # consumed verbatim until we reach the element containing "[".
        comma = re.search(",", nodelist)
        if comma and comma.start() < bracket.start():
            parts = nodelist.split(",")
            count = 0
            while count < len(parts):
                if re.search(r"\[", parts[count]):
                    break
                hostnames.append(parts[count])
                count += 1
            nodelist = ",".join(parts[count:])
            bracket = re.search(bracket_pat, nodelist)

        node_range = nodelist[bracket.start() : bracket.end()]
        prefix = nodelist[: bracket.start()].replace(",", "")

        node_range = re.sub(r"[\[\]]", "", node_range)

        for item in node_range.split(","):
            if re.match(r"^\d+$", item):
                # Single index, e.g. "node[3]".
                hostnames.append(prefix + item)
            elif re.match(r"^\d+-\d+$", item):
                # Inclusive range, e.g. "1-4" or "08-10"; keep the zero
                # padding of the lower bound.
                lo, hi = item.split("-")
                width = len(lo)
                hostnames.extend(prefix + str(i).zfill(width) for i in range(int(lo), int(hi) + 1))
            else:
                # Hyphenated literal inside brackets: treat everything up to
                # the last "-" as the new prefix and keep the item as-is.
                prefix = "-".join(item.split("-")[:-1]) + "-"
                hostnames.append(item)

        nodelist = nodelist[bracket.end() :]
        bracket = re.search(bracket_pat, nodelist)

    # Whatever remains holds only plain comma-separated hostnames.
    hostnames.extend(nodelist.split(","))

    # Drop empty strings produced by leading/trailing commas.
    return [name for name in hostnames if name]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_bind_numa(launcher):
 | 
			
		||||
    numa_node_num = numa.info.get_max_node() + 1
 | 
			
		||||
    nvsmi = nvidia_smi.getInstance()
 | 
			
		||||
    total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
 | 
			
		||||
| 
						 | 
				
			
			@ -491,16 +545,24 @@ def try_bind_numa(launcher):
 | 
			
		|||
 | 
			
		||||
    if launcher == "torch":
 | 
			
		||||
        global_rank = int(os.environ["RANK"])
 | 
			
		||||
        world_size = int(os.environ["WORLD_SIZE"])
 | 
			
		||||
        local_rank = int(os.environ["LOCAL_RANK"])
 | 
			
		||||
        nnodes = int(os.environ["NUM_NODES"])
 | 
			
		||||
        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
 | 
			
		||||
    elif launcher == "slurm":
 | 
			
		||||
        global_rank = int(os.environ["SLURM_PROCID"])
 | 
			
		||||
        world_size = int(os.environ["SLURM_NPROCS"])
 | 
			
		||||
        local_rank = int(os.environ["SLURM_LOCALID"])
 | 
			
		||||
        nnodes = int(os.environ["SLURM_NNODES"])
 | 
			
		||||
        local_world_size_group = os.environ["SLURM_TASKS_PER_NODE"]
 | 
			
		||||
        if "(" in local_world_size_group:
 | 
			
		||||
            local_world_size = int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
 | 
			
		||||
        elif "," in local_world_size_group:
 | 
			
		||||
            local_world_size_list = os.environ["SLURM_TASKS_PER_NODE"].split(",")
 | 
			
		||||
            node_list = get_node_hostnames(os.environ["SLURM_NODELIST"])
 | 
			
		||||
            hostname = socket.gethostname()
 | 
			
		||||
            index = node_list.index(hostname)
 | 
			
		||||
            local_world_size = int(local_world_size_list[index])
 | 
			
		||||
        else:
 | 
			
		||||
            local_world_size = int(local_world_size_group)
 | 
			
		||||
 | 
			
		||||
    if world_size % nnodes != 0 and world_size // nnodes != total_GPU_per_node:
 | 
			
		||||
    if total_GPU_per_node != local_world_size:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    per_numa = total_GPU_per_node // numa_node_num
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										3
									
								
								train.py
								
								
								
								
							
							
						
						
									
										3
									
								
								train.py
								
								
								
								
							| 
						 | 
				
			
			@ -291,7 +291,8 @@ if __name__ == "__main__":
 | 
			
		|||
    args = parse_args()
 | 
			
		||||
    hostname = socket.gethostname()
 | 
			
		||||
 | 
			
		||||
    try_bind_numa(args.launcher)
 | 
			
		||||
    if args.bind_numa:
 | 
			
		||||
        try_bind_numa(args.launcher)
 | 
			
		||||
 | 
			
		||||
    # initialize distributed environment
 | 
			
		||||
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue