pull/407/head
yingtongxiong 2023-10-09 20:39:57 +08:00
parent 29df765f65
commit f191853bf4
6 changed files with 99 additions and 107 deletions

View File

@@ -283,8 +283,10 @@ def args_sanity_check():
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert (
gpc.config.parallel.sequence_parallel is True
), "when the tp_mode is fstp, the sequence_parallel should be True."
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
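For reference, a minimal parallel-config fragment that satisfies the new check is sketched below (editor's illustration; only the mode and sequence_parallel fields come from this diff, the tensor size is a placeholder):

    # illustrative config fragment, assuming the dict-style gpc.config.parallel referenced above
    parallel = dict(
        sequence_parallel=True,            # must be True whenever tensor mode is "fstp"
        tensor=dict(size=8, mode="fstp"),  # mode falls back to "origin_tp" when omitted
    )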

View File

@@ -9,10 +9,9 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import Silu, fused_dense_func_torch, fsdp_fused_dense_func
from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch
class ScaleColumnParallelLinear(nn.Linear):
@@ -124,7 +123,12 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
# If not, then the input is already gathered.
return fused_dense_func_torch(
x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, gather_dim=gather_dim,
x,
self.weight,
self.bias,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
gather_dim=gather_dim,
)
@@ -204,31 +208,13 @@ class FeedForward(nn.Module):
out = self.w3(Silu(w1_o, w2_o))
return out
class FSDPLinear(ColumnParallelLinear):
class FSTPLinear(ColumnParallelLinear):
def forward(self, x):
return fsdp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
class FSDPScaleLinear(ScaleColumnParallelLinear):
def forward(self, input): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fsdp_fused_dense_func(
input,
weight,
self.bias,
process_group=self.process_group,
)
class FSDPFeedForward(nn.Module):
class FSTPFeedForward(nn.Module):
"""
FeedForward.
@@ -259,7 +245,7 @@ class FSDPFeedForward(nn.Module):
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = FSDPLinear(
self.w1 = FSTPLinear(
in_features,
hidden_features,
process_group,
@@ -268,7 +254,7 @@ class FSDPFeedForward(nn.Module):
device=device,
dtype=dtype,
)
self.w2 = FSDPLinear(
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
@@ -277,7 +263,7 @@ class FSDPFeedForward(nn.Module):
device=device,
dtype=dtype,
)
self.w3 = FSDPLinear(
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
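FSTPLinear above forwards to fstp_fused_dense_func, whose autograd function (FSTPFusedDenseFunc, later in this diff) all-gathers the weight and runs a plain matmul on the sequence-sharded input. A stripped-down, forward-only sketch of that idea follows (editor's illustration only; fstp_linear_forward_sketch is a hypothetical helper, it assumes an initialized process group and a weight sharded along dim 0, and it is not the commit's implementation):

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def fstp_linear_forward_sketch(x_local, weight_shard, bias, process_group):
        # x_local: [seq / tp_size, ..., in_features] -- activations stay sequence-sharded
        # weight_shard: this rank's 1/tp_size slice of the [out_features, in_features] weight
        world_size = dist.get_world_size(process_group)
        full_shape = (weight_shard.shape[0] * world_size, *weight_shard.shape[1:])
        total_weight = torch.empty(full_shape, dtype=weight_shard.dtype, device=weight_shard.device)
        # gather the full weight just for this matmul; the gathered copy is temporary
        dist.all_gather_into_tensor(total_weight, weight_shard.contiguous(), group=process_group)
        return F.linear(x_local, total_weight, bias)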

View File

@@ -15,13 +15,16 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
FeedForward,
FSTPFeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
FSDPScaleLinear,
FSDPFeedForward,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, split_forward_gather_backward
from internlm.model.utils import (
gather_forward_split_backward,
split_forward_gather_backward,
try_import_RMSNorm,
)
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
@@ -74,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
tp_mode: str = 'origin_tp',
tp_mode: str = "origin_tp",
):
super().__init__()
self.checkpoint = checkpoint
@@ -111,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu:
mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
@@ -173,7 +176,6 @@ class PackedFlashBaseLayer1D(nn.Module):
else:
normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
if self.checkpoint and self.training:
return activation_checkpoint(
@@ -341,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
tp_mode = self.tp_mode,
tp_mode=self.tp_mode,
)
for lid in range(num_layers)
]
@@ -388,7 +390,7 @@ class PackedFlashInternLm1D(nn.Module):
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in the sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
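The indexes split above mirrors the sequence split already applied to the hidden states; a single-process illustration of the intended effect (editor's sketch, assuming split_forward_gather_backward keeps each rank's chunk along the given dim):

    import torch

    indexes = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])  # two packed samples of length 4
    tp_size, rank = 2, 0
    local_indexes = torch.chunk(indexes, tp_size, dim=0)[rank]
    print(local_indexes)  # tensor([0, 1, 2, 3]) -- positions matching this rank's seq/tp slice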

View File

@@ -2,9 +2,10 @@
# -*- encoding: utf-8 -*-
import warnings
from typing import Optional
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from einops import rearrange
from flash_attn.modules.mha import (
CrossAttention,
@@ -13,26 +14,25 @@ from flash_attn.modules.mha import (
SelfAttention,
_update_kv_cache,
)
from torch import nn
from torch import Tensor, nn
from torch.nn import Module
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear
import torch
from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module
import torch.distributed as dist
from internlm.model.linear import (
ColumnParallelLinearTorch,
FSTPLinear,
RowParallelLinearTorch,
)
# adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
class _SeqAllToAll(torch.autograd.Function):
"sequence alltoall"
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
@@ -40,7 +40,7 @@ class _SeqAllToAll(torch.autograd.Function):
seq_world_size = dist.get_world_size(group)
input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)]
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
# TODO Use all_to_all_single instead
dist.all_to_all(output_list, input_list, group=group)
@@ -51,6 +51,7 @@ class _SeqAllToAll(torch.autograd.Function):
return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
# adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
class DistributedAttention(torch.nn.Module):
"""Initialization.
@@ -73,7 +74,7 @@ class DistributedAttention(torch.nn.Module):
second_gather_idx: int = 1,
) -> None:
super(DistributedAttention, self).__init__()
super().__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.first_scatter_idx = first_scatter_idx
@@ -82,7 +83,7 @@ class DistributedAttention(torch.nn.Module):
self.second_gather_idx = second_gather_idx
def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
""" forward
"""forward
Arguments:
query (Tensor): query input to the layer
@@ -93,24 +94,25 @@ class DistributedAttention(torch.nn.Module):
Returns:
* output (Tensor): context output
"""
# TODO Merge three alltoall calls into one
# Evaluation
if qkv.ndim == 5:
# in shape: [seq/tp_size, 3, head, head_dim]
# in shape: [batch, seq/tp_size, 3, head, head_dim]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
#out shape : [seq, head/tp_size, head_dim]
# out shape : [batch, seq, head/tp_size, head_dim]
context_layer = self.local_attn(qkv, **kwargs)
# in shape: [seq, head/tp_size, head_dim]
output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
else:
# in shape: [batch, seq, head/tp_size, head_dim]
output = _SeqAllToAll.apply(
self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1
)
else: # training
# in shape: [seq/tp_size, 3, head, head_dim]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
#out shape : [seq, head/tp_size, head_dim]
# out shape : [seq, head/tp_size, head_dim]
context_layer = self.local_attn(qkv, **kwargs)
# in shape: [seq, head/tp_size, head_dim]
output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)
#out e.g., [s/p::h]
# out e.g., [s/p::h]
return output
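Since the shape comments above carry most of the logic, a small shape-only sketch of what _SeqAllToAll does in the training path may help (editor's illustration; seq_all_to_all_shape is a hypothetical helper, and the concrete scatter/gather dims 2 and 0 are inferred from the comments rather than from the defaults, which are not visible in this hunk):

    def seq_all_to_all_shape(local_shape, world_size, scatter_idx, gather_idx):
        # per-rank shape after all-to-all: 1/world_size of the scatter dim, world_size x the gather dim
        shape = list(local_shape)
        assert shape[scatter_idx] % world_size == 0
        shape[scatter_idx] //= world_size
        shape[gather_idx] *= world_size
        return tuple(shape)

    # training path: each rank holds qkv as [seq/tp, 3, heads, head_dim]; heads (dim 2) are
    # scattered and the sequence (dim 0) is gathered before the local attention runs
    print(seq_all_to_all_shape((4, 3, 8, 64), world_size=2, scatter_idx=2, gather_idx=0))
    # -> (8, 3, 4, 64); the second all-to-all reverses this on the attention output, and
    # backward re-applies the all-to-all with the scatter/gather indices swapped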
@@ -157,7 +159,7 @@ class MHA(nn.Module):
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tp_mode: str = 'origin_tp',
tp_mode: str = "origin_tp",
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@@ -185,7 +187,7 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
@@ -201,12 +203,12 @@ class MHA(nn.Module):
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
if tp_mode == 'fstp':
if tp_mode == "fstp":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
# output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
@@ -214,7 +216,6 @@ class MHA(nn.Module):
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign tp attribute so that internlm knows it is a tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for name in ["out_proj", "Wqkv"]:
@@ -311,7 +312,6 @@ class MHA(nn.Module):
qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
qkv = self.rotary_emb(qkv, **kwargs)
kwargs.pop("indexes")
if inference_params is None:
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
with torch.cuda.amp.autocast(dtype=torch.bfloat16):

View File

@@ -3,18 +3,14 @@
from typing import Optional
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import (
all_reduce_raw,
reduce_scatter_raw,
)
from flash_attn.utils.distributed import all_reduce_raw, reduce_scatter_raw
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
import fused_dense_lib as fused_dense_cuda
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
@@ -123,8 +119,9 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
shape = list(input_.shape)
shape[gather_dim] = shape[gather_dim] * world_size
output = torch.empty(shape, dtype=input_.dtype, device=input_.device)
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
group=process_group, async_op=async_op)
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
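For context, the async variant of the helper above is typically consumed by launching the collective, overlapping other work, and only then waiting before the gathered tensor is read; a hedged usage sketch (x, weight, bias and process_group are assumed placeholders, and an initialized process group is required):

    import torch.nn.functional as F

    total_x, handle = all_gather_raw(x, process_group, async_op=True, gather_dim=0)
    # ... work that does not need total_x can overlap with the communication here ...
    handle.wait()
    output = F.linear(total_x, weight, bias)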
@@ -137,11 +134,11 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFunc(torch.autograd.Function):
"tp fused dense function"
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
sequence_parallel=True, gather_dim=0):
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
@@ -171,7 +168,7 @@ class FusedDenseFunc(torch.autograd.Function):
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *weight.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, weight, bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
@@ -184,7 +181,7 @@ class FusedDenseFunc(torch.autograd.Function):
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
grad_input, = args
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
@@ -197,7 +194,7 @@ class FusedDenseFunc(torch.autograd.Function):
else:
total_x = x
else:
weight, = ctx.saved_tensors
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
@@ -206,8 +203,7 @@ class FusedDenseFunc(torch.autograd.Function):
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_output, weight)
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@@ -282,7 +278,8 @@ class FusedDenseFuncTorch(FusedDenseFunc):
return grad_input, grad_weight, grad_bias, None, None, None, None
class FSDPFusedDenseFunc(torch.autograd.Function):
class FSTPFusedDenseFunc(torch.autograd.Function):
"FSTP fused dense function"
@staticmethod
@custom_fwd
@@ -319,7 +316,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, total_weight, total_bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
@@ -332,14 +329,14 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
grad_input, = args
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
total_x = x
else:
weight, = ctx.saved_tensors
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
@@ -357,8 +354,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
if not ctx.return_residual:
grad_input = F.linear(grad_output, total_weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_output, total_weight)
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
else:
grad_input = None
@@ -399,12 +395,14 @@ def fused_dense_func_torch(
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
return_residual: bool = False, process_group = None):
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
def fstp_fused_dense_func(
x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
else:
assert process_group is None
out = F.linear(x, weight, bias)

View File

@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
if gpc.config.parallel["tensor"]["mode"] == 'fstp':
if gpc.config.parallel["tensor"]["mode"] == "fstp":
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
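The hunk above sits inside the usual save/override/restore pattern; reconstructed here as an editor's sketch (the contextmanager decorator, the yield, and the finally clause are assumed, since they are not shown in this hunk):

    from contextlib import contextmanager

    @contextmanager
    def switch_sequence_parallel_mode_sketch(gpc):
        # during evaluation, sequence_parallel tracks the tensor-parallel mode: on for
        # "fstp", off otherwise; the previous value is restored afterwards
        prev_mode = gpc.config.parallel.sequence_parallel
        try:
            gpc.config.parallel.sequence_parallel = gpc.config.parallel["tensor"]["mode"] == "fstp"
            yield
        finally:
            gpc.config.parallel.sequence_parallel = prev_mode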
@@ -106,10 +106,14 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.parallel['tensor']['mode'] == 'fstp':
if gpc.config.parallel["tensor"]["mode"] == "fstp":
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
[
data_cfg.micro_bsz,
batch[0]["input_ids"].shape[1] // sequence_world_size,
gpc.config.HIDDEN_SIZE,
]
)
else:
tensor_shape = torch.Size(