mirror of https://github.com/InternLM/InternLM
fix async reduce scatter
parent bc5a85c624
commit 4c1cd5d49b
@@ -328,13 +328,12 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
     SchedulerHook for fstp overlap handler
     """

-    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
-        super().__init__()
-
+    def __init__(self, overlap_handler: FSTPOverlapHandler, zero_optim) -> None:
         self._overlap_handler = overlap_handler
+        self._zero_optim = zero_optim

     def before_forward(self, scheduler, inputs) -> None:
-        if self._overlap_handler is not None:
+        if self._overlap_handler.model_checkpoint:
             self._overlap_handler.set_forward_mode(True)

     def after_forward(self, scheduler, outputs) -> None:
@@ -347,11 +346,11 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
         pass

     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if self._overlap_handler is not None:
+        if self._overlap_handler.model_checkpoint:
             self._overlap_handler.set_forward_mode(False)

     def after_backward(self, scheduler, inputs_grad) -> None:
-        pass
+        self._zero_optim.accumulate_left_grads_after_backward()

     def post_helper_func(self, scheduler, outputs, label) -> None:
         pass
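For readers unfamiliar with the hook protocol: the change above threads the ZeRO optimizer into the scheduler hook so that, once the whole backward pass has finished, gradients whose asynchronous reduce-scatter completed late can still be accumulated. Below is a minimal sketch of that call order with simplified stand-ins for the overlap handler, the optimizer, and the scheduler; it is illustrative only and not the InternLM engine.

# Illustrative stand-ins; the real classes live in InternLM.
class FakeOverlapHandler:
    model_checkpoint = True  # activation checkpointing enabled

    def set_forward_mode(self, is_forward: bool) -> None:
        print(f"forward_mode={is_forward}")


class FakeZeroOptim:
    def accumulate_left_grads_after_backward(self) -> None:
        print("accumulating grads whose reduce-scatter finished after backward")


class FakeHook:
    """Mirrors the lifecycle of FSTPOverlapSchedulerHook after this commit."""

    def __init__(self, overlap_handler, zero_optim) -> None:
        self._overlap_handler = overlap_handler
        self._zero_optim = zero_optim

    def before_forward(self, scheduler, inputs) -> None:
        if self._overlap_handler.model_checkpoint:
            self._overlap_handler.set_forward_mode(True)

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        if self._overlap_handler.model_checkpoint:
            self._overlap_handler.set_forward_mode(False)

    def after_backward(self, scheduler, inputs_grad) -> None:
        self._zero_optim.accumulate_left_grads_after_backward()


hook = FakeHook(FakeOverlapHandler(), FakeZeroOptim())
hook.before_forward(None, None)         # ... forward pass would run here ...
hook.before_backward(None, None, None)  # ... backward pass would run here ...
hook.after_backward(None, None)         # flush pending async reduce-scatters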
@@ -66,10 +66,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale

-        self._fstp_handler = None
-        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.fstp_handler
-
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
@@ -133,6 +129,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None

+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
+            self._fstp_handler = gpc.fstp_handler
+        else:
+            self._fstp_handler = None
+
+        self._accum_grad_buckets: List[BucketStore] = []

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
@@ -221,8 +223,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()

     @property
     def zero_local_rank(self):
@@ -289,60 +290,79 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad:
-                    reduce_rank = None
-
-                    def _define_and_attach(param, reduce_rank=None):
-                        # get the AccumulateGrad object of the param itself
-                        # If these objects are not kept, reduction hooks may not be attached successfully.
-                        accum_grad_obj = get_grad_accumulate_object(param)
-                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)
-
-                        reduction_func = partial(
-                            self._store_and_try_reduce_grads_by_bucket,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
-
-                        reduce_scatter_checker = partial(
-                            self._wait_reduce_scatter_and_accumulate_grads,
-                            param=param,
-                            reduce_rank=reduce_rank,
-                        )
-
-                        def reduction_sp_func():
-                            handle = reduce_tensor(
-                                param.grad,
-                                dtype=None,
-                                dst_rank=reduce_rank,
-                                parallel_mode=ParallelMode.TENSOR,
-                            )
-                            handle.wait()
-
-                        # define hook
-                        # NOT IMPORTANT BUT GOOD TO KNOW:
-                        # args here is not grad, but allow_unreacable and accumulate_grad
-                        def reduce_grad_hook(*args):  # pylint: disable=W0613
-                            if self._fstp_handler is not None:
-                                reduce_scatter_checker()
-
-                            if self.skip_grad_reduce is False:
-                                reduction_func()
-
-                        # define hook for sequence_parallel
-                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
-                            if self.skip_grad_reduce is False:
-                                reduction_sp_func()
-
-                        # if sequence_parallel is True,
-                        # the grad of norm should be all-reduce across the tp process group
-                        if gpc.config.parallel.sequence_parallel is True:
-                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                                accum_grad_obj_sp = get_grad_accumulate_object(param)
-                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
-
-                        accum_grad_obj.register_hook(reduce_grad_hook)
-
-                    _define_and_attach(param, reduce_rank)
+                if not param.requires_grad:
+                    continue
+
+                reduce_rank = None
+
+                def _define_and_attach(param, reduce_rank=None):
+                    reduction_func = partial(
+                        self._store_and_try_reduce_grads_by_bucket,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
+
+                    reduce_scatter_checker = partial(
+                        self._wait_reduce_scatter_and_accumulate_grads,
+                        param=param,
+                        reduce_rank=reduce_rank,
+                    )
+
+                    def reduction_sp_func():
+                        handle = reduce_tensor(
+                            param.grad,
+                            dtype=None,
+                            dst_rank=reduce_rank,
+                            parallel_mode=ParallelMode.TENSOR,
+                        )
+                        handle.wait()
+
+                    # define hook
+                    # NOT IMPORTANT BUT GOOD TO KNOW:
+                    # args here is not grad, but allow_unreacable and accumulate_grad
+                    def reduce_grad_hook(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_func()
+
+                    # define hook for real gradient accumulation.
+                    def accum_grad_hook(*args):  # pylint: disable=W0613
+                        reduce_scatter_checker()
+
+                    # define hook for sequence_parallel
+                    def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                        if self.skip_grad_reduce is False:
+                            reduction_sp_func()
+
+                    # get the AccumulateGrad object of the param itself
+                    # If these objects are not kept, reduction hooks may not be attached successfully.
+                    accum_grad_obj = get_grad_accumulate_object(param)
+                    self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+
+                    # if sequence_parallel is True,
+                    # the grad of norm should be all-reduce across the tp process group
+                    if (
+                        gpc.config.parallel.sequence_parallel is True
+                        and hasattr(param, IS_SEQUENCE_PARALLEL)
+                        and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                    ):
+                        accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
+                    # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
+                    # we must keep up with reduce_grad_hook.
+                    if self._fstp_handler is not None:
+                        accum_grad_obj.register_hook(accum_grad_hook)
+
+                    if self._overlap_sync_grad:
+                        accum_grad_obj.register_hook(reduce_grad_hook)
+
+                _define_and_attach(param, reduce_rank)
+
+    def accumulate_left_grads_after_backward(self):
+        if self._fstp_handler is None:
+            return
+
+        for group_id in range(self.num_param_groups):
+            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])

     def belongs_to_current_rank(self, param) -> bool:
         """
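The hook registration above is the heart of the async reduce-scatter fix: launching the reduce-scatter is now owned by a dedicated accum_grad_hook, reduce_grad_hook is left with only the bucketed reduction, and accumulate_left_grads_after_backward drains whatever is still in flight once backward ends. The following is a self-contained sketch of the same pattern written directly against torch.distributed; the shard storage, helper names, and the global pending list are illustrative assumptions rather than the HybridZeroOptimizer API, and it assumes a process group is already initialized (e.g. under torchrun) and that each gradient's numel is divisible by the world size.

import torch
import torch.distributed as dist

# (work handle, param, flat grad, output shard) for launched-but-unfinished reduce-scatters
_pending = []


def _get_accumulate_grad_node(param: torch.nn.Parameter):
    # Standard trick to reach the AccumulateGrad autograd node of a leaf tensor;
    # hooks registered on it fire after param.grad has been written.
    return param.expand_as(param).grad_fn.next_functions[0][0]


def attach_async_reduce_scatter(param: torch.nn.Parameter) -> None:
    accum_grad_node = _get_accumulate_grad_node(param)

    def accum_grad_hook(*args):  # hook args are autograd bookkeeping, not the grad itself
        world_size = dist.get_world_size()
        flat_grad = param.grad.reshape(-1).contiguous()
        shard = torch.empty(
            flat_grad.numel() // world_size, dtype=flat_grad.dtype, device=flat_grad.device
        )
        # async_op=True returns immediately, so the communication overlaps with
        # the rest of the backward pass instead of blocking it.
        handle = dist.reduce_scatter_tensor(shard, flat_grad, async_op=True)
        _pending.append((handle, param, flat_grad, shard))

    accum_grad_node.register_hook(accum_grad_hook)


def accumulate_left_grads_after_backward() -> None:
    # Role analogous to HybridZeroOptimizer.accumulate_left_grads_after_backward in this
    # commit: wait on every reduce-scatter that had not completed by the end of backward
    # and keep the resulting local shard.
    for handle, param, _flat_grad, shard in _pending:
        handle.wait()
        param.sharded_grad = shard  # illustrative attribute for the local gradient shard
    _pending.clear()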
@@ -633,10 +653,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             if param.grad is not None:
                 self._store_and_try_reduce_grads_by_bucket(param)

-        # we need to accumulate gradients left in the accumulate gardient bucket
-        for group_id in range(self.num_param_groups):
-            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
-
         # we need to reduce the gradients left in the communication bucket
         for group_id in range(self.num_param_groups):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
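Taken together with the scheduler-hook change, the flush of leftover asynchronously reduce-scattered gradients now happens once per step, right after backward, instead of inside the bucket-reduction path removed in the hunk above. A compressed, illustrative view of the resulting step order follows; the function and argument names are simplified stand-ins, not the InternLM trainer.

def train_step(engine_forward, engine_backward, overlap_hook, zero_optim, batch):
    # Illustrative ordering only; the real schedule lives in InternLM's engine/scheduler.
    overlap_hook.before_forward(None, batch)
    loss = engine_forward(batch)              # weight all-gathers overlap with forward compute

    overlap_hook.before_backward(None, loss, None)
    engine_backward(loss)                     # accum_grad_hook launches async reduce-scatters

    overlap_hook.after_backward(None, None)   # zero_optim.accumulate_left_grads_after_backward()
    zero_optim.step()                         # remaining bucketed grads are reduced during step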
train.py (13 changed lines)
@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import List, Optional
+from typing import List

 import torch
 import torch.distributed as dist
@@ -70,9 +70,7 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger


-def get_scheduler_hooks(
-    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
-) -> List[SchedulerHook]:
+def get_scheduler_hooks(metric, zero_optim) -> List[SchedulerHook]:
     scheduler_hooks: List[SchedulerHook] = []

     if metric is not None:
@@ -87,9 +85,8 @@ def get_scheduler_hooks(
             ),
         ),
     )
-
-    if activation_checkpoint:
-        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+    if gpc.fstp_handler is not None:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler, zero_optim))

     return scheduler_hooks

@@ -196,7 +193,7 @@ def main(args):
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
+        scheduler_hooks=get_scheduler_hooks(metric, optimizer),
     )

     # initialize simple memory profiler