From dcd89ed30466b7552f79077af5049e3581d46270 Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Fri, 20 Oct 2023 17:50:56 +0800 Subject: [PATCH] refactor linear --- configs/7B_sft.py | 2 +- internlm/model/linear.py | 350 ++++++++---------- internlm/model/modeling_internlm.py | 24 +- internlm/model/multi_head_attention.py | 12 +- internlm/model/utils.py | 206 +++++++++-- .../solver/optimizer/hybrid_zero_optim.py | 54 +-- internlm/train/training_internlm.py | 3 +- train.py | 4 +- 8 files changed, 356 insertions(+), 299 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 6ea8b96..0058e04 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False. """ parallel = dict( zero1=dict(size=-1, fsdp=False), - tensor=dict(size=8, mode="fstp", overlap=True), + tensor=dict(size=8, sp="intern", intern_overlap=True), pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=True, ) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 4f05cd3..8f57a02 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -19,25 +19,26 @@ from internlm.model.utils import ( all_gather_raw_memory_pool, fstp_fused_dense_func, fused_dense_func_torch, + megatron_fused_dense_func_torch, ) -class ScaleColumnParallelLinear(nn.Linear): +class BaseScaleColumnParallelLinear(nn.Linear): """ - ScaleColumnParallelLinear. + Base class for ScaleColumnParallelLinear. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + If not, then the input is already gathered. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. """ def __init__( @@ -57,6 +58,10 @@ class ScaleColumnParallelLinear(nn.Linear): self.process_group = process_group self.weight_scale = weight_scale +class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): + """ + ScaleColumnParallelLinear in flash implementation. + """ def forward(self, input, gather_dim=0): # pylint: disable=W0622 # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. 
@@ -74,6 +79,27 @@ class ScaleColumnParallelLinear(nn.Linear): gather_dim=gather_dim, ) +class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear): + """ + ScaleColumnParallelLinear in megatron implementation. + """ + + def forward(self, input, gather_dim=0): # pylint: disable=W0622 + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. + if self.weight_scale != 1: + weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() + else: + weight = self.weight + return megatron_fused_dense_func_torch( + input, + weight, + self.bias, + process_group=self.process_group, + sequence_parallel=gpc.config.parallel.sequence_parallel, + gather_dim=gather_dim, + ) class RewardModelLinear(ScaleColumnParallelLinear): """ @@ -129,7 +155,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. - return fused_dense_func_torch( x, self.weight, @@ -139,6 +164,19 @@ class ColumnParallelLinearTorch(ColumnParallelLinear): gather_dim=gather_dim, ) +class MegatronColumnParallelLinearTorch(ColumnParallelLinear): + def forward(self, x, gather_dim=0): + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. + return megatron_fused_dense_func_torch( + x, + self.weight, + self.bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + gather_dim=gather_dim, + ) class RowParallelLinearTorch(RowParallelLinear): def forward(self, x): @@ -150,10 +188,20 @@ class RowParallelLinearTorch(RowParallelLinear): reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group) +class MegatronRowParallelLinearTorch(RowParallelLinear): + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. + """ + out = megatron_fused_dense_func_torch(x, self.weight, self.bias) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) -class FeedForward(nn.Module): + +class BaseFeedForward(nn.Module): """ - FeedForward. + Base FeedForward in flash implementation. 
Args: in_features (int): size of each input sample @@ -177,13 +225,13 @@ class FeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, - block_idx: int = 0, + colum_cls = None, + row_cls = None, ): super().__init__() - hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) - self.w1 = ColumnParallelLinearTorch( + self.w1 = colum_cls( in_features, hidden_features, process_group, @@ -192,7 +240,7 @@ class FeedForward(nn.Module): device=device, dtype=dtype, ) - self.w2 = ColumnParallelLinearTorch( + self.w2 = colum_cls( in_features, hidden_features, process_group, @@ -201,7 +249,7 @@ class FeedForward(nn.Module): device=device, dtype=dtype, ) - self.w3 = RowParallelLinearTorch( + self.w3 = row_cls( hidden_features, out_features, process_group, @@ -217,21 +265,9 @@ class FeedForward(nn.Module): out = self.w3(Silu(w1_o, w2_o)) return out - -class FSTPLinear(ColumnParallelLinear): - def forward(self, x): - block_index = gpc.config.fstp_handler.module_to_index[self] - name_index = gpc.config.fstp_handler.module_name_index[self] - name = gpc.config.fstp_handler.module_name[name_index] - return fstp_fused_dense_func( - x, self.weight, self.bias, process_group=self.process_group, - module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name - ) - - -class FSTPFeedForward(nn.Module): +class FeedForward(BaseFeedForward): """ - FeedForward. + FeedForward in flash implementation. Args: in_features (int): size of each input sample @@ -255,169 +291,106 @@ class FSTPFeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, - block_idx: int = 0, ): - super().__init__() + super().__init__(in_features, hidden_features, out_features, process_group, bias, device, + dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch) + - hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) +class MegatronFeedForward(BaseFeedForward): + """ + FeedForward in megatron implementation. - self.w1 = FSTPLinear( - in_features, - hidden_features, - process_group, - bias, - sequence_parallel=gpc.config.parallel.sequence_parallel, - device=device, - dtype=dtype, - ) - self.w2 = FSTPLinear( - in_features, - hidden_features, - process_group, - bias, - sequence_parallel=gpc.config.parallel.sequence_parallel, - device=device, - dtype=dtype, - ) - self.w3 = FSTPLinear( - hidden_features, - out_features, - process_group, - bias=bias, - sequence_parallel=gpc.config.parallel.sequence_parallel, - device=device, - dtype=dtype, - ) + Args: + in_features (int): size of each input sample + hidden_features (int): size of hidden state of FFN + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. 
+ """ + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int = None, + process_group: Optional[torch.distributed.ProcessGroup] = None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + multiple_of: int = 256, + ): + super().__init__(in_features, hidden_features, out_features, process_group, bias, device, + dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch) + +class FSTPLinear(ColumnParallelLinear): def forward(self, x): - w1_o = self.w1(x) - w2_o = self.w2(x) - out = self.w3(F.silu(w1_o) * w2_o) - return out + block_index = gpc.config.fstp_handler.module_to_index[self] + name_index = gpc.config.fstp_handler.module_name_index[self] + name = gpc.config.fstp_handler.module_name[name_index] + return fstp_fused_dense_func( + x, self.weight, self.bias, process_group=self.process_group, + module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name + ) - -class FSTPAllGatherSyncHandler: +class FSTPFeedForward(BaseFeedForward): """ - All-gather handler for overlapping the all-gather in adjcent FSTP linear. + FeedForward in FSTP. + + Args: + in_features (int): size of each input sample + hidden_features (int): size of hidden state of FFN + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. 
""" - def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: - # import pdb; pdb.set_trace() - self.process_group = process_group - self.FSTP_modules = [] - self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward - self.module_handler = dict() # key: FSTP module; value: all-gather handler - self.module_block = dict() # key: FSTP module; value: transformer block index - self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} - self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int = None, + process_group: Optional[torch.distributed.ProcessGroup] = None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + multiple_of: int = 256, + ): + super().__init__(in_features, hidden_features, out_features, process_group, bias, device, + dtype, multiple_of, FSTPLinear, FSTPLinear) - self.reduce_scatter_handlers = {} - self.all_reduce_handlers = {} - - # just want to share same for loop for ModuleList and Module - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model - - for _chunk_name, children in _chunk.named_children(): - if isinstance(children, nn.ModuleList): - for idx, block in enumerate(children): - index = 0 - self.block_module[idx] = {} - for _sub_name, sub in block.named_children(): - sub_modules = list(sub.children()) - if len(sub_modules) > 0: - for name, child in sub.named_children(): - if isinstance(child, FSTPLinear): - - _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") - if child.bias is not None: - setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") - - self.FSTP_modules.append(child) - self.module_block[child] = idx - self.block_module[idx][index] = child - self.module_name_index[child] = index - index = index + 1 - else: - continue - - def _register_sync_parameters_hook(self) -> None: - """ - register pre_forward_hook and pre_backward_hook for FSTPLinear. 
- """ - - def _pre_forward_hook(module: nn.Module, inputs: Any): - block_index = self.module_block[module] - name_index = self.module_name_index[module] - if name_index == 0: - total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler.wait() - self.FSTP_global_weights[module] = total_weight - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index + 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - else: - handler = self.module_handler[module] - handler.wait() - if name_index != 4: - next_module = self.block_module[block_index][name_index + 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - - def _post_forward_hook(module: nn.Module, input, output): - if module in self.FSTP_global_weights: - del self.FSTP_global_weights[module] - if module in self.module_handler: - del self.module_handler[module] - - def _pre_backward_hook(module: nn.Module, grad_output): - block_index = self.module_block[module] - name_index = self.module_name_index[module] - if name_index == 4: - total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler.wait() - self.FSTP_global_weights[module] = total_weight - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - else: - handler = self.module_handler[module] - handler.wait() - if name_index != 0: - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - - def _post_backward_hook(module, grad_input, grad_output): - del self.FSTP_global_weights[module] - - for module in self.FSTP_modules: - # import pdb; pdb.set_trace() - module.register_forward_pre_hook(_pre_forward_hook) - module.register_forward_hook(_post_forward_hook) - # module.register_backward_pre_hook(_pre_backward_hook) - # module.register_backward_hook(_post_backward_hook) - module.register_full_backward_pre_hook(_pre_backward_hook) - module.register_full_backward_hook(_post_backward_hook) +def get_mlp_cls(sp_mode: str): + if sp_mode in ["none", "flash-attn"]: + mlp_cls = FeedForward + elif sp_mode == "megatron": + mlp_cls = MegatronFeedForward + else: + mlp_cls = FSTPFeedForward + return mlp_cls +def get_linear_cls(sp_mode: str, parallel_mode: str): + if parallel_mode == "column": + if sp_mode in ["none", "flash-attn"]: + cls = ColumnParallelLinearTorch + elif sp_mode == "megatron": + cls = MegatronColumnParallelLinearTorch + else: + cls = FSTPLinear + elif parallel_mode == 'row': + if sp_mode in ["none", "flash-attn"]: + cls = RowParallelLinearTorch + elif sp_mode == "megatron": + cls = MegatronRowParallelLinearTorch + else: + cls = FSTPLinear + return cls class CoarseGrainedFSTPAllGatherSyncHandler: """ @@ -468,7 +441,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler: sub_modules = list(sub.children()) if len(sub_modules) > 0: for name, child in 
sub.named_children(): - # print(f"name: {name}", flush=True) if name == "out_proj": self.FSTP_outs.append(child) self.module_to_index[child] = idx diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index b004dff..99d540f 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -15,9 +15,12 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no from internlm.model.embedding import Embedding1D from internlm.model.linear import ( FeedForward, + MegatronFeedForward, FSTPFeedForward, RewardModelLinear, ScaleColumnParallelLinear, + MegatronScaleColumnParallelLinear, + get_mlp_cls, ) from internlm.model.multi_head_attention import MHA from internlm.model.utils import ( @@ -77,8 +80,7 @@ class PackedFlashBaseLayer1D(nn.Module): use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, - tp_mode: str = "origin_tp", - block_idx: int = 0, + sp_mode: str = "none", ): super().__init__() self.checkpoint = checkpoint @@ -103,8 +105,7 @@ class PackedFlashBaseLayer1D(nn.Module): use_flash_attn=use_flash_attn, device=device, dtype=dtype, - tp_mode=tp_mode, - block_idx=block_idx, + sp_mode=sp_mode, ) self.dropout1 = nn.Dropout(drop_rate) @@ -116,7 +117,7 @@ class PackedFlashBaseLayer1D(nn.Module): self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) if use_swiglu: - mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward + mlp_cls = get_mlp_cls(sp_mode) self.mlp = mlp_cls( hidden_size, int(hidden_size * mlp_ratio), @@ -299,12 +300,16 @@ class PackedFlashInternLm1D(nn.Module): super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = gpc.config.parallel["tensor"]["mode"] + self.sp_mode = gpc.config.parallel["tensor"]["sp"] + if self.sp_mode == "none": + gpc.config.parallel.sequence_parallel = False + else: + gpc.config.parallel.sequence_parallel = True if is_reward: head_cls = RewardModelLinear else: - head_cls = ScaleColumnParallelLinear + head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear if first: if embed_split_hidden: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) @@ -345,8 +350,7 @@ class PackedFlashInternLm1D(nn.Module): use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, - tp_mode=self.tp_mode, - block_idx=lid, + sp_mode=self.sp_mode, ) for lid in range(num_layers) ] @@ -393,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module): # The indexes are used to indicate the actual position IDs of each token in the packed input. indexes = indexes[0] # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension. 
- if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp": + if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern": indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 7a0f4ed..8ba49ed 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -42,6 +42,9 @@ from internlm.model.linear import ( ColumnParallelLinearTorch, FSTPLinear, RowParallelLinearTorch, + MegatronColumnParallelLinearTorch, + MegatronRowParallelLinearTorch, + get_linear_cls, ) @@ -175,8 +178,7 @@ class MHA(nn.Module): use_flash_attn: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - tp_mode: str = "origin_tp", - block_idx: int = 0, + sp_mode: str = "none", ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -204,7 +206,7 @@ class MHA(nn.Module): self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) # notice here should change bias=True - Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear + Wqkv_cls = get_linear_cls(sp_mode, "column") self.Wqkv = Wqkv_cls( embed_dim, 3 * embed_dim, @@ -220,12 +222,12 @@ class MHA(nn.Module): self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) - if tp_mode == "fstp": + if sp_mode == "intern": self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group) self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) # output projection always have the bias (for now) - out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear + out_proj_cls = get_linear_cls(sp_mode, 'row') self.out_proj = out_proj_cls( embed_dim, embed_dim, diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 2667efe..6757906 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -164,7 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py class FusedDenseFunc(torch.autograd.Function): - "tp fused dense function" + "FusedDenseFunc for tensor parallel in flash-attn implementation." @staticmethod @custom_fwd @@ -255,9 +255,96 @@ class FusedDenseFunc(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None, None +class MegatronFusedDenseFunc(torch.autograd.Function): + ''' + FusedDenseFunc for tensor parallel in megatron implementation. + The difference between the flash-attn and megatron implementations is that in megatron the total_x can be saved for backward, + so that the all-gather in backward is omitted. + ''' + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
+ """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(total_x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + + if ctx.compute_weight_gradient: + total_x, weight = ctx.saved_tensors + else: + (weight,) = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None, None, None + # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py class FusedDenseFuncTorch(FusedDenseFunc): - """A custom PyTorch module extending FusedDenseFunc.""" + '''FusedDenseFunc in flash implementation for supporting torch.float32''' @staticmethod @custom_bwd @@ -307,17 +394,61 @@ class FusedDenseFuncTorch(FusedDenseFunc): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None +class MegatronFusedDenseFuncTorch(FusedDenseFunc): + '''FusedDenseFunc in megatron implementation for supporting torch.float32''' + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, 
*args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + gather_dim = ctx.gather_dim + if ctx.compute_weight_gradient: + total_x, weight = ctx.saved_tensors + else: + (weight,) = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + # we remove the cuda independence, which is different from flash_attn. + grad_weight, grad_bias = linear_bias_wgrad_torch( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None, None, None class FSTPFusedDenseFunc(torch.autograd.Function): - "FSTP fused dense function" + "FusedDenseFunc for FSTP, which is optimized based on flash implementation." @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None, block_index=None, module_name=None): + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group - ctx.all_gather_handler = all_gather_handler + ctx.overlap_handler = overlap_handler ctx.module = module ctx.block_index = block_index ctx.module_name = module_name @@ -329,13 +460,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: # do all_gather for weight and bias before actual computation - if all_gather_handler is not None:# and module in all_gather_handler.FSTP_global_weights: - # total_weight = all_gather_handler.FSTP_global_weights[module] - total_weight = gpc.config.block_memory[block_index % 2][module_name] + if overlap_handler is not None: + total_weight = gpc.config.block_memory[block_index % 2][module_name] else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() - + # TODO memory pool for bias if bias is not None: total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) handle_bias.wait() @@ -356,6 +486,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): if min(batch_dim, n, *total_weight.shape) > 65535 * 32: raise RuntimeError("fused_dense only supports matrix dims <= 2M") output = F.linear(total_x, total_weight, total_bias) + # release memory del total_weight del total_bias if ctx.compute_weight_gradient: @@ -372,8 +503,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): (grad_input,) = 
args grad_input = grad_input.contiguous() process_group = ctx.process_group - all_gather_handler = ctx.all_gather_handler - module = ctx.module + overlap_handler = ctx.overlap_handler block_index = ctx.block_index module_name = ctx.module_name @@ -389,51 +519,35 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: - total_weight = gpc.config.block_memory[block_index % 2][module_name] - # # do all-gather for weight before backward - # if module in all_gather_handler.FSTP_global_weights: - # total_weight = all_gather_handler.FSTP_global_weights[module] - # else: - # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) - # handle_weight.wait() + if overlap_handler is not None: + total_weight = gpc.config.block_memory[block_index % 2][module_name] + else: + total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + handle_weight.wait() else: total_weight = weight # compute weight grad if ctx.needs_input_grad[1]: assert ctx.compute_weight_gradient - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] ) if world_size > 1: - if gpc.config.fstp_handler is not None: - # grad_weight_async, handle_grad_weight = all_reduce_raw(grad_weight, process_group, async_op=True) - # assert hasattr(weight, "_fstp_all_reduce_str") - # all_gather_handler.all_reduce_handlers[weight._fstp_all_reduce_str] = (handle_grad_weight, grad_weight_async) - # grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) - # if grad_bias is not None: - # grad_bias_async, handle_grad_bias = all_reduce_raw(grad_bias, process_group, async_op=True) - # assert hasattr(bias, "_fstp_all_reduce_str") - # all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async) - # grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) - + if overlap_handler is not None: grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True) assert hasattr(weight, "_fstp_reduce_scatter_str") - all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) - grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) + overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) + grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) if grad_bias is not None: grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True) assert hasattr(bias, "_fstp_reduce_scatter_str") - all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) - grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, 
device=grad_bias.device) + overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) + grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) else: grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) if grad_bias is not None: grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) - # grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) - # if grad_bias is not None: - # grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) else: grad_weight = None grad_bias = grad_output if ctx.needs_input_grad[2] else None @@ -449,7 +563,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): del total_weight if ctx.needs_input_grad[1]: - if world_size > 1 and gpc.config.fstp_handler is None: + if world_size > 1 and overlap_handler is None: handle_grad_weight.wait() if grad_bias is not None: handle_grad_bias.wait() @@ -473,6 +587,22 @@ def fused_dense_func_torch( else: return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) +def megatron_fused_dense_func_torch( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + return_residual: bool = False, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, + gather_dim: int = 0, +): + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) + if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: + return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + else: + return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) def fstp_fused_dense_func( x: Tensor, diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 96a54c0..4de5c7c 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -40,11 +40,6 @@ from .utils import compute_norm inf = math.inf logger = get_logger(__file__) -def print_memory(msg): - print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True) - print("===========================================") - - class HybridZeroOptimizer(BaseOptimizer): """ Hybrid Zero Optimizer. 
@@ -70,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer): hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: + if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True: self._fstp_handler = gpc.config.fstp_handler # Zero related args @@ -358,20 +353,7 @@ class HybridZeroOptimizer(BaseOptimizer): del self._fstp_handler.reduce_scatter_handlers[key] self._fstp_handler.reduce_scatter_handlers[key] = None assert key in self._fstp_handler.reduce_scatter_handlers - # if not hasattr(_param, "_fstp_all_reduce_str"): - # continue - # key = getattr(_param, "_fstp_all_reduce_str") - # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key] - # comm_handle.wait() - # with torch.no_grad(): - # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0) - # _param.grad.add_(_grad) - # # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - # del self._fstp_handler.all_reduce_handlers[key] - # self._fstp_handler.all_reduce_handlers[key] = None - # assert key in self._fstp_handler.all_reduce_handlers bucket.reset_by_rank(rank) @@ -401,21 +383,6 @@ class HybridZeroOptimizer(BaseOptimizer): del self._fstp_handler.reduce_scatter_handlers[key] self._fstp_handler.reduce_scatter_handlers[key] = None assert key in self._fstp_handler.reduce_scatter_handlers - - # if not hasattr(_param, "_fstp_all_reduce_str"): - # continue - - # key = getattr(_param, "_fstp_all_reduce_str") - # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key] - # comm_handle.wait() - # with torch.no_grad(): - # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0) - # _param.grad.add_(_grad) - # # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - # del self._fstp_handler.all_reduce_handlers[key] - # self._fstp_handler.all_reduce_handlers[key] = None - # assert key in self._fstp_handler.all_reduce_handlers current_bucket.reset_by_rank(reduce_rank) @@ -634,7 +601,6 @@ class HybridZeroOptimizer(BaseOptimizer): # if not overlapping communication (no reduction hook is attached) # we need to manually reduce these gradients - print_memory("No 1") if not self._overlap_sync_grad: for group_id in range(len(self._fp16_param_groups)): for param in self._fp16_param_groups[group_id]: @@ -659,7 +625,6 @@ class HybridZeroOptimizer(BaseOptimizer): bucket.empty() self._bucket_in_progress = [] self._param_store.clear_grads_of_previous_reduced_params() - print_memory("No 2") # compute norm for gradients in the last bucket total_norms = {} for group_id in range(self.num_param_groups): @@ -681,19 +646,11 @@ class HybridZeroOptimizer(BaseOptimizer): scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=pg) total_norms[group_name] = scaled_norm_tensor.item() - print_memory("No 3") timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - print_memory("No 4") - - try: - res = self._step(closure=closure, norms=total_norms) - except torch.cuda.OutOfMemoryError as e: - print(e, flush=True) - print(torch.cuda.memory_summary(), flush=True) - torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + res = self._step(closure=closure, norms=total_norms) return res @@ -740,7 +697,6 @@ class HybridZeroOptimizer(BaseOptimizer): self._grad_store._averaged_gradients = dict() 
self.zero_grad() return False, norms - print_memory("No 5") # copy the grad of fp16 param to fp32 param single_grad_partition_groups = [] for group_id in range(self.num_param_groups): @@ -781,7 +737,6 @@ class HybridZeroOptimizer(BaseOptimizer): single_grad_partition_groups.append(flat_fp32_avg_grads) device = self._fp32_flat_param_groups_of_current_rank[group_id].device self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) - print_memory("No 6") # unscale and clip grads # get the global norm global_norm_groups = {} @@ -804,12 +759,9 @@ class HybridZeroOptimizer(BaseOptimizer): # For those ranks that are not assigned parameters, we just wait for other ranks # to send them updated their own parameters. if self.has_params: - print_memory("No 7") self.optim.step() - print_memory("No 8") # release the fp32 grad release_param_grad(self._fp32_flat_param_groups_of_current_rank.values()) - print_memory("No 9") # update fp16 partition updated by the current rank for group_id in range(len(self._fp16_param_groups)): if self.param_group_has_params[group_id]: @@ -818,7 +770,6 @@ class HybridZeroOptimizer(BaseOptimizer): ) fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) - print_memory("No 10") torch.cuda.synchronize() with torch.cuda.stream(self._comm_bcast_stream): self.broadcast_params() @@ -829,7 +780,6 @@ class HybridZeroOptimizer(BaseOptimizer): # so synchronization is maintained for group_name, global_norm in global_norm_groups.items(): global_norm_groups[group_name] = global_norm / loss_scale - print_memory("No 11") return True, global_norm_groups def broadcast_params(self): diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 2816da0..20592c2 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -38,7 +38,6 @@ from internlm.model.embedding import Embedding1D from internlm.model.linear import ( CoarseGrainedFSTPAllGatherSyncHandler, FeedForward, - FSTPAllGatherSyncHandler, RewardModelLinear, ScaleColumnParallelLinear, ) @@ -111,7 +110,7 @@ def initialize_model(): gpc.config.fstp_handler = None - if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: + if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True: handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler._register_sync_parameters_hook() diff --git a/train.py b/train.py index 41ab070..a917d12 100644 --- a/train.py +++ b/train.py @@ -195,7 +195,7 @@ def main(args): # start iterating the train data and begin training for batch_count in range(train_state.batch_count, total_steps): empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) - torch.cuda.memory._record_memory_history() + # torch.cuda.memory._record_memory_history() start_time = time.time() timer("one-batch").start() @@ -300,7 +300,7 @@ def main(args): if gpc.config.fstp_handler is not None: gpc.config.fstp_handler.zero_const_pool = {} gpc.config.fstp_handler.reduce_scatter_memory = {} - torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats() ckpt_manager.wait_async_upload_finish()
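
A minimal, standalone sketch of how the renamed config keys (tensor=dict(size=8, sp="intern", intern_overlap=True)) map onto the sp_mode dispatch added in internlm/model/linear.py. The pick_mlp_cls stand-in below is hypothetical and returns class names as strings so the snippet runs on its own; the real get_mlp_cls/get_linear_cls helpers return the classes themselves (FeedForward, MegatronFeedForward, FSTPFeedForward and the corresponding linear classes).

# Illustrative sketch only; assumes the sp_mode values introduced by this patch.
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="intern", intern_overlap=True),  # was: mode="fstp", overlap=True
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)

def pick_mlp_cls(sp_mode: str) -> str:
    # Mirrors the branching in get_mlp_cls: "none"/"flash-attn" -> flash FeedForward,
    # "megatron" -> MegatronFeedForward, anything else ("intern") -> FSTPFeedForward.
    if sp_mode in ("none", "flash-attn"):
        return "FeedForward"
    if sp_mode == "megatron":
        return "MegatronFeedForward"
    return "FSTPFeedForward"

assert pick_mlp_cls(parallel["tensor"]["sp"]) == "FSTPFeedForward"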