mirror of https://github.com/InternLM/InternLM
fix(optimizer/hybrid_zero_optim.py): add reduce_scatter_overlap switch
parent 6b843253eb
commit b3def4c162
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=True,
-    total_steps=50,
+    total_steps=10,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, sp="intern", intern_overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )

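The new reduce_scatter_overlap key is always read with .get() and a False default, so configs that omit it keep the previous (non-overlapped) behaviour. Below is a minimal, self-contained sketch of that lookup; the plain dict stands in for gpc.config.parallel and is not part of this commit.

parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)

# Key absent: the switch defaults to False, i.e. the old code path.
assert parallel["tensor"].get("reduce_scatter_overlap", False) is False

# Key present and True: the overlapped reduce-scatter path is enabled.
parallel["tensor"]["reduce_scatter_overlap"] = True
assert parallel["tensor"].get("reduce_scatter_overlap", False) is True
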
@@ -70,10 +70,11 @@ class FSTPOverlapHandler:
                         setattr(child, "_fstp_name", name)

-                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
-                        if child.bias is not None:
-                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
+                        if gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                            _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                            setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                            if child.bias is not None:
+                                setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

         self.num_blocks = len(self.index_to_fstp_modules)

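A hedged illustration of what the conditional setattr above does: the reduce-scatter name tag is only attached to a parameter when the switch is on, and later code can probe for it with getattr. The attribute name and the _full_name format come from the diff; the toy module and names below are illustrative only.

import torch.nn as nn

# Toy stand-in for an FSTPLinear child module (illustrative, not InternLM code).
child = nn.Linear(4, 4, bias=True)
reduce_scatter_overlap = True  # stands in for the gpc.config lookup above

_full_name = "blocks.0.mlp.w1"  # follows f"{_chunk_name}.{idx}.{_sub_name}.{name}"
if reduce_scatter_overlap:
    setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
    if child.bias is not None:
        setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

# When the switch is off, the attribute is simply absent.
print(getattr(child.weight, "_fstp_reduce_scatter_str", None))
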
@@ -568,7 +568,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                     total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
                 )
                 if world_size > 1:
-                    if overlap_handler is not None:
+                    if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
                         grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                             grad_weight, process_group, async_op=True
                         )
@@ -621,7 +621,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             del total_weight

         if ctx.needs_input_grad[1]:
-            if world_size > 1 and overlap_handler is None:
+            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()

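The two FSTPFusedDenseFunc changes are mirror images: the asynchronous reduce-scatter of grad_weight is only issued when both an overlap handler exists and the new switch is on, and the inline handle_grad_weight.wait() is skipped in exactly that case, since not (handler is not None and flag) is equivalent to handler is None or not flag. A small self-contained check of that condition (the function name is illustrative, not from the patch):

def wait_inline(overlap_handler, reduce_scatter_overlap, world_size):
    # Mirrors the condition that guards handle_grad_weight.wait() above.
    overlapped = overlap_handler is not None and reduce_scatter_overlap
    return world_size > 1 and not overlapped

assert wait_inline(None, True, 8) is True        # no handler: wait inline as before
assert wait_inline(object(), False, 8) is True   # switch off: wait inline as before
assert wait_inline(object(), True, 8) is False   # overlap on: the wait is deferred to the handler path
assert wait_inline(object(), True, 1) is False   # single rank: no reduce-scatter to wait on
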
@@ -133,6 +133,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fstp_handler = gpc.fstp_handler
         else:
             self._fstp_handler = None
+        self._reduce_scatter_overlap = gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -348,7 +349,7 @@ class HybridZeroOptimizer(BaseOptimizer):

             # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
             # we must keep up with reduce_grad_hook.
-            if self._fstp_handler is not None:
+            if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
                 accum_grad_obj.register_hook(accum_grad_hook)

             if self._overlap_sync_grad:
@@ -357,7 +358,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     _define_and_attach(param, reduce_rank)

     def accumulate_left_grads_after_backward(self):
-        if self._fstp_handler is None:
+        if self._fstp_handler is None or self._reduce_scatter_overlap is False:
             return

         for group_id in range(self.num_param_groups):
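On the optimizer side, the cached self._reduce_scatter_overlap flag gates both the accum_grad_hook registration and accumulate_left_grads_after_backward, so with the switch off HybridZeroOptimizer falls back to the non-overlapped reduce-scatter path even when an FSTP handler is present. A stripped-down sketch of just that gating (class name and return values are illustrative, not the real optimizer):

class GatingSketch:
    """Illustrative reduction of the two conditions added above."""

    def __init__(self, fstp_handler, reduce_scatter_overlap):
        self._fstp_handler = fstp_handler
        self._reduce_scatter_overlap = reduce_scatter_overlap

    def should_register_accum_hook(self):
        # Mirrors: if self._fstp_handler is not None and self._reduce_scatter_overlap is True
        return self._fstp_handler is not None and self._reduce_scatter_overlap is True

    def accumulate_left_grads_after_backward(self):
        # Mirrors the early return added above.
        if self._fstp_handler is None or self._reduce_scatter_overlap is False:
            return "skipped"
        return "accumulated"

assert GatingSketch(object(), True).should_register_accum_hook() is True
assert GatingSketch(object(), False).accumulate_left_grads_after_backward() == "skipped"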