try_bind_numa should not raise exception

pull/320/head
877825076@qq.com 2023-09-25 17:50:02 +08:00
parent 83070d3454
commit c139a05a94
1 changed file with 46 additions and 49 deletions

@@ -16,6 +16,16 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
+# check package
+try:
+    import numa
+    from numa import memory, schedule
+    from pynvml.smi import nvidia_smi
+except (AttributeError, ImportError):
+    get_numa = False
+else:
+    get_numa = True
+
 logger = get_logger(__file__)
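For context, the block moved here is the usual optional-dependency guard: probe the imports once at module load and record the outcome in a flag that later code checks. A minimal standalone sketch of the same pattern (the bind_or_skip helper is illustrative, not part of this patch):

# Probe the optional NUMA/NVML bindings once at import time, never at call time.
try:
    import numa
    from numa import memory, schedule  # process and memory binding APIs
    from pynvml.smi import nvidia_smi  # GPU inventory via NVML
except (AttributeError, ImportError):  # AttributeError covers broken/partial installs
    get_numa = False
else:
    get_numa = True


def bind_or_skip(numa_id: int) -> bool:
    """Bind the current process to numa_id, or no-op if support is missing."""
    if not get_numa:
        return False
    schedule.run_on_nodes(numa_id)
    memory.set_membind_nodes(numa_id)
    return True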
@@ -385,6 +395,8 @@ def launch_from_slurm(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the SLURM environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size)
+
     launch(
         config=config,
         rank=rank,
@@ -418,6 +430,8 @@ def launch_from_torch(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the torch environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size, local_rank=local_rank)
+
     launch(
         config=config,
         local_rank=local_rank,
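Both call sites now pass ranks explicitly instead of having try_bind_numa guess them from the environment. A hedged sketch of where those values come from under each launcher (standard torchrun and SLURM variables; the helper name is illustrative, not from the patch):

import os


def ranks_from_env(launcher: str):
    """Return (global_rank, world_size, local_rank) for the given launcher."""
    if launcher == "torch":  # set by torchrun for every worker process
        return int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]), int(os.environ["LOCAL_RANK"])
    if launcher == "slurm":  # set by srun; local_rank is left None here
        return int(os.environ["SLURM_PROCID"]), int(os.environ["SLURM_NPROCS"]), None
    raise ValueError(f"unknown launcher: {launcher}")

When local_rank is None, the new try_bind_numa derives it from torch.cuda.device_count(), as a later hunk shows.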
@@ -448,8 +462,6 @@ def initialize_distributed_env(
         seed (int, optional): Specified random seed for every process. 1024 by default.
     """
 
-    try_bind_numa(launcher)
-
     # disable automatic garbage collection
     gc.disable()
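The gc.disable() kept here switches off automatic garbage collection at startup. The usual companion to this (an assumption about intent, not shown in this diff) is to trigger collection manually at controlled points, e.g. at step boundaries, so a collection never lands in the middle of a training step:

import gc

gc.disable()  # no automatic collections mid-step

GC_INTERVAL = 50  # illustrative cadence
for step in range(200):
    # ... one training step would run here ...
    if step % GC_INTERVAL == 0:
        gc.collect()  # collect only at a step boundary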
@@ -489,29 +501,17 @@ def get_config_value(config, key, defalut):
     return value
 
 
-# check package
-try:
-    import numa
-    from numa import memory, schedule
-    from pynvml.smi import nvidia_smi
-except (AttributeError, ImportError):
-    get_numa = False
-else:
-    get_numa = True
-
-
-def try_bind_numa(launcher):
+def try_bind_numa(global_rank, world_size, local_rank=None):
     # Early return if numa module not available
     if not get_numa:
-        global_rank = int(os.environ["SLURM_PROCID"]) if "SLURM_PROCID" in os.environ else int(os.environ["RANK"])
         if global_rank == 0:
             logger.info(
                 "Try bind numa failed! Package import error, if numa is not installed, "
-                "please implement: pip install --upgrade py-libnuma"
+                "please implement: pip install --upgrade py-libnuma, Ref: https://pypi.org/project/py-libnuma/"
             )
         return
 
     # get numa node number
-    numa_node_num = numa.info.get_max_node() + 1
-    # get total gpu number of current node
-    nvsmi = nvidia_smi.getInstance()
+    try:
+        numa_node_num = numa.info.get_max_node() + 1
+        # get total gpu number of current node
+        nvsmi = nvidia_smi.getInstance()
@@ -522,20 +522,14 @@ def try_bind_numa(launcher):
             return
         if total_GPU_per_node % numa_node_num != 0:
             return
 
-    if launcher == "torch":
-        global_rank = int(os.environ["RANK"])
-        local_rank = int(os.environ["LOCAL_RANK"])
-        world_size = int(os.environ["WORLD_SIZE"])
-    elif launcher == "slurm":
-        global_rank = int(os.environ["SLURM_PROCID"])
-        local_rank = int(os.environ["SLURM_LOCALID"])
-        world_size = int(os.environ["SLURM_NPROCS"])
-
         # return while the number of processes is smaller than one node GPUs num
         if world_size < total_GPU_per_node:
             return
 
+        if local_rank is None:
+            devices_per_node = torch.cuda.device_count()
+            local_rank = global_rank % devices_per_node
+
         # compute numa id for each local rank
         per_numa = total_GPU_per_node // numa_node_num
         numa_id = local_rank // per_numa
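To make the arithmetic concrete, a worked example with illustrative numbers (8 GPUs and 2 NUMA nodes are assumptions, not values from the patch):

total_GPU_per_node = 8  # GPUs on this node
numa_node_num = 2       # NUMA nodes on this node
per_numa = total_GPU_per_node // numa_node_num  # 4 GPUs share each NUMA node

for local_rank in range(total_GPU_per_node):
    numa_id = local_rank // per_numa
    print(f"local_rank {local_rank} -> numa node {numa_id}")
# local ranks 0-3 bind to numa node 0, ranks 4-7 to numa node 1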
@@ -543,4 +537,7 @@ def try_bind_numa(launcher):
         # bind numa node
         schedule.run_on_nodes(numa_id)
         memory.set_membind_nodes(numa_id)
-    logger.info(f"Rank: {global_rank} success bind process to numa node: {numa_id}")
+    except Exception:
+        return  # try_bind_numa should not raise exception
+    else:
+        logger.info(f"Rank: {global_rank} success bind process to numa node: {numa_id}")