support optimized sp

pull/407/head
yingtongxiong 2023-10-07 14:03:47 +08:00
parent c8242572f2
commit 10aa63f0e1
5 changed files with 378 additions and 30 deletions

View File

@@ -146,10 +146,10 @@ pipeline parallel (dict):
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
-    zero1=8,
-    tensor=1,
+    zero1=-1,
+    tensor=2,
    pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=False,
+    sequence_parallel=True,
)
cudnn_deterministic = False
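
For orientation, not part of the diff: a minimal single-process sketch of what sequence_parallel=True means for activations, in the spirit of the split_forward_gather_backward helper used later in this commit. The helper name sequence_shard and the shapes below are illustrative assumptions.

import torch

def sequence_shard(hidden_states: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's contiguous slice of dim 0 (the packed sequence dimension)."""
    # assumes the packed sequence length is divisible by the tensor-parallel world size
    return torch.chunk(hidden_states, world_size, dim=0)[rank].contiguous()

# with tensor=2 as configured above, each tensor-parallel rank keeps half of the tokens
tokens = torch.randn(4096, 2048)  # (packed_seq_len, hidden_size)
assert sequence_shard(tokens, rank=0, world_size=2).shape == (2048, 2048)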

View File

@@ -5,13 +5,32 @@ from typing import Optional
import torch
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
-from flash_attn.utils.distributed import all_reduce, reduce_scatter
+from flash_attn.utils.distributed import all_reduce, reduce_scatter, all_gather_raw, reduce_scatter_raw
from torch import Tensor
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import Silu, fused_dense_func_torch
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce
class ScaleColumnParallelLinear(nn.Linear):
"""
@@ -200,3 +219,201 @@ class FeedForward(nn.Module):
w2_o = self.w2(x)
out = self.w3(Silu(w1_o, w2_o))
return out
class FusedDenseFunc_fsdp(torch.autograd.Function):
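    """FSDP/ZeRO-style fused dense: each rank holds only a shard of the weight (and bias).
    Forward all-gathers the full weight before F.linear; backward all-gathers it again for
    the input gradient and reduce-scatters the weight/bias gradients back into shards.
    """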
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
total_x = x
# do all_gather for weight and bias before actual computation
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
if bias is not None:
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
handle_bias.wait()
else:
total_bias = bias
if torch.is_autocast_enabled():
total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
handle_weight.wait()
total_weight = total_weight.contiguous()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
output = F.linear(total_x, total_weight, total_bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
else:
ctx.save_for_backward(weight)
return output if not return_residual else (output, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
grad_input, = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
total_x = x
else:
weight, = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# do all-gather for weight before backward
weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
# if process_group is not None:
# import pdb; pdb.set_trace()
# grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, async_op=True)
# grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True)
else:
grad_input = None
# import pdb; pdb.set_trace()
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
if grad_bias is not None:
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
handle_grad_bias.wait()
handle_grad_weight.wait()
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
# if process_group is not None and ctx.needs_input_grad[0]:
# handle_grad_input.wait()
# import pdb; pdb.set_trace()
return grad_input, grad_weight, grad_bias, None, None, None
def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
return_residual: bool = False, process_group = None):
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc_fsdp.apply(x, weight, bias, return_residual, process_group)
else:
assert process_group is None
out = F.linear(x, weight, bias)
return out if not return_residual else (out, x)
class FSDPLinear(ColumnParallelLinear):
def forward(self, x):
return fsdp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
class FSDPScaleLinear(ScaleColumnParallelLinear):
def forward(self, input): # pylint: disable=W0622
        # The weight shard is all-gathered across the process group inside fsdp_fused_dense_func
        # before the matmul (FSDP-style), so no extra gather of the input is needed here.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fsdp_fused_dense_func(
input,
weight,
self.bias,
process_group=self.process_group,
)
class FSDPFeedForward(nn.Module):
"""
FeedForward.
Args:
in_features (int): size of each input sample
hidden_features (int): size of hidden state of FFN
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
"""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int = None,
process_group: Optional[torch.distributed.ProcessGroup] = None,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
):
super().__init__()
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = FSDPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSDPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSDPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
def forward(self, x):
w1_o = self.w1(x)
w2_o = self.w2(x)
out = self.w3(Silu(w1_o, w2_o))
return out
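
A single-process sanity sketch, not part of the diff, of the sharding convention the classes above rely on: flash-attn's all_gather_raw concatenates shards along dim 0, so each rank's weight is a row slice of the full (out_features, in_features) matrix and gathering the shards reproduces an ordinary linear layer. The shard count and sizes below are arbitrary assumptions.

import torch
import torch.nn.functional as F

world_size, in_features, out_features = 4, 16, 32
full = torch.nn.Linear(in_features, out_features, bias=False)
shards = torch.chunk(full.weight.detach(), world_size, dim=0)  # one row-shard per simulated rank

x = torch.randn(5, in_features)
total_weight = torch.cat(shards, dim=0)  # what all_gather_raw rebuilds at runtime
assert torch.allclose(F.linear(x, total_weight), full(x))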

View File

@@ -17,9 +17,11 @@ from internlm.model.linear import (
FeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
FSDPScaleLinear,
FSDPFeedForward,
)
from internlm.model.multi_head_attention import MHA
-from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, split_forward_gather_backward
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
@@ -107,7 +109,16 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu:
self.mlp = FeedForward(
# self.mlp = FeedForward(
# hidden_size,
# int(hidden_size * mlp_ratio),
# out_features=hidden_size,
# process_group=gpc.get_group(ParallelMode.TENSOR),
# bias=False,
# device=device,
# dtype=dtype,
# )
self.mlp = FSDPFeedForward(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
@@ -293,7 +304,8 @@ class PackedFlashInternLm1D(nn.Module):
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = ScaleColumnParallelLinear
# head_cls = ScaleColumnParallelLinear
head_cls = FSDPScaleLinear
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -379,6 +391,9 @@ class PackedFlashInternLm1D(nn.Module):
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
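        # when sequence parallel is enabled, shard the position indexes along the sequence
        # dimension so they stay aligned with the sequence-sharded tokens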
if gpc.config.parallel.sequence_parallel:
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
for _, block in enumerate(self.blocks):
@@ -394,6 +409,7 @@ class PackedFlashInternLm1D(nn.Module):
hidden_states = self.norm(hidden_states.float())
if hasattr(self, "head"):
hidden_states = self.head(hidden_states)
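            # re-assemble the sequence dimension that was sharded for sequence parallelism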
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=0)
if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)

View File

@@ -18,7 +18,114 @@ from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
-from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
+from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear
import torch
from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module
import torch.distributed as dist
class _SeqAllToAll(torch.autograd.Function):
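    """Single all-to-all re-shard: split the input into world_size chunks along scatter_idx,
    exchange them across the group, and concatenate the received chunks along gather_idx.
    The backward pass applies the inverse all-to-all.
    """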
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx
seq_world_size = dist.get_world_size(group)
input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
# TODO Use all_to_all_single instead
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_idx).contiguous()
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
class DistributedAttention(torch.nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
local_attention: Module,
sequence_process_group: dist.ProcessGroup,
scatter_idx: int = 2,
gather_idx: int = 0,
) -> None:
super(DistributedAttention, self).__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
# def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
# """ forward
# Arguments:
# query (Tensor): query input to the layer
# key (Tensor): key input to the layer
# value (Tensor): value input to the layer
# args: other args
# Returns:
# * output (Tensor): context output
# """
# # TODO Merge three alltoall calls into one
# #in shape : e.g., [s/p:h:]
# query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
# key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
# value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
# #out shape : e.g., [s:h/p:]
# context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
# output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
# #out e.g., [s/p::h]
# return output
    def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
        """ forward

        Arguments:
            qkv (Tensor): packed query/key/value input to the layer
            kwargs: other keyword args forwarded to the local attention module

        Returns:
            * output (Tensor): context output
        """
        # a single all-to-all exchanges the packed qkv, instead of one call each for q, k and v
        # in shape : e.g., [s/p:h:]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.scatter_idx, self.gather_idx)
# key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
# value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
#out shape : e.g., [s:h/p:]
context_layer = self.local_attn(qkv, **kwargs)
output = _SeqAllToAll.apply(self.spg, context_layer, 0, 2)
#out e.g., [s/p::h]
return output
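
To make the layout change concrete without a distributed runtime, here is a single-process simulation, not part of the diff, of what _SeqAllToAll does per rank; the 2-D shapes and the P/S/H names are simplifying assumptions, while the real module operates on packed qkv tensors.

import torch

P, S, H = 4, 8, 12  # simulated ranks, sequence length, hidden size
full = torch.arange(S * H).reshape(S, H)
rank_inputs = torch.tensor_split(full, P, dim=0)  # each rank starts with a (S/P, H) shard

def simulated_all_to_all(rank: int, scatter_idx: int = 1, gather_idx: int = 0) -> torch.Tensor:
    # rank `rank` receives chunk `rank` (split along scatter_idx) from every peer and
    # concatenates the received chunks along gather_idx, mirroring _SeqAllToAll.forward
    received = [torch.tensor_split(inp, P, dim=scatter_idx)[rank] for inp in rank_inputs]
    return torch.cat(received, dim=gather_idx)

out0 = simulated_all_to_all(rank=0)
assert out0.shape == (S, H // P)  # full sequence, 1/P of the hidden (or head) dimension
assert torch.equal(out0, full[:, : H // P])

In the MHA changes further down, inner_attn and inner_cross_attn are wrapped this way with scatter_idx=3 and gather_idx=0.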
class MHA(nn.Module):
@@ -91,7 +198,16 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
        # note: bias should be set to True here
self.Wqkv = ColumnParallelLinearTorch(
# self.Wqkv = ColumnParallelLinearTorch(
# embed_dim,
# 3 * embed_dim,
# process_group,
# bias=True,
# sequence_parallel=gpc.config.parallel.sequence_parallel,
# **factory_kwargs,
# ) # according to https://spaces.ac.cn/archives/9577
self.Wqkv = FSDPLinear(
embed_dim,
3 * embed_dim,
process_group,
@@ -107,8 +223,18 @@ class MHA(nn.Module):
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
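        # all-to-all sequence-parallel wrappers around the local attention modules;
        # forward uses these in place of the direct inner_attn calls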
        self.inner_attn_sp = DistributedAttention(
            self.inner_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0
        )
        self.inner_cross_attn_sp = DistributedAttention(
            self.inner_cross_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0
        )
        # output projection always has the bias (for now)
self.out_proj = RowParallelLinearTorch(
# self.out_proj = RowParallelLinearTorch(
# embed_dim,
# embed_dim,
# process_group,
# sequence_parallel=gpc.config.parallel.sequence_parallel,
# **factory_kwargs,
# )
self.out_proj = FSDPLinear(
embed_dim,
embed_dim,
process_group,
@@ -217,9 +343,11 @@ class MHA(nn.Module):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if qkv.dtype not in [torch.float16, torch.bfloat16]:
qkv = qkv.to(torch.bfloat16)
context = self.inner_attn(qkv, **kwargs).to(x.dtype)
# context = self.inner_attn(qkv, **kwargs).to(x.dtype)
context = self.inner_attn_sp(qkv, **kwargs).to(x.dtype)
else:
context = self.inner_attn(qkv, **kwargs)
# context = self.inner_attn(qkv, **kwargs)
context = self.inner_attn_sp(qkv, **kwargs)
else:
raise RuntimeError("Not support this right now")

View File

@@ -110,7 +110,6 @@ def main(args):
# initialize and resume train state
train_state = TrainState(gpc.config, train_dl.batch_sampler)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager = CheckpointManager(
@@ -171,6 +170,7 @@ def main(args):
scheduler_hooks=scheduler_hooks,
)
# initialize simple memory profiler
if args.profiling:
memory_profiler = SimpleMemoryProfiler(
@@ -219,21 +219,9 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
-        moe_loss = None
-        if hasattr(gpc.config.model, "num_experts"):
-            _, _, loss, moe_loss = trainer.execute_schedule(
-                batch,
-                forward_only=False,
-                return_loss=True,
-                return_output_label=False,
-            )
-        else:
-            _, _, loss = trainer.execute_schedule(
-                batch,
-                forward_only=False,
-                return_loss=True,
-                return_output_label=False,
-            )
+        _, _, loss = trainer.execute_schedule(
+            batch, forward_only=False, return_loss=True, return_output_label=False
+        )
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)
@@ -266,7 +254,6 @@
trainer=trainer,
start_time=start_time,
loss=loss,
-            moe_loss=moe_loss,
grad_norm=grad_norm_groups,
metric=metric,
update_panel=uniscale_logger is not None,