mirror of https://github.com/InternLM/InternLM
reset the sp allreduce in optimizer
parent 9b1265c591
commit 7c6d2936b3
@@ -308,6 +308,15 @@ class HybridZeroOptimizer(BaseOptimizer):
                             reduce_rank=reduce_rank,
                         )

+                        def reduction_sp_func():
+                            handle = reduce_tensor(
+                                param.grad,
+                                dtype=None,
+                                dst_rank=reduce_rank,
+                                parallel_mode=ParallelMode.TENSOR,
+                            )
+                            handle.wait()
+
                         # define hook
                         # NOT IMPORTANT BUT GOOD TO KNOW:
                         # args here is not grad, but allow_unreacable and accumulate_grad
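The added reduction_sp_func issues the gradient reduction as an asynchronous collective over the tensor-parallel group and then blocks on the returned handle. Below is a minimal sketch of that pattern using plain torch.distributed; reduce_tensor and ParallelMode.TENSOR are InternLM internals, and tp_group is a hypothetical stand-in for the tensor-parallel process group (assuming dst_rank=None corresponds to an all-reduce).

# Minimal sketch (not the InternLM implementation): async all-reduce of a
# gradient over a tensor-parallel group, waiting on the handle before reuse.
import torch
import torch.distributed as dist

def all_reduce_grad_async(grad: torch.Tensor, tp_group: dist.ProcessGroup) -> torch.Tensor:
    # async_op=True returns a work handle instead of blocking immediately,
    # mirroring the handle = reduce_tensor(...) / handle.wait() pair above.
    handle = dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_group, async_op=True)
    handle.wait()  # the gradient is only safe to read after the collective completes
    return grad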
@@ -319,11 +328,25 @@ class HybridZeroOptimizer(BaseOptimizer):
                         def accum_grad_hook(*args):  # pylint: disable=W0613
                             reduce_scatter_checker()

+                        # define hook for sequence_parallel
+                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                            if self.skip_grad_reduce is False:
+                                reduction_sp_func()
+
                         # get the AccumulateGrad object of the param itself
                         # If these objects are not kept, reduction hooks may not be attached successfully.
                         accum_grad_obj = get_grad_accumulate_object(param)
                         self._grad_store.add_accumulate_grad_object(accum_grad_obj)

+                        # if sequence_parallel is True,
+                        # the grad of norm should be all-reduce across the tp process group
+                        if (
+                            gpc.config.parallel.sequence_parallel is True
+                            and hasattr(param, IS_SEQUENCE_PARALLEL)
+                            and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                        ):
+                            accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
                         # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
                         # we must keep up with reduce_grad_hook.
                         if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
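These hooks are attached to the parameter's AccumulateGrad autograd node, so reduce_grad_hook_sp fires during backward, right after the gradient has been written into param.grad, rather than waiting for step(). A common way to obtain that node, and presumably what get_grad_accumulate_object does, is the expand_as trick sketched below; the helper name and hook body are illustrative only.

# Sketch of the usual AccumulateGrad-hook idiom (an assumption about what
# get_grad_accumulate_object provides, not InternLM's actual helper).
import torch

def get_accumulate_grad_node(param: torch.nn.Parameter):
    # expand_as(param) adds one autograd node whose next_function is the
    # AccumulateGrad node of param; keep a reference so the hook stays attached.
    return param.expand_as(param).grad_fn.next_functions[0][0]

param = torch.nn.Parameter(torch.randn(4, 4))
accum_grad_obj = get_accumulate_grad_node(param)

def post_accumulate_hook(*args):  # args are not the grad; read param.grad instead
    print("grad accumulated, norm =", param.grad.norm().item())

accum_grad_obj.register_hook(post_accumulate_hook)
(param * 2).sum().backward()  # hook fires once param.grad has been accumulated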
@@ -621,26 +644,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"

-        # do all-reduce for layernorm when sequence_parallel is True
-        if gpc.config.parallel.sequence_parallel is True:
-            for group_id in range(len(self._fp16_param_groups)):
-                norm_bucket = TensorBucket(size=0)
-                for param in self._fp16_param_groups[group_id]:
-                    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                        norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
-                # import pdb; pdb.set_trace()
-                if not norm_bucket.is_empty():
-                    norm_bucket.flatten()
-                    norm_bucket.commu_handle = reduce_tensor(
-                        tensor=norm_bucket.get_flat_tensor(),
-                        dtype=None,
-                        dst_rank=None,
-                        parallel_mode=ParallelMode.TENSOR,
-                    )
-                    norm_bucket.commu_handle.wait()
-                    norm_bucket.unflatten_and_copy()
-                    # norm_bucket.empty()
-
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_sync_grad:
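With the backward-time hook in place, the per-step pass removed above is no longer needed: it collected the sequence-parallel (norm) gradients into a TensorBucket, flattened them, all-reduced the flat buffer over the tensor-parallel group, and copied the reduced values back. The sketch below illustrates that flatten / all-reduce / unflatten-and-copy pattern with torch._utils helpers; TensorBucket and reduce_tensor are InternLM internals, and tp_group is a hypothetical process group.

# Illustrative sketch of the removed step()-time bucketing (assumes norm_params
# are the parameters flagged IS_SEQUENCE_PARALLEL; tp_group is hypothetical).
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_norm_grads(norm_params, tp_group):
    grads = [p.grad for p in norm_params if p.grad is not None]
    if not grads:
        return
    flat = _flatten_dense_tensors(grads)            # one contiguous buffer ("flatten")
    dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=tp_group)
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)                          # "unflatten_and_copy" back into .grad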