modify grad clipping due to moe

pull/375/head
Wenwen Qu 2023-08-09 15:04:19 +08:00
parent bc699ad46f
commit e21ddc364b
3 changed files with 46 additions and 13 deletions

View File

@ -116,7 +116,7 @@ class NonPipelineScheduler(BaseScheduler):
self._call_hooks("after_criterion", loss)
moe_loss = sum(moe_losses) * moe_loss_coeff
loss += moe_loss
loss /= scale_loss # TODO: check whether mos_loss should be scaled
loss /= scale_loss
# backward
if not forward_only:

View File

@ -10,6 +10,7 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.moe import is_moe_param
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
GradientStore,
@ -29,7 +30,6 @@ from internlm.solver.optimizer.utils import (
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.monitor import send_alert_message
from .utils import compute_norm
@ -89,6 +89,7 @@ class HybridZeroOptimizer(BaseOptimizer):
overlap_broadcast=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
has_moe: bool = False,
):
# DynamicGradScaler related args
if gpc.config.model.dtype is torch.float32:
@ -109,6 +110,8 @@ class HybridZeroOptimizer(BaseOptimizer):
super().__init__(optim=optimizer)
self.has_moe = has_moe
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._cpu_offload = cpu_offload
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@ -277,7 +280,9 @@ class HybridZeroOptimizer(BaseOptimizer):
no_params_ranks.append(rank)
if gpc.is_rank_for_log():
logger.info(f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}")
logger.info( # pylint: disable=W1203
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
)
return params_per_rank, set(no_params_ranks)
@ -503,6 +508,20 @@ class HybridZeroOptimizer(BaseOptimizer):
return self._step(closure=closure)
def _get_norm_with_moe_layers(self, norm_groups):
# all_groups_norm_old = all_groups_norm
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
pg = gpc.get_group(ParallelMode.DATA)
print(type(norm_groups))
scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
scaled_norm_tensor = torch.tensor(
scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
)
dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item()
# print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
return all_groups_norm
def _step(self, closure=None):
assert closure is None, "closure is not supported by step()"
@ -582,6 +601,9 @@ class HybridZeroOptimizer(BaseOptimizer):
if self._clip_grad_norm > 0:
global_norm = sum(norm_groups) ** 0.5
if self.has_moe:
global_norm = self._get_norm_with_moe_layers(global_norm)
# the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0:

View File

@ -30,6 +30,7 @@ from internlm.data.packed_dataset import (
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.model.moe import has_moe_layers
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
@ -247,7 +248,7 @@ def get_validation_data_loader(num_worker: int = 0):
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
if batch_size == 0 and gpc.is_rank_for_log():
logger.info(f"skip validate {val_name}.")
logger.info(f"skip validate {val_name}.") # pylint: disable=W1203
continue
val_dls[val_name] = get_dpsampler_dataloader(
@ -255,7 +256,7 @@ def get_validation_data_loader(num_worker: int = 0):
) # drop_last=True, otherwise it may cause problems in the last batch
if gpc.is_rank_for_log():
logger.info(
logger.info( # pylint: disable=W1203
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
f"samples {str(len(val_dls[val_name]))}."
)
@ -307,8 +308,12 @@ def initialize_optimizer(model: nn.Module):
eps=adam_cfg.adam_eps,
)
has_moe = has_moe_layers(model)
optimizer = HybridZeroOptimizer(
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
has_moe=has_moe,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
@ -472,19 +477,19 @@ def main(args):
model_load_path = None
if load_resume_ckpt_folder is not None:
logger.info(
logger.info( # pylint: disable=W1203
f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = load_resume_ckpt_folder
elif load_model_only_folder is not None:
logger.info(
logger.info( # pylint: disable=W1203
f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = load_model_only_folder
else:
logger.info(
logger.info( # pylint: disable=W1203
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
@ -580,7 +585,7 @@ def main(args):
train_state.num_consumed_samples_in_epoch += len(batch[1])
if batch_skipper(batch_count): # skip this batch
if gpc.is_rank_for_log():
logger.info(f"Skip batch count:`{batch_count}`...")
logger.info(f"Skip batch count:`{batch_count}`...") # pylint: disable=W1203
timer("one-batch").stop()
continue
@ -596,7 +601,13 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff = gpc.config.loss.moe_loss_coeff)
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
)
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)
@ -609,7 +620,7 @@ def main(args):
else:
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
logger.warning(f"Warning: skip parameter update at step {batch_count}.") # pylint: disable=W1203
send_alert_message(
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
)
@ -667,7 +678,7 @@ if __name__ == "__main__":
try:
main(args)
except Exception:
logger.error(
logger.error( # pylint: disable=W1203
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
exc_info=traceback.format_exc(),
)