feat(init): add skip args check flag and add zero overlap flag (#222)

* feat(init): add skip args check flag

* fix(optim): add param overlap enable flag
pull/228/head
Guoteng 2023-08-24 16:44:18 +08:00 committed by GitHub
parent 9cd1e0314e
commit 7c820cfa40
8 changed files with 88 additions and 51 deletions

View File

@@ -75,7 +75,8 @@ grad_scaler = dict(
hybrid_zero_optimizer = dict(
# Enable low_level_optimzer overlap_communication
zero_overlap_communication=True,
overlap_sync_grad=True,
overlap_sync_param=True,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
@@ -120,7 +121,7 @@ model = dict(
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.tf32", # dtype could be in "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32",
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
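
For reference, a sketch of the optimizer section after this change (the two overlap flags and the bucket size are taken from this diff; the clipping value is illustrative):

hybrid_zero_optimizer = dict(
    # overlap gradient reduction with the backward pass
    overlap_sync_grad=True,
    # overlap the post-step parameter broadcast with computation
    overlap_sync_param=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping (illustrative value, not from this diff)
    clip_grad_norm=1.0,
)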

View File

@@ -1,9 +1,15 @@
from .initialize_trainer import initialize_trainer
from .launch import get_default_parser, launch_from_slurm, launch_from_torch
from .launch import (
get_default_parser,
initialize_distributed_env,
launch_from_slurm,
launch_from_torch,
)
__all__ = [
"get_default_parser",
"initialize_trainer",
"launch_from_slurm",
"launch_from_torch",
"initialize_distributed_env",
]

View File

@@ -10,6 +10,7 @@ import torch
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager
@@ -276,6 +277,19 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
optim_ckpt = gpc.config.hybrid_zero_optimizer
if "zero_overlap_communication" in optim_ckpt:
# Compatible with the old interfaces.
optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
if "overlap_sync_grad" not in optim_ckpt:
optim_ckpt._add_item("overlap_sync_grad", False)
if "overlap_sync_param" not in optim_ckpt:
optim_ckpt._add_item("overlap_sync_param", False)
if gpc.is_rank_for_log():
logger.info(
f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
)
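
To make the compatibility check above concrete: an old-style config that still uses the legacy key, e.g.

hybrid_zero_optimizer = dict(
    zero_overlap_communication=True,
    reduce_bucket_size=512 * 1024 * 1024,
)

is accepted unchanged; after args_sanity_check it behaves as overlap_sync_grad=True (copied from the legacy key) and overlap_sync_param=False (the new default).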
def launch(
config: Union[str, Path, Config, Dict],
@@ -322,8 +336,6 @@ def launch(
# init process groups for different parallel modes from config
gpc.init_parallel_groups()
args_sanity_check()
# set cuda device
if torch.cuda.is_available():
# if local rank is not given, calculate automatically
@@ -376,7 +388,11 @@ def launch_from_slurm(
)
def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
def launch_from_torch(
config: Union[str, Path, Config, Dict],
backend: str = "nccl",
seed: int = 1024,
):
"""A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@@ -404,3 +420,38 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nc
backend=backend,
seed=seed,
)
def initialize_distributed_env(
config: str,
launcher: str = "slurm",
master_port: int = 8888,
seed: int = 1024,
args_check: bool = True,
):
"""
Initialize distributed environment for distributed training.
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
master_port (int): The master port for distributed training. 8888 by default.
seed (int, optional): Specified random seed for every process. 1024 by default.
args_check (bool, optional): Whether to run args_sanity_check after launching. True by default.
"""
torch.cuda.empty_cache()
if launcher == "torch":
launch_from_torch(config=config, seed=seed)
elif launcher == "slurm":
launch_from_slurm(
config=config,
host=get_master_node(),
port=master_port,
seed=seed,
)
else:
assert launcher in ["slurm", "torch"], "launcher only supports slurm or torch"
if args_check:
args_sanity_check()
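
A usage sketch for the new entry point (the config path is hypothetical; the manual check call shows what the new args_check flag is for):

from internlm.initialize import initialize_distributed_env
from internlm.initialize.launch import args_sanity_check

# launch under torchrun, but postpone the automatic config sanity check
initialize_distributed_env(
    config="configs/my_exp.py",  # hypothetical config path
    launcher="torch",            # or "slurm"
    seed=1024,
    args_check=False,
)
# ...adjust gpc.config in-process if needed, then validate explicitly
args_sanity_check()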

View File

@@ -133,7 +133,7 @@ class MHA(nn.Module):
if inference_params is None:
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
with torch.cuda.amp.autocast(dtype=torch.float16):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if qkv.dtype not in [torch.float16, torch.bfloat16]:
qkv = qkv.to(torch.bfloat16)
context = self.inner_attn(qkv).to(x.dtype)
@@ -171,7 +171,7 @@ class MHA(nn.Module):
if inference_params is None:
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
with torch.cuda.amp.autocast(dtype=torch.float16):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if qkv.dtype not in [torch.float16, torch.bfloat16]:
qkv = qkv.to(torch.bfloat16)
context = self.inner_attn(qkv, **kwargs).to(x.dtype)
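
The two hunks above change the fp32 + flash-attention path to autocast into bfloat16 instead of float16. A standalone sketch of that guard (illustrative only; bfloat16 keeps float32's exponent range, which makes it the safer downcast here):

import torch

def cast_qkv_for_flash_attn(qkv: torch.Tensor) -> torch.Tensor:
    # Same pattern as in MHA.forward: open a bfloat16 autocast region and
    # make sure the packed qkv tensor is half precision before the kernel runs.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.bfloat16)
        return qkv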

View File

@@ -106,9 +106,10 @@ class HybridZeroOptimizer(BaseOptimizer):
max_scale = grad_scal_cfg.max_scale
# Zero related args
overlap_communication = zero_cfg.zero_overlap_communication
reduce_bucket_size = zero_cfg.reduce_bucket_size
clip_grad_norm = zero_cfg.clip_grad_norm
self._overlap_sync_grad = zero_cfg.overlap_sync_grad
self._overlap_sync_param = zero_cfg.overlap_sync_param
super().__init__(optim=optimizer)
@@ -129,7 +130,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._fp32_flat_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
# self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
# gradient scaler
@@ -161,8 +162,11 @@ class HybridZeroOptimizer(BaseOptimizer):
)
self.params_per_rank_id_dict = []
self._param_bcast_sync_handler = param_bcast_sync_handler
if self._overlap_communication:
if self._overlap_sync_param:
assert self._param_bcast_sync_handler is not None
self._broadcast_comm_stream = torch.cuda.Stream()
else:
self._broadcast_comm_stream = torch.cuda.current_stream()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -232,14 +236,14 @@ class HybridZeroOptimizer(BaseOptimizer):
# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
if self._overlap_sync_grad:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication:
if self._overlap_sync_grad:
self._attach_reduction_hook()
@property
@@ -273,7 +277,7 @@ class HybridZeroOptimizer(BaseOptimizer):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])
if self._overlap_communication:
if self._overlap_sync_param:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
@@ -394,7 +398,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
if self._overlap_sync_grad:
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
@@ -517,7 +521,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
if not self._overlap_communication:
if not self._overlap_sync_grad:
for group_id in range(len(self._fp16_param_groups)):
for param in self._fp16_param_groups[group_id]:
if param.grad is not None:
@@ -532,7 +536,7 @@ class HybridZeroOptimizer(BaseOptimizer):
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
# clear reduced grads
if self._overlap_communication:
if self._overlap_sync_grad:
# grads in the last bucket is reduced
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
@@ -641,7 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
with torch.cuda.stream(self._comm_stream):
with torch.cuda.stream(self._broadcast_comm_stream):
self.broadcast_params()
timer("step").stop()
@@ -668,7 +672,7 @@ class HybridZeroOptimizer(BaseOptimizer):
async_op=True,
)
if self._overlap_communication:
if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
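
Net effect in the optimizer: the single overlap switch is split in two, with overlap_sync_grad owning the gradient-reduction stream and reduction hooks, and overlap_sync_param owning the parameter-broadcast stream and the ParamBcastSyncHandler rank assignment. A condensed sketch of the stream selection (simplified, assumes a CUDA device is available):

import torch

class OverlapStreams:
    # Each overlap feature gets its own side stream only when its flag is on;
    # otherwise it falls back to the current (default) stream, i.e. no overlap.
    def __init__(self, overlap_sync_grad: bool, overlap_sync_param: bool):
        self.comm_stream = (
            torch.cuda.Stream() if overlap_sync_grad else torch.cuda.current_stream()
        )
        self.broadcast_comm_stream = (
            torch.cuda.Stream() if overlap_sync_param else torch.cuda.current_stream()
        )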

View File

@@ -1,7 +1,6 @@
from .training_internlm import (
get_train_data_loader,
get_validation_data_loader,
initialize_distributed_env,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
@@ -12,7 +11,6 @@ from .training_internlm import (
__all__ = [
"get_train_data_loader",
"get_validation_data_loader",
"initialize_distributed_env",
"initialize_llm_profile",
"initialize_model",
"initialize_optimizer",

View File

@@ -10,7 +10,6 @@ import torch.distributed as dist
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
@@ -31,7 +30,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.utils.common import DummyProfile, get_master_node
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
@@ -44,32 +43,6 @@ from internlm.utils.registry import MODEL_INITIALIZER
logger = get_logger(__file__)
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
"""
Initialize distributed environment for distributed training.
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
master_port (str): The master port for distributed training. 8888 by default.
seed (int, optional): Specified random seed for every process. 1024 by default.
"""
torch.cuda.empty_cache()
if launcher == "torch":
internlm.launch_from_torch(config=config, seed=seed)
elif launcher == "slurm":
internlm.launch_from_slurm(
config=config,
host=get_master_node(),
port=master_port,
seed=seed,
)
else:
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
def initialize_model():
"""
Initialize model.
@@ -119,7 +92,11 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
param_bcast_sync_handler = ParamBcastSyncHandler(model)
if gpc.config.hybrid_zero_optimizer.overlap_sync_param:
param_bcast_sync_handler = ParamBcastSyncHandler(model)
else:
param_bcast_sync_handler = None
adam_cfg = gpc.config.adam
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
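
With the handler now optional, the optimizer wiring looks roughly like the sketch below (keyword names follow the attribute names visible in this diff and should be treated as assumptions; model is the already-initialized model object):

import torch
from internlm.core.context import global_context as gpc
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler

param_bcast_sync_handler = (
    ParamBcastSyncHandler(model)
    if gpc.config.hybrid_zero_optimizer.overlap_sync_param
    else None  # no handler needed when the param broadcast is not overlapped
)
adam_cfg = gpc.config.adam
naive_optimizer = torch.optim.AdamW(
    params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
)
optimizer = HybridZeroOptimizer(
    naive_optimizer,
    param_bcast_sync_handler=param_bcast_sync_handler,
    grad_scal_cfg=gpc.config.grad_scaler,
    zero_cfg=gpc.config.hybrid_zero_optimizer,
)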

View File

@@ -15,6 +15,7 @@ from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message
@@ -22,7 +23,6 @@ from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
initialize_distributed_env,
initialize_llm_profile,
initialize_model,
initialize_optimizer,