mirror of https://github.com/InternLM/InternLM
change the order of dp and sp all-reduce
parent
1655a90f34
commit
1bc3c33b75
|
@ -318,15 +318,14 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
if self.skip_grad_reduce is False:
|
if self.skip_grad_reduce is False:
|
||||||
reduction_sp_func()
|
reduction_sp_func()
|
||||||
|
|
||||||
|
|
||||||
accum_grad_obj.register_hook(reduce_grad_hook)
|
|
||||||
|
|
||||||
# if sequence_parallel is True, the grad of norm should be all-reduced across the tp process group
|
# if sequence_parallel is True, the grad of norm should be all-reduced across the tp process group
|
||||||
if gpc.config.parallel.sequence_parallel is True:
|
if gpc.config.parallel.sequence_parallel is True:
|
||||||
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
|
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
|
||||||
accum_grad_obj_sp = get_grad_accumulate_object(param)
|
accum_grad_obj_sp = get_grad_accumulate_object(param)
|
||||||
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
|
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
|
||||||
|
|
||||||
|
accum_grad_obj.register_hook(reduce_grad_hook)
|
||||||
|
|
||||||
_define_and_attach(param, reduce_rank)
|
_define_and_attach(param, reduce_rank)
|
||||||
|
|
||||||
def belongs_to_current_rank(self, param) -> bool:
|
def belongs_to_current_rank(self, param) -> bool:
|
||||||
|
|
Loading…
Reference in New Issue