mirror of https://github.com/InternLM/InternLM
modify grad clipping due to moe
parent bc699ad46f
commit e21ddc364b
@@ -116,7 +116,7 @@ class NonPipelineScheduler(BaseScheduler):
             self._call_hooks("after_criterion", loss)
             moe_loss = sum(moe_losses) * moe_loss_coeff
             loss += moe_loss
-            loss /= scale_loss # TODO: check whether mos_loss should be scaled
+            loss /= scale_loss

         # backward
         if not forward_only:
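Note: this hunk is where the scheduler folds the MoE auxiliary losses into the task loss before dividing by the gradient-accumulation factor; the commit itself only drops the stale TODO. A minimal standalone sketch of that accumulation step, with the training-loop scaffolding assumed:

    import torch

    def combine_losses(loss, moe_losses, moe_loss_coeff, scale_loss):
        # one auxiliary (load-balancing) loss per MoE layer, summed and weighted
        moe_loss = sum(moe_losses) * moe_loss_coeff
        loss = loss + moe_loss
        return loss / scale_loss  # scale_loss = number of accumulation steps

    # usage sketch with dummy values
    task_loss = torch.tensor(2.0, requires_grad=True)
    aux_losses = [torch.tensor(0.10), torch.tensor(0.05)]
    scaled = combine_losses(task_loss, aux_losses, moe_loss_coeff=0.01, scale_loss=4)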
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.moe import is_moe_param
+from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
     GradientStore,
@@ -29,7 +30,6 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.monitor import send_alert_message

 from .utils import compute_norm

@@ -89,6 +89,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        has_moe: bool = False,
     ):
         # DynamicGradScaler related args
         if gpc.config.model.dtype is torch.float32:
@@ -109,6 +110,8 @@ class HybridZeroOptimizer(BaseOptimizer):

         super().__init__(optim=optimizer)

+        self.has_moe = has_moe
+
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -277,7 +280,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                 no_params_ranks.append(rank)

         if gpc.is_rank_for_log():
-            logger.info(f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}")
+            logger.info( # pylint: disable=W1203
+                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+            )

         return params_per_rank, set(no_params_ranks)

@@ -503,6 +508,20 @@ class HybridZeroOptimizer(BaseOptimizer):

         return self._step(closure=closure)

+    def _get_norm_with_moe_layers(self, norm_groups):
+        # all_groups_norm_old = all_groups_norm
+        # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
+        pg = gpc.get_group(ParallelMode.DATA)
+        print(type(norm_groups))
+        scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(
+            scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
+        )
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        all_groups_norm = scaled_norm_tensor.item()
+        # print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
+        return all_groups_norm
+
     def _step(self, closure=None):
         assert closure is None, "closure is not supported by step()"

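Note: `_get_norm_with_moe_layers` exists because expert (MoE) gradients are not synchronized by the regular data-parallel all-reduce, so each rank can see a different local norm; the method averages the per-rank norms over the DATA group. Averaging norms is an approximation — the exact global norm would all-reduce the sum of squares before taking the square root — and the stray `print(type(norm_groups))` looks like leftover debugging. A minimal standalone sketch of the averaging, assuming `dp_group` is an initialized torch.distributed data-parallel process group:

    import torch
    import torch.distributed as dist

    def average_norm_across_dp(local_norm: float, dp_group) -> float:
        # each rank contributes norm / world_size; summing via all-reduce yields the mean
        world_size = dist.get_world_size(group=dp_group)
        scaled = torch.tensor(local_norm / world_size, dtype=torch.float)
        dist.all_reduce(scaled, group=dp_group)  # default reduce op is SUM
        return scaled.item()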
@@ -582,6 +601,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._clip_grad_norm > 0:
             global_norm = sum(norm_groups) ** 0.5

+            if self.has_moe:
+                global_norm = self._get_norm_with_moe_layers(global_norm)
+
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
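Note: with `has_moe` set, the locally computed `global_norm` is replaced by its data-parallel average before clipping. For context, a hedged sketch of the standard way a global norm is turned into a clip coefficient — the actual scaling in HybridZeroOptimizer happens further down in `_step` and is not part of this diff:

    def clip_grads_(grads, global_norm: float, max_norm: float, eps: float = 1e-6):
        # shrink all gradients by the same factor when the combined norm is too large
        clip_coef = max_norm / (global_norm + eps)
        if clip_coef < 1.0:
            for g in grads:
                g.mul_(clip_coef)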
train.py (31 changed lines)
@@ -30,6 +30,7 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.model.moe import has_moe_layers
 from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -247,7 +248,7 @@ def get_validation_data_loader(num_worker: int = 0):
         batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz

         if batch_size == 0 and gpc.is_rank_for_log():
-            logger.info(f"skip validate {val_name}.")
+            logger.info(f"skip validate {val_name}.") # pylint: disable=W1203
             continue

         val_dls[val_name] = get_dpsampler_dataloader(
@@ -255,7 +256,7 @@ def get_validation_data_loader(num_worker: int = 0):
         ) # drop_last=True, otherwise it may cause problems in the last batch

         if gpc.is_rank_for_log():
-            logger.info(
+            logger.info( # pylint: disable=W1203
                 f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
                 f"samples {str(len(val_dls[val_name]))}."
             )
@@ -307,8 +308,12 @@ def initialize_optimizer(model: nn.Module):
         eps=adam_cfg.adam_eps,
     )

+    has_moe = has_moe_layers(model)
     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        has_moe=has_moe,
     )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
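Note: `initialize_optimizer` now probes the model once and threads the result through to the optimizer. `has_moe_layers` is imported from `internlm.model.moe`, but its body is not shown in this diff; a plausible sketch, assuming expert parameters carry a marker attribute of the kind a helper like `is_moe_param` would check (the attribute name here is hypothetical):

    import torch.nn as nn

    def has_moe_layers(model: nn.Module) -> bool:
        # hypothetical: true if any parameter is flagged as an expert parameter
        return any(getattr(p, "is_expert", False) for p in model.parameters())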
@@ -472,19 +477,19 @@ def main(args):

    model_load_path = None
    if load_resume_ckpt_folder is not None:
-        logger.info(
+        logger.info( # pylint: disable=W1203
            f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
            f"{socket.gethostname()}==========="
        )
        model_load_path = load_resume_ckpt_folder
    elif load_model_only_folder is not None:
-        logger.info(
+        logger.info( # pylint: disable=W1203
            f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
            f"{socket.gethostname()}==========="
        )
        model_load_path = load_model_only_folder
    else:
-        logger.info(
+        logger.info( # pylint: disable=W1203
            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
            f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
@@ -580,7 +585,7 @@ def main(args):
         train_state.num_consumed_samples_in_epoch += len(batch[1])
         if batch_skipper(batch_count): # skip this batch
             if gpc.is_rank_for_log():
-                logger.info(f"Skip batch count:`{batch_count}`...")
+                logger.info(f"Skip batch count:`{batch_count}`...") # pylint: disable=W1203
             timer("one-batch").stop()
             continue

@@ -596,7 +601,13 @@ def main(args):

         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff = gpc.config.loss.moe_loss_coeff)
+        _, _, loss = trainer.execute_schedule(
+            batch,
+            forward_only=False,
+            return_loss=True,
+            return_output_label=False,
+            moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
+        )
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)
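Note: the reformatted call reads `moe_loss_coeff` from `gpc.config.loss`, so the training config must define it. A sketch of what that section of an InternLM-style python config might look like — the coefficient value is illustrative, not taken from this commit:

    loss = dict(
        label_smoothing=0,
        moe_loss_coeff=0.01,  # weight on the MoE auxiliary (load-balancing) loss
    )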
@@ -609,7 +620,7 @@ def main(args):
         else:
             train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
             if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
-                logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                logger.warning(f"Warning: skip parameter update at step {batch_count}.") # pylint: disable=W1203
                 send_alert_message(
                     address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
                 )
|
@ -667,7 +678,7 @@ if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
main(args)
|
main(args)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(
|
logger.error( # pylint: disable=W1203
|
||||||
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
|
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
|
||||||
exc_info=traceback.format_exc(),
|
exc_info=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
|
|