diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index 8462def..7805e11 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -328,13 +328,12 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
     SchedulerHook for fstp overlap handler
     """

-    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
-        super().__init__()
-
+    def __init__(self, overlap_handler: FSTPOverlapHandler, zero_optim) -> None:
         self._overlap_handler = overlap_handler
+        self._zero_optim = zero_optim

     def before_forward(self, scheduler, inputs) -> None:
-        if self._overlap_handler is not None:
+        if self._overlap_handler.model_checkpoint:
             self._overlap_handler.set_forward_mode(True)

     def after_forward(self, scheduler, outputs) -> None:
@@ -347,11 +346,11 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
         pass

     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if self._overlap_handler is not None:
+        if self._overlap_handler.model_checkpoint:
             self._overlap_handler.set_forward_mode(False)

     def after_backward(self, scheduler, inputs_grad) -> None:
-        pass
+        self._zero_optim.accumulate_left_grads_after_backward()

     def post_helper_func(self, scheduler, outputs, label) -> None:
         pass
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 19a79bf..2d04bc6 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -66,10 +66,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale

-        self._fstp_handler = None
-        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.fstp_handler
-
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
@@ -133,6 +129,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None

+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
+            self._fstp_handler = gpc.fstp_handler
+        else:
+            self._fstp_handler = None
+        self._accum_grad_buckets: List[BucketStore] = []
+
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
@@ -221,8 +223,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()

     @property
     def zero_local_rank(self):
@@ -289,60 +290,79 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad:
-                    reduce_rank = None
+                if not param.requires_grad:
+                    continue

-                    def _define_and_attach(param, reduce_rank=None):
-                        # get the AccumulateGrad object of the param itself
-                        # If these objects are not kept, reduction hooks may not be attached successfully.
-                        accum_grad_obj = get_grad_accumulate_object(param)
-                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+                reduce_rank = None

-                        reduction_func = partial(
-                            self._store_and_try_reduce_grads_by_bucket,
-                            param=param,
-                            reduce_rank=reduce_rank,
+                def _define_and_attach(param, reduce_rank=None):
+                    reduction_func = partial(
+                        self._store_and_try_reduce_grads_by_bucket,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
+
+                    reduce_scatter_checker = partial(
+                        self._wait_reduce_scatter_and_accumulate_grads,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
+
+                    def reduction_sp_func():
+                        handle = reduce_tensor(
+                            param.grad,
+                            dtype=None,
+                            dst_rank=reduce_rank,
+                            parallel_mode=ParallelMode.TENSOR,
                         )
+                        handle.wait()

-                        reduce_scatter_checker = partial(
-                            self._wait_reduce_scatter_and_accumulate_grads,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
-
-                        def reduction_sp_func():
-                            handle = reduce_tensor(
-                                param.grad,
-                                dtype=None,
-                                dst_rank=reduce_rank,
-                                parallel_mode=ParallelMode.TENSOR,
-                            )
-                            handle.wait()
+                    # define hook
+                    # NOT IMPORTANT BUT GOOD TO KNOW:
+                    # args here is not grad, but allow_unreacable and accumulate_grad
+                    def reduce_grad_hook(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_func()

-                        # define hook
-                        # NOT IMPORTANT BUT GOOD TO KNOW:
-                        # args here is not grad, but allow_unreacable and accumulate_grad
-                        def reduce_grad_hook(*args):  # pylint: disable=W0613
-                            if self._fstp_handler is not None:
-                                reduce_scatter_checker()
+                    # define hook for real gradient accumulation.
+                    def accum_grad_hook(*args):  # pylint: disable=W0613
+                        reduce_scatter_checker()

-                            if self.skip_grad_reduce is False:
-                                reduction_func()
+                    # define hook for sequence_parallel
+                    def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_sp_func()

-                        # define hook for sequence_parallel
-                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
-                            if self.skip_grad_reduce is False:
-                                reduction_sp_func()
+                    # get the AccumulateGrad object of the param itself
+                    # If these objects are not kept, reduction hooks may not be attached successfully.
+                    accum_grad_obj = get_grad_accumulate_object(param)
+                    self._grad_store.add_accumulate_grad_object(accum_grad_obj)

-                        # if sequence_parallel is True,
-                        # the grad of norm should be all-reduce across the tp process group
-                        if gpc.config.parallel.sequence_parallel is True:
-                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                                accum_grad_obj_sp = get_grad_accumulate_object(param)
-                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+                    # if sequence_parallel is True,
+                    # the grad of norm should be all-reduce across the tp process group
+                    if (
+                        gpc.config.parallel.sequence_parallel is True
+                        and hasattr(param, IS_SEQUENCE_PARALLEL)
+                        and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                    ):
+                        accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
+                    # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
+                    # we must keep up with reduce_grad_hook.
+                    if self._fstp_handler is not None:
+                        accum_grad_obj.register_hook(accum_grad_hook)
+
+                    if self._overlap_sync_grad:
                         accum_grad_obj.register_hook(reduce_grad_hook)

-                    _define_and_attach(param, reduce_rank)
+                _define_and_attach(param, reduce_rank)
+
+    def accumulate_left_grads_after_backward(self):
+        if self._fstp_handler is None:
+            return
+
+        for group_id in range(self.num_param_groups):
+            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])

     def belongs_to_current_rank(self, param) -> bool:
         """
@@ -633,10 +653,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 if param.grad is not None:
                     self._store_and_try_reduce_grads_by_bucket(param)

-        # we need to accumulate gradients left in the accumulate gardient bucket
-        for group_id in range(self.num_param_groups):
-            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
-
         # we need to reduce the gradients left in the communication bucket
         for group_id in range(self.num_param_groups):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
diff --git a/train.py b/train.py
index 4511762..644bbeb 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import List, Optional
+from typing import List

 import torch
 import torch.distributed as dist
@@ -70,9 +70,7 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger


-def get_scheduler_hooks(
-    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
-) -> List[SchedulerHook]:
+def get_scheduler_hooks(metric, zero_optim) -> List[SchedulerHook]:
     scheduler_hooks: List[SchedulerHook] = []

     if metric is not None:
@@ -87,9 +85,8 @@ def get_scheduler_hooks(
                 ),
             ),
         )
-
-    if activation_checkpoint:
-        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+    if gpc.fstp_handler is not None:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler, zero_optim))

     return scheduler_hooks

@@ -112,7 +109,7 @@ def main(args):
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
         mlp_ratio=gpc.config.MLP_RATIO,
     )
-    
+
     get_tflops_func_2 = partial(
         get_megatron_flops_2,
         checkpoint=gpc.config.model.checkpoint,
@@ -196,7 +193,7 @@ def main(args):
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
+        scheduler_hooks=get_scheduler_hooks(metric, optimizer),
     )

     # initialize simple memory profiler
@@ -323,7 +320,7 @@ def main(args):

         if memory_profiler is not None:
             memory_profiler.step()
-        
+
         if batch_count % 2 == 0:
             prof.step()
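Taken together, the diff moves the flush of leftover reduce-scatter gradients out of the final reduction pass and into the scheduler hook, so the ZeRO optimizer drains its accumulation buckets right after every backward. The standalone sketch below only illustrates that call order with invented Dummy* stub classes; it is not the real InternLM implementation and the names are placeholders for this illustration.

# Minimal, self-contained sketch of the hook protocol used above (assumed
# simplification; the real classes live in internlm.model.overlap_handler
# and internlm.solver.optimizer.hybrid_zero_optim).

class DummyZeroOptim:
    def __init__(self):
        self.flushed = 0

    def accumulate_left_grads_after_backward(self):
        # In the real optimizer this drains _accum_grad_buckets, accumulating
        # any reduce-scattered grads left over from the overlap path.
        self.flushed += 1


class DummyOverlapHook:
    """Mirrors the scheduler-hook interface shown in the diff."""

    def __init__(self, overlap_handler, zero_optim):
        self._overlap_handler = overlap_handler
        self._zero_optim = zero_optim

    def before_forward(self, scheduler, inputs):
        if self._overlap_handler.model_checkpoint:
            self._overlap_handler.set_forward_mode(True)

    def before_backward(self, scheduler, outputs, outputs_grad):
        if self._overlap_handler.model_checkpoint:
            self._overlap_handler.set_forward_mode(False)

    def after_backward(self, scheduler, inputs_grad):
        # This is the new call added by the diff: flush leftover grads
        # once per backward pass instead of once per optimizer step.
        self._zero_optim.accumulate_left_grads_after_backward()


class DummyHandler:
    model_checkpoint = True

    def set_forward_mode(self, flag):
        self.forward_mode = flag


if __name__ == "__main__":
    optim = DummyZeroOptim()
    hook = DummyOverlapHook(DummyHandler(), optim)
    for _ in range(4):  # pretend 4 micro-batches in one training step
        hook.before_forward(None, None)
        hook.before_backward(None, None, None)
        hook.after_backward(None, None)
    print(optim.flushed)  # 4: buckets are flushed after every backward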