overlap grad_input computation and grad_weight reduce_scatter

pull/407/head
yingtongxiong 2023-10-10 17:06:13 +08:00
parent db637542a6
commit 0fac845c36
3 changed files with 24 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
 JOB_NAME = "7b_train"
 DO_ALERT = False
-SEQ_LEN = 2048
+SEQ_LEN = 4096
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
@@ -154,10 +154,10 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=dict(size=8, fsdp=False),
-    tensor=dict(size=1, mode='origin_tp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=1, fsdp=False),
+    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=False,
+    sequence_parallel=True,
 )
 cudnn_deterministic = False
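
The inline comment on the tensor entry encodes a constraint that is easy to violate when editing the config by hand: mode='fstp' only works with sequence_parallel=True. A minimal sanity-check sketch in Python (the validate_parallel_config helper below is hypothetical, not part of the repository):

def validate_parallel_config(parallel: dict) -> None:
    # Hypothetical helper: assert the constraints stated in the config comment above.
    mode = parallel["tensor"]["mode"]
    assert mode in ("origin_tp", "fstp"), "tensor mode should be 'origin_tp' or 'fstp'"
    if mode == "fstp":
        # per the comment above, fstp requires sequence parallelism to be enabled
        assert parallel["sequence_parallel"], "mode='fstp' requires sequence_parallel=True"

validate_parallel_config(parallel)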

View File

@@ -407,10 +407,11 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
+            # Evaluation
             if hidden_states.ndim == 3:
                 hidden_states = self.head(hidden_states, gather_dim=1)
-            else:
-                hidden_states = self.head(hidden_states)
+            else: # Training
+                hidden_states = self.head(hidden_states, gather_dim=0)
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
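
The gather_dim switch follows the activation layouts: during evaluation hidden_states is 3-D (batch, seq, hidden) and the sequence-parallel shard sits on dim 1, while packed training flattens all tokens into a 2-D (tokens, hidden) tensor whose sequence dimension is dim 0, so the head gathers on dim 0 instead. A shape-only sketch, assuming the tensor-parallel size of 8 and SEQ_LEN of 4096 from the config above and a hypothetical evaluation batch of 2:

import torch

tp, seq_len, hidden = 8, 4096, 4096

# Evaluation: each rank holds seq_len // tp tokens on dim 1;
# head(..., gather_dim=1) reassembles the full (2, 4096, 4096) activation.
eval_shard = torch.empty(2, seq_len // tp, hidden)

# Packed training: all tokens are flattened on dim 0;
# head(..., gather_dim=0) reassembles the full (4096, 4096) activation.
train_shard = torch.empty(seq_len // tp, hidden)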

View File

@@ -349,16 +349,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             handle_weight.wait()
         else:
             total_weight = weight
-        if ctx.needs_input_grad[0]:
-            if not ctx.return_residual:
-                grad_input = F.linear(grad_output, total_weight.t())
-            else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
-            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
-        else:
-            grad_input = None
         # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
@@ -369,11 +361,24 @@
                 grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 if grad_bias is not None:
                     grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                    handle_grad_bias.wait()
-                handle_grad_weight.wait()
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, total_weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            if world_size > 1:
+                handle_grad_weight.wait()
+                if grad_bias is not None:
+                    handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None
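
The reordering above is the substance of the commit: the reduce-scatter of grad_weight (and grad_bias) is launched with async_op=True first, the grad_input GEMM runs while that communication is still in flight, and the waits are only issued at the end of backward. A minimal sketch of the same overlap pattern written against plain torch.distributed collectives (the function and argument names below are illustrative, not the repository's API):

import torch
import torch.nn.functional as F
import torch.distributed as dist

def overlapped_backward(grad_output, total_weight, grad_weight_full, process_group):
    # Illustrative sketch: hide the grad_weight reduce-scatter behind the grad_input GEMM.
    world_size = dist.get_world_size(process_group)

    # 1) Launch the reduce-scatter of the full weight gradient asynchronously.
    grad_weight = torch.empty(
        grad_weight_full.shape[0] // world_size,
        *grad_weight_full.shape[1:],
        dtype=grad_weight_full.dtype,
        device=grad_weight_full.device,
    )
    handle = dist.reduce_scatter_tensor(grad_weight, grad_weight_full, group=process_group, async_op=True)

    # 2) Compute grad_input = grad_output @ W while the communication runs.
    grad_input = F.linear(grad_output, total_weight.t())

    # 3) Block on the communication only when its result is actually needed.
    handle.wait()
    return grad_input, grad_weight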