fix async reduce scatter

pull/436/head
mwiacx 2023-10-31 19:39:24 +08:00
parent bc5a85c624
commit 4c1cd5d49b
3 changed files with 81 additions and 69 deletions

View File

@@ -328,13 +328,12 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
     SchedulerHook for fstp overlap handler
     """
 
-    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
-        super().__init__()
-
+    def __init__(self, overlap_handler: FSTPOverlapHandler, zero_optim) -> None:
         self._overlap_handler = overlap_handler
+        self._zero_optim = zero_optim
 
     def before_forward(self, scheduler, inputs) -> None:
-        if self._overlap_handler is not None:
+        if self._overlap_handler.model_checkpoint:
             self._overlap_handler.set_forward_mode(True)
 
     def after_forward(self, scheduler, outputs) -> None:
@@ -347,11 +346,11 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
         pass
 
     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if self._overlap_handler is not None:
+        if self._overlap_handler.model_checkpoint:
             self._overlap_handler.set_forward_mode(False)
 
     def after_backward(self, scheduler, inputs_grad) -> None:
-        pass
+        self._zero_optim.accumulate_left_grads_after_backward()
 
     def post_helper_func(self, scheduler, outputs, label) -> None:
         pass
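
Taken together, the hook change moves the final gradient accumulation into after_backward. The toy driver below is a simplified stand-in for illustration only (FakeOverlapHandler, FakeZeroOptim and OverlapHook are invented names, not the repo's classes); it shows the intended call order: any reduce-scatter results still pending when backward finishes are waited on and accumulated before the optimizer step reads the gradients.

# Minimal, self-contained sketch of the new hook ordering (not the repo's scheduler).
class FakeOverlapHandler:
    model_checkpoint = True

    def set_forward_mode(self, flag):
        print(f"forward mode -> {flag}")

class FakeZeroOptim:
    def __init__(self):
        # pretend handles of async reduce-scatters launched during backward
        self.pending = ["rs-handle-0", "rs-handle-1"]

    def accumulate_left_grads_after_backward(self):
        while self.pending:
            print(f"wait + accumulate {self.pending.pop()}")

    def step(self):
        print("optimizer step")

class OverlapHook:
    def __init__(self, handler, zero_optim):
        self._overlap_handler, self._zero_optim = handler, zero_optim

    def before_forward(self, scheduler, inputs):
        if self._overlap_handler.model_checkpoint:
            self._overlap_handler.set_forward_mode(True)

    def before_backward(self, scheduler, outputs, outputs_grad):
        if self._overlap_handler.model_checkpoint:
            self._overlap_handler.set_forward_mode(False)

    def after_backward(self, scheduler, inputs_grad):
        self._zero_optim.accumulate_left_grads_after_backward()

optim = FakeZeroOptim()
hook = OverlapHook(FakeOverlapHandler(), optim)
hook.before_forward(None, None)        # forward pass runs here
hook.before_backward(None, None, None)
# backward pass runs here; per-param hooks launch async reduce-scatters
hook.after_backward(None, None)        # leftovers are waited on and accumulated
optim.step()                           # gradients are complete before the step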

View File

@@ -66,10 +66,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale
 
-        self._fstp_handler = None
-        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.fstp_handler
-
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
@@ -133,6 +129,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
 
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
+            self._fstp_handler = gpc.fstp_handler
+        else:
+            self._fstp_handler = None
+        self._accum_grad_buckets: List[BucketStore] = []
+
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
@@ -221,8 +223,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()
 
     @property
     def zero_local_rank(self):
@@ -289,60 +290,79 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad:
-                    reduce_rank = None
-
-                    def _define_and_attach(param, reduce_rank=None):
-                        # get the AccumulateGrad object of the param itself
-                        # If these objects are not kept, reduction hooks may not be attached successfully.
-                        accum_grad_obj = get_grad_accumulate_object(param)
-                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)
-
-                        reduction_func = partial(
-                            self._store_and_try_reduce_grads_by_bucket,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
-
-                        reduce_scatter_checker = partial(
-                            self._wait_reduce_scatter_and_accumulate_grads,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
-
-                        def reduction_sp_func():
-                            handle = reduce_tensor(
-                                param.grad,
-                                dtype=None,
-                                dst_rank=reduce_rank,
-                                parallel_mode=ParallelMode.TENSOR,
-                            )
-                            handle.wait()
-
-                        # define hook
-                        # NOT IMPORTANT BUT GOOD TO KNOW:
-                        # args here is not grad, but allow_unreacable and accumulate_grad
-                        def reduce_grad_hook(*args):  # pylint: disable=W0613
-                            if self._fstp_handler is not None:
-                                reduce_scatter_checker()
-                            if self.skip_grad_reduce is False:
-                                reduction_func()
-
-                        # define hook for sequence_parallel
-                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
-                            if self.skip_grad_reduce is False:
-                                reduction_sp_func()
-
-                        # if sequence_parallel is True,
-                        # the grad of norm should be all-reduce across the tp process group
-                        if gpc.config.parallel.sequence_parallel is True:
-                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                                accum_grad_obj_sp = get_grad_accumulate_object(param)
-                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
-
-                        accum_grad_obj.register_hook(reduce_grad_hook)
-
-                    _define_and_attach(param, reduce_rank)
+                if not param.requires_grad:
+                    continue
+
+                reduce_rank = None
+
+                def _define_and_attach(param, reduce_rank=None):
+                    reduction_func = partial(
+                        self._store_and_try_reduce_grads_by_bucket,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
+
+                    reduce_scatter_checker = partial(
+                        self._wait_reduce_scatter_and_accumulate_grads,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
+
+                    def reduction_sp_func():
+                        handle = reduce_tensor(
+                            param.grad,
+                            dtype=None,
+                            dst_rank=reduce_rank,
+                            parallel_mode=ParallelMode.TENSOR,
+                        )
+                        handle.wait()
+
+                    # define hook
+                    # NOT IMPORTANT BUT GOOD TO KNOW:
+                    # args here is not grad, but allow_unreacable and accumulate_grad
+                    def reduce_grad_hook(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_func()
+
+                    # define hook for real gradient accumulation.
+                    def accum_grad_hook(*args):  # pylint: disable=W0613
+                        reduce_scatter_checker()
+
+                    # define hook for sequence_parallel
+                    def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_sp_func()
+
+                    # get the AccumulateGrad object of the param itself
+                    # If these objects are not kept, reduction hooks may not be attached successfully.
+                    accum_grad_obj = get_grad_accumulate_object(param)
+                    self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+
+                    # if sequence_parallel is True,
+                    # the grad of norm should be all-reduce across the tp process group
+                    if (
+                        gpc.config.parallel.sequence_parallel is True
+                        and hasattr(param, IS_SEQUENCE_PARALLEL)
+                        and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                    ):
+                        accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
+                    # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
+                    # we must keep up with reduce_grad_hook.
+                    if self._fstp_handler is not None:
+                        accum_grad_obj.register_hook(accum_grad_hook)
+
+                    if self._overlap_sync_grad:
+                        accum_grad_obj.register_hook(reduce_grad_hook)
+
+                _define_and_attach(param, reduce_rank)
+
+    def accumulate_left_grads_after_backward(self):
+        if self._fstp_handler is None:
+            return
+
+        for group_id in range(self.num_param_groups):
+            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])
 
     def belongs_to_current_rank(self, param) -> bool:
         """
@@ -633,10 +653,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             if param.grad is not None:
                 self._store_and_try_reduce_grads_by_bucket(param)
 
-        # we need to accumulate gradients left in the accumulate gardient bucket
-        for group_id in range(self.num_param_groups):
-            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
-
         # we need to reduce the gradients left in the communication bucket
         for group_id in range(self.num_param_groups):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
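
What _wait_reduce_scatter_and_accumulate_grads and the new _accum_grad_buckets manage is, at bottom, the async reduce-scatter pattern sketched below with plain torch.distributed. This is an illustration only, not the repo's implementation; it assumes an NCCL backend and a torchrun launch, and the tensor names are made up. The collective is issued with async_op=True so it overlaps with the rest of backward, and the handle must be waited on and the shard accumulated before the optimizer consumes gradients.

# Launch with: torchrun --nproc_per_node=2 async_reduce_scatter_sketch.py (needs GPUs + NCCL)
import os

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    full_grad = torch.randn(world_size * 4, device="cuda")  # unsharded gradient
    shard_out = torch.empty(4, device="cuda")               # this rank's reduced shard
    grad_accum = torch.zeros(4, device="cuda")              # local gradient accumulator

    # issued while backward is still producing other gradients
    handle = dist.reduce_scatter_tensor(shard_out, full_grad, op=dist.ReduceOp.SUM, async_op=True)

    # ... more backward work overlaps with the collective here ...

    # the "left grads" step: wait and accumulate before optimizer.step()
    handle.wait()
    grad_accum.add_(shard_out)

    if rank == 0:
        print("accumulated shard:", grad_accum)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()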

View File

@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import List, Optional
+from typing import List
 
 import torch
 import torch.distributed as dist
@@ -70,9 +70,7 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger
 
 
-def get_scheduler_hooks(
-    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
-) -> List[SchedulerHook]:
+def get_scheduler_hooks(metric, zero_optim) -> List[SchedulerHook]:
     scheduler_hooks: List[SchedulerHook] = []
 
     if metric is not None:
@@ -87,9 +85,8 @@
             ),
         ),
     )
 
-    if activation_checkpoint:
-        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+    if gpc.fstp_handler is not None:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler, zero_optim))
 
     return scheduler_hooks
@@ -196,7 +193,7 @@ def main(args):
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
+        scheduler_hooks=get_scheduler_hooks(metric, optimizer),
     )
 
     # initialize simple memory profiler