mirror of https://github.com/InternLM/InternLM
merge origin
commit b80e6cdcf3
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=4, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
 
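Note: the dropped reduce_scatter_overlap switch was always read through a defaulted dict lookup (see the hunks below), so configs that never define the key behave as if it were False. A minimal sketch of that lookup pattern, using a stand-in parallel dict rather than the full InternLM config:

# Sketch only: defaulted lookup of an optional tensor-parallel switch.
# The dict below is a stand-in for gpc.config.parallel, not the real config object.
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=4, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)

# A missing key falls back to False, so older configs keep running unchanged.
reduce_scatter_overlap = parallel["tensor"].get("reduce_scatter_overlap", False)
print(reduce_scatter_overlap)  # prints False for the config above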
@@ -73,7 +73,6 @@ class FSTPOverlapHandler:
 
             setattr(child, "_fstp_name", name)
 
-            if gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
             _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
             setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
             if child.bias is not None:
 
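The _fstp_reduce_scatter_str string attached to each weight (and bias) is what lets the optimizer side identify a parameter by name when handling its reduce-scattered gradient. A hedged sketch of that tagging step, with a toy two-layer block standing in for the real FSTP modules and an illustrative naming scheme in place of the repo's _chunk_name/_sub_name path:

# Sketch only: tag parameters with fully qualified names so a communication
# handler can later look them up by string key.
import torch.nn as nn

block = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # toy stand-in for an FSTP block
for idx, child in enumerate(block):
    full_name = f"block.{idx}"  # hypothetical naming scheme for this sketch
    setattr(child, "_fstp_name", full_name)
    setattr(child.weight, "_fstp_reduce_scatter_str", f"{full_name}.weight")
    if child.bias is not None:
        setattr(child.bias, "_fstp_reduce_scatter_str", f"{full_name}.bias")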
@@ -568,7 +568,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                if overlap_handler is not None:
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                         grad_weight, process_group, async_op=True
                     )
 
@@ -621,14 +621,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
 
 class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
     "FusedDenseFunc for FSTP, which is optimized based on flash implementation."
 
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
 
@@ -667,7 +669,7 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                if overlap_handler is not None:
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                         grad_weight, process_group, async_op=True
                     )
 
@@ -720,12 +722,13 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
                 del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
 
 def fused_dense_func_torch(
     x: Tensor,
     weight: Tensor,
 
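In both the FSTPFusedDenseFunc and FSTPFusedDenseFuncTorch backward hunks above, the change only affects who waits on the asynchronous reduce-scatter of grad_weight: with an overlap handler present, the communication handle is kept and waited on later, while without one backward must wait before returning its gradients. A minimal single-rank sketch of the same launch-now/wait-later pattern, using plain torch.distributed with the gloo backend and all_reduce in place of the repo's reduce_scatter_raw_memory_pool so it runs on CPU:

# Sketch only: start a collective asynchronously, overlap other work, wait later.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

grad_weight = torch.randn(4, 4)
handle = dist.all_reduce(grad_weight, async_op=True)  # communication is in flight here

# ... other backward computation would run here, overlapped with the collective ...

handle.wait()  # whichever component holds the handle is responsible for this wait
dist.destroy_process_group()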
@@ -133,7 +133,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fstp_handler = gpc.fstp_handler
         else:
             self._fstp_handler = None
-        self._reduce_scatter_overlap = gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
 
@@ -349,7 +348,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
             # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
             # we must keep up with reduce_grad_hook.
-            if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
+            if self._fstp_handler is not None:
                 accum_grad_obj.register_hook(accum_grad_hook)
 
             if self._overlap_sync_grad:
 
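The hook guarded here is attached to what the surrounding code treats as the parameter's gradient-accumulation node, so it fires once that parameter's gradient has been produced during backward. A generic sketch of obtaining that node and attaching a callback; the variable names mirror the diff, but the surrounding code is illustrative rather than the optimizer's:

# Sketch only: hook the AccumulateGrad node of a parameter so a callback runs
# right after its gradient is written during backward.
import torch

param = torch.nn.Parameter(torch.randn(3))
# expand_as yields a view whose grad_fn leads back to the AccumulateGrad node.
accum_grad_obj = param.expand_as(param).grad_fn.next_functions[0][0]

def accum_grad_hook(*unused):
    print("accum_grad_hook fired for param")

accum_grad_obj.register_hook(accum_grad_hook)
(param * 2.0).sum().backward()  # the hook fires during this backward pass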
@@ -358,7 +357,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             _define_and_attach(param, reduce_rank)
 
     def accumulate_left_grads_after_backward(self):
-        if self._fstp_handler is None or self._reduce_scatter_overlap is False:
+        if self._fstp_handler is None:
             return
 
         for group_id in range(self.num_param_groups):
 
@@ -644,6 +643,27 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"
 
+<<<<<<< HEAD
+=======
+        # do all-reduce for layernorm when sequence_parallel is True
+        if gpc.config.parallel.sequence_parallel is True:
+            for group_id in range(len(self._fp16_param_groups)):
+                norm_bucket = TensorBucket(size=0)
+                for param in self._fp16_param_groups[group_id]:
+                    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                        norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
+                if not norm_bucket.is_empty():
+                    norm_bucket.flatten()
+                    norm_bucket.commu_handle = reduce_tensor(
+                        tensor=norm_bucket.get_flat_tensor(),
+                        dtype=None,
+                        dst_rank=None,
+                        parallel_mode=ParallelMode.TENSOR,
+                    )
+                    norm_bucket.commu_handle.wait()
+                    norm_bucket.unflatten_and_copy()
+
+>>>>>>> c517ec5b8cdf9c675f97dcc615bfd39c2ffda010
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_sync_grad:
 
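On the incoming (>>>>>>> c517ec5b…) side of the committed conflict markers, the added block reduces the gradients of every parameter flagged with IS_SEQUENCE_PARALLEL across the tensor-parallel group, since with sequence parallelism each rank computes those gradients from only its shard of the sequence. A reduced single-rank sketch of the same idea, flattening the flagged grads by hand instead of going through TensorBucket and reduce_tensor:

# Sketch only: reduce the gradients of sequence-parallel-flagged params in one
# flattened buffer, then copy the synced values back into the original grads.
import os
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)
    p.is_sequence_parallel = True  # stand-in for the IS_SEQUENCE_PARALLEL attribute flag

grads = [p.grad for p in params if getattr(p, "is_sequence_parallel", False)]
flat = _flatten_dense_tensors(grads)  # one buffer, one collective, as the bucket flatten does
dist.all_reduce(flat)  # the real code reduces over ParallelMode.TENSOR, not the default group
for g, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
    g.copy_(synced)
dist.destroy_process_group()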