mirror of https://github.com/InternLM/InternLM

Commit c139a05a94: try_bind_numa should not raise exception
Parent: 83070d3454
@@ -16,6 +16,16 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
+# check package
+try:
+    import numa
+    from numa import memory, schedule
+    from pynvml.smi import nvidia_smi
+except (AttributeError, ImportError):
+    get_numa = False
+else:
+    get_numa = True
+
 logger = get_logger(__file__)
 
 
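For reference, the module-level probe added above can be exercised on its own. The sketch below is not part of the commit; it reuses the same calls the binding code relies on later (numa.info.get_max_node and nvidia_smi.getInstance().DeviceQuery), assuming py-libnuma and pynvml are installed:

# Standalone availability probe (illustration only, not part of the commit).
try:
    import numa
    from pynvml.smi import nvidia_smi
except (AttributeError, ImportError):
    print("NUMA binding unavailable; install with: pip install --upgrade py-libnuma")
else:
    numa_node_num = numa.info.get_max_node() + 1                        # NUMA nodes on this host
    gpus = nvidia_smi.getInstance().DeviceQuery("memory.total")["gpu"]  # GPUs visible on this node
    print(f"{numa_node_num} NUMA nodes, {len(gpus)} GPUs on this node")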
@@ -385,6 +395,8 @@ def launch_from_slurm(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the SLURM environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size)
+
     launch(
         config=config,
         rank=rank,
@@ -418,6 +430,8 @@ def launch_from_torch(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the torch environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size, local_rank=local_rank)
+
     launch(
         config=config,
         local_rank=local_rank,
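Both new call sites pass ranks that the launchers have already read from their environment; launch_from_slurm does not forward a local rank, so the refactored try_bind_numa (last hunk below) derives one itself. A rough sketch of where those values come from, using the environment variables that also appear in the old function body, with a hypothetical helper name read_dist_env:

import os

def read_dist_env(launcher: str):
    # Illustration only: the env vars each launcher exposes before try_bind_numa is called.
    if launcher == "torch":              # set by torchrun
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
    elif launcher == "slurm":            # set by srun
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NPROCS"])
        local_rank = None                # not forwarded; derived inside try_bind_numa
    else:
        raise ValueError(f"unknown launcher: {launcher}")
    return rank, world_size, local_rank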
@@ -448,8 +462,6 @@ def initialize_distributed_env(
         seed (int, optional): Specified random seed for every process. 1024 by default.
     """
 
-    try_bind_numa(launcher)
-
     # close automatic garbage collection
     gc.disable()
 
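This hunk only relocates the call: try_bind_numa no longer runs inside initialize_distributed_env, where it had to re-read ranks from the environment, and is instead invoked by launch_from_slurm / launch_from_torch with explicit ranks (previous two hunks). The neighboring gc.disable() is untouched; with automatic collection off, any garbage collection has to be triggered explicitly. A generic sketch of that pattern, not taken from this repository:

import gc

gc.disable()                      # as in initialize_distributed_env
try:
    for step in range(100):       # hypothetical training loop
        ...                       # forward / backward / optimizer work
        if step % 50 == 0:
            gc.collect()          # collect only at controlled points
finally:
    gc.enable()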
@@ -489,58 +501,43 @@ def get_config_value(config, key, defalut):
     return value
 
 
-# check package
-try:
-    import numa
-    from numa import memory, schedule
-    from pynvml.smi import nvidia_smi
-except (AttributeError, ImportError):
-    get_numa = False
-    global_rank = int(os.environ["SLURM_PROCID"]) if "SLURM_PROCID" in os.environ else int(os.environ["RANK"])
-    if global_rank == 0:
-        logger.info(
-            "Try bind numa failed! Package import error, if numa is not installed, "
-            "please run: pip install --upgrade py-libnuma"
-        )
-else:
-    get_numa = True
-
-
-def try_bind_numa(launcher):
+def try_bind_numa(global_rank, world_size, local_rank=None):
     # Early return if numa module not available
     if not get_numa:
-        return
+        if global_rank == 0:
+            logger.info(
+                "Try bind numa failed! Package import error, if numa is not installed, "
+                "please run: pip install --upgrade py-libnuma, Ref: https://pypi.org/project/py-libnuma/"
+            )
 
     # get numa node number
-    numa_node_num = numa.info.get_max_node() + 1
-    # get total gpu number of current node
-    nvsmi = nvidia_smi.getInstance()
-    total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
+    try:
+        numa_node_num = numa.info.get_max_node() + 1
+        # get total gpu number of current node
+        nvsmi = nvidia_smi.getInstance()
+        total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
 
-    # return unless total_GPU_per_node is larger than numa_node_num and divisible by it
-    if total_GPU_per_node <= numa_node_num:
-        return
-    if total_GPU_per_node % numa_node_num != 0:
-        return
+        # return unless total_GPU_per_node is larger than numa_node_num and divisible by it
+        if total_GPU_per_node <= numa_node_num:
+            return
+        if total_GPU_per_node % numa_node_num != 0:
+            return
+        # return while the number of processes is smaller than the GPU count of one node
+        if world_size < total_GPU_per_node:
+            return
 
-    if launcher == "torch":
-        global_rank = int(os.environ["RANK"])
-        local_rank = int(os.environ["LOCAL_RANK"])
-        world_size = int(os.environ["WORLD_SIZE"])
-    elif launcher == "slurm":
-        global_rank = int(os.environ["SLURM_PROCID"])
-        local_rank = int(os.environ["SLURM_LOCALID"])
-        world_size = int(os.environ["SLURM_NPROCS"])
+        if local_rank is None:
+            devices_per_node = torch.cuda.device_count()
+            local_rank = global_rank % devices_per_node
 
-    # return while the number of processes is smaller than the GPU count of one node
-    if world_size < total_GPU_per_node:
-        return
+        # compute numa id for each local rank
+        per_numa = total_GPU_per_node // numa_node_num
+        numa_id = local_rank // per_numa
 
-    # compute numa id for each local rank
-    per_numa = total_GPU_per_node // numa_node_num
-    numa_id = local_rank // per_numa
-    # bind numa node
-    schedule.run_on_nodes(numa_id)
-    memory.set_membind_nodes(numa_id)
-    logger.info(f"Rank: {global_rank} successfully bound process to numa node: {numa_id}")
+        # bind numa node
+        schedule.run_on_nodes(numa_id)
+        memory.set_membind_nodes(numa_id)
+    except Exception:
+        return  # try_bind_numa should not raise exception
+    else:
+        logger.info(f"Rank: {global_rank} successfully bound process to numa node: {numa_id}")
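To make the binding arithmetic concrete: on a hypothetical host with 8 GPUs and 2 NUMA nodes, per_numa is 4, so local ranks 0-3 bind to NUMA node 0 and local ranks 4-7 to node 1. A minimal sketch of that mapping and of the new call signature:

# Illustration only: the rank-to-NUMA mapping computed by the refactored try_bind_numa.
# Hypothetical host layout: 8 GPUs spread over 2 NUMA nodes.
total_GPU_per_node = 8
numa_node_num = 2
per_numa = total_GPU_per_node // numa_node_num    # 4 local ranks per NUMA node

for local_rank in range(total_GPU_per_node):
    numa_id = local_rank // per_numa              # 0,0,0,0,1,1,1,1
    print(f"local_rank {local_rank} -> numa node {numa_id}")

# Call sites after this commit (keyword arguments as in the hunks above):
#   try_bind_numa(global_rank=rank, world_size=world_size)                          # SLURM
#   try_bind_numa(global_rank=rank, world_size=world_size, local_rank=local_rank)   # torch
# When local_rank is None, it is derived as global_rank % torch.cuda.device_count().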