mirror of https://github.com/InternLM/InternLM
modify grad clipping due to moe
parent
bc699ad46f
commit
e21ddc364b
|
@ -116,7 +116,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
self._call_hooks("after_criterion", loss)
|
||||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
loss += moe_loss
|
||||
loss /= scale_loss # TODO: check whether mos_loss should be scaled
|
||||
loss /= scale_loss
|
||||
|
||||
# backward
|
||||
if not forward_only:
|
||||
|
|
|
@ -10,6 +10,7 @@ from torch.optim import Optimizer
|
|||
from internlm.core.context import Config, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.moe import is_moe_param
|
||||
from internlm.monitor import send_alert_message
|
||||
from internlm.solver.optimizer.store import (
|
||||
BucketStore,
|
||||
GradientStore,
|
||||
|
@ -29,7 +30,6 @@ from internlm.solver.optimizer.utils import (
|
|||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.monitor import send_alert_message
|
||||
|
||||
from .utils import compute_norm
|
||||
|
||||
|
@ -89,6 +89,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
overlap_broadcast=False,
|
||||
grad_scal_cfg: Config = None,
|
||||
zero_cfg: Config = None,
|
||||
has_moe: bool = False,
|
||||
):
|
||||
# DynamicGradScaler related args
|
||||
if gpc.config.model.dtype is torch.float32:
|
||||
|
@ -109,6 +110,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
super().__init__(optim=optimizer)
|
||||
|
||||
self.has_moe = has_moe
|
||||
|
||||
self._dtype = self.optim.param_groups[0]["params"][0].dtype
|
||||
self._cpu_offload = cpu_offload
|
||||
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
|
@ -277,7 +280,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
no_params_ranks.append(rank)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}")
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
|
||||
)
|
||||
|
||||
return params_per_rank, set(no_params_ranks)
|
||||
|
||||
|
@ -503,6 +508,20 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
return self._step(closure=closure)
|
||||
|
||||
def _get_norm_with_moe_layers(self, norm_groups):
|
||||
# all_groups_norm_old = all_groups_norm
|
||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
||||
pg = gpc.get_group(ParallelMode.DATA)
|
||||
print(type(norm_groups))
|
||||
scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
||||
scaled_norm_tensor = torch.tensor(
|
||||
scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
|
||||
)
|
||||
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||
all_groups_norm = scaled_norm_tensor.item()
|
||||
# print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
|
||||
return all_groups_norm
|
||||
|
||||
def _step(self, closure=None):
|
||||
assert closure is None, "closure is not supported by step()"
|
||||
|
||||
|
@ -582,6 +601,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if self._clip_grad_norm > 0:
|
||||
global_norm = sum(norm_groups) ** 0.5
|
||||
|
||||
if self.has_moe:
|
||||
global_norm = self._get_norm_with_moe_layers(global_norm)
|
||||
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if gpc.config.model.dtype is not torch.float32:
|
||||
if len(single_grad_partition_groups) != 0:
|
||||
|
|
31
train.py
31
train.py
|
@ -30,6 +30,7 @@ from internlm.data.packed_dataset import (
|
|||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.model.moe import has_moe_layers
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
|
@ -247,7 +248,7 @@ def get_validation_data_loader(num_worker: int = 0):
|
|||
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
|
||||
|
||||
if batch_size == 0 and gpc.is_rank_for_log():
|
||||
logger.info(f"skip validate {val_name}.")
|
||||
logger.info(f"skip validate {val_name}.") # pylint: disable=W1203
|
||||
continue
|
||||
|
||||
val_dls[val_name] = get_dpsampler_dataloader(
|
||||
|
@ -255,7 +256,7 @@ def get_validation_data_loader(num_worker: int = 0):
|
|||
) # drop_last=True, otherwise it may cause problems in the last batch
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
|
||||
f"samples {str(len(val_dls[val_name]))}."
|
||||
)
|
||||
|
@ -307,8 +308,12 @@ def initialize_optimizer(model: nn.Module):
|
|||
eps=adam_cfg.adam_eps,
|
||||
)
|
||||
|
||||
has_moe = has_moe_layers(model)
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=gpc.config.grad_scaler,
|
||||
zero_cfg=gpc.config.hybrid_zero_optimizer,
|
||||
has_moe=has_moe,
|
||||
)
|
||||
|
||||
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
|
||||
|
@ -472,19 +477,19 @@ def main(args):
|
|||
|
||||
model_load_path = None
|
||||
if load_resume_ckpt_folder is not None:
|
||||
logger.info(
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = load_resume_ckpt_folder
|
||||
elif load_model_only_folder is not None:
|
||||
logger.info(
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = load_model_only_folder
|
||||
else:
|
||||
logger.info(
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
|
||||
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
|
||||
|
@ -580,7 +585,7 @@ def main(args):
|
|||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...")
|
||||
logger.info(f"Skip batch count:`{batch_count}`...") # pylint: disable=W1203
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
|
||||
|
@ -596,7 +601,13 @@ def main(args):
|
|||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff = gpc.config.loss.moe_loss_coeff)
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
|
@ -609,7 +620,7 @@ def main(args):
|
|||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.") # pylint: disable=W1203
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
|
||||
)
|
||||
|
@ -667,7 +678,7 @@ if __name__ == "__main__":
|
|||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
logger.error(
|
||||
logger.error( # pylint: disable=W1203
|
||||
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
|
||||
exc_info=traceback.format_exc(),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue