From 6682f5d92a02111777f5c1fbc8c0765c9770ffa2 Mon Sep 17 00:00:00 2001
From: "chenxun.p"
Date: Tue, 17 Oct 2023 15:10:07 +0800
Subject: [PATCH] fix reduce scatter async bug

---
 internlm/model/utils.py                        | 4 ++--
 internlm/solver/optimizer/hybrid_zero_optim.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 78ad456..0194e84 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -371,12 +371,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
             assert hasattr(weight, "_fstp_reduce_scatter_str")
             all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-            grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
+            grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
             if grad_bias is not None:
                 grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
                 assert hasattr(bias, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
+                grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index c6e9aab..950d35e 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -333,7 +333,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     key = getattr(_param, "_fstp_reduce_scatter_str")
                     comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                     comm_handle.wait()
-                    _param.grad = _grad
+                    _param.grad += _grad

             bucket.reset_by_rank(rank)

@@ -356,7 +356,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     key = getattr(_param, "_fstp_reduce_scatter_str")
                     comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                     comm_handle.wait()
-                    _param.grad = _grad
+                    _param.grad += _grad

         # reduce grad
         if self.skip_grad_reduce is False:
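
Reviewer note on why the two halves of the patch go together (a reading of the diff, not part of it): the optimizer now accumulates the reduce-scattered gradient into _param.grad with += instead of overwriting it, so the placeholder gradient that the backward pass leaves in its place must start from zeros. A torch.empty buffer holds uninitialized memory, and accumulating into it folds arbitrary values into the final gradient. The sketch below illustrates that failure mode on a single process, with made-up shapes standing in for the real sharded gradients; no actual reduce_scatter_raw call or process group is involved.

import torch

# Stand-in for the sharded gradient shape, i.e.
# grad_weight.shape[0] // world_size, *grad_weight.shape[1:].
shard_shape = (4, 8)

# Stand-in for the tensor the optimizer pulls out of
# reduce_scatter_handlers after comm_handle.wait().
reduce_scattered_grad = torch.randn(shard_shape)

# Before the fix: torch.empty returns uninitialized memory, so the
# later "+=" in the optimizer mixes arbitrary values into the gradient.
placeholder_before = torch.empty(shard_shape)
placeholder_before += reduce_scattered_grad

# After the fix: torch.zeros gives a well-defined base for accumulation,
# so the accumulated value equals the reduce-scattered gradient.
placeholder_after = torch.zeros(shard_shape)
placeholder_after += reduce_scattered_grad

print(torch.equal(placeholder_after, reduce_scattered_grad))   # True
print(torch.equal(placeholder_before, reduce_scattered_grad))  # not guaranteed: empty memory is arbitrary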