support fstp and refactor code

pull/407/head
yingtongxiong 2023-10-09 17:26:20 +08:00
parent bd4af3a31f
commit 189a313da6
7 changed files with 104 additions and 144 deletions

View File

@ -5,7 +5,7 @@ SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 32
NUM_LAYER = 4
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@ -55,7 +55,7 @@ data = dict(
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=1000,
valid_every=10,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
@ -64,7 +64,7 @@ data = dict(
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=10,
empty_cache_and_diag_interval=100,
diag_outlier_ratio=1.1,
)
@ -135,7 +135,7 @@ model = dict(
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
@ -155,7 +155,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=-1,
tensor=2,
tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)

View File

@ -568,7 +568,8 @@ class ParallelContext(metaclass=SingletonMeta):
# during model construction), this is because the random state will be different on different tensor parallel
# devices of the same data parallel group. The underlying reason is that the device with tp_rank = 0 performs
# additional random operations during the RowParallelLinear module building process.
set_mode(ParallelMode.DUMMY)
# set_mode(ParallelMode.DUMMY)
set_mode(ParallelMode.TENSOR)
seeds = get_seeds()
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])

View File

@ -279,6 +279,12 @@ def args_sanity_check():
assert not (
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert gpc.config.parallel.sequence_parallel is True, "sequence_parallel must be True when the tensor parallel mode is 'fstp'"
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
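The checks above normalize the new tensor-parallel mode field and tie 'fstp' to sequence parallelism. As a minimal standalone sketch of that normalization, assuming the `tensor` field may still appear in the legacy integer form (`tensor=2`, as in the old config line) alongside the new dict form; the helper name is hypothetical and not part of this commit:

# Hypothetical helper mirroring the normalization above; the legacy int form is an assumption.
def normalize_tensor_parallel_config(tensor_cfg):
    if isinstance(tensor_cfg, int):
        # legacy form: tensor=2 -> dict(size=2, mode="origin_tp")
        tensor_cfg = dict(size=tensor_cfg, mode="origin_tp")
    # dict form without an explicit mode defaults to "origin_tp"
    tensor_cfg.setdefault("mode", "origin_tp")
    assert tensor_cfg["mode"] in ("origin_tp", "fstp"), "tensor parallel mode should be 'origin_tp' or 'fstp'"
    return tensor_cfg

assert normalize_tensor_parallel_config(2) == {"size": 2, "mode": "origin_tp"}
assert normalize_tensor_parallel_config(dict(size=2, mode="fstp"))["mode"] == "fstp"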

View File

@ -4,44 +4,20 @@
from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter, all_gather_raw, reduce_scatter_raw
from torch import Tensor
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import Silu, fused_dense_func_torch
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
from flash_attn.utils.distributed import all_gather_raw, all_reduce_raw
# reduce_scatter_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, op=torch.distributed.ReduceOp.SUM):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), op=op, group=process_group, async_op=async_op
)
return output, handle
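Since reduce_scatter_raw is redefined here, a single-process illustration of its SUM semantics may help: each rank ends up with the sum over ranks of its own 1/world_size slice along dim 0. This sketch only mimics the collective locally; it is not a distributed call:

import torch

world_size = 4
# one "input_" per rank; shape[0] must be divisible by world_size
inputs = [torch.randn(8, 3) for _ in range(world_size)]
summed = torch.stack(inputs).sum(dim=0)        # elementwise SUM across ranks
shards = summed.chunk(world_size, dim=0)       # rank r would receive shards[r]
assert shards[0].shape == (2, 3)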
class ScaleColumnParallelLinear(nn.Linear):
"""
@ -231,7 +207,7 @@ class FeedForward(nn.Module):
out = self.w3(Silu(w1_o, w2_o))
return out
class FusedDenseFunc_fsdp(torch.autograd.Function):
class FSDPFusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
@ -243,21 +219,26 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
total_x = x
total_x = x.contiguous()
# do all_gather for weight and bias before actual computation
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
if bias is not None:
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
handle_bias.wait()
world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1:
# do all_gather for weight and bias before actual computation
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
if bias is not None:
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
handle_bias.wait()
else:
total_bias = bias
handle_weight.wait()
else:
total_weight = weight
total_bias = bias
if torch.is_autocast_enabled():
total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
handle_weight.wait()
total_weight = total_weight.contiguous()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
@ -289,9 +270,13 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# do all-gather for weight before backward
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()
world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1:
# do all-gather for weight before backward
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()
else:
total_weight = weight
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
@ -300,32 +285,24 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_output, total_weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
# if process_group is not None:
# import pdb; pdb.set_trace()
# grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, async_op=True)
# grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True)
else:
grad_input = None
# import pdb; pdb.set_trace()
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
if grad_bias is not None:
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
handle_grad_bias.wait()
handle_grad_weight.wait()
if world_size > 1:
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
if grad_bias is not None:
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
handle_grad_bias.wait()
handle_grad_weight.wait()
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
# if process_group is not None and ctx.needs_input_grad[0]:
# handle_grad_input.wait()
# import pdb; pdb.set_trace()
return grad_input, grad_weight, grad_bias, None, None, None
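The local gradient math above (before any reduce-scatter) is the standard linear-layer backward; a small single-process check against autograd, just to make the formulas explicit:

import torch
import torch.nn.functional as F

# For y = x @ w.T + b and loss = y.sum(), grad_output is all ones and:
#   grad_input = grad_output @ w, grad_weight = grad_output.T @ x, grad_bias = grad_output.sum(0)
x = torch.randn(6, 4, requires_grad=True)
w = torch.randn(8, 4, requires_grad=True)
b = torch.randn(8, requires_grad=True)
F.linear(x, w, b).sum().backward()
go = torch.ones(6, 8)
assert torch.allclose(x.grad, go @ w)
assert torch.allclose(w.grad, go.t() @ x)
assert torch.allclose(b.grad, go.sum(dim=0))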
@ -334,7 +311,7 @@ def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = No
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc_fsdp.apply(x, weight, bias, return_residual, process_group)
return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
else:
assert process_group is None
out = F.linear(x, weight, bias)
@ -426,5 +403,5 @@ class FSDPFeedForward(nn.Module):
def forward(self, x):
w1_o = self.w1(x)
w2_o = self.w2(x)
out = self.w3(Silu(w1_o, w2_o))
out = self.w3(F.silu(w1_o) * w2_o)
return out
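Taken together, FSDPFusedDenseFunc implements the FSTP pattern: each rank stores only a shard of the weight, all-gathers the full weight just in time for the matmul, and reduce-scatters the weight gradient back into shards. A minimal sketch of that flow with raw torch.distributed collectives (launch with torchrun --nproc_per_node=2; gradients are formed manually and all names are illustrative, not this repo's API):

import torch
import torch.distributed as dist

def fstp_linear_demo():
    dist.init_process_group(backend="gloo")  # gloo so the sketch also runs on CPU
    world = dist.get_world_size()
    out_features, in_features = 8, 4
    assert out_features % world == 0
    # each rank keeps a 1/world shard of the weight, sharded along dim 0
    w_shard = torch.randn(out_features // world, in_features)
    x = torch.randn(6, in_features)

    # forward: all-gather the full weight right before the matmul
    gathered = [torch.empty_like(w_shard) for _ in range(world)]
    dist.all_gather(gathered, w_shard)
    y = x @ torch.cat(gathered, dim=0).t()

    # backward: reduce-scatter the full weight gradient back into shards
    grad_w_full = torch.ones_like(y).t() @ x     # dL/dW for a dummy loss y.sum()
    grad_w_shard = torch.empty_like(w_shard)
    dist.reduce_scatter(grad_w_shard, list(grad_w_full.chunk(world, dim=0)))
    if dist.get_rank() == 0:
        print("y:", tuple(y.shape), "weight grad shard:", tuple(grad_w_shard.shape))
    dist.destroy_process_group()

if __name__ == "__main__":
    fstp_linear_demo()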

View File

@ -74,6 +74,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
tp_mode: str = 'origin_tp',
):
super().__init__()
self.checkpoint = checkpoint
@ -98,6 +99,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
tp_mode=tp_mode,
)
self.dropout1 = nn.Dropout(drop_rate)
@ -109,16 +111,8 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu:
# self.mlp = FeedForward(
# hidden_size,
# int(hidden_size * mlp_ratio),
# out_features=hidden_size,
# process_group=gpc.get_group(ParallelMode.TENSOR),
# bias=False,
# device=device,
# dtype=dtype,
# )
self.mlp = FSDPFeedForward(
mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
@ -179,6 +173,7 @@ class PackedFlashBaseLayer1D(nn.Module):
else:
normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
if self.checkpoint and self.training:
return activation_checkpoint(
@ -300,12 +295,12 @@ class PackedFlashInternLm1D(nn.Module):
super().__init__()
checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"]
if is_reward:
head_cls = RewardModelLinear
else:
# head_cls = ScaleColumnParallelLinear
head_cls = FSDPScaleLinear
head_cls = ScaleColumnParallelLinear
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@ -346,6 +341,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode,
)
for lid in range(num_layers)
]
@ -391,7 +387,8 @@ class PackedFlashInternLm1D(nn.Module):
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
if gpc.config.parallel.sequence_parallel:
# when the tensor parallel mode is 'fstp', the indexes must also be split along the sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
@ -408,8 +405,12 @@ class PackedFlashInternLm1D(nn.Module):
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
if hasattr(self, "head"):
# if hidden_states.ndim == 3:
# import pdb; pdb.set_trace()
# hidden_states = self.head(hidden_states, dim=1)
# else:
# hidden_states = self.head(hidden_states)
hidden_states = self.head(hidden_states)
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=0)
if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
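For the index handling above, the forward pass of split_forward_gather_backward is assumed to give each tensor-parallel rank its own contiguous slice of the packed position IDs along dim 0; a local illustration of that assumption:

import torch

tp_size, seq_len = 2, 8
indexes = torch.arange(seq_len)           # packed position IDs, e.g. [0..7]
local = indexes.chunk(tp_size, dim=0)     # rank 0 keeps [0..3], rank 1 keeps [4..7]
assert local[1].tolist() == [4, 5, 6, 7]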

View File

@ -57,49 +57,29 @@ class DistributedAttention(torch.nn.Module):
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
first_scatter_idx (int): scatter_idx for the first all2all comm
first_gather_idx (int): gather_idx for the first all2all comm
second_scatter_idx (int): scatter_idx for the second all2all comm
second_gather_idx (int): gather_idx for the second all2all comm
"""
def __init__(
self,
local_attention: Module,
sequence_process_group: dist.ProcessGroup,
scatter_idx: int = 2,
gather_idx: int = 0,
first_scatter_idx: int = 2,
first_gather_idx: int = 0,
second_scatter_idx: int = 0,
second_gather_idx: int = 1,
) -> None:
super(DistributedAttention, self).__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
# def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
# """ forward
# Arguments:
# query (Tensor): query input to the layer
# key (Tensor): key input to the layer
# value (Tensor): value input to the layer
# args: other args
# Returns:
# * output (Tensor): context output
# """
# # TODO Merge three alltoall calls into one
# #in shape : e.g., [s/p:h:]
# query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
# key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
# value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
# #out shape : e.g., [s:h/p:]
# context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
# output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
# #out e.g., [s/p::h]
# return output
self.first_scatter_idx = first_scatter_idx
self.first_gather_idx = first_gather_idx
self.second_scatter_idx = second_scatter_idx
self.second_gather_idx = second_gather_idx
def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
""" forward
@ -114,15 +94,21 @@ class DistributedAttention(torch.nn.Module):
* output (Tensor): context output
"""
# TODO Merge three alltoall calls into one
#in shape : e.g., [s/p:h:]
qkv = _SeqAllToAll.apply(self.spg, qkv, 2, 0)
# key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
# value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
#out shape : e.g., [s:h/p:]
context_layer = self.local_attn(qkv, **kwargs)
output = _SeqAllToAll.apply(self.spg, context_layer, 0, 1)
if qkv.ndim == 5:
# qkv shape: [batch, seq/tp_size, 3, head, head_dim]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
# qkv shape after all2all: [batch, seq, 3, head/tp_size, head_dim]
context_layer = self.local_attn(qkv, **kwargs)
# context shape: [batch, seq, head/tp_size, head_dim]
output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
else:
# packed qkv shape: [seq/tp_size, 3, head, head_dim]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
# qkv shape after all2all: [seq, 3, head/tp_size, head_dim]
context_layer = self.local_attn(qkv, **kwargs)
# context shape: [seq, head/tp_size, head_dim]
output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)
# output: sequence dim re-sharded across tp ranks, full head dim restored
return output
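The scatter/gather bookkeeping above can be sanity-checked with a small shape calculation: an all-to-all over p ranks shrinks the scatter dimension by p and grows the gather dimension by p. A hypothetical helper (not part of this commit) checking both all2all steps:

# Hypothetical shape check for an all-to-all over p ranks.
def all2all_shape(shape, scatter_idx, gather_idx, p):
    shape = list(shape)
    assert shape[scatter_idx] % p == 0
    shape[scatter_idx] //= p
    shape[gather_idx] *= p
    return tuple(shape)

p = 2
# packed qkv: [seq/p, 3, head, head_dim] -> [seq, 3, head/p, head_dim]
assert all2all_shape((1024, 3, 32, 128), scatter_idx=2, gather_idx=0, p=p) == (2048, 3, 16, 128)
# context output: [seq, head/p, head_dim] -> [seq/p, head, head_dim]
assert all2all_shape((2048, 16, 128), scatter_idx=0, gather_idx=1, p=p) == (1024, 32, 128)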
@ -171,6 +157,7 @@ class MHA(nn.Module):
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tp_mode: str = 'origin_tp',
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@ -198,16 +185,8 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# note: bias should be set to True here
# self.Wqkv = ColumnParallelLinearTorch(
# embed_dim,
# 3 * embed_dim,
# process_group,
# bias=True,
# sequence_parallel=gpc.config.parallel.sequence_parallel,
# **factory_kwargs,
# ) # according to https://spaces.ac.cn/archives/9577
self.Wqkv = FSDPLinear(
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
process_group,
@ -222,25 +201,20 @@ class MHA(nn.Module):
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.inner_attn_sp = DistributedAttention(self.inner_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0)
self.inner_cross_attn_sp = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0)
if tp_mode == 'fstp':
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
# output projection always have the bias (for now)
# self.out_proj = RowParallelLinearTorch(
# embed_dim,
# embed_dim,
# process_group,
# sequence_parallel=gpc.config.parallel.sequence_parallel,
# **factory_kwargs,
# )
self.out_proj = FSDPLinear(
out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign the tp attribute so that InternLM knows it is a tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for name in ["out_proj", "Wqkv"]:
@ -343,11 +317,9 @@ class MHA(nn.Module):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if qkv.dtype not in [torch.float16, torch.bfloat16]:
qkv = qkv.to(torch.bfloat16)
# context = self.inner_attn(qkv, **kwargs).to(x.dtype)
context = self.inner_attn_sp(qkv, **kwargs).to(x.dtype)
context = self.inner_attn(qkv, **kwargs).to(x.dtype)
else:
# context = self.inner_attn(qkv, **kwargs)
context = self.inner_attn_sp(qkv, **kwargs)
context = self.inner_attn(qkv, **kwargs)
else:
raise RuntimeError("Not support this right now")

View File

@ -54,7 +54,10 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
gpc.config.parallel.sequence_parallel = False
if gpc.config.parallel["tensor"]["mode"] == 'fstp':
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
yield
finally:
gpc.config.parallel.sequence_parallel = prev_mode