diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 2011934..51d2e9c 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -55,7 +55,7 @@ data = dict(
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
-    valid_every=50,
+    valid_every=1000,
     pack_sample_into_one=False,
     total_steps=50000,
     skip_batches="",
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 56661d8..9768790 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -202,10 +202,10 @@ class NonPipelineScheduler(BaseScheduler):
             if return_output_label:
                 outputs.append(_output)
                 labels.append(_label)
-
+
         if not return_output_label:
             outputs, labels = None, None
-
+
         # Compatible for non-moe
         if hasattr(gpc.config.model, "num_experts"):
             return outputs, labels, loss, moe_loss
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 5ee1af9..5ea0e80 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -28,9 +28,20 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 import fused_dense_lib as fused_dense_cuda
 from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
+from flash_attn.utils.distributed import all_gather_raw, all_reduce_raw
+# reduce_scatter_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
+def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, op=torch.distributed.ReduceOp.SUM):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    output = torch.empty(
+        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.reduce_scatter_tensor(
+        output, input_.contiguous(), op=op, group=process_group, async_op=async_op
+    )
+    return output, handle
 
 class ScaleColumnParallelLinear(nn.Linear):
     """
@@ -279,15 +290,15 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
 
         # do all-gather for weight before backward
-        weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
         handle_weight.wait()
 
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
-                grad_input = F.linear(grad_output, weight.t())
+                grad_input = F.linear(grad_output, total_weight.t())
             else:
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, weight)
+                                         grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         # if process_group is not None:
         #     import pdb; pdb.set_trace()
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 8ac8c58..0db99ad 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -372,6 +372,7 @@ class PackedFlashInternLm1D(nn.Module):
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
+        import pdb; pdb.set_trace()
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:
diff --git a/train.py b/train.py
index 9bc4bd7..1adcc22 100644
--- a/train.py
+++ b/train.py
@@ -254,6 +254,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
            loss=loss,
+            moe_loss=None,
            grad_norm=grad_norm_groups,
            metric=metric,
            update_panel=uniscale_logger is not None,