mirror of https://github.com/InternLM/InternLM
fix async reduce scatter
parent bc5a85c624
commit 4c1cd5d49b
@@ -328,13 +328,12 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
     SchedulerHook for fstp overlap handler
     """
 
-    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
-        super().__init__()
-
+    def __init__(self, overlap_handler: FSTPOverlapHandler, zero_optim) -> None:
         self._overlap_handler = overlap_handler
+        self._zero_optim = zero_optim
 
     def before_forward(self, scheduler, inputs) -> None:
         if self._overlap_handler is not None:
             if self._overlap_handler.model_checkpoint:
                 self._overlap_handler.set_forward_mode(True)
 
     def after_forward(self, scheduler, outputs) -> None:
@@ -347,11 +346,11 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
         pass
 
     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
         if self._overlap_handler is not None:
             if self._overlap_handler.model_checkpoint:
                 self._overlap_handler.set_forward_mode(False)
 
     def after_backward(self, scheduler, inputs_grad) -> None:
-        pass
+        self._zero_optim.accumulate_left_grads_after_backward()
 
     def post_helper_func(self, scheduler, outputs, label) -> None:
         pass
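For context, the hook only has an effect because the scheduler invokes these callbacks around every forward/backward pass. The sketch below is illustrative only and is not InternLM's scheduler implementation (scheduler, model, criterion, batch, and hooks are placeholders); it shows the call order the SchedulerHook interface above assumes, and where the new after_backward flush fires.

# Illustrative sketch, not part of this commit and not InternLM's scheduler code.
def run_one_step(scheduler, model, criterion, batch, hooks):
    inputs, label = batch

    for hook in hooks:
        hook.before_forward(scheduler, inputs)
    outputs = model(inputs)
    for hook in hooks:
        hook.after_forward(scheduler, outputs)

    loss = criterion(outputs, label)
    for hook in hooks:
        hook.before_backward(scheduler, outputs, None)
    loss.backward()
    for hook in hooks:
        # With this commit, FSTPOverlapSchedulerHook.after_backward calls
        # zero_optim.accumulate_left_grads_after_backward() at exactly this point.
        hook.after_backward(scheduler, None)

    for hook in hooks:
        hook.post_helper_func(scheduler, outputs, label)
    return outputs, loss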
@@ -66,10 +66,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale
 
-        self._fstp_handler = None
-        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.fstp_handler
-
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
@@ -133,6 +129,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
 
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
+            self._fstp_handler = gpc.fstp_handler
+        else:
+            self._fstp_handler = None
+        self._accum_grad_buckets: List[BucketStore] = []
+
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
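The branch added above keys off the tensor-parallel section of the runtime config. A minimal config sketch is shown below; only the sp and intern_overlap keys are taken from this diff, every other field and value is an illustrative assumption.

# Illustrative config sketch (assumed shape); gpc.fstp_handler is only consulted
# when sp == "intern" and intern_overlap is True.
parallel = dict(
    tensor=dict(size=8, sp="intern", intern_overlap=True),
)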
@@ -221,8 +223,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()
 
     @property
     def zero_local_rank(self):
@@ -289,60 +290,79 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad:
-                    reduce_rank = None
+                if not param.requires_grad:
+                    continue
 
-                    def _define_and_attach(param, reduce_rank=None):
-                        # get the AccumulateGrad object of the param itself
-                        # If these objects are not kept, reduction hooks may not be attached successfully.
-                        accum_grad_obj = get_grad_accumulate_object(param)
-                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+                reduce_rank = None
 
-                        reduction_func = partial(
-                            self._store_and_try_reduce_grads_by_bucket,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
+                def _define_and_attach(param, reduce_rank=None):
+                    reduction_func = partial(
+                        self._store_and_try_reduce_grads_by_bucket,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
 
-                        reduce_scatter_checker = partial(
-                            self._wait_reduce_scatter_and_accumulate_grads,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
+                    reduce_scatter_checker = partial(
+                        self._wait_reduce_scatter_and_accumulate_grads,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
 
-                        def reduction_sp_func():
-                            handle = reduce_tensor(
-                                param.grad,
-                                dtype=None,
-                                dst_rank=reduce_rank,
-                                parallel_mode=ParallelMode.TENSOR,
-                            )
-                            handle.wait()
+                    def reduction_sp_func():
+                        handle = reduce_tensor(
+                            param.grad,
+                            dtype=None,
+                            dst_rank=reduce_rank,
+                            parallel_mode=ParallelMode.TENSOR,
+                        )
+                        handle.wait()
 
-                        # define hook
-                        # NOT IMPORTANT BUT GOOD TO KNOW:
-                        # args here is not grad, but allow_unreacable and accumulate_grad
-                        def reduce_grad_hook(*args):  # pylint: disable=W0613
-                            if self._fstp_handler is not None:
-                                reduce_scatter_checker()
-
-                            if self.skip_grad_reduce is False:
-                                reduction_func()
+                    # define hook
+                    # NOT IMPORTANT BUT GOOD TO KNOW:
+                    # args here is not grad, but allow_unreacable and accumulate_grad
+                    def reduce_grad_hook(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_func()
 
-                        # define hook for sequence_parallel
-                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
-                            if self.skip_grad_reduce is False:
-                                reduction_sp_func()
+                    # define hook for real gradient accumulation.
+                    def accum_grad_hook(*args):  # pylint: disable=W0613
+                        reduce_scatter_checker()
 
-                        # if sequence_parallel is True,
-                        # the grad of norm should be all-reduce across the tp process group
-                        if gpc.config.parallel.sequence_parallel is True:
-                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                                accum_grad_obj_sp = get_grad_accumulate_object(param)
-                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+                    # define hook for sequence_parallel
+                    def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_sp_func()
 
-                        if self._overlap_sync_grad:
-                            accum_grad_obj.register_hook(reduce_grad_hook)
+                    # get the AccumulateGrad object of the param itself
+                    # If these objects are not kept, reduction hooks may not be attached successfully.
+                    accum_grad_obj = get_grad_accumulate_object(param)
+                    self._grad_store.add_accumulate_grad_object(accum_grad_obj)
 
-                    _define_and_attach(param, reduce_rank)
+                    # if sequence_parallel is True,
+                    # the grad of norm should be all-reduce across the tp process group
+                    if (
+                        gpc.config.parallel.sequence_parallel is True
+                        and hasattr(param, IS_SEQUENCE_PARALLEL)
+                        and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                    ):
+                        accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
+                    # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
+                    # we must keep up with reduce_grad_hook.
+                    if self._fstp_handler is not None:
+                        accum_grad_obj.register_hook(accum_grad_hook)
+
+                    if self._overlap_sync_grad:
+                        accum_grad_obj.register_hook(reduce_grad_hook)
+
+                _define_and_attach(param, reduce_rank)
+
+    def accumulate_left_grads_after_backward(self):
+        if self._fstp_handler is None:
+            return
+
+        for group_id in range(self.num_param_groups):
+            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])
 
     def belongs_to_current_rank(self, param) -> bool:
         """
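The hook plumbing in this hunk rests on one PyTorch technique: fetch a parameter's AccumulateGrad node and register a hook on it, so that code runs as soon as that parameter's gradient has been accumulated for the current backward pass. The standalone sketch below shows the technique in isolation; it is an illustration, not InternLM code, and the local get_grad_accumulate_object mirrors what the imported helper is assumed to do.

# Standalone sketch of the AccumulateGrad-hook technique (illustration only).
import torch

def get_grad_accumulate_object(param):
    # expand_as builds a tiny graph whose next_functions[0][0] is the AccumulateGrad
    # node of `param`; keep a reference to it, or the registered hook may be lost.
    return param.expand_as(param).grad_fn.next_functions[0][0]

param = torch.nn.Parameter(torch.randn(4))
accum_grad_obj = get_grad_accumulate_object(param)

def accum_grad_hook(*args):  # args are the node's grad inputs/outputs, unused here
    # By the time this fires, param.grad has been written for this backward pass;
    # a ZeRO-style optimizer would now bucket it and launch or wait on an async
    # reduce / reduce-scatter, which is what the hooks above do.
    print("grad ready:", param.grad.norm().item())

accum_grad_obj.register_hook(accum_grad_hook)

loss = (param * 2.0).sum()
loss.backward()  # the hook fires when the AccumulateGrad node runs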
@@ -633,10 +653,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             if param.grad is not None:
                 self._store_and_try_reduce_grads_by_bucket(param)
 
-        # we need to accumulate gradients left in the accumulate gardient bucket
-        for group_id in range(self.num_param_groups):
-            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
-
         # we need to reduce the gradients left in the communication bucket
         for group_id in range(self.num_param_groups):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
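The lines removed above take the leftover-accumulation flush out of this method; with this commit it runs in accumulate_left_grads_after_backward, driven by the scheduler hook. Stripped of InternLM specifics, the bucketing idea looks like the sketch below: gradients are collected into a bucket, an asynchronous collective is launched once the bucket is large enough, and whatever remains in the last bucket has to be flushed explicitly before the optimizer step. Class and method names are illustrative, not the HybridZeroOptimizer API, and running it requires an initialized torch.distributed process group.

# Generic gradient-bucketing sketch (illustration only, not HybridZeroOptimizer internals).
import torch
import torch.distributed as dist

class GradBucket:
    def __init__(self, reduce_bucket_size: int):
        self.reduce_bucket_size = reduce_bucket_size
        self.grads = []
        self.numel = 0
        self.handles = []

    def add(self, grad: torch.Tensor) -> None:
        # Store the gradient; once the bucket is big enough, reduce it asynchronously
        # so communication overlaps with the rest of the backward pass.
        self.grads.append(grad)
        self.numel += grad.numel()
        if self.numel >= self.reduce_bucket_size:
            self.flush()

    def flush(self) -> None:
        # Launch an async all-reduce on the flattened bucket; a real implementation
        # would also copy the reduced values back into the per-parameter gradients.
        if not self.grads:
            return
        flat = torch.cat([g.reshape(-1) for g in self.grads])
        self.handles.append(dist.all_reduce(flat, async_op=True))
        self.grads, self.numel = [], 0

    def wait_all(self) -> None:
        # The "last bucket" path: flush what is left and wait for all outstanding
        # communication before stepping the optimizer.
        self.flush()
        for handle in self.handles:
            handle.wait()
        self.handles = []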
train.py (17 changes)
@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import List, Optional
+from typing import List
 
 import torch
 import torch.distributed as dist
@@ -70,9 +70,7 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger
 
 
-def get_scheduler_hooks(
-    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
-) -> List[SchedulerHook]:
+def get_scheduler_hooks(metric, zero_optim) -> List[SchedulerHook]:
     scheduler_hooks: List[SchedulerHook] = []
 
     if metric is not None:
@@ -87,9 +85,8 @@
                 ),
             ),
         )
 
-    if activation_checkpoint:
-        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+    if gpc.fstp_handler is not None:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler, zero_optim))
 
     return scheduler_hooks
-
@@ -112,7 +109,7 @@ def main(args):
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
         mlp_ratio=gpc.config.MLP_RATIO,
     )
-
+
     get_tflops_func_2 = partial(
         get_megatron_flops_2,
         checkpoint=gpc.config.model.checkpoint,
@@ -196,7 +193,7 @@ def main(args):
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
+        scheduler_hooks=get_scheduler_hooks(metric, optimizer),
     )
 
     # initialize simple memory profiler
@@ -323,7 +320,7 @@ def main(args):
 
         if memory_profiler is not None:
             memory_profiler.step()
-
+
         if batch_count % 2 == 0:
             prof.step()
 