fix async reduce scatter

pull/436/head
mwiacx 2023-10-31 19:39:24 +08:00
parent bc5a85c624
commit 4c1cd5d49b
3 changed files with 81 additions and 69 deletions

View File

@ -328,13 +328,12 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
SchedulerHook for fstp overlap handler
"""
def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
super().__init__()
def __init__(self, overlap_handler: FSTPOverlapHandler, zero_optim) -> None:
self._overlap_handler = overlap_handler
self._zero_optim = zero_optim
def before_forward(self, scheduler, inputs) -> None:
if self._overlap_handler is not None:
if self._overlap_handler.model_checkpoint:
self._overlap_handler.set_forward_mode(True)
def after_forward(self, scheduler, outputs) -> None:
@ -347,11 +346,11 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
pass
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
if self._overlap_handler is not None:
if self._overlap_handler.model_checkpoint:
self._overlap_handler.set_forward_mode(False)
def after_backward(self, scheduler, inputs_grad) -> None:
pass
self._zero_optim.accumulate_left_grads_after_backward()
def post_helper_func(self, scheduler, outputs, label) -> None:
pass

View File

@ -66,10 +66,6 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale
self._fstp_handler = None
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.fstp_handler
# Zero related args
reduce_bucket_size = zero_cfg.reduce_bucket_size
clip_grad_norm = zero_cfg.clip_grad_norm
@ -133,6 +129,12 @@ class HybridZeroOptimizer(BaseOptimizer):
if self._overlap_sync_param:
assert self._param_bcast_sync_handler is not None
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.fstp_handler
else:
self._fstp_handler = None
self._accum_grad_buckets: List[BucketStore] = []
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
@ -221,8 +223,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_sync_grad:
self._attach_reduction_hook()
self._attach_reduction_hook()
@property
def zero_local_rank(self):
@ -289,60 +290,79 @@ class HybridZeroOptimizer(BaseOptimizer):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
# we should not reduce the param in moe
if param.requires_grad:
reduce_rank = None
if not param.requires_grad:
continue
def _define_and_attach(param, reduce_rank=None):
# get the AccumulateGrad object of the param itself
# If these objects are not kept, reduction hooks may not be attached successfully.
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduce_rank = None
reduction_func = partial(
self._store_and_try_reduce_grads_by_bucket,
param=param,
reduce_rank=reduce_rank,
def _define_and_attach(param, reduce_rank=None):
reduction_func = partial(
self._store_and_try_reduce_grads_by_bucket,
param=param,
reduce_rank=reduce_rank,
)
reduce_scatter_checker = partial(
self._wait_reduce_scatter_and_accumulate_grads,
param=param,
reduce_rank=reduce_rank,
)
def reduction_sp_func():
handle = reduce_tensor(
param.grad,
dtype=None,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.TENSOR,
)
handle.wait()
reduce_scatter_checker = partial(
self._wait_reduce_scatter_and_accumulate_grads,
param=param,
reduce_rank=reduce_rank,
)
def reduction_sp_func():
handle = reduce_tensor(
param.grad,
dtype=None,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.TENSOR,
)
handle.wait()
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args): # pylint: disable=W0613
if self.skip_grad_reduce is False:
reduction_func()
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args): # pylint: disable=W0613
if self._fstp_handler is not None:
reduce_scatter_checker()
# define hook for real gradient accumulation.
def accum_grad_hook(*args): # pylint: disable=W0613
reduce_scatter_checker()
if self.skip_grad_reduce is False:
reduction_func()
# define hook for sequence_parallel
def reduce_grad_hook_sp(*args): # pylint: disable=W0613
if self.skip_grad_reduce is False:
reduction_sp_func()
# define hook for sequence_parallel
def reduce_grad_hook_sp(*args): # pylint: disable=W0613
if self.skip_grad_reduce is False:
reduction_sp_func()
# get the AccumulateGrad object of the param itself
# If these objects are not kept, reduction hooks may not be attached successfully.
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
# if sequence_parallel is True,
# the grad of norm should be all-reduce across the tp process group
if gpc.config.parallel.sequence_parallel is True:
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
accum_grad_obj_sp = get_grad_accumulate_object(param)
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
# if sequence_parallel is True,
# the grad of norm should be all-reduce across the tp process group
if (
gpc.config.parallel.sequence_parallel is True
and hasattr(param, IS_SEQUENCE_PARALLEL)
and getattr(param, IS_SEQUENCE_PARALLEL) is True
):
accum_grad_obj.register_hook(reduce_grad_hook_sp)
# we should not only register for parameters which have _fstp_reduce_scatter_str attr.
# we must keep up with reduce_grad_hook.
if self._fstp_handler is not None:
accum_grad_obj.register_hook(accum_grad_hook)
if self._overlap_sync_grad:
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)
_define_and_attach(param, reduce_rank)
def accumulate_left_grads_after_backward(self):
if self._fstp_handler is None:
return
for group_id in range(self.num_param_groups):
self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])
def belongs_to_current_rank(self, param) -> bool:
"""
@ -633,10 +653,6 @@ class HybridZeroOptimizer(BaseOptimizer):
if param.grad is not None:
self._store_and_try_reduce_grads_by_bucket(param)
# we need to accumulate gradients left in the accumulate gardient bucket
for group_id in range(self.num_param_groups):
self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
# we need to reduce the gradients left in the communication bucket
for group_id in range(self.num_param_groups):
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)

View File

@ -5,7 +5,7 @@ import socket
import time
import traceback
from functools import partial
from typing import List, Optional
from typing import List
import torch
import torch.distributed as dist
@ -70,9 +70,7 @@ def initialize_llm_logger(start_time: str):
return uniscale_logger
def get_scheduler_hooks(
metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
) -> List[SchedulerHook]:
def get_scheduler_hooks(metric, zero_optim) -> List[SchedulerHook]:
scheduler_hooks: List[SchedulerHook] = []
if metric is not None:
@ -87,9 +85,8 @@ def get_scheduler_hooks(
),
),
)
if activation_checkpoint:
scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
if gpc.fstp_handler is not None:
scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler, zero_optim))
return scheduler_hooks
@ -112,7 +109,7 @@ def main(args):
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
mlp_ratio=gpc.config.MLP_RATIO,
)
get_tflops_func_2 = partial(
get_megatron_flops_2,
checkpoint=gpc.config.model.checkpoint,
@ -196,7 +193,7 @@ def main(args):
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
scheduler_hooks=get_scheduler_hooks(metric, optimizer),
)
# initialize simple memory profiler
@ -323,7 +320,7 @@ def main(args):
if memory_profiler is not None:
memory_profiler.step()
if batch_count % 2 == 0:
prof.step()