mirror of https://github.com/InternLM/InternLM
fix the bug so that the sequence parallel norm is all-reduced when overlap is False (#534)
parent d904730be7
commit 432bd5ee9f
@@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False
 
-        # reduction hook is only used if overlapping communication
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()
 
     @property
     def zero_local_rank(self):
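With this change, `_attach_reduction_hook()` is called unconditionally in the constructor instead of only when `self._overlap_sync_grad` is set; the decision about which individual hooks actually depend on overlap moves into the registration code (see the second hunk below). A minimal sketch of that control-flow change, using a toy class rather than InternLM's real `HybridZeroOptimizer`:

```python
class ToyOptimizer:
    """Toy stand-in for HybridZeroOptimizer; names and structure are illustrative only."""

    def __init__(self, overlap_sync_grad: bool, sequence_parallel: bool):
        self._overlap_sync_grad = overlap_sync_grad
        self._sequence_parallel = sequence_parallel
        self.registered = []
        # After the fix: hooks are always attached. Before the fix this call
        # sat behind `if self._overlap_sync_grad:` and was skipped entirely,
        # so the sequence-parallel all-reduce never ran without overlap.
        self._attach_reduction_hook()

    def _attach_reduction_hook(self):
        # The sequence-parallel all-reduce hook is needed for correctness,
        # so it is registered regardless of the overlap setting.
        if self._sequence_parallel:
            self.registered.append("reduce_grad_hook_sp")
        # The bucketed reduce hook only exists to overlap communication with
        # backward computation, so it stays behind the flag.
        if self._overlap_sync_grad:
            self.registered.append("reduce_grad_hook")


# With overlap disabled, the sequence-parallel hook is still attached,
# which is exactly what the old guard prevented.
assert ToyOptimizer(False, True).registered == ["reduce_grad_hook_sp"]
```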
@@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):
 
             # if sequence_parallel is True,
             # the grad of norm should be all-reduce across the tp process group
-            if gpc.config.parallel.sequence_parallel is True:
-                if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                    accum_grad_obj_sp = get_grad_accumulate_object(param)
-                    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+            if (
+                gpc.config.parallel.sequence_parallel is True
+                and hasattr(param, IS_SEQUENCE_PARALLEL)
+                and getattr(param, IS_SEQUENCE_PARALLEL) is True
+            ):
+                accum_grad_obj.register_hook(reduce_grad_hook_sp)
 
-            accum_grad_obj.register_hook(reduce_grad_hook)
+            if self._overlap_sync_grad:
+                accum_grad_obj.register_hook(reduce_grad_hook)
 
         _define_and_attach(param, reduce_rank)