pull/407/head
yingtongxiong 2023-10-09 20:39:57 +08:00
parent 29df765f65
commit f191853bf4
6 changed files with 99 additions and 107 deletions

View File

@@ -283,8 +283,10 @@ def args_sanity_check():
     if gpc.config.parallel["tensor"].get("mode", None) is None:
         gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-    if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
-        assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
+    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
+        assert (
+            gpc.config.parallel.sequence_parallel is True
+        ), "when the tp_mode is fstp, the sequence_parallel should be True."

     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:

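For reference, a config fragment that satisfies the new check. This is a sketch only: the check reads parallel["tensor"]["mode"] and parallel.sequence_parallel, while the remaining keys (zero1, size, pipeline) follow the usual InternLM config style and are assumptions here.

# Hypothetical InternLM-style training config fragment.
# mode="fstp" passes args_sanity_check() only when sequence_parallel is True;
# mode="origin_tp" (the default filled in above) carries no such requirement.
parallel = dict(
    zero1=8,
    tensor=dict(size=2, mode="fstp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)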
View File

@@ -9,10 +9,9 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import Silu, fused_dense_func_torch, fsdp_fused_dense_func
+from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch


 class ScaleColumnParallelLinear(nn.Linear):
@@ -124,7 +123,12 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
         # If not, then the input is already gathered.
         return fused_dense_func_torch(
-            x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, gather_dim=gather_dim,
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel,
+            gather_dim=gather_dim,
         )
@@ -204,31 +208,13 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out


-class FSDPLinear(ColumnParallelLinear):
+class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        return fsdp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)


-class FSDPScaleLinear(ScaleColumnParallelLinear):
-    def forward(self, input):  # pylint: disable=W0622
-        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
-        # we do an all_gather of x before doing the matmul.
-        # If not, then the input is already gathered.
-        if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
-        else:
-            weight = self.weight
-        return fsdp_fused_dense_func(
-            input,
-            weight,
-            self.bias,
-            process_group=self.process_group,
-        )
-
-
-class FSDPFeedForward(nn.Module):
+class FSTPFeedForward(nn.Module):
     """
     FeedForward.
@@ -259,7 +245,7 @@ class FSDPFeedForward(nn.Module):
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = FSDPLinear(
+        self.w1 = FSTPLinear(
             in_features,
             hidden_features,
             process_group,
@@ -268,7 +254,7 @@
             device=device,
             dtype=dtype,
         )
-        self.w2 = FSDPLinear(
+        self.w2 = FSTPLinear(
             in_features,
             hidden_features,
             process_group,
@@ -277,7 +263,7 @@
             device=device,
             dtype=dtype,
         )
-        self.w3 = FSDPLinear(
+        self.w3 = FSTPLinear(
             hidden_features,
             out_features,
             process_group,

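For context, the SwiGLU composition that both FeedForward and FSTPFeedForward compute (out = self.w3(Silu(w1_o, w2_o)) in the hunk above), sketched with plain nn.Linear layers and assuming Silu(w1_o, w2_o) is the usual F.silu(w1_o) * w2_o gate; shapes are illustrative.

import torch
import torch.nn.functional as F

in_features, hidden_features, out_features = 16, 32, 16
w1 = torch.nn.Linear(in_features, hidden_features, bias=False)
w2 = torch.nn.Linear(in_features, hidden_features, bias=False)
w3 = torch.nn.Linear(hidden_features, out_features, bias=False)

x = torch.randn(4, in_features)
out = w3(F.silu(w1(x)) * w2(x))  # stands in for self.w3(Silu(w1_o, w2_o))
print(out.shape)  # torch.Size([4, 16])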
View File

@@ -15,13 +15,16 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
+    FSTPFeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
-    FSDPScaleLinear,
-    FSDPFeedForward,
 )
 from internlm.model.multi_head_attention import MHA
-from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, split_forward_gather_backward
+from internlm.model.utils import (
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+    try_import_RMSNorm,
+)
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -74,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = 'origin_tp',
+        tp_mode: str = "origin_tp",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -111,7 +114,7 @@
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
+            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -173,7 +176,6 @@
                     else:
                         normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

-
     def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
         if self.checkpoint and self.training:
             return activation_checkpoint(
@@ -341,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
-                    tp_mode = self.tp_mode,
+                    tp_mode=self.tp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -388,7 +390,7 @@
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
             # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-            if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
+            if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

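A standalone illustration of the index split added for "fstp" above, using torch.tensor_split as a hypothetical stand-in for split_forward_gather_backward (the real helper also gathers the gradient back in the backward pass); values are illustrative.

import torch

tp_world_size, tp_rank = 4, 1  # example tensor-parallel group
indexes = torch.arange(16)     # packed position ids for a sequence of length 16
local_indexes = torch.tensor_split(indexes, tp_world_size, dim=0)[tp_rank]
print(local_indexes)  # tensor([4, 5, 6, 7])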
View File

@@ -2,9 +2,10 @@
 # -*- encoding: utf-8 -*-

 import warnings
-from typing import Optional
+from typing import Any, Optional, Tuple

 import torch
+import torch.distributed as dist
 from einops import rearrange
 from flash_attn.modules.mha import (
     CrossAttention,
@@ -13,26 +14,25 @@ from flash_attn.modules.mha import (
     SelfAttention,
     _update_kv_cache,
 )
-from torch import nn
+from torch import Tensor, nn
+from torch.nn import Module

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
-from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear
-
-import torch
-
-from typing import Any, Tuple
-from torch import Tensor
-from torch.nn import Module
-import torch.distributed as dist
+from internlm.model.linear import (
+    ColumnParallelLinearTorch,
+    FSTPLinear,
+    RowParallelLinearTorch,
+)


-# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
 class _SeqAllToAll(torch.autograd.Function):
+    "sequence alltoall"
+
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
+    def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
         ctx.group = group
         ctx.scatter_idx = scatter_idx
@@ -40,7 +40,7 @@ class _SeqAllToAll(torch.autograd.Function):
         seq_world_size = dist.get_world_size(group)

-        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
+        input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)]
         output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
         # TODO Use all_to_all_single instead
         dist.all_to_all(output_list, input_list, group=group)
@@ -51,6 +51,7 @@ class _SeqAllToAll(torch.autograd.Function):
         return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)


+# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
 class DistributedAttention(torch.nn.Module):
     """Initialization.
@@ -73,7 +74,7 @@ class DistributedAttention(torch.nn.Module):
         second_gather_idx: int = 1,
     ) -> None:
-        super(DistributedAttention, self).__init__()
+        super().__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
         self.first_scatter_idx = first_scatter_idx
@@ -82,7 +83,7 @@
         self.second_gather_idx = second_gather_idx

     def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
-        """ forward
+        """forward

         Arguments:
             query (Tensor): query input to the layer
@@ -93,24 +94,25 @@
         Returns:
             * output (Tensor): context output
         """
-        # TODO Merge three alltoall calls into one
+        # Evaluation
         if qkv.ndim == 5:
-            # in shape: [seq/tp_size, 3, head, head_dim]
+            # in shape: [batch, seq/tp_size, 3, head, head_dim]
             qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
-            #out shape : [seq, head/tp_size, head_dim]
+            # out shape : [batch, seq, head/tp_size, head_dim]
             context_layer = self.local_attn(qkv, **kwargs)
-            # in shape: [seq, head/tp_size, head_dim]
-            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
-        else:
+            # in shape: [batch, seq, head/tp_size, head_dim]
+            output = _SeqAllToAll.apply(
+                self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1
+            )
+        else:  # training
             # in shape: [seq/tp_size, 3, head, head_dim]
             qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
-            #out shape : [seq, head/tp_size, head_dim]
+            # out shape : [seq, head/tp_size, head_dim]
             context_layer = self.local_attn(qkv, **kwargs)
             # in shape: [seq, head/tp_size, head_dim]
             output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)
-        #out e.g., [s/p::h]
+        # out e.g., [s/p::h]
         return output
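A single-process sketch of the layout change the first all-to-all in DistributedAttention.forward performs on the packed (training) path. The tensor-parallel group is simulated with tensor_split/cat, so nothing is communicated and only the shapes, not the values a real dist.all_to_all would deliver, are meaningful.

import torch

# qkv enters as [seq/tp_size, 3, head, head_dim] and leaves the first
# all-to-all as [seq, 3, head/tp_size, head_dim]; attention then runs on
# full sequences with a slice of the heads.
world_size = 2
seq, heads, head_dim = 8, 4, 16
qkv_local = torch.randn(seq // world_size, 3, heads, head_dim)

# Each rank splits its heads into world_size chunks (what _SeqAllToAll scatters)...
chunks = [t.contiguous() for t in torch.tensor_split(qkv_local, world_size, dim=2)]
# ...and after the exchange holds one head-chunk per rank, concatenated back
# along the sequence dimension (what _SeqAllToAll gathers).
qkv_exchanged = torch.cat(chunks, dim=0)
print(qkv_exchanged.shape)  # torch.Size([8, 3, 2, 16])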
@@ -157,7 +159,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = 'origin_tp',
+        tp_mode: str = "origin_tp",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -185,7 +187,7 @@
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -201,12 +203,12 @@
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == 'fstp':
+        if tp_mode == "fstp":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
@@ -214,7 +216,6 @@
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
-
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
@@ -311,7 +312,6 @@
             qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
             qkv = self.rotary_emb(qkv, **kwargs)
             kwargs.pop("indexes")
-
         if inference_params is None:
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):

View File

@@ -3,18 +3,14 @@
 from typing import Optional

-import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import (
-    all_reduce_raw,
-    reduce_scatter_raw,
-)
+from flash_attn.utils.distributed import all_reduce_raw, reduce_scatter_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup

+import fused_dense_lib as fused_dense_cuda
+
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
@@ -123,8 +119,9 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
     shape = list(input_.shape)
     shape[gather_dim] = shape[gather_dim] * world_size
     output = torch.empty(shape, dtype=input_.dtype, device=input_.device)
-    handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
-                                                      group=process_group, async_op=async_op)
+    handle = torch.distributed.all_gather_into_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
     return output, handle
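The gather_dim bookkeeping in all_gather_raw, shown with toy sizes (illustrative world size; no collective is launched here): the output buffer is the input shape with the gather dimension scaled by the process-group world size.

import torch

world_size, gather_dim = 4, 1
input_ = torch.randn(2, 8, 16)
shape = list(input_.shape)
shape[gather_dim] = shape[gather_dim] * world_size
output = torch.empty(shape, dtype=input_.dtype, device=input_.device)
print(output.shape)  # torch.Size([2, 32, 16])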
@@ -137,11 +134,11 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):

 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
+    "tp fused dense function"
+
     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
-                sequence_parallel=True, gather_dim=0):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
         with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
@@ -171,7 +168,7 @@ class FusedDenseFunc(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
         if min(batch_dim, n, *weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -184,7 +181,7 @@
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
-            grad_input, = args
+            (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
@@ -197,7 +194,7 @@
             else:
                 total_x = x
         else:
-            weight, = ctx.saved_tensors
+            (weight,) = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
@@ -206,8 +203,7 @@
             if not ctx.return_residual:
                 grad_input = F.linear(grad_output, weight.t())
             else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, weight)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         if process_group is not None:
             reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@@ -282,7 +278,8 @@ class FusedDenseFuncTorch(FusedDenseFunc):
         return grad_input, grad_weight, grad_bias, None, None, None, None


-class FSDPFusedDenseFunc(torch.autograd.Function):
+class FSTPFusedDenseFunc(torch.autograd.Function):
+    "FSTP fused dense function"

     @staticmethod
     @custom_fwd
@@ -319,7 +316,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
         if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -332,14 +329,14 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
-            grad_input, = args
+            (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             total_x = x
         else:
-            weight, = ctx.saved_tensors
+            (weight,) = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
@@ -357,8 +354,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
             if not ctx.return_residual:
                 grad_input = F.linear(grad_output, total_weight.t())
             else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, total_weight)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         else:
             grad_input = None
@@ -399,12 +395,14 @@ def fused_dense_func_torch(
     return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)


-def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                          return_residual: bool = False, process_group = None):
-    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
-                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+def fstp_fused_dense_func(
+    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)

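The dispatch rule in fstp_fused_dense_func, pulled out as a standalone helper for illustration (hypothetical name _dtype_eligible; the real code inlines the expression and additionally requires x, weight, and bias to live on CUDA before taking the fused path).

import torch

def _dtype_eligible(x: torch.Tensor) -> bool:
    # fp16/bf16 always qualify; fp32 only inside an autocast region.
    return x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )

print(_dtype_eligible(torch.randn(2, 2, dtype=torch.bfloat16)))  # True
print(_dtype_eligible(torch.randn(2, 2)))                        # False outside autocast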
View File

@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        if gpc.config.parallel["tensor"]["mode"] == 'fstp':
+        if gpc.config.parallel["tensor"]["mode"] == "fstp":
            gpc.config.parallel.sequence_parallel = True
         else:
            gpc.config.parallel.sequence_parallel = False
@@ -106,10 +106,14 @@ def evaluate_on_val_dls(
             total_val_bsz = len(batch[1])
             assert total_val_bsz % data_cfg.micro_bsz == 0
             num_microbatches = total_val_bsz // data_cfg.micro_bsz
-            if gpc.config.parallel['tensor']['mode'] == 'fstp':
+            if gpc.config.parallel["tensor"]["mode"] == "fstp":
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
+                    [
+                        data_cfg.micro_bsz,
+                        batch[0]["input_ids"].shape[1] // sequence_world_size,
+                        gpc.config.HIDDEN_SIZE,
+                    ]
                 )
             else:
                 tensor_shape = torch.Size(
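Toy numbers for the "fstp" branch above: the tensor-parallel group doubles as the sequence-parallel group, so the per-rank evaluation tensor shape uses the full sequence length divided by that group's world size (sizes below are illustrative).

import torch

micro_bsz, seq_len, hidden_size = 1, 2048, 4096  # illustrative values
sequence_world_size = 8                          # gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size([micro_bsz, seq_len // sequence_world_size, hidden_size])
print(tensor_shape)  # torch.Size([1, 256, 4096])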