[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
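
In short: the `enable_sequence_overlap` flag is removed everywhere (plugins, `ShardConfig`, sharded layers, and policies), the `split_gather` and `ring` code paths are unified behind `linear_gather_forward_reducescatter_backward` / `matmul_gather_forward_reducescatter_backward`, and the ring variant now returns the gathered input so the backward pass can reuse it instead of re-gathering. After this change, sequence parallelism is configured with `enable_sequence_parallelism` and `sequence_parallelism_mode` alone. A minimal usage sketch (the launch/boost scaffolding below is illustrative and not part of this diff; argument values are examples):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()

# The former `enable_sequence_overlap=...` keyword no longer exists;
# sequence parallelism is driven by these two options alone.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # or "ring" / "all_to_all"
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)
```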
Branch: pull/6090/head
Author: Hongxin Liu (committed via GitHub, 1 month ago)
Commit: dc2cdaf3e8 (parent 6b2c506fc5)

@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
  enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
  enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
  enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
  use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
  verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
  fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
  enable_flash_attention: bool = False,
  enable_sequence_parallelism: bool = False,
  enable_jit_fused: bool = False,
- enable_sequence_overlap: bool = False,
  enable_async_reduce: bool = True,
  use_fp8: bool = False,
  verbose: bool = False,
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
  self.enable_flash_attention = enable_flash_attention
  self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
  self.enable_jit_fused = enable_jit_fused
- self.enable_sequence_overlap = enable_sequence_overlap
  self.verbose = verbose
  self.tp_size = tp_size
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
  enable_flash_attention=self.enable_flash_attention,
  enable_jit_fused=self.enable_jit_fused,
  enable_sequence_parallelism=self.enable_sequence_parallelism,
- enable_sequence_overlap=self.enable_sequence_overlap,
  )
  def __del__(self):

@@ -951,7 +951,6 @@ class HybridParallelPlugin(PipelinePluginBase):
  enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
  enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
  sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
  parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
  num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
  microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -1002,7 +1001,6 @@ class HybridParallelPlugin(PipelinePluginBase):
  enable_jit_fused: bool = False,
  enable_sequence_parallelism: bool = False,
  sequence_parallelism_mode: str = None,
- enable_sequence_overlap: bool = False,
  parallel_output: bool = True,
  num_microbatches: Optional[int] = None,
  microbatch_size: Optional[int] = None,
@@ -1174,7 +1172,6 @@ class HybridParallelPlugin(PipelinePluginBase):
  enable_jit_fused=self.enable_jit_fused,
  enable_sequence_parallelism=enable_sequence_parallelism,
  sequence_parallelism_mode=sequence_parallelism_mode,
- enable_sequence_overlap=enable_sequence_overlap,
  parallel_output=parallel_output,
  make_vocab_size_divisible_by=make_vocab_size_divisible_by,
  gradient_checkpoint_config=gradient_checkpoint_config,

@@ -140,7 +140,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
  enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
  enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
  sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
  parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
  num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
  microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -189,7 +188,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
  enable_jit_fused: bool = False,
  enable_sequence_parallelism: bool = False,
  sequence_parallelism_mode: str = None,
- enable_sequence_overlap: bool = False,
  parallel_output: bool = True,
  num_microbatches: Optional[int] = None,
  microbatch_size: Optional[int] = None,
@@ -351,7 +349,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
  enable_jit_fused=self.enable_jit_fused,
  enable_sequence_parallelism=enable_sequence_parallelism,
  sequence_parallelism_mode=sequence_parallelism_mode,
- enable_sequence_overlap=enable_sequence_overlap,
  parallel_output=parallel_output,
  make_vocab_size_divisible_by=make_vocab_size_divisible_by,
  gradient_checkpoint_config=gradient_checkpoint_config,

@@ -102,7 +102,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
  grad_output = grad_output.view(-1, grad_output.shape[-1])
  total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_allreduce and fp8_communication:
+ if fp8_communication or not ctx.async_grad_allreduce:
  _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
  elif ctx.async_grad_allreduce:
  # Asynchronous all-reduce
@@ -216,10 +216,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
  for k in recv_tensors:
  send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+ input_tensors = []
  output_tensors = []
  handles = communicate_step()
  # first round: special case, retrive from local tensor
+ input_tensors.append(input_to_gather)
  output_tensors.append(func(**input_to_gather, **input_local))
  for i in range(group_size - 2):
  for handle in handles:
@@ -230,14 +232,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
  handles = communicate_step()
  # actual computation
+ input_tensors.append(send_tensors)
  output_tensors.append(func(**send_tensors, **input_local))
  # final round: special case, no need to send/recv again
  for handle in handles:
  handle.wait()
+ input_tensors.append(send_tensors)
  output_tensors.append(func(**recv_tensors, **input_local))
- return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+ gathered_input = {}
+ for k in input_to_gather:
+ input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
+ gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
+ gathered_output = torch.cat(
+ output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
+ )
+ return gathered_output, gathered_input
  class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -293,29 +306,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
  """
  @staticmethod
- def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
  ctx.save_for_backward(input_, weight, bias)
  ctx.use_bias = bias is not None
  ctx.process_group = process_group
  ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
  ctx.dim = dim
- ctx.overlap = overlap
  if ring is True:
  input_to_gather = {"input": input_}
  input_local = {"weight": weight}
- output = _ring_as_gather(
+ output, input_dict = _ring_as_gather(
  F.linear,
  input_to_gather=input_to_gather,
  input_local=input_local,
  process_group=process_group,
  )
+ ctx.gathered_input = input_dict["input"]
  if bias is not None:
  output += bias
  else:
  input_parallel = _gather(input_, dim, process_group)
+ ctx.gathered_input = input_parallel
  if bias is not None:
  output = F.linear(input_parallel, weight, bias)
  else:
@@ -329,100 +343,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
  use_bias = ctx.use_bias
  dim = ctx.dim
  process_group = ctx.process_group
- overlap = ctx.overlap
  # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
  if use_bias:
  bias = bias.view(bias.shape)
- if not overlap:
- input_parallel = _gather(input_, dim, process_group)
- total_input = input_parallel
- grad_input = grad_output.matmul(weight)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_reduce_scatter:
- # Asynchronous reduce-scatter
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(
- input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
- ).contiguous()
- handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
- # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
- if _grad_accum_fusion_available and weight.grad is not None:
- grad = weight.grad
- if grad.dtype == torch.float32:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
- grad_weight = None
- elif grad.dtype == torch.float16:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
- grad_weight = None
- else:
- grad_weight = grad_output.t().matmul(total_input)
- else:
- grad_weight = grad_output.t().matmul(total_input)
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- if ctx.async_grad_reduce_scatter:
- handle.wait()
- else:
- input_ = input_.contiguous()
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- # do all gather in is async way
- gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
- # calculate gradient and prepare data asynchronously with all-gather
- # calculate
- grad_input = grad_output.matmul(weight)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- # prepare data
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
- # wait until all-gather finished
- gather_handle.wait()
- # do reduce-scatter in async way
- reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
- # calculate gradient
- if len(input_parallel.shape) > 2:
- input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
- if _grad_accum_fusion_available and weight.grad is not None:
- grad = weight.grad
- if grad.dtype == torch.float32:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
- grad_weight = None
- elif grad.dtype == torch.float16:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
- grad_weight = None
- else:
- grad_weight = grad_output.t().matmul(input_parallel)
- else:
- grad_weight = grad_output.t().matmul(input_parallel)
- # grad_weight = grad_output.t().matmul(input_parallel)
- # wait until reduce-scatter finished
- reducescatter_handle.wait()
- return output, grad_weight, grad_bias, None, None, None, None, None
+ input_parallel = ctx.gathered_input
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
+ if _grad_accum_fusion_available and weight.grad is not None:
+ grad = weight.grad
+ if grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+ grad_weight = None
+ elif grad.dtype == torch.float16:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+ return output, grad_weight, grad_bias, None, None, None, None
  def _ring_as_reducescatter(
@@ -553,7 +517,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
  # Convert the tensor shapes to 2D for execution compatibility
  if len(grad_output.shape) > 2:
  grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
+ total_input = total_input.reshape(-1, total_input.shape[-1])
  grad_weight = grad_output.t().matmul(total_input)
  grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -611,34 +575,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
  """
  @staticmethod
- def forward(
- ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
- ):
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
  ctx.save_for_backward(input_, weight, bias)
  ctx.use_bias = bias is not None
  ctx.process_group = process_group
  ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
  ctx.dim = dim
- ctx.overlap = overlap
  ctx.fp8_communication = fp8_communication
  if ring is True:
- input_to_gather = {}
- input_local = {}
- input_to_gather["input"] = input_
- input_local["other"] = weight
- output = _ring_as_gather(
+ input_to_gather = {"input": input_}
+ input_local = {"other": weight}
+ output, input_dict = _ring_as_gather(
  torch.matmul,
  input_to_gather=input_to_gather,
  input_local=input_local,
  process_group=process_group,
  gather_dim=dim,
  )
+ ctx.gathered_input = input_dict["input"]
  else:
  input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
+ ctx.gathered_input = input_parallel
  output = torch.matmul(input_parallel, weight)
  if bias is not None:
@@ -651,76 +611,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
  use_bias = ctx.use_bias
  dim = ctx.dim
  process_group = ctx.process_group
- overlap = ctx.overlap
- fp8_communication = ctx.fp8_communication
  # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
  weight = weight.view(weight.shape)
  if use_bias:
  bias = bias.view(bias.shape)
- if not overlap:
- input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
- total_input = input_parallel
- grad_input = grad_output.matmul(weight.T)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_reduce_scatter:
- # Asynchronous reduce-scatter
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(
- input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
- ).contiguous()
- handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
- # all-reduce scheduled first and have GPU resources allocated
- grad_weight = total_input.t().matmul(grad_output)
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- if ctx.async_grad_reduce_scatter:
- handle.wait()
- else:
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- # do all gather in is async way
- gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
- # calculate gradient and prepare data asynchronously with all-gather
- # calculate
- grad_input = grad_output.matmul(weight.T)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- # prepare data
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
- # wait until all-gather finished
- gather_handle.wait()
- # do reduce-scatter in async way
- reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
- # calculate gradient
- if len(input_parallel.shape) > 2:
- input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
- grad_weight = input_parallel.t().matmul(grad_output)
- # wait until reduce-scatter finished
- reducescatter_handle.wait()
- return output, grad_weight, grad_bias, None, None, None, None, None, None
+ input_parallel = ctx.gathered_input
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight.T)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # all-reduce scheduled first and have GPU resources allocated
+ grad_weight = total_input.t().matmul(grad_output)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+ return output, grad_weight, grad_bias, None, None, None, None, None
  class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -1050,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
  def linear_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
  ):
  return _LinearWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
  )
@@ -1070,10 +993,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
  def matmul_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
  ):
  return _MatmulWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
  )

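The core change in the `_operation.py` hunks above: the forward pass now caches the gathered activation on `ctx` (`ctx.gathered_input`), and `_ring_as_gather` returns the gathered input alongside the output, so the backward pass reuses it instead of re-running an all-gather, which makes the separate `overlap` path unnecessary. Below is a toy, single-process sketch of that caching pattern only; it is not the library code, the gather is stood in for by a no-op, and all names are illustrative:

```python
import torch


class CachedGatherLinear(torch.autograd.Function):
    """Single-process analogue of the pattern above: forward keeps the
    (already gathered) activation on ctx so backward can reuse it instead
    of re-gathering or overlapping an all-gather with grad computation."""

    @staticmethod
    def forward(ctx, input_, weight):
        gathered = input_.contiguous()  # stands in for _gather(...) / the ring gather
        ctx.save_for_backward(weight)
        ctx.gathered_input = gathered   # cached for backward, as in the diff
        return gathered @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        total_input = ctx.gathered_input  # reused, no second gather
        grad_input = grad_output @ weight
        grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
            total_input.reshape(-1, total_input.shape[-1])
        )
        return grad_input, grad_weight


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
CachedGatherLinear.apply(x, w).sum().backward()
```

The trade-off is holding the gathered activation in memory until the backward pass, in exchange for dropping both the backward all-gather and the overlap bookkeeping.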
@@ -23,17 +23,15 @@ from colossalai.tensor.d_tensor.api import (
  )
  from ._operation import (
- gather_forward_reducescatter_backward,
  gather_forward_split_backward,
  linear_gather_forward_reducescatter_backward,
  linear_reducescatter_forward_gather_backward,
  linear_with_async_comm,
  reduce_forward,
- reducescatter_forward_gather_backward,
  split_forward_gather_backward,
  )
  from .parallel_module import PaddingParallelModule, ParallelModule
- from .utils import create_randomizer_with_offset
+ from .utils import create_randomizer_with_offset, is_share_sp_tp
  __all__ = ["Linear1D_Col", "Linear1D_Row"]
@@ -55,7 +53,6 @@ class Linear1D_Col(ParallelModule):
  to all GPUs, otherwise, every GPU will have its output
  which is :math:`Y_i = XA_i`, defaults to False
  seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
- overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
  skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
  which is preserved for kernel fusion, defaults to False
  weight_initializer (`typing.Callable`):
@@ -78,7 +75,6 @@ class Linear1D_Col(ParallelModule):
  gather_output: bool = False,
  seq_parallel_mode: str = None,
  seq_parallel_dim: int = 1,
- overlap: torch.cuda.Stream = None,
  skip_bias_add: bool = False,
  weight: Optional[Parameter] = None,
  bias_: Optional[Parameter] = None,
@@ -95,7 +91,6 @@ class Linear1D_Col(ParallelModule):
  self.gather_output = gather_output
  self.seq_parallel_mode = seq_parallel_mode
  self.seq_parallel_dim = seq_parallel_dim
- self.overlap = overlap
  self.skip_bias_add = skip_bias_add
  self.device = device
  self.process_group = process_group
@@ -202,16 +197,15 @@ class Linear1D_Col(ParallelModule):
  # Matrix multiply.
  bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode == "split_gather":
- input_parallel = gather_forward_reducescatter_backward(
- input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
- )
- output_parallel = linear_with_async_comm(
- input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
- )
- elif self.seq_parallel_mode == "ring":
- output_parallel = linear_gather_forward_reducescatter_backward(
- input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+ if is_share_sp_tp(self.seq_parallel_mode):
+ output_parallel = linear_gather_forward_reducescatter_backward(
+ input_parallel,
+ self.weight,
+ bias,
+ self.process_group,
+ True,
+ self.seq_parallel_dim,
+ ring=self.seq_parallel_mode == "ring",
  )
  else:
  output_parallel = linear_with_async_comm(
@@ -428,18 +422,13 @@ class Linear1D_Row(ParallelModule):
  handle.wait()
  output = torch.cat(output_parallel_list, dim=-1)
  else:
- if self.seq_parallel_mode == "split_gather":
- output_parallel = F.linear(input_, self.weight)
- output = reducescatter_forward_gather_backward(
- output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
- )
- elif self.seq_parallel_mode == "ring":
+ if is_share_sp_tp(self.seq_parallel_mode):
  output = linear_reducescatter_forward_gather_backward(
  input_,
  self.weight,
  process_group=self.process_group,
  dim=self.seq_parallel_dim,
- ring=True,
+ ring=self.seq_parallel_mode == "ring",
  )
  else:
  output_parallel = F.linear(input_, self.weight)
@@ -551,7 +540,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
  to all GPUs, otherwise, every GPU will have its output
  which is :math:`Y_i = XA_i`, defaults to False
  seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
- overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
  skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
  which is preserved for kernel fusion, defaults to False
  weight_initializer (`typing.Callable`):

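`is_share_sp_tp` (imported from `.utils` above) is what lets a single branch cover both `split_gather` and `ring`. Its implementation is not part of this diff; a plausible sketch of what such a helper checks is shown below (an assumption — the real helper in the shardformer layer `utils` module may differ):

```python
def is_share_sp_tp(sp_mode: str) -> bool:
    """True for the sequence-parallel modes that reuse the tensor-parallel
    process group and the gather-forward / reduce-scatter-backward pattern."""
    return sp_mode in ("split_gather", "ring")
```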
@@ -25,19 +25,17 @@ from colossalai.tensor.d_tensor.api import (
  )
  from ._operation import (
- gather_forward_reducescatter_backward,
  linear_gather_forward_reducescatter_backward,
  linear_reducescatter_forward_gather_backward,
  linear_with_async_comm,
  matmul_gather_forward_reducescatter_backward,
  matmul_with_async_comm,
- reduce_backward,
  reduce_forward,
  reducescatter_forward_gather_backward,
  split_forward_gather_backward,
  )
  from .parallel_module import ParallelModule
- from .utils import create_randomizer_with_offset
+ from .utils import create_randomizer_with_offset, is_share_sp_tp
  __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
@@ -222,10 +220,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
  dtype: torch.dtype = None,
  device: torch.device = None,
  process_group: ProcessGroup = None,
- async_communication: bool = False,
  gather_output: bool = False,
  seq_parallel_mode: str = None,
- overlap: bool = False,
  skip_bias_add: bool = False,
  weight: Optional[Parameter] = None,
  bias_: Optional[Parameter] = None,
@@ -240,12 +236,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
  self.out_features = out_features
  self.gather_output = gather_output
  self.seq_parallel_mode = seq_parallel_mode
- self.overlap = overlap
  self.skip_bias_add = skip_bias_add
  self.device = device
  self.split_sizes = split_sizes
  self.process_group = process_group
- self.async_communication = async_communication
  self.fp8_communication = fp8_communication
  assert (
@@ -370,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
  # Matrix multiply.
  bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode == "split_gather":
+ if is_share_sp_tp(self.seq_parallel_mode):
  input_parallel = input_
  output_parallel = matmul_gather_forward_reducescatter_backward(
  input_parallel,
@@ -379,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
  self.process_group,
  True,
  1,
- self.overlap,
- fp8_communication=self.fp8_communication,
- )
- elif self.seq_parallel_mode == "ring":
- input_parallel = input_
- output_parallel = matmul_gather_forward_reducescatter_backward(
- input_parallel,
- self.weight,
- bias,
- self.process_group,
- True,
- 1,
- self.overlap,
- True,
+ ring=self.seq_parallel_mode == "ring",
  fp8_communication=self.fp8_communication,
  )
  elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
  # Set up backprop all-reduce.
- input_parallel = reduce_backward(input_, self.process_group)
+ input_parallel = input_
  output_parallel = matmul_with_async_comm(
  input_parallel,
  self.weight,
  bias,
  self.process_group,
- self.async_communication,
+ True,
  fp8_communication=self.fp8_communication,
  )
  else:
@@ -620,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
  if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
  output_parallel = torch.matmul(input_, self.weight)
  output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
- elif self.seq_parallel_mode == "split_gather":
+ elif is_share_sp_tp(self.seq_parallel_mode):
  output_parallel = torch.matmul(input_, self.weight)
  output = reducescatter_forward_gather_backward(
  output_parallel,
@@ -628,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
  1,
  self.fp8_communication,
  )
- elif self.seq_parallel_mode == "ring":
- output_parallel = torch.matmul(input_, self.weight)
- output = reducescatter_forward_gather_backward(
- output_parallel,
- self.process_group,
- 1,
- )
  else:
  raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@@ -691,7 +665,6 @@ class FusedLinear1D_Col(ParallelModule):
  gather_output: bool = False,
  seq_parallel_mode: str = None,
  seq_parallel_dim: int = 1,
- overlap: torch.cuda.Stream = None,
  skip_bias_add: bool = False,
  weight: Optional[Parameter] = None,
  bias_: Optional[Parameter] = None,
@@ -706,7 +679,6 @@ class FusedLinear1D_Col(ParallelModule):
  self.gather_output = gather_output
  self.seq_parallel_mode = seq_parallel_mode
  self.seq_parallel_dim = seq_parallel_dim
- self.overlap = overlap
  self.skip_bias_add = skip_bias_add
  self.device = device
  self.split_sizes = split_sizes
@@ -830,16 +802,15 @@ class FusedLinear1D_Col(ParallelModule):
  # Matrix multiply.
  bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode == "split_gather":
- input_parallel = gather_forward_reducescatter_backward(
- input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
- )
- output_parallel = linear_with_async_comm(
- input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
- )
- elif self.seq_parallel_mode == "ring":
- output_parallel = linear_gather_forward_reducescatter_backward(
- input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+ if is_share_sp_tp(self.seq_parallel_mode):
+ output_parallel = linear_gather_forward_reducescatter_backward(
+ input_parallel,
+ self.weight,
+ bias,
+ self.process_group,
+ True,
+ self.seq_parallel_dim,
+ ring=self.seq_parallel_mode == "ring",
  )
  else:
  output_parallel = linear_with_async_comm(
@@ -1031,18 +1002,13 @@ class FusedLinear1D_Row(ParallelModule):
  )
  input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
- if self.seq_parallel_mode == "split_gather":
- output_parallel = F.linear(input_, self.weight)
- output = reducescatter_forward_gather_backward(
- output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
- )
- elif self.seq_parallel_mode == "ring":
+ if is_share_sp_tp(self.seq_parallel_mode):
  output = linear_reducescatter_forward_gather_backward(
  input_,
  self.weight,
  process_group=self.process_group,
  dim=self.seq_parallel_dim,
- ring=True,
+ ring=self.seq_parallel_mode == "ring",
  )
  else:
  output_parallel = F.linear(input_, self.weight)

@@ -73,7 +73,6 @@ class BertPolicy(Policy):
  )
  sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
  sp_partial_derived = sp_mode == "split_gather"
  if self.shard_config.enable_tensor_parallelism:
@@ -97,7 +96,6 @@ class BertPolicy(Policy):
  target_module=col_nn.Linear1D_Col,
  kwargs={
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -106,7 +104,6 @@ class BertPolicy(Policy):
  target_module=col_nn.Linear1D_Col,
  kwargs={
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -115,7 +112,6 @@ class BertPolicy(Policy):
  target_module=col_nn.Linear1D_Col,
  kwargs={
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -140,7 +136,6 @@ class BertPolicy(Policy):
  target_module=col_nn.Linear1D_Col,
  kwargs={
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "skip_bias_add": self.enable_bias_gelu_fused,
  "fp8_communication": self.shard_config.fp8_communication,
  },

@@ -57,7 +57,6 @@ class BloomPolicy(Policy):
  )
  sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
  sp_partial_derived = sp_mode == "split_gather"
  if self.shard_config.enable_tensor_parallelism:
@@ -78,7 +77,6 @@ class BloomPolicy(Policy):
  target_module=col_nn.Linear1D_Col,
  kwargs={
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -99,7 +97,6 @@ class BloomPolicy(Policy):
  target_module=col_nn.Linear1D_Col,
  kwargs={
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),

@@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
  f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
  )
  sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
  sp_partial_derived = sp_mode in ["split_gather"]
  if sp_mode == "all_to_all":
@@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
  kwargs={
  "seq_parallel_mode": sp_mode,
  "seq_parallel_dim": 0,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),

@@ -65,7 +65,6 @@ class GPT2Policy(Policy):
  f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
  )
  self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
  sp_partial_derived = sp_mode in ["split_gather", "ring"]
  use_flash_attention = self.shard_config.enable_flash_attention
  if self.shard_config.enable_tensor_parallelism:
@@ -94,7 +93,6 @@ class GPT2Policy(Policy):
  kwargs={
  "split_sizes": [self.model.config.hidden_size] * 3,
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -109,7 +107,6 @@ class GPT2Policy(Policy):
  kwargs={
  "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
  "seq_parallel_mode": sp_mode,
- "overlap": overlap,
  "skip_bias_add": self.enable_bias_gelu_fused,
  "fp8_communication": self.shard_config.fp8_communication,
  },

@@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
  self.shard_config.enable_sequence_parallelism = False
  warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
- overlap = self.shard_config.enable_sequence_overlap
  if self.shard_config.enable_tensor_parallelism:
  assert (
  self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
  suffix="attn.k_proj",
  target_module=col_nn.Linear1D_Col,
  kwargs={
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
  suffix="attn.q_proj",
  target_module=col_nn.Linear1D_Col,
  kwargs={
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),
@@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
  suffix="attn.v_proj",
  target_module=col_nn.Linear1D_Col,
  kwargs={
- "overlap": overlap,
  "fp8_communication": self.shard_config.fp8_communication,
  },
  ),

@@ -26,7 +26,6 @@ class ShardConfig:
  enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
  enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
  enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
- enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
  gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
  enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
  fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@@ -44,7 +43,6 @@ class ShardConfig:
  enable_jit_fused: bool = False
  enable_sequence_parallelism: bool = False
  sequence_parallelism_mode: str = None
- enable_sequence_overlap: bool = False
  parallel_output: bool = True
  make_vocab_size_divisible_by: int = 64
  gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
@@ -84,24 +82,12 @@ class ShardConfig:
  assert (
  self.enable_tensor_parallelism
  ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
- elif self.sequence_parallelism_mode in ["all_to_all"]:
- # assert (
- # not self.enable_tensor_parallelism
- # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
- if self.enable_sequence_overlap:
- self.enable_sequence_overlap = False
- warnings.warn(
- f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
- )
  else:
  if self.sequence_parallelism_mode:
  self.sequence_parallelism_mode = None
  warnings.warn(
  f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
  )
- assert (
- not self.enable_sequence_overlap
- ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
  # get the tensor parallel size
  if not self.enable_tensor_parallelism:
@@ -134,4 +120,3 @@ class ShardConfig:
  # This can cause non-in-place param sharding when used without ZeRO.
  # It may also slow down training when seq len is small. Plz enable manually.
  # self.enable_sequence_parallelism = True
- # self.enable_sequence_overlap = True

@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
  return x
- def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
+ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
  ctx = LazyInitContext() if lazy_init else nullcontext()
  linear = Conv1D(192, 48).cuda()
  with ctx:
@@ -52,7 +52,6 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
  gather_output=True,
  seq_parallel_mode=seq_parallel_mode,
  split_sizes=[64] * 3,
- overlap=overlap,
  )
  assert linear.weight.shape == torch.Size([48, 192])
@@ -121,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
  @parameterize("lazy_init", [False, True])
  @parameterize("seq_parallel_mode", ["split_gather", None])
- @parameterize("overlap", [True])
- def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
- check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
+ def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
+ check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
  check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
