[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
Hongxin Liu · commit 646b3c5a90 (parent f4daf04270)
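In short, policies now describe a fused projection with an explicit `split_sizes` list instead of an `n_fused` count, which is what makes uneven Q/K/V widths (e.g. grouped-query attention) expressible. A minimal sketch of the new-style kwargs, using hypothetical config values rather than any specific model:

```python
# Hedged sketch of the new policy kwargs; hidden_size / kv_size are hypothetical values.
hidden_size = 4096   # width of the fused Q projection
kv_size = 1024       # width of each of K and V (uneven split, e.g. GQA)

# Before this PR: kwargs={"n_fused": 3} assumed three equal chunks.
# After this PR: every fused sub-projection is listed explicitly.
kwargs = {"split_sizes": [hidden_size, kv_size, kv_size]}

# FusedLinear1D_Col asserts that the sizes add up to the layer's out_features.
assert sum(kwargs["split_sizes"]) == hidden_size + 2 * kv_size
```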

@@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
+                    suffix="self_attn.W_pack",
+                    target_module=FusedLinear1D_Col,
+                    kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",

@@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
-from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

 __all__ = [
     "Embedding1D",
@@ -34,4 +34,5 @@ __all__ = [
     "RingAttention",
     "get_pad_info",
     "all_to_all_comm",
+    "FusedLinear1D_Row",
 ]

@@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
         ctx.gather_dim = gather_dim
         ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
-        bsz, _, _ = input_.shape
+        bsz = input_.shape[0]

         # using all_to_all_single when batch size is 1
         if bsz == 1:
@@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function):
         gather_dim = ctx.scatter_dim
         fp8_communication = ctx.fp8_communication
         world_size = dist.get_world_size(process_group)
-        bsz, _, _ = grad_output.shape
+        bsz = grad_output.shape[0]

         if bsz == 1:
             return_grad = _all_to_all_single(

@@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule):
                 handle.wait()
             output = torch.cat(output_parallel_list, dim=-1)
         else:
-            if self.seq_parallel_mode is None:
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
-            elif self.seq_parallel_mode == "split_gather":
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+            if self.seq_parallel_mode == "split_gather":
+                output_parallel = F.linear(input_, self.weight)
                 output = reducescatter_forward_gather_backward(
                     output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
                 )
@@ -445,8 +442,8 @@ class Linear1D_Row(ParallelModule):
                     ring=True,
                 )
             else:
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group)
+                output_parallel = F.linear(input_, self.weight)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)

         if not self.skip_bias_add:
             if self.bias is not None:

@@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -24,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
 )

 from ._operation import (
-    gather_forward_split_backward,
+    gather_forward_reducescatter_backward,
+    linear_gather_forward_reducescatter_backward,
+    linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     matmul_gather_forward_reducescatter_backward,
     matmul_with_async_comm,
@@ -44,21 +47,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"

 def split_fused_qkv_in_gpt2_style(
-    qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
+    qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
 ):
     """
     The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].

     Args:
         qkv (torch.Tensor): The fused qkv tensor.
-        n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+        split_sizes (List[int]): The sizes of the split tensor.
         process_group (ProcessGroup): The process group for distributed communication.
         is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
     """
     # get the number of slice for the fused qkv
     rank = dist.get_rank(group=process_group)
     world_size = dist.get_world_size(group=process_group)
-    order = torch.arange(world_size * n_fused)
+    order = torch.arange(world_size * len(split_sizes))
+    new_split_sizes = []
+    for sz in split_sizes:
+        assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
+        new_split_sizes.extend([sz // world_size] * world_size)

     # split the fused qkv
     # from
@@ -66,9 +73,9 @@ def split_fused_qkv_in_gpt2_style(
     # to
     # [Q1, Q2, K1, K2, V1, V2]
     if is_transposed:
-        weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
+        weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
     else:
-        weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
+        weight_chunks = torch.split(qkv, new_split_sizes, dim=0)

     # rearrange the slice into the final order
     # from
@@ -85,18 +92,23 @@ def split_fused_qkv_in_gpt2_style(

 def gather_fused_qkv_in_gpt2_style(
-    qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
+    qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
 ):
     """
     The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].

     Args:
         qkv (torch.Tensor): The fused qkv tensor.
-        n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+        split_sizes (List[int]): The sizes of the split tensor.
         process_group (ProcessGroup): The process group for distributed communication.
         is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
     """
     world_size = dist.get_world_size(group=process_group)
+    new_split_sizes = []
+    for sz in split_sizes:
+        assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
+        new_split_sizes.append(sz // world_size)
+    new_split_sizes = new_split_sizes * world_size

     # gather the tensors
     # from
@@ -121,13 +133,13 @@ def gather_fused_qkv_in_gpt2_style(
     # to
     # [Q1, Q2, K1, K2, V1, V2]
     if is_transposed:
-        weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
+        weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
     else:
-        weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
+        weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)

     reordered_chunk_list = []
-    for i in range(n_fused):
-        reordered_chunk_list.extend(weight_chunks[i::n_fused])
+    for i in range(len(split_sizes)):
+        reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])

     if is_transposed:
         reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
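To make the split order concrete, here is a small standalone illustration of how an unevenly fused weight ends up on each rank. It uses toy sizes and plain torch only, and is simplified relative to the functions above, which also handle transposed layouts and build the permutation with an index tensor:

```python
import torch

world_size = 2
split_sizes = [4, 2, 2]                 # illustrative uneven Q/K/V widths
qkv = torch.arange(8).unsqueeze(-1)     # fused rows laid out as [Q(4), K(2), V(2)]

# Each sub-projection is cut into world_size equal pieces: [Q1, Q2, K1, K2, V1, V2].
per_rank_sizes = []
for sz in split_sizes:
    assert sz % world_size == 0, "each fused chunk must divide evenly across ranks"
    per_rank_sizes.extend([sz // world_size] * world_size)
chunks = torch.split(qkv, per_rank_sizes, dim=0)

# Rank r keeps every world_size-th chunk starting at r, i.e. [Qr, Kr, Vr].
for rank in range(world_size):
    local = torch.cat(chunks[rank::world_size], dim=0)
    print(rank, local.flatten().tolist())
# 0 [0, 1, 4, 6]
# 1 [2, 3, 5, 7]
```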
@@ -136,6 +148,42 @@ def gather_fused_qkv_in_gpt2_style(
     return reordered_gather_weight


+class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+        ctx.split_sizes = split_sizes
+        ctx.process_group = process_group
+        return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = gather_fused_qkv_in_gpt2_style(
+            grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
+        )
+        return grad_output, None, None
+
+
+def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+    return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
+class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+        ctx.split_sizes = split_sizes
+        ctx.process_group = process_group
+        return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
+        return grad_output, None, None
+
+
+def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+    return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
 class GPT2FusedLinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.
@@ -145,10 +193,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
+        split_sizes (List[int]): The sizes of the split tensor.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
-        n_fused (int): The number items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
@@ -169,6 +217,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self,
         in_features: int,
         out_features: int,
+        split_sizes: List[int],
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
@@ -178,7 +227,6 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         seq_parallel_mode: str = None,
         overlap: bool = False,
         skip_bias_add: bool = False,
-        n_fused: int = 3,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -195,11 +243,15 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.n_fused = n_fused
+        self.split_sizes = split_sizes
         self.process_group = process_group
         self.async_communication = async_communication
         self.fp8_communication = fp8_communication

+        assert (
+            sum(split_sizes) == out_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -223,10 +275,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             self.weight = weight

         def shard_fn(tensor):
-            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)

         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)

         if not is_customized_distributed_tensor(self.weight):
             with torch.no_grad():
@@ -252,7 +304,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
     @staticmethod
     def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Module,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        split_sizes: List[int],
+        *args,
+        **kwargs,
     ) -> ParallelModule:
         r"""
         Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -260,7 +316,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         Args:
             module (`nn.Linear`): The module to be converted.
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
-            n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
+            split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
         """
         LazyInitContext.materialize(module)
         # get the attributes
@@ -291,6 +347,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
+            split_sizes=split_sizes,
             *args,
             **kwargs,
         )
@@ -354,9 +411,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(
-                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
-            )
+            output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
         else:
             output = output_parallel
@@ -605,10 +660,10 @@ class FusedLinear1D_Col(ParallelModule):
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
+        split_sizes (List[int]): The sizes of the split tensor.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
-        n_fused (int): The number items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
             to all GPUs, otherwise, every GPU will have its output
@@ -628,14 +683,16 @@ class FusedLinear1D_Col(ParallelModule):
         self,
         in_features: int,
         out_features: int,
+        split_sizes: List[int],
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        async_communication: bool = False,
         gather_output: bool = False,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
+        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
-        n_fused: int = 3,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -647,13 +704,19 @@ class FusedLinear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.n_fused = n_fused
+        self.split_sizes = split_sizes
         self.process_group = process_group
-        self.async_communication = async_communication
         self.fp8_communication = fp8_communication

+        assert (
+            sum(split_sizes) == out_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -677,10 +740,10 @@ class FusedLinear1D_Col(ParallelModule):
             self.weight = weight

         def shard_fn(tensor):
-            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)

         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)

         if not is_customized_distributed_tensor(self.weight):
             with torch.no_grad():
@@ -706,7 +769,11 @@ class FusedLinear1D_Col(ParallelModule):
     @staticmethod
     def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
+        module: nn.Module,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        split_sizes: List[int],
+        *args,
+        **kwargs,
     ) -> ParallelModule:
         r"""
         Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@@ -714,7 +781,7 @@ class FusedLinear1D_Col(ParallelModule):
         Args:
             module (`nn.Linear`): The module to be converted.
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
-            n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
+            split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
         """
         LazyInitContext.materialize(module)
@@ -737,25 +804,11 @@ class FusedLinear1D_Col(ParallelModule):
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            n_fused=n_fused,
+            split_sizes=split_sizes,
             *args,
             **kwargs,
         )

-        # # TODO: copy the sharded weights
-        # with torch.no_grad():
-        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-        #                                                    n_fused=n_fused,
-        #                                                    process_group=process_group,
-        #                                                    is_transposed=False)
-        #     linear_1d.weight.data.copy_(sharded_weight.data)
-        #     if bias:
-        #         sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-        #                                                      n_fused=n_fused,
-        #                                                      process_group=process_group,
-        #                                                      is_transposed=False)
-        #         linear_1d.bias.data.copy_(sharded_bias.data)
         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -772,19 +825,30 @@ class FusedLinear1D_Col(ParallelModule):
             input_.shape, self.weight.shape, self.weight.shape[-1]
         )
         # Set up backprop all-reduce.
-        # input_parallel = reduce_backward(input_, self.process_group)
         input_parallel = input_

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        if self.seq_parallel_mode == "split_gather":
+            input_parallel = gather_forward_reducescatter_backward(
+                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+            )
+            output_parallel = linear_with_async_comm(
+                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
+            )
+        elif self.seq_parallel_mode == "ring":
+            output_parallel = linear_gather_forward_reducescatter_backward(
+                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+            )
+        else:
+            output_parallel = linear_with_async_comm(
+                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
+            )

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(
-                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
-            )
+            output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
         else:
             output = output_parallel
@@ -792,3 +856,201 @@ class FusedLinear1D_Col(ParallelModule):
             return output, self.bias
         else:
             return output
+
+
+class FusedLinear1D_Row(ParallelModule):
+    r"""Linear layer with row parallelism
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
+        seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+            which is preserved for kernel fusion, defaults to False
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        split_sizes: List[int],
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        process_group: ProcessGroup = None,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
+        parallel_input: bool = True,
+        skip_bias_add: bool = False,
+        weight: Optional[Parameter] = None,
+        bias_: Optional[Parameter] = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+        self.split_sizes = split_sizes
+        self.parallel_input = parallel_input
+        self.skip_bias_add = skip_bias_add
+        self.process_group = process_group
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
+        self.num_partitions = dist.get_world_size(self.process_group)
+        self.fp8_communication = fp8_communication
+
+        assert (
+            sum(split_sizes) == in_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
+
+        if skip_bias_add and not bias:
+            raise ValueError("cannot skip bias addition if bias is None")
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
+        else:
+            assert bias_ is None, "bias_ must be None if weight is None"
+
+        # Parameters.
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {"device": device, "dtype": dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+
+        def shard_fn(tensor):
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+        def gather_fn(tensor):
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+        else:
+            self.bias = None
+
+        if weight is None:
+            with self.randomizer.fork_rng(enable_cpu=True):
+                self.reset_parameters(weight_initializer, bias_initializer)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
+    ) -> ParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
+        """
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        linear_1d = FusedLinear1D_Row(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
+            split_sizes=split_sizes,
+            **kwargs,
+        )
+
+        return linear_1d
+
+    @torch.no_grad()
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.out_features
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+            if self.process_group is None:
+                src_rank = 0
+            else:
+                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+            origin_device = self.bias.device
+            bias = self.bias.cuda()
+            dist.broadcast(bias, src=src_rank, group=self.process_group)
+            bias = bias.to(origin_device)
+            self.bias.copy_(bias)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        # Set up backprop all-reduce.
+        if self.parallel_input:
+            assert (
+                input_.shape[-1] == self.weight.shape[-1]
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+                input_.shape, self.weight.shape, self.weight.shape[-1]
+            )
+            input_ = input_
+        else:
+            assert (
+                divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+                input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
+            )
+            input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
+
+        if self.seq_parallel_mode == "split_gather":
+            output_parallel = F.linear(input_, self.weight)
+            output = reducescatter_forward_gather_backward(
+                output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+            )
+        elif self.seq_parallel_mode == "ring":
+            output = linear_reducescatter_forward_gather_backward(
+                input_,
+                self.weight,
+                process_group=self.process_group,
+                dim=self.seq_parallel_dim,
+                ring=True,
+            )
+        else:
+            output_parallel = F.linear(input_, self.weight)
+            output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
+
+        if not self.skip_bias_add:
+            if self.bias is not None:
+                output = output + self.bias
+            return output
+        else:
+            return output, self.bias
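For completeness, a hedged sketch of how the new column/row pair is meant to be composed from plain `nn.Linear` modules. This mirrors the `[32, 32, 16]` uneven split used in the tests further down and assumes a tensor-parallel group has already been initialized via `colossalai.launch`:

```python
import torch.nn as nn

from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row

# Illustrative sizes: a fused projection with uneven sub-projections 32 + 32 + 16 = 80.
qkv_proj = nn.Linear(8, 80).cuda()
out_proj = nn.Linear(80, 8).cuda()

# process_group=None means the default group created by colossalai.launch(...).
col = FusedLinear1D_Col.from_native_module(qkv_proj, process_group=None, split_sizes=[32, 32, 16])
row = FusedLinear1D_Row.from_native_module(out_proj, process_group=None, split_sizes=[32, 32, 16])

# col emits the per-rank [Qr, Kr, Vr] slice; row consumes it and all-reduces,
# so row(col(x)) should match out_proj(qkv_proj(x)) computed on the unsharded modules.
```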

@@ -71,7 +71,7 @@ class BlipPolicy(Policy):
                     suffix="self_attn.qkv",
                     target_module=col_nn.FusedLinear1D_Col,
                     kwargs={
-                        "n_fused": 3,
+                        "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),

@@ -92,7 +92,7 @@ class GPT2Policy(Policy):
                     suffix="attn.c_attn",
                     target_module=col_nn.GPT2FusedLinearConv1D_Col,
                     kwargs={
-                        "n_fused": 3,
+                        "split_sizes": [self.model.config.hidden_size] * 3,
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
                         "fp8_communication": self.shard_config.fp8_communication,
@@ -107,7 +107,7 @@ class GPT2Policy(Policy):
                     suffix="mlp.c_fc",
                     target_module=col_nn.GPT2FusedLinearConv1D_Col,
                     kwargs={
-                        "n_fused": 1,
+                        "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
                         "skip_bias_add": self.enable_bias_gelu_fused,

@@ -42,7 +42,7 @@ class SamPolicy(Policy):
                     suffix="attn.qkv",
                     target_module=col_nn.FusedLinear1D_Col,
                     kwargs={
-                        "n_fused": 3,
+                        "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),

@@ -41,21 +41,6 @@ class Conv1D(nn.Module):
         return x


-def rearrange(tensor: torch.Tensor, dim: int):
-    tensor = tensor.clone()
-    world_size = 2
-    order = torch.arange(world_size * 3)
-    new_order = []
-    for i in range(world_size):
-        new_order.append(order[i::world_size])
-    new_order = torch.cat(new_order)
-    tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
-    rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
-    rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
-    return rearanged_tensor
-
-
 def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
@@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
         process_group=None,
         gather_output=True,
         seq_parallel_mode=seq_parallel_mode,
-        n_fused=3,
+        split_sizes=[64] * 3,
         overlap=overlap,
     )
@@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
         x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
     )
     gather_out = linear_conv_col(x_for_shard)
-    assert_close(rearrange(out, -1), gather_out)
+    assert_close(out, gather_out)

     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()

-    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
     assert_close(target_grad, linear_conv_col.weight.grad)

@@ -2,13 +2,12 @@ import os
 from contextlib import nullcontext

 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
-    Basically works like a linear layer but the weights are transposed.
-
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        self.weight = nn.Parameter(torch.empty(nx, nf))
-        self.bias = nn.Parameter(torch.zeros(nf))
-        nn.init.normal_(self.weight, std=0.02)
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
-def rearrange(tensor: torch.Tensor, dim: int):
-    tensor = tensor.clone()
-    world_size = 2
-    order = torch.arange(world_size * 3)
-    new_order = []
-    for i in range(world_size):
-        new_order.append(order[i::world_size])
-    new_order = torch.cat(new_order)
-    tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
-    rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
-    rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
-    return rearanged_tensor
-
-
 @parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
-    linear = Conv1D(192, 48).cuda()
+    linear = nn.Linear(8, 80).cuda()
     with ctx:
-        linear_copy = Conv1D(192, 48).cuda()
-    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, n_fused=3
+        linear_copy = nn.Linear(8, 80).cuda()
+    linear_col = FusedLinear1D_Col.from_native_module(
+        linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
     )

-    assert linear.weight.shape == torch.Size([48, 192])
-    assert linear.bias.shape == torch.Size([192])
-    assert linear_conv_col.weight.shape == torch.Size([48, 96])
-    assert linear_conv_col.bias.shape == torch.Size([96])
-    assert linear_copy.weight is linear_conv_col.weight
-    assert linear_copy.bias is linear_conv_col.bias
+    assert linear.weight.shape == torch.Size([80, 8])
+    assert linear.bias.shape == torch.Size([80])
+    assert linear_col.weight.shape == torch.Size([40, 8])
+    assert linear_col.bias.shape == torch.Size([40])
+    assert linear_copy.weight is linear_col.weight
+    assert linear_copy.bias is linear_col.bias

     # ensure weights are reversibly loadable
-    linear_conv_col.load_state_dict(linear.state_dict())
-    linear.load_state_dict(linear_conv_col.state_dict())
+    linear_col.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_col.state_dict())

     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(4, 8).cuda()
     out = linear(x)
-    gather_out = linear_conv_col(x)
-    assert_close(rearrange(out, 1), gather_out)
+    gather_out = linear_col(x)
+    assert_close(out, gather_out)

     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()

-    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
-    assert_close(target_grad, linear_conv_col.weight.grad)
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
+    assert_close(target_grad, linear_col.weight.grad)


 @parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
-    linear = Conv1D(192, 48).cuda()
+    linear = nn.Linear(80, 8).cuda()
     with ctx:
-        linear_copy = Conv1D(192, 48).cuda()
-    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+        linear_copy = nn.Linear(80, 8).cuda()
+    linear_row = FusedLinear1D_Row.from_native_module(
+        linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
+    )

-    assert linear.weight.shape == torch.Size([48, 192])
-    assert linear_row.weight.shape == torch.Size([24, 192])
-    assert linear_row.bias.shape == torch.Size([192])
+    assert linear.weight.shape == torch.Size([8, 80])
+    assert linear_row.weight.shape == torch.Size([8, 40])
+    assert linear_row.bias.shape == torch.Size([8])
     assert linear_copy.weight is linear_row.weight
     assert linear_copy.bias is linear_row.bias
@@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
     linear.load_state_dict(linear_row.state_dict())

     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(4, 80).cuda()
     out = linear(x)
     gather_out = linear_row(x)
     assert_close(out, gather_out)
@@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
     out.sum().backward()
     gather_out.sum().backward()

-    rank = dist.get_rank()
-    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
     assert_close(target_grad, linear_row.weight.grad)


+@parameterize("lazy_init", [False, True])
+def check_linear_1d_col_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    linear1 = nn.Linear(8, 80).cuda()
+    linear2 = nn.Linear(80, 8).cuda()
+    with ctx:
+        linear1_copy = nn.Linear(8, 80).cuda()
+        linear2_copy = nn.Linear(80, 8).cuda()
+    linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
+    linear_row = FusedLinear1D_Row.from_native_module(
+        linear2_copy,
+        process_group=None,
+        split_sizes=[32, 32, 16],
+    )
+
+    # ensure weights are reversibly loadable
+    linear_col.load_state_dict(linear1.state_dict())
+    linear_row.load_state_dict(linear2.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 8).cuda()
+    target_out = linear2(linear1(x))
+    out = linear_row(linear_col(x))
+    assert_close(out, target_out)
+
+    # check backward correctness
+    target_out.sum().backward()
+    out.sum().backward()
+
+    target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
+    assert_close(target_grad1, linear_col.weight.grad)
+    target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
+    assert_close(target_grad2, linear_row.weight.grad)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

-    # test for linear conv
-    check_linear_conv_1d_col()
-    check_linear_conv_1d_row()
+    check_linear_1d_col()
+    check_linear_1d_row()
+    check_linear_1d_col_row()


 @rerun_if_address_is_in_use()
