mirror of https://github.com/InternLM/InternLM
feat(init): add skip args check flag and add zero overlap flag (#222)
* feat(init): add skip args check flag
* fix(optim): add param overlap enable flag
parent 9cd1e0314e
commit 7c820cfa40
@@ -75,7 +75,8 @@ grad_scaler = dict(
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimzer overlap_communication
-    zero_overlap_communication=True,
+    overlap_sync_grad=True,
+    overlap_sync_param=True,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
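Note: for context, a minimal sketch of how the updated optimizer block could look in a full training config. Only the overlap flags and bucket size come from the hunk above; the clip_grad_norm value is an assumption for illustration.

# Hypothetical config excerpt; clip_grad_norm is assumed, not shown in this hunk.
hybrid_zero_optimizer = dict(
    overlap_sync_grad=True,   # overlap gradient synchronization with backward compute
    overlap_sync_param=True,  # overlap parameter broadcast with forward compute
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping (value assumed for illustration)
    clip_grad_norm=1.0,
)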
@@ -120,7 +121,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.tf32",  # dtype could be in "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32",
+    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
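Note: the dtype field stays a string in the config. A hedged sketch of how such a string might be resolved to a torch dtype; the helper name and the TF32 handling are assumptions for illustration, not InternLM's actual implementation.

import torch

def resolve_model_dtype(name: str) -> torch.dtype:
    # Hypothetical helper: map the config string to a torch dtype.
    # "torch.tf32" is treated here as float32 with TF32 matmuls enabled (assumption).
    if name == "torch.tf32":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        return torch.float32
    mapping = {
        "torch.float16": torch.float16,
        "torch.half": torch.half,
        "torch.bfloat16": torch.bfloat16,
        "torch.float32": torch.float32,
    }
    return mapping[name]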
@@ -1,9 +1,15 @@
 from .initialize_trainer import initialize_trainer
-from .launch import get_default_parser, launch_from_slurm, launch_from_torch
+from .launch import (
+    get_default_parser,
+    initialize_distributed_env,
+    launch_from_slurm,
+    launch_from_torch,
+)

 __all__ = [
     "get_default_parser",
     "initialize_trainer",
     "launch_from_slurm",
     "launch_from_torch",
+    "initialize_distributed_env",
 ]
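Note: with initialize_distributed_env re-exported from the package root, callers can import it directly. A minimal usage sketch; the config path is a placeholder and the call assumes a torchrun launch that sets the rank/world-size environment variables.

from internlm.initialize import initialize_distributed_env

# Placeholder config path; requires the env vars set by torchrun / torch.distributed.launch.
initialize_distributed_env(config="./configs/7B_sft.py", launcher="torch", seed=1024)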
@@ -10,6 +10,7 @@ import torch

 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
+from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.storage_manager import init_storage_manager

@@ -276,6 +277,19 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     if "alert_address" not in gpc.config:
         gpc.config._add_item("alert_address", None)
+
+    optim_ckpt = gpc.config.hybrid_zero_optimizer
+    if "zero_overlap_communication" in optim_ckpt:
+        # Compatible with the old interfaces.
+        optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
+    if "overlap_sync_grad" not in optim_ckpt:
+        optim_ckpt._add_item("overlap_sync_grad", False)
+    if "overlap_sync_param" not in optim_ckpt:
+        optim_ckpt._add_item("overlap_sync_param", False)
+    if gpc.is_rank_for_log():
+        logger.info(
+            f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
+        )


 def launch(
     config: Union[str, Path, Config, Dict],
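Note: the shim above keeps old configs working. A rough illustration of its effect using a plain dict; the real gpc.config is a Config object with _add_item, so setdefault is only an approximation.

# Old-style config: only the deprecated zero_overlap_communication flag is set.
old_style = {"zero_overlap_communication": True, "reduce_bucket_size": 512 * 1024 * 1024}

if "zero_overlap_communication" in old_style:
    old_style.setdefault("overlap_sync_grad", old_style["zero_overlap_communication"])
old_style.setdefault("overlap_sync_grad", False)
old_style.setdefault("overlap_sync_param", False)

assert old_style["overlap_sync_grad"] is True    # inherited from the deprecated flag
assert old_style["overlap_sync_param"] is False  # defaults off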
@@ -322,8 +336,6 @@ def launch(
     # init process groups for different parallel modes from config
     gpc.init_parallel_groups()

-    args_sanity_check()
-
     # set cuda device
     if torch.cuda.is_available():
         # if local rank is not given, calculate automatically
@@ -376,7 +388,11 @@ def launch_from_slurm(
     )


-def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
+def launch_from_torch(
+    config: Union[str, Path, Config, Dict],
+    backend: str = "nccl",
+    seed: int = 1024,
+):
     """A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch
@@ -404,3 +420,38 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
         backend=backend,
         seed=seed,
     )
+
+
+def initialize_distributed_env(
+    config: str,
+    launcher: str = "slurm",
+    master_port: int = 8888,
+    seed: int = 1024,
+    args_check=True,
+):
+    """
+    Initialize distributed environment for distributed training.
+
+    Args:
+        config (str): Config file path.
+        launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
+        master_port (str): The master port for distributed training. 8888 by default.
+        seed (int, optional): Specified random seed for every process. 1024 by default.
+    """
+
+    torch.cuda.empty_cache()
+
+    if launcher == "torch":
+        launch_from_torch(config=config, seed=seed)
+    elif launcher == "slurm":
+        launch_from_slurm(
+            config=config,
+            host=get_master_node(),
+            port=master_port,
+            seed=seed,
+        )
+    else:
+        assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
+
+    if args_check:
+        args_sanity_check()
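Note: args_check is the skip-args-check flag named in the commit title. A minimal usage sketch of disabling the sanity check, e.g. for a quick local run with a partial config; the call site and config path are hypothetical.

from internlm.initialize import initialize_distributed_env

initialize_distributed_env(
    config="./configs/7B_sft.py",  # placeholder path
    launcher="slurm",
    master_port=8888,
    seed=1024,
    args_check=False,  # skip args_sanity_check(); defaults to True
)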
@@ -133,7 +133,7 @@ class MHA(nn.Module):

         if inference_params is None:
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
-                with torch.cuda.amp.autocast(dtype=torch.float16):
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
                         qkv = qkv.to(torch.bfloat16)
                     context = self.inner_attn(qkv).to(x.dtype)
@@ -171,7 +171,7 @@ class MHA(nn.Module):

         if inference_params is None:
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
-                with torch.cuda.amp.autocast(dtype=torch.float16):
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
                         qkv = qkv.to(torch.bfloat16)
                     context = self.inner_attn(qkv, **kwargs).to(x.dtype)
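Note: flash-attention kernels only accept half-precision inputs, so a float32 model wraps the attention call in autocast. A standalone sketch of the same pattern; the inner_attn callable and tensor layout are illustrative, not the module's real signature.

import torch

def run_attention_in_bf16(inner_attn, qkv: torch.Tensor) -> torch.Tensor:
    # Cast to bfloat16 for the flash-attention kernel, then restore the caller's dtype.
    out_dtype = qkv.dtype
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.bfloat16)
        context = inner_attn(qkv)
    return context.to(out_dtype)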
@@ -106,9 +106,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         max_scale = grad_scal_cfg.max_scale

         # Zero related args
-        overlap_communication = zero_cfg.zero_overlap_communication
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
+        self._overlap_sync_grad = zero_cfg.overlap_sync_grad
+        self._overlap_sync_param = zero_cfg.overlap_sync_param

         super().__init__(optim=optimizer)
@@ -129,7 +130,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._fp32_flat_param_groups_of_current_rank = dict()

         # communication params
-        self._overlap_communication = overlap_communication
+        # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size

         # gradient scaler
@@ -161,8 +162,11 @@ class HybridZeroOptimizer(BaseOptimizer):
         )
         self.params_per_rank_id_dict = []
         self._param_bcast_sync_handler = param_bcast_sync_handler
-        if self._overlap_communication:
+        if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
+            self._broadcast_comm_stream = torch.cuda.Stream()
+        else:
+            self._broadcast_comm_stream = torch.cuda.current_stream()

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -232,14 +236,14 @@ class HybridZeroOptimizer(BaseOptimizer):

         # initialize communication stream for
         # communication-computation overlapping
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             self._comm_stream = torch.cuda.Stream()
         else:
             self._comm_stream = torch.cuda.current_stream()

         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             self._attach_reduction_hook()

     @property
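Note: after this change, gradient reduction and parameter broadcast overlap independently, each getting its own CUDA stream only when its flag is set. A simplified sketch of the two stream choices, not the full optimizer class:

import torch

class OverlapStreams:
    # Illustration of the stream selection done in HybridZeroOptimizer.__init__.
    def __init__(self, overlap_sync_grad: bool, overlap_sync_param: bool):
        # Gradient reduction gets a dedicated stream only when grad sync is overlapped.
        self.comm_stream = torch.cuda.Stream() if overlap_sync_grad else torch.cuda.current_stream()
        # Parameter broadcast gets a dedicated stream only when param sync is overlapped.
        self.broadcast_comm_stream = torch.cuda.Stream() if overlap_sync_param else torch.cuda.current_stream()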
@@ -273,7 +277,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_id = str(i)
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
-            if self._overlap_communication:
+            if self._overlap_sync_param:
                 rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
             else:
                 rank_to_go = numel_per_rank.index(min(numel_per_rank))
@@ -394,7 +398,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)

     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
            self._comm_stream.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()

@@ -517,7 +521,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
-        if not self._overlap_communication:
+        if not self._overlap_sync_grad:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
                     if param.grad is not None:
@@ -532,7 +536,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

         # clear reduced grads
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             # grads in the last bucket is reduced
             self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
@@ -641,7 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)

-        with torch.cuda.stream(self._comm_stream):
+        with torch.cuda.stream(self._broadcast_comm_stream):
             self.broadcast_params()

         timer("step").stop()
@@ -668,7 +672,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 async_op=True,
             )

-            if self._overlap_communication:
+            if self._overlap_sync_param:
                 self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
             else:
                 handles.append(handle)
@@ -1,7 +1,6 @@
 from .training_internlm import (
     get_train_data_loader,
     get_validation_data_loader,
-    initialize_distributed_env,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,
@@ -12,7 +11,6 @@ from .training_internlm import (
 __all__ = [
     "get_train_data_loader",
     "get_validation_data_loader",
-    "initialize_distributed_env",
     "initialize_llm_profile",
     "initialize_model",
     "initialize_optimizer",
@@ -10,7 +10,6 @@ import torch.distributed as dist
 from torch import nn
 from torch.utils.data import ConcatDataset, DataLoader

-import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
@@ -31,7 +30,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
-from internlm.utils.common import DummyProfile, get_master_node
+from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
@@ -44,32 +43,6 @@ from internlm.utils.registry import MODEL_INITIALIZER
 logger = get_logger(__file__)


-def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
-    """
-    Initialize distributed environment for distributed training.
-
-    Args:
-        config (str): Config file path.
-        launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
-        master_port (str): The master port for distributed training. 8888 by default.
-        seed (int, optional): Specified random seed for every process. 1024 by default.
-    """
-
-    torch.cuda.empty_cache()
-
-    if launcher == "torch":
-        internlm.launch_from_torch(config=config, seed=seed)
-    elif launcher == "slurm":
-        internlm.launch_from_slurm(
-            config=config,
-            host=get_master_node(),
-            port=master_port,
-            seed=seed,
-        )
-    else:
-        assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
-
-
 def initialize_model():
     """
     Initialize model.
@@ -119,7 +92,11 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):

     Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
-    param_bcast_sync_handler = ParamBcastSyncHandler(model)
+    if gpc.config.hybrid_zero_optimizer.overlap_sync_param:
+        param_bcast_sync_handler = ParamBcastSyncHandler(model)
+    else:
+        param_bcast_sync_handler = None
+
     adam_cfg = gpc.config.adam
     naive_optimizer = torch.optim.AdamW(
         params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
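Note: ParamBcastSyncHandler is now only constructed when parameter-sync overlap is enabled. A reduced sketch of that decision, with the gpc lookup replaced by a plain argument for illustration:

from internlm.solver.optimizer.utils import ParamBcastSyncHandler

def maybe_build_bcast_handler(model, overlap_sync_param: bool):
    # The handler is only needed when parameter broadcasts overlap with forward compute.
    return ParamBcastSyncHandler(model) if overlap_sync_param else None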
train.py
@@ -15,6 +15,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
+from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message
@@ -22,7 +23,6 @@ from internlm.monitor.monitor import monitor_manager as mm
 from internlm.train import (
     get_train_data_loader,
     get_validation_data_loader,
-    initialize_distributed_env,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,