mirror of https://github.com/InternLM/InternLM

commit 9b1265c591
parent 5a18b3b651

    modify the sp allreduce and support tf32 for fstp linear
@@ -152,6 +152,8 @@ fstp_logs/
 configs/7B_train/*
 configs/13B_train/*
 configs/30B_train/*
+configs/test_loss/*
+loss_tensorboard/*

 atb
 pip
@@ -39,6 +39,14 @@ for idx, root_name in enumerate(root_names):

         log_name = root_name + "_" + output_file_name[:-3]

+        skip = True
+
+        if sp_mode == "intern" and intern_overlap[i] is True:
+            skip = False
+
+        if skip:
+            continue
+
         print(log_name)
         command = f"srun -p llm_s -N 8 -n 64 --ntasks-per-node=8 --gpus-per-task=1 --time=30 python train.py --config {write_file} --profiling 2>&1 | tee ./fstp_logs/{log_name}.log"
         process = subprocess.Popen(command, shell=True, executable='/bin/bash')
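Note on the launcher pattern above: subprocess.Popen returns as soon as the srun command is spawned, so if the sweep is meant to run one config at a time, a wait() presumably follows in code elided from this hunk. A minimal sketch of the blocking variant (an assumption, not repo code):

import subprocess

command = "echo launching training job"  # stands in for the srun command above
process = subprocess.Popen(command, shell=True, executable="/bin/bash")
process.wait()  # block until this job exits before launching the next config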
@@ -627,6 +627,104 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


+class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
+    "FusedDenseFunc for FSTP, which is optimized based on flash implementation."
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        overlap_handler = ctx.overlap_handler
+        module = ctx.module
+
+        if ctx.compute_weight_gradient:
+            x, weight, bias = ctx.saved_tensors
+            total_x = x
+        else:
+            weight, bias = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if world_size > 1:
+            if overlap_handler is not None:
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
+        else:
+            total_weight = weight
+
+        # compute weight grad
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            grad_weight, grad_bias = linear_bias_wgrad_torch(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+            if world_size > 1:
+                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                    grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
+                        grad_weight, process_group, async_op=True
+                    )
+                    assert hasattr(weight, "_fstp_reduce_scatter_str")
+                    overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (
+                        handle_grad_weight,
+                        grad_weight_async,
+                    )
+                    grad_weight = overlap_handler.get_zero_by_shape(
+                        (
+                            grad_weight.shape[0] // torch.distributed.get_world_size(process_group),
+                            *grad_weight.shape[1:],
+                        ),
+                        dtype=grad_weight.dtype,
+                        device=grad_weight.device,
+                    )
+                    if grad_bias is not None:
+                        grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(
+                            grad_bias, process_group, async_op=True
+                        )
+                        assert hasattr(bias, "_fstp_reduce_scatter_str")
+                        overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (
+                            handle_grad_bias,
+                            grad_bias_async,
+                        )
+                        grad_bias = overlap_handler.get_zero_by_shape(
+                            (
+                                grad_bias.shape[0] // torch.distributed.get_world_size(process_group),
+                                *grad_bias.shape[1:],
+                            ),
+                            dtype=grad_bias.dtype,
+                            device=grad_bias.device,
+                        )
+                else:
+                    grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                    if grad_bias is not None:
+                        grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, total_weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+        else:
+            grad_input = None
+        del total_weight
+
+        if ctx.needs_input_grad[1]:
+            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+                handle_grad_weight.wait()
+                if grad_bias is not None:
+                    handle_grad_bias.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


 def fused_dense_func_torch(
     x: Tensor,
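Note: the torch-path backward above differs from the flash-optimized parent mainly in computing the weight gradient through linear_bias_wgrad_torch rather than a fused fp16/bf16 kernel, which is what allows the GEMM to run in fp32/TF32. A hedged sketch of what such a helper plausibly computes (hypothetical, not the repo's implementation):

import torch

def linear_bias_wgrad_torch_sketch(x: torch.Tensor, grad_output: torch.Tensor, has_d_bias: bool):
    # x: (batch, in_features); grad_output: (batch, out_features)
    grad_weight = grad_output.t().matmul(x)  # dL/dW, shape (out_features, in_features)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None  # dL/db
    return grad_weight, grad_bias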
@@ -683,9 +781,7 @@ def fstp_fused_dense_func(
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
-        assert process_group is None
-        out = F.linear(x, weight, bias)
-        return out if not return_residual else (out, x)
+        return FSTPFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, module, handler)


 def try_import_RMSNorm():
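Note: this dispatch is the tf32 half of the commit. When dtype_eligible is false (e.g. fp32/TF32 inputs, which the fused kernels do not accept), the call now still goes through an FSTP-aware autograd function instead of a plain F.linear that asserted process_group is None. For TF32 to actually be used on Ampere-class GPUs, PyTorch's global switches must also be enabled; a reminder using the standard PyTorch flags (not shown in this diff):

import torch

# Allow TF32 in matmuls and cuDNN kernels (standard PyTorch settings).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True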
@@ -308,15 +308,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                     reduce_rank=reduce_rank,
                 )

-            def reduction_sp_func():
-                handle = reduce_tensor(
-                    param.grad,
-                    dtype=None,
-                    dst_rank=reduce_rank,
-                    parallel_mode=ParallelMode.TENSOR,
-                )
-                handle.wait()
-
             # define hook
             # NOT IMPORTANT BUT GOOD TO KNOW:
             # args here is not grad, but allow_unreacable and accumulate_grad
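Note: the removed reduction_sp_func performed one synchronous collective per norm parameter as soon as its grad was ready. The same pattern expressed with bare torch.distributed primitives, as a sketch (reduce_tensor is the repo's wrapper; only the pattern is shown, and the dst_rank handling is an assumption):

import torch
import torch.distributed as dist

def reduce_grad(grad: torch.Tensor, dst_rank=None, group=None):
    # dst_rank given: reduce to one rank, as the removed helper did;
    # dst_rank None: all-reduce, matching the new bucketed path in step().
    if dst_rank is not None:
        work = dist.reduce(grad, dst=dst_rank, op=dist.ReduceOp.SUM, group=group, async_op=True)
    else:
        work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group, async_op=True)
    work.wait()  # mirrors handle.wait() in the removed code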
@@ -328,25 +319,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             def accum_grad_hook(*args):  # pylint: disable=W0613
                 reduce_scatter_checker()

-            # define hook for sequence_parallel
-            def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
-                if self.skip_grad_reduce is False:
-                    reduction_sp_func()
-
             # get the AccumulateGrad object of the param itself
             # If these objects are not kept, reduction hooks may not be attached successfully.
             accum_grad_obj = get_grad_accumulate_object(param)
             self._grad_store.add_accumulate_grad_object(accum_grad_obj)

-            # if sequence_parallel is True,
-            # the grad of norm should be all-reduce across the tp process group
-            if (
-                gpc.config.parallel.sequence_parallel is True
-                and hasattr(param, IS_SEQUENCE_PARALLEL)
-                and getattr(param, IS_SEQUENCE_PARALLEL) is True
-            ):
-                accum_grad_obj.register_hook(reduce_grad_hook_sp)
-
             # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
             # we must keep up with reduce_grad_hook.
             if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
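Note: the net effect of these two hunks is that the eager per-parameter hook path for sequence-parallel norm grads is gone; the reduction moves into step() as one bucketed all-reduce (next hunk). For background, a sketch of how such per-parameter hooks attach in PyTorch, the same AccumulateGrad trick that get_grad_accumulate_object presumably wraps:

import torch

param = torch.nn.Parameter(torch.randn(4))
# The AccumulateGrad node of a leaf tensor; hooks on it fire after the grad
# has been accumulated into param.grad.
accum_grad_obj = param.expand_as(param).grad_fn.next_functions[0][0]
accum_grad_obj.register_hook(lambda *args: print("grad ready"))  # args are not the grad itself

torch.sum(param * 2).backward()  # triggers the hook once the grad is accumulated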
@@ -644,6 +621,26 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"

+        # do all-reduce for layernorm when sequence_parallel is True
+        if gpc.config.parallel.sequence_parallel is True:
+            for group_id in range(len(self._fp16_param_groups)):
+                norm_bucket = TensorBucket(size=0)
+                for param in self._fp16_param_groups[group_id]:
+                    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                        norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
+                if not norm_bucket.is_empty():
+                    norm_bucket.flatten()
+                    norm_bucket.commu_handle = reduce_tensor(
+                        tensor=norm_bucket.get_flat_tensor(),
+                        dtype=None,
+                        dst_rank=None,
+                        parallel_mode=ParallelMode.TENSOR,
+                    )
+                    norm_bucket.commu_handle.wait()
+                    norm_bucket.unflatten_and_copy()
+
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_sync_grad:
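Note: the bucketed pattern above, flatten, one collective, then unflatten and copy back, is standard; a self-contained sketch using torch internals (TensorBucket is the repo's wrapper, not shown here):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_grads_bucketed(grads, group=None):
    # One contiguous buffer and one collective call instead of a call per tensor.
    flat = _flatten_dense_tensors(grads)
    dist.all_reduce(flat, group=group)
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)  # write the reduced values back in place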