fix reduce scatter async bug

pull/407/head
chenxun.p 2023-10-17 15:10:07 +08:00
parent 229cc5c68c
commit 6682f5d92a
2 changed files with 4 additions and 4 deletions


@@ -371,12 +371,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
             assert hasattr(weight, "_fstp_reduce_scatter_str")
             all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-            grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
+            grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
             if grad_bias is not None:
                 grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
                 assert hasattr(bias, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
+                grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
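This hunk swaps torch.empty for torch.zeros when allocating the placeholder gradient that backward returns: the real reduce-scattered shard is added into .grad later by the optimizer (see the second file), so an uninitialized buffer would inject garbage into the accumulated gradient. Below is a minimal sketch of the backward-side pattern, not the repository's code; it assumes reduce_scatter_raw wraps torch.distributed.reduce_scatter_tensor, and the function and variable names are illustrative only.

# Minimal sketch (illustrative names, not the repository's code).
import torch
import torch.distributed as dist

def launch_grad_reduce_scatter(full_grad, process_group, handlers, key):
    # Assumes full_grad.shape[0] is divisible by the group's world size.
    world_size = dist.get_world_size(process_group)
    shard_shape = (full_grad.shape[0] // world_size, *full_grad.shape[1:])
    # Output buffer that will receive this rank's shard of the reduced gradient.
    shard = torch.empty(shard_shape, dtype=full_grad.dtype, device=full_grad.device)
    # Launch the reduce-scatter asynchronously and stash the handle so the
    # optimizer can wait on it later.
    handle = dist.reduce_scatter_tensor(shard, full_grad, group=process_group, async_op=True)
    handlers[key] = (handle, shard)
    # Placeholder handed back to autograd: it must be zero-initialized because
    # the real shard is accumulated into .grad afterwards; torch.empty here
    # would add uninitialized memory to the gradient.
    return torch.zeros(shard_shape, dtype=full_grad.dtype, device=full_grad.device)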


@@ -333,7 +333,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                key = getattr(_param, "_fstp_reduce_scatter_str")
                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                comm_handle.wait()
-               _param.grad = _grad
+               _param.grad += _grad
                bucket.reset_by_rank(rank)
@@ -356,7 +356,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                key = getattr(_param, "_fstp_reduce_scatter_str")
                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                comm_handle.wait()
-               _param.grad = _grad
+               _param.grad += _grad
                # reduce grad
                if self.skip_grad_reduce is False:
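The optimizer-side change pairs with the zero-initialized placeholder above: after waiting on the stored handle, the reduce-scattered shard is added to _param.grad rather than assigned, so whatever autograd has already accumulated there is not discarded. A minimal sketch of that consume step, again with illustrative names rather than the repository's:

# Minimal sketch (illustrative names): wait for the async reduce-scatter,
# then accumulate the shard into the parameter's gradient.
def consume_grad_reduce_scatter(param, handlers, key):
    handle, shard = handlers.pop(key)
    handle.wait()                      # block until the collective has finished
    if param.grad is None:
        param.grad = shard.clone()     # first contribution to this gradient
    else:
        param.grad += shard            # accumulate, do not overwrite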