From c21906534835fd0ce465672e6716ed587342cd1b Mon Sep 17 00:00:00 2001
From: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Date: Mon, 7 Aug 2023 16:42:52 +0800
Subject: [PATCH] feat(*): support sequence_parallel (#180)

* support sequence_parallel for no pipeline
* sequence_parallel does not support no-flash-attn
* support sequence parallel for pipeline
* add memory profiler
* Update 13B.py
* add memory profiler
* fix evaluation bug
* remove some unnecessary code
* remove some unnecessary code
* Update parallel_context.py
* modify the config
* remove memory profiler
* modify the config
* support selective dropout
---
 configs/7B_sft.py                             |  1 +
 internlm/core/context/parallel_context.py     |  1 -
 internlm/core/scheduler/pipeline_scheduler.py | 19 ++++++++----
 internlm/initialize/launch.py                 |  5 ++++
 internlm/model/embedding.py                   |  7 +++--
 internlm/model/linear.py                      | 15 ++++------
 internlm/model/modeling_internlm.py           |  7 ++---
 internlm/model/multi_head_attention.py        |  5 ++--
 internlm/model/utils.py                       | 29 +++++++++++++++++++
 internlm/utils/evaluation.py                  | 22 ++++++++---
 train.py                                      | 19 ++++++------
 11 files changed, 94 insertions(+), 36 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index b5b1352..30655de 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -118,6 +118,7 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
+    sequence_parallel=False,
 )
 """
 zero1 parallel:
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index bc0346c..87d3114 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
             initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
-
         for initializer in initializers:
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 9a2c1bb..ac13073 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -30,10 +30,17 @@ def get_tensor_shape():

     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
-            tensor_shape = (
-                gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-                gpc.config.HIDDEN_SIZE,
-            )
+            if gpc.config.model.sequence_parallel:
+                sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+                tensor_shape = (
+                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.HIDDEN_SIZE,
+                )
+            else:
+                tensor_shape = (
+                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                    gpc.config.HIDDEN_SIZE,
+                )
         else:
             tensor_shape = (
                 gpc.config.data["micro_bsz"],
@@ -132,6 +139,9 @@ class PipelineScheduler(BaseScheduler):
             and gpc.is_initialized(ParallelMode.TENSOR)
             and gpc.get_world_size(ParallelMode.TENSOR) > 1
         )
+
+        if gpc.config.model.sequence_parallel:
+            self.scatter_gather_tensors = False

         # cache for the batch data
         self.batch_data = None
@@ -254,7 +264,6 @@ class PipelineScheduler(BaseScheduler):

             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 self._call_hooks("post_helper_func", output_obj, label)
-
             if return_output_label:
                 return_tensors.append((output_obj, label))
             if accum_loss is not None:
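For reference, the new branch in get_tensor_shape() only rescales the first dimension of the pipeline p2p buffer: with sequence parallelism on, each rank holds just its sequence shard between stages, so the buffer shrinks by the tensor-parallel world size. A standalone sketch of the arithmetic with hypothetical config values:

    # Minimal sketch of the shape arithmetic above; SEQ_LEN, MICRO_BSZ,
    # HIDDEN_SIZE and the world size are hypothetical example values.
    SEQ_LEN, MICRO_BSZ, HIDDEN_SIZE = 2048, 2, 4096
    tensor_world_size = 8  # doubles as the sequence-parallel world size

    # packed flash-attn layout: tokens are flattened into one leading dim
    full_shape = (SEQ_LEN * MICRO_BSZ, HIDDEN_SIZE)

    # with sequence_parallel=True each rank only sends/receives its shard
    sharded_shape = (SEQ_LEN * MICRO_BSZ // tensor_world_size, HIDDEN_SIZE)

    assert full_shape[0] == sharded_shape[0] * tensor_world_size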
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index cde8bc0..1f60adc 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -38,6 +38,7 @@ def get_default_parser():
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
     parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     parser.add_argument("--seed", type=int, default=1024)
+    parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
     return parser


@@ -198,6 +199,10 @@ def args_sanity_check():
     # process the model config
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
+    if "sequence_parallel" not in gpc.config.model:
+        gpc.config.model._add_item("sequence_parallel", False)
+    else:
+        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"


 def launch(
diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index ced9e33..0951ccd 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -13,7 +13,7 @@ from torch import Tensor, nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc

-from .utils import gather_forward_split_backward
+from .utils import gather_forward_split_backward, split_forward_gather_backward


 class Embedding1D(nn.Module):
@@ -55,7 +55,10 @@ class Embedding1D(nn.Module):
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
-
+
+        if gpc.config.model.sequence_parallel:
+            output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
+
         return output
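The split added to Embedding1D.forward shards the gathered embeddings along the sequence dimension (dim=1). A shape-only sketch with hypothetical sizes, using torch.chunk as a stand-in for split_forward_gather_backward (which additionally wires up the gather in the backward pass):

    # Standalone sketch of what the new split does to shapes; world_size,
    # rank and the tensor sizes are hypothetical.
    import torch

    world_size, rank = 4, 1
    bsz, seqlen, hidden = 2, 2048, 4096

    # output as it looks after gather_forward_split_backward
    output = torch.randn(bsz, seqlen, hidden)
    local = torch.chunk(output, world_size, dim=1)[rank].contiguous()
    assert local.shape == (bsz, seqlen // world_size, hidden)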
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index c0dfcf9..2fa249c 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -38,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
         out_features: int,
         process_group: Optional[torch.distributed.ProcessGroup],
         bias: bool = True,
-        sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
@@ -48,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
             raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
         super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
         self.process_group = process_group
-        self.sequence_parallel = sequence_parallel
         self.weight_scale = weight_scale

     def forward(self, input):  # pylint: disable=W0622
@@ -60,7 +58,7 @@ class ScaleColumnParallelLinear(nn.Linear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
+            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
         )


@@ -87,12 +85,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         out_features: int,
         process_group: Optional[torch.distributed.ProcessGroup],
         bias: bool = True,
-        sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
     ) -> None:
-        super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
+        super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
         torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
         if bias:
             torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
@@ -106,7 +103,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
+            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
         )


@@ -168,19 +165,19 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             device=device,
             dtype=dtype,
         )
         self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
+            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
         )
         self.w3 = RowParallelLinearTorch(
             hidden_features,
             out_features,
             process_group,
             bias=bias,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             device=device,
             dtype=dtype,
         )
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index ce5e993..31138fa 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -89,7 +89,6 @@ class PackedFlashBaseLayer1D(nn.Module):
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
             use_flash_attn=use_flash_attn,
-            sequence_parallel=False,
             device=device,
             dtype=dtype,
         )
@@ -121,7 +120,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             process_group=gpc.get_group(ParallelMode.TENSOR),
             bias1=False,
             bias2=False,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             checkpoint_lvl=0,
             heuristic="auto",
             device=device,
@@ -300,7 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
             max_position_embeddings=-1,
             process_group=gpc.get_group(ParallelMode.TENSOR),
             padding_idx=None,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -342,7 +341,6 @@ class PackedFlashInternLm1D(nn.Module):
             out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
             process_group=gpc.get_group(ParallelMode.TENSOR),
             bias=False,
-            sequence_parallel=False,
             device=device,
             dtype=dtype,
             weight_scale=embed_grad_scale,
@@ -463,6 +461,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
+    sequence_parallel: bool = False,
 ):
     """
     Builde model with config
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 096c4e6..2b213ec 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -59,7 +59,6 @@ class MHA(nn.Module):
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
         use_flash_attn: bool = True,
-        sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> None:
@@ -83,7 +82,7 @@ class MHA(nn.Module):
             3 * embed_dim,
             process_group,
             bias=True,
-            sequence_parallel=sequence_parallel,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             **factory_kwargs,
         )  # according to https://spaces.ac.cn/archives/9577
@@ -96,7 +95,7 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = RowParallelLinearTorch(
-            embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
+            embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
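Routing sequence_parallel=gpc.config.model.sequence_parallel into the flash-attn Column/RowParallelLinearTorch modules toggles the usual Megatron-style communication pattern: a column-parallel layer all-gathers the sequence shards before its matmul, and a row-parallel layer reduce-scatters its partial sums instead of all-reducing. A rough sketch of that pattern with raw torch.distributed collectives (PyTorch >= 2.0 collective names; this is not the flash-attn implementation):

    # Sketch only: assumes dist.init_process_group was called, x_shard is the
    # local (tokens/world, hidden) sequence shard, and weight is this rank's
    # (out_local, hidden) slice of a column-parallel weight.
    import torch
    import torch.distributed as dist

    def column_parallel_fwd(x_shard, weight, group=None):
        world = dist.get_world_size(group)
        # all-gather the sequence shards so every rank sees the full sequence
        full = torch.empty(x_shard.shape[0] * world, *x_shard.shape[1:],
                           device=x_shard.device, dtype=x_shard.dtype)
        dist.all_gather_into_tensor(full, x_shard.contiguous(), group=group)
        return full @ weight.t()  # each rank computes its output-feature slice

    def row_parallel_fwd(partial, group=None):
        world = dist.get_world_size(group)
        # reduce-scatter replaces the usual all-reduce: sum partial results
        # and hand each rank back only its own sequence shard
        out = torch.empty(partial.shape[0] // world, *partial.shape[1:],
                          device=partial.device, dtype=partial.dtype)
        dist.reduce_scatter_tensor(out, partial.contiguous(), group=group)
        return out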
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 0c7ed2e..a84f058 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -157,6 +157,35 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)


+class _SplitForwardGatherBackward(torch.autograd.Function):
+    """
+    Split the input and keep only the chunk corresponding to the rank.
+
+    Args:
+        input_: input matrix.
+        parallel_mode: parallel mode.
+        dim: the dimension to split along.
+    """
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
+    @staticmethod
+    def forward(ctx, input_, parallel_mode, dim):
+        ctx.mode = parallel_mode
+        ctx.dim = dim
+        return _split(input_, parallel_mode, dim)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather(grad_output, ctx.mode, ctx.dim), None, None
+
+
+def split_forward_gather_backward(input_, parallel_mode, dim):
+    return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
+
+
 def try_import_RMSNorm():
     """
     Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index e6cd792..8424e16 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -90,9 +90,15 @@ def evaluate_on_val_dls(
                 total_val_bsz = len(batch[1])
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                )
+                if gpc.config.model.sequence_parallel:
+                    sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+                    tensor_shape = torch.Size(
+                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
+                    )
+                else:
+                    tensor_shape = torch.Size(
+                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                    )

                 with switch_evaluation_pipeline_scheduler(
                     trainer=trainer,
@@ -155,3 +161,13 @@ def evaluate_on_val_dls(
     trainer.train()
     torch.cuda.empty_cache()
     dist.barrier()
+
+
+@contextmanager
+def switch_sequence_parallel_mode():
+    prev_mode = gpc.config.model.sequence_parallel
+    try:
+        gpc.config.model.sequence_parallel = False
+        yield
+    finally:
+        gpc.config.model.sequence_parallel = prev_mode
\ No newline at end of file
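_SplitForwardGatherBackward leans on the _split and _gather helpers that already live in internlm/model/utils.py (they also back the existing gather_forward_split_backward). A hedged sketch of their presumable behavior, written against gpc the way the surrounding file uses it:

    # Sketch of the pre-existing helpers the autograd Function calls; this is
    # the presumable shape behavior, not a verbatim copy of the file.
    import torch
    import torch.distributed as dist

    from internlm.core.context import global_context as gpc

    def _split_sketch(input_, parallel_mode, dim=-1):
        world_size = gpc.get_world_size(parallel_mode)
        if world_size == 1:
            return input_
        # keep only this rank's chunk along `dim`
        rank = gpc.get_local_rank(parallel_mode)
        return torch.chunk(input_, world_size, dim=dim)[rank].contiguous()

    def _gather_sketch(input_, parallel_mode, dim=-1):
        world_size = gpc.get_world_size(parallel_mode)
        if world_size == 1:
            return input_
        # collect every rank's chunk and reassemble along `dim`
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous(), group=gpc.get_group(parallel_mode))
        return torch.cat(shards, dim=dim)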
diff --git a/train.py b/train.py
index b457498..59729e7 100644
--- a/train.py
+++ b/train.py
@@ -41,7 +41,7 @@ from internlm.utils.common import (
     launch_time,
     parse_args,
 )
-from internlm.utils.evaluation import evaluate_on_val_dls
+from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
@@ -618,14 +618,15 @@ def main(args):

         # evaluate on validation data loaders
         if valid_every > 0 and train_state.step_count % valid_every == 0:
-            evaluate_on_val_dls(
-                trainer=trainer,
-                val_dls=val_dls,
-                writer=writer,
-                logger=logger,
-                step_count=train_state.step_count,
-                update_panel=uniscale_logger is not None,
-            )
+            with switch_sequence_parallel_mode():
+                evaluate_on_val_dls(
+                    trainer=trainer,
+                    val_dls=val_dls,
+                    writer=writer,
+                    logger=logger,
+                    step_count=train_state.step_count,
+                    update_panel=uniscale_logger is not None,
+                )

         # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
         # save batch sampler that tracks the true consumed samples
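Wrapping evaluation in switch_sequence_parallel_mode() forces the dense (non-sequence-parallel) path during validation and restores the previous flag afterwards, even if evaluation fails. A usage sketch, assuming the gpc and switch_sequence_parallel_mode imports shown in the diffs above:

    # The try/finally inside the context manager guarantees the flag is
    # restored even when the wrapped block raises.
    gpc.config.model.sequence_parallel = True
    try:
        with switch_sequence_parallel_mode():
            assert gpc.config.model.sequence_parallel is False
            raise RuntimeError("simulated eval failure")
    except RuntimeError:
        pass
    assert gpc.config.model.sequence_parallel is True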