feat(init): add skip args check flag and add zero overlap flag (#222)

* feat(init): add skip args check flag

* fix(optim): add param overlap enable flag
Guoteng 2023-08-24 16:44:18 +08:00 committed by GitHub
parent 9cd1e0314e
commit 7c820cfa40
8 changed files with 88 additions and 51 deletions
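Taken together, the change replaces the single zero_overlap_communication switch with separate overlap_sync_grad / overlap_sync_param flags and moves the distributed-setup entry point into internlm.initialize as initialize_distributed_env, which gains an args_check flag for skipping the config sanity check. A hedged usage sketch based only on the signature added in this commit (the config path and launcher value below are placeholders, not part of the diff):

# Hypothetical caller sketch; the path and launcher are example values.
from internlm.initialize import initialize_distributed_env

initialize_distributed_env(
    config="./configs/example_config.py",  # placeholder config path
    launcher="slurm",                       # "slurm" (default) or "torch"
    master_port=8888,
    seed=1024,
    args_check=False,                       # new flag: skip args_sanity_check()
)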


@@ -75,7 +75,8 @@ grad_scaler = dict(
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimzer overlap_communication
-    zero_overlap_communication=True,
+    overlap_sync_grad=True,
+    overlap_sync_param=True,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
@@ -120,7 +121,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.tf32",  # dtype could be in "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32",
+    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
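For reference, a minimal sketch of the hybrid_zero_optimizer block after this change, keeping only the fields visible in the hunk above; the inline comments paraphrase how the optimizer code further below uses each flag:

hybrid_zero_optimizer = dict(
    # overlap gradient reduction with the backward pass
    overlap_sync_grad=True,
    # overlap the post-step parameter broadcast with computation
    overlap_sync_param=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # remaining fields (e.g. grad clipping) are unchanged from the template
)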


@@ -1,9 +1,15 @@
 from .initialize_trainer import initialize_trainer
-from .launch import get_default_parser, launch_from_slurm, launch_from_torch
+from .launch import (
+    get_default_parser,
+    initialize_distributed_env,
+    launch_from_slurm,
+    launch_from_torch,
+)
 
 __all__ = [
     "get_default_parser",
     "initialize_trainer",
     "launch_from_slurm",
     "launch_from_torch",
+    "initialize_distributed_env",
 ]


@@ -10,6 +10,7 @@ import torch
 
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
+from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.storage_manager import init_storage_manager
 
@@ -276,6 +277,19 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     if "alert_address" not in gpc.config:
         gpc.config._add_item("alert_address", None)
 
+    optim_ckpt = gpc.config.hybrid_zero_optimizer
+    if "zero_overlap_communication" in optim_ckpt:
+        # Compatible with the old interfaces.
+        optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
+    if "overlap_sync_grad" not in optim_ckpt:
+        optim_ckpt._add_item("overlap_sync_grad", False)
+    if "overlap_sync_param" not in optim_ckpt:
+        optim_ckpt._add_item("overlap_sync_param", False)
+    if gpc.is_rank_for_log():
+        logger.info(
+            f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
+        )
+
 
 def launch(
     config: Union[str, Path, Config, Dict],
@@ -322,8 +336,6 @@ def launch(
     # init process groups for different parallel modes from config
     gpc.init_parallel_groups()
 
-    args_sanity_check()
-
     # set cuda device
     if torch.cuda.is_available():
         # if local rank is not given, calculate automatically
@@ -376,7 +388,11 @@ def launch_from_slurm(
     )
 
 
-def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
+def launch_from_torch(
+    config: Union[str, Path, Config, Dict],
+    backend: str = "nccl",
+    seed: int = 1024,
+):
     """A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch
 
@@ -404,3 +420,38 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nc
         backend=backend,
         seed=seed,
     )
+
+
+def initialize_distributed_env(
+    config: str,
+    launcher: str = "slurm",
+    master_port: int = 8888,
+    seed: int = 1024,
+    args_check=True,
+):
+    """
+    Initialize distributed environment for distributed training.
+
+    Args:
+        config (str): Config file path.
+        launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
+        master_port (str): The master port for distributed training. 8888 by default.
+        seed (int, optional): Specified random seed for every process. 1024 by default.
+    """
+
+    torch.cuda.empty_cache()
+
+    if launcher == "torch":
+        launch_from_torch(config=config, seed=seed)
+    elif launcher == "slurm":
+        launch_from_slurm(
+            config=config,
+            host=get_master_node(),
+            port=master_port,
+            seed=seed,
+        )
+    else:
+        assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
+
+    if args_check:
+        args_sanity_check()
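Note that the compatibility shim above only maps the legacy flag onto overlap_sync_grad; overlap_sync_param stays at its default of False unless it is set explicitly. A rough, stand-alone illustration of that mapping, using a plain dict in place of the gpc.config Config object:

# Stand-in for gpc.config.hybrid_zero_optimizer loaded from an old-style config.
optim_cfg = {"zero_overlap_communication": True}

if "zero_overlap_communication" in optim_cfg:
    # the legacy flag only drives gradient-sync overlap
    optim_cfg["overlap_sync_grad"] = optim_cfg["zero_overlap_communication"]
if "overlap_sync_grad" not in optim_cfg:
    optim_cfg["overlap_sync_grad"] = False
if "overlap_sync_param" not in optim_cfg:
    optim_cfg["overlap_sync_param"] = False

print(optim_cfg["overlap_sync_grad"], optim_cfg["overlap_sync_param"])  # True False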


@@ -133,7 +133,7 @@ class MHA(nn.Module):
         if inference_params is None:
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
-                with torch.cuda.amp.autocast(dtype=torch.float16):
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
                         qkv = qkv.to(torch.bfloat16)
                     context = self.inner_attn(qkv).to(x.dtype)
@@ -171,7 +171,7 @@ class MHA(nn.Module):
         if inference_params is None:
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
-                with torch.cuda.amp.autocast(dtype=torch.float16):
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
                         qkv = qkv.to(torch.bfloat16)
                     context = self.inner_attn(qkv, **kwargs).to(x.dtype)
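Both attention branches now autocast to bfloat16 instead of float16 when the model runs in float32 with flash attention, matching the explicit .to(torch.bfloat16) cast that follows. A stripped-down sketch of the pattern, assuming a CUDA device; inner_attn below is a stub standing in for the flash-attention module:

import torch

def inner_attn(qkv: torch.Tensor) -> torch.Tensor:
    # stub for self.inner_attn; a real flash-attention kernel needs fp16/bf16 inputs
    return qkv.sum(dim=1)

x = torch.randn(2, 8, device="cuda", dtype=torch.float32)
qkv = torch.randn(2, 3, 8, device="cuda", dtype=torch.float32)

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    if qkv.dtype not in [torch.float16, torch.bfloat16]:
        qkv = qkv.to(torch.bfloat16)
    context = inner_attn(qkv).to(x.dtype)  # cast back to the residual dtype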


@@ -106,9 +106,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         max_scale = grad_scal_cfg.max_scale
 
         # Zero related args
-        overlap_communication = zero_cfg.zero_overlap_communication
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
+        self._overlap_sync_grad = zero_cfg.overlap_sync_grad
+        self._overlap_sync_param = zero_cfg.overlap_sync_param
 
         super().__init__(optim=optimizer)
 
@@ -129,7 +130,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._fp32_flat_param_groups_of_current_rank = dict()
 
         # communication params
-        self._overlap_communication = overlap_communication
+        # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
 
         # gradient scaler
@@ -161,8 +162,11 @@ class HybridZeroOptimizer(BaseOptimizer):
         )
         self.params_per_rank_id_dict = []
         self._param_bcast_sync_handler = param_bcast_sync_handler
-        if self._overlap_communication:
+        if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
+            self._broadcast_comm_stream = torch.cuda.Stream()
+        else:
+            self._broadcast_comm_stream = torch.cuda.current_stream()
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -232,14 +236,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         # initialize communication stream for
         # communication-computation overlapping
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             self._comm_stream = torch.cuda.Stream()
         else:
             self._comm_stream = torch.cuda.current_stream()
 
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             self._attach_reduction_hook()
 
     @property
@@ -273,7 +277,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_id = str(i)
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
-            if self._overlap_communication:
+            if self._overlap_sync_param:
                 rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
             else:
                 rank_to_go = numel_per_rank.index(min(numel_per_rank))
@@ -394,7 +398,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
 
     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
 
@@ -517,7 +521,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # if not overlapping communication (no reduction hook is attached)
        # we need to manually reduce these gradients
-        if not self._overlap_communication:
+        if not self._overlap_sync_grad:
            for group_id in range(len(self._fp16_param_groups)):
                for param in self._fp16_param_groups[group_id]:
                    if param.grad is not None:
@@ -532,7 +536,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
 
         # clear reduced grads
-        if self._overlap_communication:
+        if self._overlap_sync_grad:
             # grads in the last bucket is reduced
             self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
@@ -641,7 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)
 
-        with torch.cuda.stream(self._comm_stream):
+        with torch.cuda.stream(self._broadcast_comm_stream):
             self.broadcast_params()
 
         timer("step").stop()
@@ -668,7 +672,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 async_op=True,
             )
 
-            if self._overlap_communication:
+            if self._overlap_sync_param:
                 self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
             else:
                 handles.append(handle)
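The optimizer now keeps two independent switches backed by two CUDA streams: overlap_sync_grad gates the gradient-reduction stream and the backward reduction hook, while overlap_sync_param gates the post-step parameter broadcast stream and requires a ParamBcastSyncHandler. A condensed sketch of that wiring, reusing the attribute names from the diff; it is not a drop-in replacement for HybridZeroOptimizer and needs a CUDA device to instantiate:

import torch

class OverlapStreamsSketch:
    # Condensed illustration of the stream setup added in the hunks above.
    def __init__(self, overlap_sync_grad: bool, overlap_sync_param: bool, bcast_handler=None):
        self._overlap_sync_grad = overlap_sync_grad
        self._overlap_sync_param = overlap_sync_param
        self._param_bcast_sync_handler = bcast_handler

        # gradient reduction: dedicated stream only when grad overlap is on
        if self._overlap_sync_grad:
            self._comm_stream = torch.cuda.Stream()
        else:
            self._comm_stream = torch.cuda.current_stream()

        # parameter broadcast: dedicated stream only when param overlap is on,
        # in which case a broadcast sync handler must be provided
        if self._overlap_sync_param:
            assert self._param_bcast_sync_handler is not None
            self._broadcast_comm_stream = torch.cuda.Stream()
        else:
            self._broadcast_comm_stream = torch.cuda.current_stream()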


@@ -1,7 +1,6 @@
 from .training_internlm import (
     get_train_data_loader,
     get_validation_data_loader,
-    initialize_distributed_env,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,
@@ -12,7 +11,6 @@ from .training_internlm import (
 __all__ = [
     "get_train_data_loader",
     "get_validation_data_loader",
-    "initialize_distributed_env",
     "initialize_llm_profile",
     "initialize_model",
     "initialize_optimizer",


@@ -10,7 +10,6 @@ import torch.distributed as dist
 from torch import nn
 from torch.utils.data import ConcatDataset, DataLoader
 
-import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
@@ -31,7 +30,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
-from internlm.utils.common import DummyProfile, get_master_node
+from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
@@ -44,32 +43,6 @@ from internlm.utils.registry import MODEL_INITIALIZER
 logger = get_logger(__file__)
 
 
-def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
-    """
-    Initialize distributed environment for distributed training.
-
-    Args:
-        config (str): Config file path.
-        launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
-        master_port (str): The master port for distributed training. 8888 by default.
-        seed (int, optional): Specified random seed for every process. 1024 by default.
-    """
-
-    torch.cuda.empty_cache()
-
-    if launcher == "torch":
-        internlm.launch_from_torch(config=config, seed=seed)
-    elif launcher == "slurm":
-        internlm.launch_from_slurm(
-            config=config,
-            host=get_master_node(),
-            port=master_port,
-            seed=seed,
-        )
-    else:
-        assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
-
-
 def initialize_model():
     """
     Initialize model.
@@ -119,7 +92,11 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
-    param_bcast_sync_handler = ParamBcastSyncHandler(model)
+    if gpc.config.hybrid_zero_optimizer.overlap_sync_param:
+        param_bcast_sync_handler = ParamBcastSyncHandler(model)
+    else:
+        param_bcast_sync_handler = None
+
     adam_cfg = gpc.config.adam
     naive_optimizer = torch.optim.AdamW(
         params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
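Since the broadcast handler is only needed for parameter-sync overlap, initialize_optimizer now builds it conditionally and passes None otherwise, which lines up with the assert in the optimizer constructor above. A rough, runnable illustration of that guard; the handler class and model are stubs, not the real ParamBcastSyncHandler:

import torch.nn as nn

class _StubBcastSyncHandler:
    # stand-in for internlm.solver.optimizer.utils.ParamBcastSyncHandler
    def __init__(self, model: nn.Module):
        self.model = model

def build_bcast_handler(model: nn.Module, overlap_sync_param: bool):
    if overlap_sync_param:
        return _StubBcastSyncHandler(model)
    return None

model = nn.Linear(4, 4)
print(build_bcast_handler(model, overlap_sync_param=False))  # None
print(build_bcast_handler(model, overlap_sync_param=True))   # handler instance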


@@ -15,6 +15,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
+from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor import initialize_monitor_manager, send_alert_message
@@ -22,7 +23,6 @@ from internlm.monitor.monitor import monitor_manager as mm
 from internlm.train import (
     get_train_data_loader,
     get_validation_data_loader,
-    initialize_distributed_env,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,