feat:add numa

pull/320/head
li126com 2023-09-25 13:09:56 +08:00
parent 5738e7cf50
commit 5d99f0be62
2 changed files with 2 additions and 6 deletions


@@ -42,7 +42,6 @@ def get_default_parser():
     parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     parser.add_argument("--seed", type=int, default=1024)
     parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
-    parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.")
     return parser
@@ -441,7 +440,7 @@ def initialize_distributed_env(
         master_port (str): The master port for distributed training. 8888 by default.
         seed (int, optional): Specified random seed for every process. 1024 by default.
     """
+    try_bind_numa(launcher)
     torch.cuda.empty_cache()
     if launcher == "torch":


@@ -14,7 +14,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
-from internlm.initialize import initialize_distributed_env, try_bind_numa
+from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message
@@ -291,9 +291,6 @@ if __name__ == "__main__":
     args = parse_args()
     hostname = socket.gethostname()
 
-    if args.bind_numa:
-        try_bind_numa(args.launcher)
-
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
     assert hasattr(gpc, "config") and gpc.config is not None
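
The body of try_bind_numa is not shown in this diff. For context, the following is a minimal, hypothetical sketch of what such a helper could do on Linux: pin the process to the CPUs of the NUMA node that hosts its GPU. It is not the repository's implementation; it assumes pynvml is installed and that the LOCAL_RANK environment variable identifies the local GPU.

# Hypothetical sketch only -- not InternLM's try_bind_numa.
# Assumes Linux sysfs, pynvml, and LOCAL_RANK set by the launcher.
import os

import pynvml


def _parse_cpulist(text):
    """Parse a sysfs cpulist such as '0-31,64-95' into a set of CPU ids."""
    cpus = set()
    for part in text.strip().split(","):
        if "-" in part:
            lo, hi = part.split("-")
            cpus.update(range(int(lo), int(hi) + 1))
        elif part:
            cpus.add(int(part))
    return cpus


def try_bind_numa_sketch(local_rank=None):
    """Best-effort: bind this process to the NUMA node of its GPU."""
    try:
        if local_rank is None:
            local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
        bus_id = pynvml.nvmlDeviceGetPciInfo(handle).busId
        if isinstance(bus_id, bytes):
            bus_id = bus_id.decode()
        bus_id = bus_id.lower()
        # sysfs uses a 4-digit PCI domain, e.g. 0000:3b:00.0
        if len(bus_id.split(":")[0]) == 8:
            bus_id = bus_id[4:]
        with open(f"/sys/bus/pci/devices/{bus_id}/numa_node") as f:
            node = int(f.read().strip())
        if node < 0:  # platform does not report a NUMA node
            return
        with open(f"/sys/devices/system/node/node{node}/cpulist") as f:
            cpus = _parse_cpulist(f.read())
        if cpus:
            os.sched_setaffinity(0, cpus)
    except Exception:
        # NUMA binding is an optimization; never fail training because of it.
        pass

Calling the helper unconditionally inside initialize_distributed_env, as this commit does, keeps train.py free of the detail and removes the --bind_numa flag; keeping the body wrapped in a broad try/except makes the binding best-effort on platforms where it cannot work.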