Merge pull request #5899 from BurkeHulk/SP_fp8

[Feature] FP8 communication in ShardFormer
pull/5932/head
Guangyao Zhang 2024-07-18 10:46:59 +08:00 committed by GitHub
commit d0bdb51f48
7 changed files with 180 additions and 72 deletions

View File

@ -945,7 +945,8 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
"""
def __init__(
@ -1119,6 +1120,7 @@ class HybridParallelPlugin(PipelinePluginBase):
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
)
self.amp_config = dict(
initial_scale=initial_scale,
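
For context, a minimal usage sketch of the new plugin flag (a sketch only: the distributed launch, model, and optimizer setup are assumed to exist elsewhere, and every argument except fp8_communication is ordinary plugin configuration):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# colossalai.launch_from_torch() (or an equivalent launcher) is assumed to have
# initialized the distributed environment already.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    fp8_communication=True,  # cast tensor-parallel collectives to fp8 during communication
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion=criterion, dataloader=dataloader
# )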

View File

@ -12,7 +12,6 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
scale: scaling factor for fp8 casting. If it is None, it is computed automatically. Per-channel scaling
is applied if the input tensor is 2-dimensional; otherwise, per-tensor scaling is applied.
fp8_format: e4m3 or e5m2
Returns:
Tuples: A tuple (fp8_tensor, scale)
"""
@ -39,12 +38,10 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
r"""
Args:
inp: should be an fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
scale_inv: scaling factor returned by the cast_to_fp8 function.
ret_type: the datatype of the returned tensor.
Returns:
torch.Tensor
"""
@ -58,20 +55,18 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
return ret.to(ret_type)
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
Args:
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
fp8_format: e4m3 or e5m2
group: the process group to work on. If None, the default process group is used.
Returns:
None
"""
world_size = dist.get_world_size()
world_size = dist.get_world_size(group=group)
input_type = tensor.dtype
input_shape = tensor.shape
input_device = tensor.device
@ -88,19 +83,19 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
else:
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
dist.all_to_all(output_chunks, input_chunks)
dist.all_to_all(output_chunks, input_chunks, group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale)
dist.all_gather(scale_list, scale, group=group)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
dist.all_gather(scale_list, scale)
dist.all_gather(scale_list, scale, group=group)
tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_out = torch.cat(tensor_list, dim=0)
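
Usage mirrors dist.all_reduce; a sketch, assuming torch.distributed is already initialized and the tensor lives on the current device:

import torch
import torch.distributed as dist  # the default process group must already be initialized
from colossalai.quantization.fp8 import all_reduce_fp8

grad = torch.randn(1 << 20, dtype=torch.bfloat16, device="cuda")
all_reduce_fp8(grad, fp8_format="e5m2", group=None)  # in-place; group=None means the default group
# grad now holds the cross-rank sum, subject to fp8 quantization error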
@ -170,3 +165,40 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
if del_metadata:
del inp["fp8_scale"]
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
r"""
This is an in-place operation for compressed reduce_scatter using fp8.
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
Args:
output: torch.Tensor in fp32, fp16, bf16 datatype that receives this rank's reduced chunk.
input_list: list of torch.Tensor chunks to reduce (one per rank) in fp32, fp16, bf16 datatype.
group: the process group to work on.
fp8_format: e4m3 or e5m2
Returns:
None
"""
input_type = output.dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
scale_list = []
cast_input_list = []
output_chunks = []
output_scale_list = []
for input in input_list:
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
scale_list.append(scale)
ret = ret.view(torch.uint8)
cast_input_list.append(ret)
output_chunks.append(torch.empty_like(ret))
output_scale_list.append(torch.empty_like(scale))
dist.all_to_all(output_chunks, cast_input_list, group=group)
dist.all_to_all(output_scale_list, scale_list, group=group)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(output_scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
output.data = summed_out
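
The call shape mirrors dist.reduce_scatter: each rank passes world_size equally sized chunks and receives the reduced chunk for its own rank into output. A usage sketch with placeholder shapes, again assuming an initialized default process group:

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import reduce_scatter_fp8

world_size = dist.get_world_size()
chunks = list(torch.chunk(torch.randn(world_size * 8, 64, device="cuda"), world_size, dim=0))
output = torch.empty_like(chunks[0])
reduce_scatter_fp8(output, chunks, group=None, fp8_format="e5m2")  # output <- reduced chunk for this rank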

View File

@ -14,6 +14,8 @@ try:
except ImportError:
_grad_accum_fusion_available = False
from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
@ -59,11 +61,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
output = torch.matmul(input_, weight)
@ -76,6 +79,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
@ -90,7 +94,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
if ctx.async_grad_allreduce and fp8_communication:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
@ -99,10 +105,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
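
To make the control flow in MatmulWithAsyncCommunication.backward above easier to follow: with fp8_communication set, the gradient all-reduce goes through the blocking _reduce(..., fp8_communication=True) helper instead of the asynchronous dist.all_reduce, which is why the later handle.wait() is guarded with not fp8_communication. A condensed paraphrase of that backward path (not the verbatim source, and not standalone-runnable):

if ctx.async_grad_allreduce and fp8_communication:
    _reduce(grad_input, group=ctx.process_group, fp8_communication=True)  # blocking fp8 all-reduce
elif ctx.async_grad_allreduce:
    handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce and not fp8_communication:
    handle.wait()  # only the asynchronous non-fp8 path has a handle to wait on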
@ -242,7 +248,6 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
@ -253,6 +258,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None
@ -546,9 +552,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
def forward(ctx, input_, process_group, dim, fp8_communication=False):
ctx.dim = dim
ctx.process_group = process_group
ctx.fp8_communication = fp8_communication
# do reduce-scatter
new_shape = list(input_.shape)
@ -558,7 +565,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
if fp8_communication:
reduce_scatter_fp8(output, input_list, group=process_group)
else:
dist.reduce_scatter(output, input_list, group=process_group)
return output
@ -566,8 +576,8 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
return _gather(grad_output, dim, process_group), None, None
fp8_communication = ctx.fp8_communication
return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@ -582,13 +592,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
ctx.fp8_communication = fp8_communication
if ring is True:
input_to_gather = {}
@ -605,7 +618,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
)
else:
input_parallel = _gather(input_, dim, process_group)
input_parallel = _gather(input_, dim, process_group, fp8_communication)
output = torch.matmul(input_parallel, weight)
@ -620,6 +633,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
@ -627,7 +641,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
input_parallel = _gather(input_, dim, process_group, fp8_communication)
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
@ -687,7 +701,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@ -702,17 +716,20 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None):
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.fp8_communication = fp8_communication
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
# to_cast.append(grad_output.cpu().detach().numpy())
return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
class _ReduceForward(torch.autograd.Function):
@ -725,12 +742,12 @@ class _ReduceForward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group):
return _reduce(input_, process_group)
def forward(ctx, input_, process_group, fp8_communication=False):
return _reduce(input_, process_group, fp8_communication)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
return grad_output, None, None
class _ReduceBackward(torch.autograd.Function):
@ -743,13 +760,15 @@ class _ReduceBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group):
def forward(ctx, input_, process_group, fp8_communication=False):
ctx.process_group = process_group
ctx.fp8_communication = fp8_communication
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None
fp8_communication = ctx.fp8_communication
return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
@ -762,17 +781,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None):
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group)
return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None
class _AllToAll(torch.autograd.Function):
@ -831,12 +851,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)
def _reduce(input_, process_group):
def _reduce(input_, process_group, fp8_communication=False):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
dist.all_reduce(input_, group=process_group)
if fp8_communication:
all_reduce_fp8(input_, group=process_group)
else:
dist.all_reduce(input_, group=process_group)
return input_
@ -860,19 +883,39 @@ def _split(input_, dim=-1, process_group=None):
return output
def _gather(input_, dim=-1, process_group=None):
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# all gather
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, input_, group=process_group)
if fp8_communication:
input_type = input_.dtype
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
fp8_type = ret.dtype
input_ = ret.view(torch.uint8)
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
scale = torch.tensor(scale).to(input_.device)
torch.distributed.all_gather(tensor_list, input_, group=process_group)
torch.distributed.all_gather(scale_list, scale, group=process_group)
cast_tensor_list = []
for output, scale in zip(tensor_list, scale_list):
output = output.view(fp8_type)
output = cast_from_fp8(output, scale, input_type)
cast_tensor_list.append(output)
output = torch.cat(cast_tensor_list, dim=dim).contiguous()
else:
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, input_, group=process_group)
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
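
For intuition, the fp8 branch above gathers uint8-viewed fp8 shards together with their per-rank scales and dequantizes each shard with its own scale before concatenating. A stripped-down sketch of that pattern built on the same casting helpers (the function name fp8_all_gather is illustrative, and the default process group is assumed to be initialized):

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

def fp8_all_gather(shard: torch.Tensor, dim: int = -1) -> torch.Tensor:
    world_size = dist.get_world_size()
    fp8_shard, scale = cast_to_fp8(shard, fp8_format="e5m2")
    payload = fp8_shard.view(torch.uint8).contiguous()
    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload)  # ship the raw fp8 bytes
    dist.all_gather(scales, scale)      # ship each rank's scale alongside
    pieces = [
        cast_from_fp8(p.view(fp8_shard.dtype), s, shard.dtype)  # dequantize with the sender's scale
        for p, s in zip(payloads, scales)
    ]
    return torch.cat(pieces, dim=dim).contiguous()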
@ -935,8 +978,10 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
).contiguous()
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
return MatmulWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
@ -955,8 +1000,8 @@ def gather_forward_reducescatter_backward(input_, process_group, dim):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
def reducescatter_forward_gather_backward(input_, process_group, dim):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
@ -964,27 +1009,27 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def reduce_forward(input_, process_group):
return _ReduceForward.apply(input_, process_group)
def reduce_forward(input_, process_group, fp8_communication=False):
return _ReduceForward.apply(input_, process_group, fp8_communication)
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
def reduce_backward(input_, process_group, fp8_communication=False):
return _ReduceBackward.apply(input_, process_group, fp8_communication)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
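
The functional wrappers keep their existing call sites working because the new flag defaults to False; a sketch of opting in explicitly (tp_group is a placeholder process group, and the import path assumes these wrappers live in colossalai.shardformer.layer._operation, i.e. the file shown above):

import torch
from colossalai.shardformer.layer._operation import (  # assumed path of the file above
    gather_forward_split_backward,
    reduce_forward,
    split_forward_gather_backward,
)

x = torch.randn(4, 128, 256, device="cuda", requires_grad=True)
gathered = gather_forward_split_backward(x, dim=-1, process_group=tp_group, fp8_communication=True)
shard = split_forward_gather_backward(x, dim=1, process_group=tp_group, fp8_communication=True)
reduced = reduce_forward(x, tp_group, fp8_communication=True)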

View File

@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.seq_parallel_mode is None:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
output_parallel = matmul_with_async_comm(
input_parallel, self.weight, bias, self.process_group, self.async_communication
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "split_gather":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
input_parallel,
self.weight,
bias,
self.process_group,
True,
1,
self.overlap,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
input_parallel = input_
@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else:
output = output_parallel
@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
):
super().__init__()
@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.process_group = process_group
self.seq_parallel_mode = seq_parallel_mode
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions
)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
input_ = split_forward_gather_backward(
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
if self.stream_chunk_num > 1:
if self.training:
@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
else:
if self.seq_parallel_mode is None:
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
output = reducescatter_forward_gather_backward(
output_parallel,
self.process_group,
1,
self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, 1, self.fp8_communication
)
if not self.skip_bias_add:
if self.bias is not None:

View File

@ -1137,6 +1137,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -1204,6 +1205,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)

View File

@ -110,14 +110,13 @@ class GPT2Policy(Policy):
"n_fused": 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@ -127,14 +126,13 @@ class GPT2Policy(Policy):
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",

View File

@ -29,6 +29,7 @@ class ShardConfig:
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
@ -47,6 +48,7 @@ class ShardConfig:
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
ep_group: Optional[ProcessGroup] = None
fp8_communication: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
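
A sketch of wiring the flag through ShardConfig when using ShardFormer directly (process-group and model names are placeholders; the HybridParallelPlugin path above sets this field automatically):

from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,      # placeholder tensor-parallel group
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",    # the mode exercised by the GPT-2 changes above
    fp8_communication=True,
)
shard_former = ShardFormer(shard_config=shard_config)
model, _ = shard_former.optimize(model)          # model is a placeholder HF GPT-2 module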