From 95488d8e8f1737947c4f9a00f888d9f57e6ea606 Mon Sep 17 00:00:00 2001
From: "chenxun.p"
Date: Fri, 20 Oct 2023 15:58:06 +0800
Subject: [PATCH] update optimizer accumulate grad impl when fstp

---
 .../core/scheduler/no_pipeline_scheduler.py |   1 -
 .../solver/optimizer/hybrid_zero_optim.py   | 133 +++++++-----------
 2 files changed, 51 insertions(+), 83 deletions(-)

diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index f0caf05..56661d8 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler):
             _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size
             )
-            engine.optimizer.reset_reduce_bucket()
 
             if return_loss:
                 loss += _loss
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 96a54c0..2c14c65 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import math
+from typing import Optional, List
 from functools import partial
 
 import torch
@@ -40,8 +41,20 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)
 
+
 def print_memory(msg):
-    print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+    print(
+        msg,
+        " rank = ",
+        gpc.get_global_rank(),
+        " memory allocated: ",
+        torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+        " reserved memory: ",
+        torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+        " max memory: ",
+        torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+        flush=True,
+    )
     print("===========================================")
 
 
@@ -69,7 +82,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         backoff_factor = grad_scal_cfg.backoff_factor
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale
-
+
         if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
             self._fstp_handler = gpc.config.fstp_handler
 
@@ -90,8 +103,8 @@
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
-        self._bucket_store = []
-        self._bucket_store_2 = []
+        self._bucket_store: List[BucketStore] = []
+        self._accum_grad_buckets: List[BucketStore] = []
         self._bucket_in_progress = []
 
         # fp16 and fp32 params for mixed precision training
@@ -160,7 +173,7 @@
             # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
             self._broadcast_parallel_mode.append(zero_mode)
             self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
-            self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"]))
+            self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"]))
 
             # assign parameters to ranks the params in the list are sorted
             params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
@@ -306,9 +319,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                             param=param,
                             reduce_rank=reduce_rank,
                         )
-
+
                         reduce_scatter_checker = partial(
-                            self._wait_reduce_scatter_and_accumulate_grad,
self._wait_reduce_scatter_and_accumulate_grad, + self._wait_reduce_scatter_and_accumulate_grads, param=param, reduce_rank=reduce_rank, ) @@ -317,7 +330,7 @@ class HybridZeroOptimizer(BaseOptimizer): # NOT IMPORTANT BUT GOOD TO KNOW: # args here is not grad, but allow_unreacable and accumulate_grad def reduce_grad_hook(*args): # pylint: disable=W0613 - if gpc.config.fstp_handler is not None: + if self._fstp_handler is not None: reduce_scatter_checker() if self.skip_grad_reduce is False: @@ -341,84 +354,36 @@ class HybridZeroOptimizer(BaseOptimizer): group_id = getattr(param, "group_id") return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id]) - def reset_reduce_bucket(self) -> None: - for bucket in self._bucket_store_2: - for rank, params in bucket._params.items(): - for _param in params: - if not hasattr(_param, "_fstp_reduce_scatter_str"): - continue + def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None: + for _param in bucket.get_param(reduce_rank): + if not hasattr(_param, "_fstp_reduce_scatter_str"): + continue - key = getattr(_param, "_fstp_reduce_scatter_str") - comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key] - comm_handle.wait() - _param.grad.add_(_grad) - # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) - del self._fstp_handler.reduce_scatter_handlers[key] - self._fstp_handler.reduce_scatter_handlers[key] = None - assert key in self._fstp_handler.reduce_scatter_handlers - # if not hasattr(_param, "_fstp_all_reduce_str"): - # continue + # wait and accumulate gardient. + _key = getattr(_param, "_fstp_reduce_scatter_str") + _comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key] + _comm_handle.wait() + _param.grad.add_(_grad) - # key = getattr(_param, "_fstp_all_reduce_str") - # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key] - # comm_handle.wait() - # with torch.no_grad(): - # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0) - # _param.grad.add_(_grad) - # # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - # del self._fstp_handler.all_reduce_handlers[key] - # self._fstp_handler.all_reduce_handlers[key] = None - # assert key in self._fstp_handler.all_reduce_handlers + # release cuda memory. 
+            self._fstp_handler.reduce_scatter_handlers[_key] = None
+            _grad = None
 
-                bucket.reset_by_rank(rank)
-
-    def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
+        bucket.reset_by_rank(reduce_rank)
+
+    def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None):
         param_size = param.numel()
+        group_id = getattr(param, "group_id")
+        current_bucket = self._accum_grad_buckets[group_id]
 
         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        group_id = getattr(param, "group_id")
-        current_bucket = self._bucket_store_2[group_id]
+        if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
+            self._accum_grads_store_in_bucket(current_bucket, reduce_rank)
 
-        if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024:
-            # wait reduce scatter communication
-            params = current_bucket.get_param(reduce_rank)
-            for _param in params:
-                if not hasattr(_param, "_fstp_reduce_scatter_str"):
-                    continue
-
-                key = getattr(_param, "_fstp_reduce_scatter_str")
-                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-                comm_handle.wait()
-                _param.grad.add_(_grad)
-                # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
-                del self._fstp_handler.reduce_scatter_handlers[key]
-                self._fstp_handler.reduce_scatter_handlers[key] = None
-                assert key in self._fstp_handler.reduce_scatter_handlers
-
-                # if not hasattr(_param, "_fstp_all_reduce_str"):
-                #     continue
-
-                # key = getattr(_param, "_fstp_all_reduce_str")
-                # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                # comm_handle.wait()
-                # with torch.no_grad():
-                #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                #     _param.grad.add_(_grad)
-                # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                # del self._fstp_handler.all_reduce_handlers[key]
-                # self._fstp_handler.all_reduce_handlers[key] = None
-                # assert key in self._fstp_handler.all_reduce_handlers
-
-            current_bucket.reset_by_rank(reduce_rank)
-
+        # otherwise, add the parameter to the bucket.
         current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
         current_bucket.add_param(param, reduce_rank)
 
@@ -646,6 +611,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
 
+        # we need to accumulate gradients left in the accumulate gradient bucket
+        for group_id in range(self.num_param_groups):
+            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
+
         # compute norm for gradients in the before bucket
         groups_norms = []
         for group_id in range(self.num_param_groups):
@@ -685,16 +654,16 @@ class HybridZeroOptimizer(BaseOptimizer):
             timer("sync_grad").start()
             self._sync_grad()
             timer("sync_grad").stop()
-
+
         print_memory("No 4")
-
+
         try:
-            res = self._step(closure=closure, norms=total_norms)
+            res = self._step(closure=closure, norms=total_norms)
         except torch.cuda.OutOfMemoryError as e:
            print(e, flush=True)
            print(torch.cuda.memory_summary(), flush=True)
            torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
-
+
         return res
 
     def _step(self, closure=None, norms=None):
@@ -822,7 +791,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             torch.cuda.synchronize()
             with torch.cuda.stream(self._comm_bcast_stream):
                 self.broadcast_params()
-
+
         timer("step").stop()
 
         # update gradients may not be needed here, because the sync_params function is used in initialization,
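
What the patch changes, in short: the scheduler no longer calls engine.optimizer.reset_reduce_bucket(); instead the backward hook calls _wait_reduce_scatter_and_accumulate_grads(), which parks each FSTP parameter in _accum_grad_buckets, _accum_grads_store_in_bucket() waits on the pending reduce-scatter handles and accumulates into param.grad once a bucket reaches self._reduce_bucket_size, and step() flushes whatever is left before computing norms. Below is a minimal, self-contained sketch of that control flow; _ToyBucket, _ToyHandle, the plain handlers dict and the bucket_size argument are illustrative stand-ins, not the real BucketStore, FSTP handler, or optimizer attributes.

    # accumulate_grad_bucket_sketch.py
    # Toy model of the bucketed "wait for reduce-scatter, then accumulate" flow.
    # _ToyBucket and _ToyHandle are hypothetical stand-ins, not InternLM classes.
    from typing import Dict, List, Optional, Tuple

    import torch


    class _ToyBucket:
        """Collects params until the bucket holds enough elements to flush."""

        def __init__(self) -> None:
            self._params: List[torch.nn.Parameter] = []
            self._numel = 0

        def add_param(self, p: torch.nn.Parameter) -> None:
            self._params.append(p)
            self._numel += p.numel()

        def num_elements_in_bucket(self) -> int:
            return self._numel

        def get_param(self) -> List[torch.nn.Parameter]:
            return self._params

        def reset(self) -> None:
            self._params, self._numel = [], 0


    class _ToyHandle:
        """Stand-in for an async communication handle; wait() would block."""

        def wait(self) -> None:
            pass


    Handlers = Dict[str, Optional[Tuple[_ToyHandle, torch.Tensor]]]


    def accum_grads_store_in_bucket(bucket: _ToyBucket, handlers: Handlers) -> None:
        """Wait for each pending reduce-scatter, accumulate, then drop the buffer."""
        for p in bucket.get_param():
            key = getattr(p, "_reduce_scatter_key", None)
            if key is None:
                continue
            handle, reduced_grad = handlers[key]
            handle.wait()              # make sure the reduced gradient is ready
            p.grad.add_(reduced_grad)  # accumulate into the param's .grad
            handlers[key] = None       # drop the reference so the buffer can be reused
        bucket.reset()


    def wait_reduce_scatter_and_accumulate_grads(
        p: torch.nn.Parameter, bucket: _ToyBucket, handlers: Handlers, bucket_size: int
    ) -> None:
        # flush first if the bucket is already full, then enqueue the current param
        if bucket.num_elements_in_bucket() >= bucket_size:
            accum_grads_store_in_bucket(bucket, handlers)
        bucket.add_param(p)


    if __name__ == "__main__":
        w = torch.nn.Parameter(torch.zeros(4))
        w.grad = torch.zeros(4)
        w._reduce_scatter_key = "w"
        handlers: Handlers = {"w": (_ToyHandle(), torch.ones(4))}
        bucket = _ToyBucket()
        wait_reduce_scatter_and_accumulate_grads(w, bucket, handlers, bucket_size=1)
        accum_grads_store_in_bucket(bucket, handlers)  # final flush, as step() does
        print(w.grad)  # tensor([1., 1., 1., 1.])

The point of the bucketing is to avoid blocking on every parameter: the asynchronous reduce-scatter keeps running while backward continues, and the optimizer only waits when a bucket fills up or at step time, when the leftover bucket is flushed explicitly.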