diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 0c15daf..cde8bc0 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -162,8 +162,22 @@ def args_sanity_check():
         gpc.config.model.dtype = torch.bfloat16
     elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
         gpc.config.model.dtype = torch.float16
+    elif gpc.config.model.dtype == "torch.float32":
+        assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
+        gpc.config.model.dtype = torch.float32
+    elif gpc.config.model.dtype == "torch.tf32":
+        assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+        gpc.config.model.dtype = torch.float32
     else:
-        assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
+        assert gpc.config.model.dtype in [
+            "torch.float16",
+            "torch.half",
+            "torch.bfloat16",
+            "torch.float32",
+            "torch.tf32",
+        ]

     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Model Info " + "+" * 15)  # pylint: disable=W1201
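Note: the launch.py hunk above is the entry point for the new fp32/tf32 modes. "torch.tf32" is not a real torch dtype, so it resolves to torch.float32 and only flips the TF32 backend flags, and both new modes require use_flash_attn=False because the flash-attn kernels only run in fp16/bf16. In the standard InternLM python configs this corresponds to setting e.g. model = dict(dtype="torch.tf32", use_flash_attn=False, ...). The standalone sketch below distills the same mapping; the function name resolve_model_dtype is illustrative and not part of the patch.

import torch


def resolve_model_dtype(dtype_str: str, use_flash_attn: bool) -> torch.dtype:
    """Mirror of the dtype branch added to args_sanity_check(), without the gpc/config plumbing."""
    if dtype_str == "torch.bfloat16":
        return torch.bfloat16
    if dtype_str in ("torch.float16", "torch.half"):
        return torch.float16
    if dtype_str == "torch.float32":
        assert use_flash_attn is False, "flash-attn kernels do not support float32"
        return torch.float32
    if dtype_str == "torch.tf32":
        assert use_flash_attn is False, "flash-attn kernels do not support tf32"
        # tf32 is a compute mode, not a storage dtype: parameters stay float32 and
        # cuDNN / cuBLAS matmuls are allowed to use TensorFloat-32 internally.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.float32
    raise ValueError(f"unsupported dtype string: {dtype_str}")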
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 88129af..c0dfcf9 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -5,15 +5,13 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
-from flash_attn.ops.fused_dense import (
-    ColumnParallelLinear,
-    RowParallelLinear,
-    fused_dense_func,
-)
+from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.utils import fused_dense_func_torch


 class ScaleColumnParallelLinear(nn.Linear):
@@ -61,7 +59,7 @@ class ScaleColumnParallelLinear(nn.Linear):
             weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
         else:
             weight = self.weight
-        return fused_dense_func(
+        return fused_dense_func_torch(
             input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
         )

@@ -107,11 +105,33 @@ class RewardModelLinear(ScaleColumnParallelLinear):
             weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
         else:
             weight = self.weight
-        return fused_dense_func(
+        return fused_dense_func_torch(
             input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
         )


+class ColumnParallelLinearTorch(ColumnParallelLinear):
+    def forward(self, x):
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+
+        return fused_dense_func_torch(
+            x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
+        )
+
+
+class RowParallelLinearTorch(RowParallelLinear):
+    def forward(self, x):
+        """
+        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
+        a reduce_scatter of the result.
+        """
+        out = fused_dense_func_torch(x, self.weight, self.bias)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
+
+
 class FeedForward(nn.Module):
     """
     FeedForward.
@@ -143,7 +163,7 @@ class FeedForward(nn.Module):

         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = ColumnParallelLinear(
+        self.w1 = ColumnParallelLinearTorch(
             in_features,
             hidden_features,
             process_group,
@@ -152,10 +172,10 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = ColumnParallelLinear(
+        self.w2 = ColumnParallelLinearTorch(
             in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
         )
-        self.w3 = RowParallelLinear(
+        self.w3 = RowParallelLinearTorch(
             hidden_features,
             out_features,
             process_group,
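Note: ColumnParallelLinearTorch and RowParallelLinearTorch only swap the matmul kernel (fused_dense_func_torch instead of flash-attn's fused_dense_func); the tensor-parallel layout is unchanged. The toy below is a single-process sketch with illustrative shapes and no torch.distributed: a column-parallel linear shards the output features and concatenates, a row-parallel linear shards the input features and sums partial products, which the real classes implement with all_gather / reduce_scatter / all_reduce over the tensor-parallel process group.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)          # (tokens, in_features)
w = torch.randn(6, 8)          # (out_features, in_features)

# "column parallel": each rank holds a slice of the output features.
w_col = w.chunk(2, dim=0)      # two shards of shape (3, 8)
col_out = torch.cat([x @ shard.t() for shard in w_col], dim=-1)

# "row parallel": each rank holds a slice of the input features and produces a partial sum.
w_row = w.chunk(2, dim=1)      # two shards of shape (6, 4)
x_row = x.chunk(2, dim=-1)
row_out = sum(xi @ wi.t() for xi, wi in zip(x_row, w_row))

full = x @ w.t()
assert torch.allclose(col_out, full, atol=1e-5)
assert torch.allclose(row_out, full, atol=1e-5)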
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 7513563..096c4e6 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -12,12 +12,12 @@ from flash_attn.modules.mha import (
     SelfAttention,
     _update_kv_cache,
 )
-from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.embedding import RotaryEmbedding
+from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch


 class MHA(nn.Module):
@@ -78,7 +78,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        self.Wqkv = ColumnParallelLinear(
+        self.Wqkv = ColumnParallelLinearTorch(
             embed_dim,
             3 * embed_dim,
             process_group,
@@ -95,7 +95,7 @@ class MHA(nn.Module):
         )

         # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(
+        self.out_proj = RowParallelLinearTorch(
             embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index ba00a27..0c7ed2e 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -1,7 +1,19 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from typing import Optional
+
 import torch
+import torch.nn.functional as F
+from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.utils.distributed import (
+    all_gather_raw,
+    all_reduce_raw,
+    reduce_scatter_raw,
+)
+from torch import Tensor
+from torch.cuda.amp import custom_bwd
+from torch.distributed import ProcessGroup

 from internlm.core.context import global_context as gpc

@@ -72,6 +84,78 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)

+def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
+    assert input.dtype == grad_output.dtype
+    grad_weight = torch.matmul(grad_output.t(), input)
+    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
+    return grad_weight, grad_bias
+
+
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
+class FusedDenseFuncTorch(FusedDenseFunc):
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+        if ctx.compute_weight_gradient:
+            x, weight = ctx.saved_tensors
+            if process_group is not None and sequence_parallel:
+                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+            else:
+                total_x = x
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            if process_group is not None and sequence_parallel:
+                handle_x.wait()
+            # unlike flash_attn, compute the weight/bias grads with plain torch ops instead of the fused CUDA kernel
+            grad_weight, grad_bias = linear_bias_wgrad_torch(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+def fused_dense_func_torch(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True,
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
+        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
+    else:
+        return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)

 def try_import_RMSNorm():
     """
@@ -86,4 +170,4 @@ def try_import_RMSNorm():
         logger = get_logger(__file__)
         logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm
-        return RMSNorm
\ No newline at end of file
+        return RMSNorm
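Note: fused_dense_func_torch is the core of the fp32/tf32 path. Inputs in fp16/bf16 (or fp32 under autocast) still go through flash-attn's FusedDenseFunc; everything else, in particular plain fp32/tf32 training, falls back to FusedDenseFuncTorch, whose backward computes the weight and bias gradients with linear_bias_wgrad_torch instead of the fused CUDA kernel. The quick check below (not part of the patch; the function is copied inline without the dtype assert so the snippet runs standalone, and the tensor shapes are illustrative) verifies that this wgrad matches what autograd computes for a plain F.linear.

import torch
import torch.nn.functional as F


def linear_bias_wgrad_torch(inp, grad_output, has_d_bias):
    grad_weight = torch.matmul(grad_output.t(), inp)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias


x = torch.randn(16, 32, dtype=torch.float32)
w = torch.randn(64, 32, dtype=torch.float32, requires_grad=True)
b = torch.zeros(64, dtype=torch.float32, requires_grad=True)

out = F.linear(x, w, b)
grad_out = torch.randn_like(out)
out.backward(grad_out)

gw, gb = linear_bias_wgrad_torch(x, grad_out, has_d_bias=True)
assert torch.allclose(gw, w.grad, atol=1e-5)
assert torch.allclose(gb, b.grad, atol=1e-5)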
Please note this!") from internlm.model.norm import RMSNormTorch as RMSNorm - return RMSNorm \ No newline at end of file + return RMSNorm diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 94c3e05..116ffc2 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -89,7 +89,10 @@ class HybridZeroOptimizer(BaseOptimizer): zero_cfg: Config = None, ): # DynamicGradScaler related args - initial_scale = grad_scal_cfg.fp16.initial_scale + if gpc.config.model.dtype is torch.float32: + initial_scale = 1 + else: + initial_scale = grad_scal_cfg.fp16.initial_scale min_scale = grad_scal_cfg.fp16.min_scale growth_interval = grad_scal_cfg.fp16.growth_interval growth_factor = grad_scal_cfg.growth_factor @@ -533,7 +536,8 @@ class HybridZeroOptimizer(BaseOptimizer): norm_groups.append(norm_group) loss_scale = float(self.loss_scale.item()) # backup - self.grad_scaler.update(found_inf) + if not gpc.config.model.dtype is torch.float32: + self.grad_scaler.update(found_inf) # update loss scale if overflow occurs if found_inf: if gpc.is_rank_for_log(): @@ -576,8 +580,9 @@ class HybridZeroOptimizer(BaseOptimizer): global_norm = sum(norm_groups) ** 0.5 # the following operations are performed only on the rank to which parameters are assigned. - if len(single_grad_partition_groups) != 0: - self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale) + if not gpc.config.model.dtype is torch.float32: + if len(single_grad_partition_groups) != 0: + self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale) timer("cal_norm").stop() # update the parameters