mirror of https://github.com/InternLM/InternLM
overlap grad_input computation and grad_weight reduce_scatter
parent db637542a6
commit 0fac845c36
@@ -1,7 +1,7 @@
 JOB_NAME = "7b_train"
 DO_ALERT = False
 
-SEQ_LEN = 2048
+SEQ_LEN = 4096
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
@@ -154,10 +154,10 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=dict(size=8, fsdp=False),
-    tensor=dict(size=1, mode='origin_tp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=1, fsdp=False),
+    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=False,
+    sequence_parallel=True,
 )
 
 cudnn_deterministic = False
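The inline comment states the rule this new configuration depends on: when the tensor-parallel mode is 'fstp', sequence_parallel must be True. Below is a minimal sketch of a sanity check for that rule, working on plain Python dicts with the field names from the config above; the helper name validate_parallel_config is hypothetical and not part of InternLM.

```python
def validate_parallel_config(parallel: dict) -> None:
    """Check the constraint spelled out in the config comment above.

    Hypothetical helper for illustration; it only inspects the fields shown
    in the diff: tensor=dict(size=..., mode=...), sequence_parallel=bool.
    """
    mode = parallel.get("tensor", {}).get("mode", "origin_tp")
    if mode not in ("origin_tp", "fstp"):
        raise ValueError(f"unknown tensor parallel mode: {mode!r}")
    # Per the comment: 'fstp' requires sequence parallelism to be enabled.
    if mode == "fstp" and not parallel.get("sequence_parallel", False):
        raise ValueError("tensor mode 'fstp' requires sequence_parallel=True")


# The values introduced by this commit pass the check.
validate_parallel_config(
    dict(
        zero1=dict(size=1, fsdp=False),
        tensor=dict(size=8, mode="fstp"),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,
    )
)
```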
@@ -407,10 +407,11 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
             # Evaluation
             if hidden_states.ndim == 3:
                 hidden_states = self.head(hidden_states, gather_dim=1)
-            else:
-                hidden_states = self.head(hidden_states)
+            else:  # Training
+                hidden_states = self.head(hidden_states, gather_dim=0)
+
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
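The gather_dim argument controls along which dimension the head re-assembles the sequence-parallel shards of its input: in evaluation the hidden states keep a (batch, seq, hidden) layout, so shards are concatenated on dim=1, while packed training carries all tokens on dim 0, hence gather_dim=0. That reading of the layouts is inferred from the comments in the hunk; the sketch below only illustrates the dimension choice with a hypothetical helper built on torch.distributed.all_gather, not InternLM's actual head implementation.

```python
import torch
import torch.distributed as dist


def gather_sequence_shards(x: torch.Tensor, gather_dim: int, group=None) -> torch.Tensor:
    """All-gather sequence-parallel shards and concatenate them on `gather_dim`.

    Illustration only: an eval-style input of shape (batch, seq/world, hidden)
    would use gather_dim=1, a packed-training input of shape (tokens/world, hidden)
    would use gather_dim=0, mirroring the head call in the hunk above.
    """
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return x
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x.contiguous(), group=group)
    return torch.cat(shards, dim=gather_dim)
```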
@@ -349,16 +349,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             handle_weight.wait()
         else:
             total_weight = weight
 
-        if ctx.needs_input_grad[0]:
-            if not ctx.return_residual:
-                grad_input = F.linear(grad_output, total_weight.t())
-            else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
-            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
-        else:
-            grad_input = None
 
         # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
@@ -369,11 +361,24 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 if grad_bias is not None:
                     grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                    handle_grad_bias.wait()
-                handle_grad_weight.wait()
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
 
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, total_weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+        else:
+            grad_input = None
+
+        if ctx.needs_input_grad[1]:
+            if world_size > 1:
+                handle_grad_weight.wait()
+                if grad_bias is not None:
+                    handle_grad_bias.wait()
+
         return grad_input, grad_weight, grad_bias, None, None, None
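Taken together, the last two hunks are the commit title in code: the grad_input GEMM is moved from before the weight-gradient reduce-scatter to the gap between its asynchronous launch and its wait, so the matmul runs while the collective is still in flight and the backward blocks only right before returning. The sketch below reproduces that overlap pattern in isolation; reduce_scatter_like is a stand-in for InternLM's reduce_scatter_raw helper, and splitting grad_weight along dim 0 is an assumption made for illustration.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def reduce_scatter_like(full_grad: torch.Tensor, group=None):
    """Async reduce-scatter along dim 0; returns (local shard, wait handle)."""
    world_size = dist.get_world_size(group)
    shard = torch.empty(
        full_grad.shape[0] // world_size, *full_grad.shape[1:],
        dtype=full_grad.dtype, device=full_grad.device,
    )
    handle = dist.reduce_scatter_tensor(shard, full_grad, group=group, async_op=True)
    return shard, handle


def backward_with_overlap(grad_output, x, total_weight, group=None):
    """Overlapped backward sketch for y = x @ total_weight.t().

    grad_output: (tokens, out_features), x: (tokens, in_features),
    total_weight: (out_features, in_features), already gathered across ranks.
    """
    # 1. Launch the grad_weight reduce-scatter asynchronously.
    full_grad_weight = grad_output.t() @ x                 # (out_features, in_features)
    grad_weight, handle = reduce_scatter_like(full_grad_weight, group)

    # 2. Overlap: compute grad_input while the collective is in flight.
    grad_input = F.linear(grad_output, total_weight.t())   # equals grad_output @ total_weight

    # 3. Block on the communication only right before returning the gradients.
    handle.wait()
    return grad_input, grad_weight
```

Because the reduce-scatter mostly uses the network and copy engines while the GEMM uses the compute units, ordering the work this way can hide much of the communication latency without changing the returned gradients.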