diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 2a1ccbc..7a9ede5 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -477,19 +477,26 @@ def get_config_value(config, key, defalut): return value +# check pacakge +try: + import numa + from numa import memory, schedule + from pynvml.smi import nvidia_smi +except (AttributeError, ImportError): + get_numa = False + global_rank = int(os.environ["SLURM_PROCID"]) if "SLURM_PROCID" in os.environ else int(os.environ["RANK"]) + if global_rank == 0: + logger.info( + "Try bind numa failed! Package import error, if numa is not installed, " + "please implement: pip install --upgrade py-libnuma" + ) +else: + get_numa = True + + def try_bind_numa(launcher): - # check pacakge - try: - import numa - from numa import memory, schedule - from pynvml.smi import nvidia_smi - except (AttributeError, ImportError): - global_rank = int(os.environ["SLURM_PROCID"]) if launcher == "slurm" else int(os.environ["RANK"]) - if global_rank == 0: - logger.info( - "Try bind numa failed! Package import error, if numa is not installed, " - "please implement: pip install --upgrade py-libnuma" - ) + # Early return if numa module not available + if not get_numa: return # get numa node number