[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
pull/6090/head
Hongxin Liu 1 month ago committed by GitHub
parent 6b2c506fc5
commit dc2cdaf3e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
self.tp_size = tp_size
@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
enable_sequence_overlap=self.enable_sequence_overlap,
)
def __del__(self):

@ -951,7 +951,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@ -1002,7 +1001,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@ -1174,7 +1172,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,

@ -140,7 +140,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@ -189,7 +188,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@ -351,7 +349,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,

@ -102,7 +102,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce and fp8_communication:
if fp8_communication or not ctx.async_grad_allreduce:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
@ -216,10 +216,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
input_tensors = []
output_tensors = []
handles = communicate_step()
# first round: special case, retrive from local tensor
input_tensors.append(input_to_gather)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
@ -230,14 +232,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
handles = communicate_step()
# actual computation
input_tensors.append(send_tensors)
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
input_tensors.append(send_tensors)
output_tensors.append(func(**recv_tensors, **input_local))
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
gathered_input = {}
for k in input_to_gather:
input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
gathered_output = torch.cat(
output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
)
return gathered_output, gathered_input
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@ -293,29 +306,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
output = _ring_as_gather(
output, input_dict = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
ctx.gathered_input = input_dict["input"]
if bias is not None:
output += bias
else:
input_parallel = _gather(input_, dim, process_group)
ctx.gathered_input = input_parallel
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
@ -329,14 +343,12 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight)
@ -351,9 +363,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
@ -376,53 +386,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all gather in is async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
else:
grad_weight = grad_output.t().matmul(input_parallel)
# grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None
def _ring_as_reducescatter(
@ -553,7 +517,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@ -611,34 +575,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
ctx.fp8_communication = fp8_communication
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather["input"] = input_
input_local["other"] = weight
input_to_gather = {"input": input_}
input_local = {"other": weight}
output = _ring_as_gather(
output, input_dict = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
ctx.gathered_input = input_dict["input"]
else:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
ctx.gathered_input = input_parallel
output = torch.matmul(input_parallel, weight)
if bias is not None:
@ -651,16 +611,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
@ -675,9 +632,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
@ -688,39 +643,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all gather in is async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@ -1050,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
)
@ -1070,10 +993,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
)

@ -23,17 +23,15 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
gather_forward_reducescatter_backward,
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["Linear1D_Col", "Linear1D_Row"]
@ -55,7 +53,6 @@ class Linear1D_Col(ParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@ -78,7 +75,6 @@ class Linear1D_Col(ParallelModule):
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@ -95,7 +91,6 @@ class Linear1D_Col(ParallelModule):
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@ -202,16 +197,15 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
if is_share_sp_tp(self.seq_parallel_mode):
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
input_parallel,
self.weight,
bias,
self.process_group,
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
@ -428,18 +422,13 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode == "split_gather":
output_parallel = F.linear(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
if is_share_sp_tp(self.seq_parallel_mode):
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=True,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = F.linear(input_, self.weight)
@ -551,7 +540,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):

@ -25,19 +25,17 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
gather_forward_reducescatter_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
@ -222,10 +220,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel_mode: str = None,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@ -240,12 +236,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.split_sizes = split_sizes
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
assert (
@ -370,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
if is_share_sp_tp(self.seq_parallel_mode):
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
@ -379,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.process_group,
True,
1,
self.overlap,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
True,
1,
self.overlap,
True,
ring=self.seq_parallel_mode == "ring",
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
output_parallel = matmul_with_async_comm(
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
True,
fp8_communication=self.fp8_communication,
)
else:
@ -620,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
elif is_share_sp_tp(self.seq_parallel_mode):
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel,
@ -628,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
1,
self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel,
self.process_group,
1,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@ -691,7 +665,6 @@ class FusedLinear1D_Col(ParallelModule):
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@ -706,7 +679,6 @@ class FusedLinear1D_Col(ParallelModule):
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.split_sizes = split_sizes
@ -830,16 +802,15 @@ class FusedLinear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
if is_share_sp_tp(self.seq_parallel_mode):
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
input_parallel,
self.weight,
bias,
self.process_group,
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
@ -1031,18 +1002,13 @@ class FusedLinear1D_Row(ParallelModule):
)
input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
if self.seq_parallel_mode == "split_gather":
output_parallel = F.linear(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
if is_share_sp_tp(self.seq_parallel_mode):
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=True,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = F.linear(input_, self.weight)

@ -73,7 +73,6 @@ class BertPolicy(Policy):
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
@ -97,7 +96,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -106,7 +104,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -115,7 +112,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -140,7 +136,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},

@ -57,7 +57,6 @@ class BloomPolicy(Policy):
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
@ -78,7 +77,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -99,7 +97,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather"]
if sp_mode == "all_to_all":
@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -65,7 +65,6 @@ class GPT2Policy(Policy):
f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
if self.shard_config.enable_tensor_parallelism:
@ -94,7 +93,6 @@ class GPT2Policy(Policy):
kwargs={
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -109,7 +107,6 @@ class GPT2Policy(Policy):
kwargs={
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},

@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -26,7 +26,6 @@ class ShardConfig:
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@ -44,7 +43,6 @@ class ShardConfig:
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
enable_sequence_overlap: bool = False
parallel_output: bool = True
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
@ -84,24 +82,12 @@ class ShardConfig:
assert (
self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
elif self.sequence_parallelism_mode in ["all_to_all"]:
# assert (
# not self.enable_tensor_parallelism
# ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
if self.enable_sequence_overlap:
self.enable_sequence_overlap = False
warnings.warn(
f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
)
else:
if self.sequence_parallelism_mode:
self.sequence_parallelism_mode = None
warnings.warn(
f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
)
assert (
not self.enable_sequence_overlap
), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
# get the tensor parallel size
if not self.enable_tensor_parallelism:
@ -134,4 +120,3 @@ class ShardConfig:
# This can cause non-in-place param sharding when used without ZeRO.
# It may also slow down training when seq len is small. Plz enable manually.
# self.enable_sequence_parallelism = True
# self.enable_sequence_overlap = True

@ -41,7 +41,7 @@ class Conv1D(nn.Module):
return x
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@ -52,7 +52,6 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
gather_output=True,
seq_parallel_mode=seq_parallel_mode,
split_sizes=[64] * 3,
overlap=overlap,
)
assert linear.weight.shape == torch.Size([48, 192])
@ -121,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
@parameterize("overlap", [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)

Loading…
Cancel
Save