mirror of https://github.com/InternLM/InternLM
feat:add numa
parent 5738e7cf50
commit 5d99f0be62
@@ -42,7 +42,6 @@ def get_default_parser():
     parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     parser.add_argument("--seed", type=int, default=1024)
     parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
-    parser.add_argument("--bind_numa", default=False, action="store_true", help="enable/disable bind_numa.")
     return parser
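Note on the removal above: after this hunk, --bind_numa is simply not a recognized option any more. A minimal sketch of that effect, using a cut-down reconstruction of get_default_parser() (illustrative only, not the project's module):

import argparse

# Cut-down reconstruction of the parser after this change, used only to show
# what happens to a stale --bind_numa on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--backend", type=str, default="nccl")
parser.add_argument("--seed", type=int, default=1024)
parser.add_argument("--profiling", default=False, action="store_true")

args, unknown = parser.parse_known_args(["--profiling", "--bind_numa"])
print(args.profiling)  # True
print(unknown)         # ['--bind_numa'] -- the flag is no longer recognized
# parser.parse_args(["--bind_numa"]) would exit with "unrecognized arguments"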
@@ -441,7 +440,7 @@ def initialize_distributed_env(
         master_port (str): The master port for distributed training. 8888 by default.
         seed (int, optional): Specified random seed for every process. 1024 by default.
     """
-
+    try_bind_numa(launcher)
     torch.cuda.empty_cache()

     if launcher == "torch":
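With the flag gone, the hunk above calls try_bind_numa(launcher) unconditionally inside initialize_distributed_env, so every rank attempts NUMA binding at startup. As background, a minimal Linux-only sketch of what such a helper could do: look up the local GPU's NUMA node through pynvml and sysfs, then pin the process to that node's CPUs. This is an assumption-laden illustration, not the repository's implementation; the pynvml dependency and the LOCAL_RANK/SLURM_LOCALID lookup are assumptions.

import os

import pynvml  # assumption: the nvidia-ml-py package is available


def _parse_cpulist(text):
    # /sys/devices/system/node/nodeN/cpulist looks like "0-23,48-71"
    cpus = set()
    for part in text.strip().split(","):
        if "-" in part:
            lo, hi = part.split("-")
            cpus.update(range(int(lo), int(hi) + 1))
        elif part:
            cpus.add(int(part))
    return cpus


def try_bind_numa(launcher):
    """Best-effort: pin this process to the CPUs of its GPU's NUMA node."""
    try:
        # Assumption: torchrun exports LOCAL_RANK, slurm exports SLURM_LOCALID.
        env = "LOCAL_RANK" if launcher == "torch" else "SLURM_LOCALID"
        local_rank = int(os.environ[env])

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
        bus_id = pynvml.nvmlDeviceGetPciInfo(handle).busId
        if isinstance(bus_id, bytes):
            bus_id = bus_id.decode()
        bus_id = bus_id.lower()[-12:]  # sysfs form, e.g. "0000:3b:00.0"

        with open(f"/sys/bus/pci/devices/{bus_id}/numa_node") as f:
            node = int(f.read())
        if node < 0:  # -1 means the kernel has no NUMA info for this device
            return
        with open(f"/sys/devices/system/node/node{node}/cpulist") as f:
            cpus = _parse_cpulist(f.read())
        os.sched_setaffinity(0, cpus)  # bind the current process
    except Exception:
        pass  # binding is an optimization; never fail startup because of it

Keeping a rank's host threads and allocations on the memory controller closest to its GPU is the usual motivation for this kind of binding.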
train.py (5 changed lines)

@@ -14,7 +14,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
-from internlm.initialize import initialize_distributed_env, try_bind_numa
+from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message

@@ -291,9 +291,6 @@ if __name__ == "__main__":
     args = parse_args()
     hostname = socket.gethostname()

-    if args.bind_numa:
-        try_bind_numa(args.launcher)
-
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
     assert hasattr(gpc, "config") and gpc.config is not None
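Because the binding now happens silently inside initialize_distributed_env rather than behind a flag in train.py, a quick way to confirm it on a running rank is to inspect the process's CPU affinity right after initialization. A small illustrative check, not part of the repository:

import os

allowed = sorted(os.sched_getaffinity(0))  # CPUs this process may run on
print(f"rank may run on {len(allowed)}/{os.cpu_count()} CPUs: {allowed}")
# If the NUMA binding worked, this is roughly one node's worth of CPUs
# rather than every CPU in the machine.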