diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index d855700..ac49121 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -1,7 +1,7 @@
 JOB_NAME = "7b_train"
 DO_ALERT = False
 
-SEQ_LEN = 2048
+SEQ_LEN = 4096
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
@@ -154,10 +154,10 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=dict(size=8, fsdp=False),
-    tensor=dict(size=1, mode='origin_tp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=1, fsdp=False),
+    tensor=dict(size=8, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=False,
+    sequence_parallel=True,
 )
 
 cudnn_deterministic = False
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index b8d7e60..228dbd3 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -407,10 +407,11 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
+            # Evaluation
             if hidden_states.ndim == 3:
                 hidden_states = self.head(hidden_states, gather_dim=1)
-            else:
-                hidden_states = self.head(hidden_states)
+            else:  # Training
+                hidden_states = self.head(hidden_states, gather_dim=0)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 67e89ad..3885488 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -349,16 +349,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 handle_weight.wait()
         else:
             total_weight = weight
-
-        if ctx.needs_input_grad[0]:
-            if not ctx.return_residual:
-                grad_input = F.linear(grad_output, total_weight.t())
-            else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
-            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
-        else:
-            grad_input = None
-
+
+        # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
@@ -369,11 +361,24 @@
             if world_size > 1:
                 grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 if grad_bias is not None:
                     grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                    handle_grad_bias.wait()
-                handle_grad_weight.wait()
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
+
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, total_weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+        else:
+            grad_input = None
+
+        if ctx.needs_input_grad[1]:
+            if world_size > 1:
+                handle_grad_weight.wait()
+                if grad_bias is not None:
+                    handle_grad_bias.wait()
 
         return grad_input, grad_weight, grad_bias, None, None, None
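
Note: the reordering in FSTPFusedDenseFunc.backward moves the weight-gradient computation (and its async reduce-scatter) ahead of the input-gradient computation, and defers the wait() calls until just before the gradients are returned, so the reduce-scatter communication overlaps with the grad_input GEMM instead of serializing with it. Below is a minimal sketch of that overlap pattern only, not the InternLM implementation: the function name overlapped_linear_backward is made up for illustration, and reduce_scatter_raw in internlm/model/utils.py is assumed to wrap an async collective similar to torch.distributed.reduce_scatter_tensor.

import torch
import torch.distributed as dist


def overlapped_linear_backward(grad_output, x, total_weight, process_group=None):
    """Backward of y = x @ total_weight.t() with comm/compute overlap (sketch).

    grad_output:  (tokens, out_features)
    x:            (tokens, in_features)
    total_weight: gathered weight of shape (out_features, in_features)
    """
    # 1) Weight gradient first, so its reduce-scatter can be launched early.
    grad_weight_full = grad_output.t().matmul(x)  # (out_features, in_features)

    handle = None
    if process_group is not None and dist.get_world_size(process_group) > 1:
        world_size = dist.get_world_size(process_group)
        # Each rank keeps a 1/world_size shard of the summed weight gradient
        # (assumes out_features is divisible by world_size).
        grad_weight = torch.empty(
            grad_weight_full.shape[0] // world_size,
            grad_weight_full.shape[1],
            dtype=grad_weight_full.dtype,
            device=grad_weight_full.device,
        )
        handle = dist.reduce_scatter_tensor(
            grad_weight, grad_weight_full, group=process_group, async_op=True
        )
    else:
        grad_weight = grad_weight_full

    # 2) Input gradient is computed while the reduce-scatter is in flight.
    grad_input = grad_output.matmul(total_weight)  # (tokens, in_features)

    # 3) Block on the communication only when its result is actually needed.
    if handle is not None:
        handle.wait()
    return grad_input, grad_weight

With the previous ordering, handle_grad_weight.wait() and handle_grad_bias.wait() were called immediately after the reduce-scatter was issued, so the communication could not hide behind any computation; deferring the waits lets it overlap with the input-gradient matmul.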