
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
Author: Hongxin Liu (committed via GitHub)
Commit: 646b3c5a90
Changed files (lines changed):

1. colossalai/inference/modeling/policy/nopadding_baichuan.py (4)
2. colossalai/shardformer/layer/__init__.py (3)
3. colossalai/shardformer/layer/_operation.py (4)
4. colossalai/shardformer/layer/linear.py (11)
5. colossalai/shardformer/layer/qkv_fused_linear.py (364)
6. colossalai/shardformer/policies/blip2.py (2)
7. colossalai/shardformer/policies/gpt2.py (4)
8. colossalai/shardformer/policies/sam.py (2)
9. tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py (21)
10. tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py (141)
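At a glance, the API change in this commit: FusedLinear1D_Col (and the new FusedLinear1D_Row) take an explicit split_sizes list instead of an n_fused count, so fused projections whose packed Q, K and V widths differ can still be sharded. A minimal sketch of the new call, mirroring the updated unit test further down and assuming an already-launched distributed context (e.g. via colossalai.launch) with tensor parallel size 2:

import torch.nn as nn

from colossalai.shardformer.layer import FusedLinear1D_Col

# A fused projection packing Q (32), K (32) and V (16) output features.
fused_qkv = nn.Linear(8, 80).cuda()

# Old API: kwargs={"n_fused": 3} assumed three equal chunks.
# New API: pass the exact width of each packed projection; the sum must
# equal out_features (32 + 32 + 16 = 80).
linear_col = FusedLinear1D_Col.from_native_module(
    fused_qkv,
    process_group=None,        # default group, as in the tests
    gather_output=True,
    split_sizes=[32, 32, 16],
)
# With 2 tensor-parallel ranks, each rank holds a (16 + 16 + 8) x 8 weight shard.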

colossalai/inference/modeling/policy/nopadding_baichuan.py (4 lines changed)

@@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
suffix="self_attn.W_pack",
target_module=FusedLinear1D_Col,
kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",

colossalai/shardformer/layer/__init__.py (3 lines changed)

@@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D",
@@ -34,4 +34,5 @@ __all__ = [
"RingAttention",
"get_pad_info",
"all_to_all_comm",
"FusedLinear1D_Row",
]

colossalai/shardformer/layer/_operation.py (4 lines changed)

@@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
bsz = input_.shape[0]
# using all_to_all_single when batch size is 1
if bsz == 1:
@@ -871,7 +871,7 @@
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape
bsz = grad_output.shape[0]
if bsz == 1:
return_grad = _all_to_all_single(
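For context on the _AllToAll hunks above: only the batch size is needed, so the strict 3-D unpacking is dropped and the op no longer assumes a (bsz, seq, hidden) input. A trivial illustration of why the old pattern breaks for other tensor ranks:

import torch

x = torch.zeros(1, 2, 3, 4)  # e.g. an activation carrying an extra head dimension
# bsz, _, _ = x.shape        # ValueError: too many values to unpack (expected 3)
bsz = x.shape[0]             # works for any number of dimensions
print(bsz)                   # 1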

colossalai/shardformer/layer/linear.py (11 lines changed)

@@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
if self.seq_parallel_mode == "split_gather":
output_parallel = F.linear(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
@@ -445,8 +442,8 @@
ring=True,
)
else:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group)
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
if self.bias is not None:
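The Linear1D_Row hotfix above collapses the duplicated branches: the local matmul is a plain F.linear, the partial results are summed across ranks by reduce_forward, and fp8_communication is now forwarded on the default path as well. A single-process sketch of the row-parallel identity this relies on (illustrative shapes, no process groups involved):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 80)   # full input
w = torch.randn(8, 80)   # full weight, shape (out_features, in_features)

# Simulate 2 ranks: each holds a slice of the input features and the matching
# columns of the weight, computes a partial product, and the partials are
# summed -- which is what the all-reduce in reduce_forward does across ranks.
x_shards = torch.split(x, 40, dim=-1)
w_shards = torch.split(w, 40, dim=-1)
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
out = sum(partials)

assert torch.allclose(out, F.linear(x, w), atol=1e-4)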

colossalai/shardformer/layer/qkv_fused_linear.py (364 lines changed)

@@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -24,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
gather_forward_split_backward,
gather_forward_reducescatter_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
@@ -44,21 +47,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
def split_fused_qkv_in_gpt2_style(
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
):
"""
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
"""
# get the number of slice for the fused qkv
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_fused)
order = torch.arange(world_size * len(split_sizes))
new_split_sizes = []
for sz in split_sizes:
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
new_split_sizes.extend([sz // world_size] * world_size)
# split the fused qkv
# from
@@ -66,9 +73,9 @@ def split_fused_qkv_in_gpt2_style(
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
else:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
# rearrange the slice into the final order
# from
@@ -85,18 +92,23 @@ def split_fused_qkv_in_gpt2_style(
def gather_fused_qkv_in_gpt2_style(
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
):
"""
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
"""
world_size = dist.get_world_size(group=process_group)
new_split_sizes = []
for sz in split_sizes:
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
new_split_sizes.append(sz // world_size)
new_split_sizes = new_split_sizes * world_size
# gather the tensors
# from
@@ -121,13 +133,13 @@ def gather_fused_qkv_in_gpt2_style(
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
else:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
reordered_chunk_list = []
for i in range(n_fused):
reordered_chunk_list.extend(weight_chunks[i::n_fused])
for i in range(len(split_sizes)):
reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
if is_transposed:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
@@ -136,6 +148,42 @@ def gather_fused_qkv_in_gpt2_style(
return reordered_gather_weight
class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
ctx.split_sizes = split_sizes
ctx.process_group = process_group
return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
@staticmethod
def backward(ctx, grad_output):
grad_output = gather_fused_qkv_in_gpt2_style(
grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
)
return grad_output, None, None
def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
ctx.split_sizes = split_sizes
ctx.process_group = process_group
return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
@staticmethod
def backward(ctx, grad_output):
grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
return grad_output, None, None
def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
class GPT2FusedLinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
@@ -145,10 +193,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
@@ -169,6 +217,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
@@ -178,7 +227,6 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
seq_parallel_mode: str = None,
overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -195,11 +243,15 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.split_sizes = split_sizes
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == out_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -223,10 +275,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
@@ -252,7 +304,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]],
split_sizes: List[int],
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -260,7 +316,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
@@ -291,6 +347,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
split_sizes=split_sizes,
*args,
**kwargs,
)
@@ -354,9 +411,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
else:
output = output_parallel
@@ -605,10 +660,10 @@ class FusedLinear1D_Col(ParallelModule):
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
@@ -628,14 +683,16 @@ class FusedLinear1D_Col(ParallelModule):
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -647,13 +704,19 @@ class FusedLinear1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.split_sizes = split_sizes
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == out_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -677,10 +740,10 @@ class FusedLinear1D_Col(ParallelModule):
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
@@ -706,7 +769,11 @@ class FusedLinear1D_Col(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]],
split_sizes: List[int],
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@@ -714,7 +781,7 @@ class FusedLinear1D_Col(ParallelModule):
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
"""
LazyInitContext.materialize(module)
@@ -737,25 +804,11 @@ class FusedLinear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
n_fused=n_fused,
split_sizes=split_sizes,
*args,
**kwargs,
)
# # TODO: copy the sharded weights
# with torch.no_grad():
# sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
# n_fused=n_fused,
# process_group=process_group,
# is_transposed=False)
# linear_1d.weight.data.copy_(sharded_weight.data)
# if bias:
# sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
# n_fused=n_fused,
# process_group=process_group,
# is_transposed=False)
# linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -772,19 +825,30 @@ class FusedLinear1D_Col(ParallelModule):
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
# input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
)
else:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
else:
output = output_parallel
@@ -792,3 +856,201 @@ class FusedLinear1D_Col(ParallelModule):
return output, self.bias
else:
return output
class FusedLinear1D_Row(ParallelModule):
r"""Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.split_sizes = split_sizes
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == in_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
linear_1d = FusedLinear1D_Row(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
split_sizes=split_sizes,
**kwargs,
)
return linear_1d
@torch.no_grad()
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
bias = self.bias.cuda()
dist.broadcast(bias, src=src_rank, group=self.process_group)
bias = bias.to(origin_device)
self.bias.copy_(bias)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
input_ = input_
else:
assert (
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
)
input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
if self.seq_parallel_mode == "split_gather":
output_parallel = F.linear(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=True,
)
else:
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
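To make the uneven-split bookkeeping concrete, here is the size arithmetic that split_fused_qkv_in_gpt2_style and gather_fused_qkv_in_gpt2_style perform, as a single-process sketch using the numbers from the updated unit test (packed widths Q=32, K=32, V=16, tensor parallel size 2):

split_sizes = [32, 32, 16]   # widths of the packed Q, K, V projections
world_size = 2

# Each projection is divided evenly across ranks, so the fused dimension is
# cut into world_size * len(split_sizes) chunks of possibly different sizes:
new_split_sizes = []
for sz in split_sizes:
    assert sz % world_size == 0, "each packed projection must divide evenly"
    new_split_sizes.extend([sz // world_size] * world_size)
print(new_split_sizes)                 # [16, 16, 16, 16, 8, 8] -> Q1, Q2, K1, K2, V1, V2

# Rank r keeps every world_size-th chunk starting at index r, i.e. [Qr, Kr, Vr]:
rank0_sizes = new_split_sizes[0::world_size]
print(rank0_sizes, sum(rank0_sizes))   # [16, 16, 8] 40 -> the 40 x 8 shard asserted in the test

# gather_forward_split_backward_fused_qkv undoes this: it all-gathers the
# per-rank [Qr, Kr, Vr] shards and reorders the chunks back into
# [Q1, Q2, K1, K2, V1, V2], which is why the updated tests compare outputs
# directly instead of going through the old rearrange helper.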

colossalai/shardformer/policies/blip2.py (2 lines changed)

@@ -71,7 +71,7 @@ class BlipPolicy(Policy):
suffix="self_attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),

colossalai/shardformer/policies/gpt2.py (4 lines changed)

@@ -92,7 +92,7 @@ class GPT2Policy(Policy):
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
@@ -107,7 +107,7 @@ class GPT2Policy(Policy):
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,

colossalai/shardformer/policies/sam.py (2 lines changed)

@@ -42,7 +42,7 @@ class SamPolicy(Policy):
suffix="attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),

tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py (21 lines changed)

@@ -41,21 +41,6 @@ class Conv1D(nn.Module):
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
@@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
process_group=None,
gather_output=True,
seq_parallel_mode=seq_parallel_mode,
n_fused=3,
split_sizes=[64] * 3,
overlap=overlap,
)
@@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)

tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py (141 lines changed)

@@ -2,13 +2,12 @@ import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
self.weight = nn.Parameter(torch.empty(nx, nf))
self.bias = nn.Parameter(torch.zeros(nf))
nn.init.normal_(self.weight, std=0.02)
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
@parameterize("lazy_init", [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear = nn.Linear(8, 80).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, n_fused=3
linear_copy = nn.Linear(8, 80).cuda()
linear_col = FusedLinear1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
assert linear_conv_col.weight.shape == torch.Size([48, 96])
assert linear_conv_col.bias.shape == torch.Size([96])
assert linear_copy.weight is linear_conv_col.weight
assert linear_copy.bias is linear_conv_col.bias
assert linear.weight.shape == torch.Size([80, 8])
assert linear.bias.shape == torch.Size([80])
assert linear_col.weight.shape == torch.Size([40, 8])
assert linear_col.bias.shape == torch.Size([40])
assert linear_copy.weight is linear_col.weight
assert linear_copy.bias is linear_col.bias
# ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_conv_col.state_dict())
linear_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(4, 8).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)
gather_out = linear_col(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad, linear_col.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear = nn.Linear(80, 8).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_copy = nn.Linear(80, 8).cuda()
linear_row = FusedLinear1D_Row.from_native_module(
linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
assert linear_row.bias.shape == torch.Size([192])
assert linear.weight.shape == torch.Size([8, 80])
assert linear_row.weight.shape == torch.Size([8, 40])
assert linear_row.bias.shape == torch.Size([8])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
@@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(4, 80).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
@@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad, linear_row.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_1d_col_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear1 = nn.Linear(8, 80).cuda()
linear2 = nn.Linear(80, 8).cuda()
with ctx:
linear1_copy = nn.Linear(8, 80).cuda()
linear2_copy = nn.Linear(80, 8).cuda()
linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
linear_row = FusedLinear1D_Row.from_native_module(
linear2_copy,
process_group=None,
split_sizes=[32, 32, 16],
)
# ensure weights are reversibly loadable
linear_col.load_state_dict(linear1.state_dict())
linear_row.load_state_dict(linear2.state_dict())
# check computation correctness
x = torch.rand(4, 8).cuda()
target_out = linear2(linear1(x))
out = linear_row(linear_col(x))
assert_close(out, target_out)
# check backward correctness
target_out.sum().backward()
out.sum().backward()
target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad1, linear_col.weight.grad)
target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad2, linear_row.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()
check_linear_1d_col()
check_linear_1d_row()
check_linear_1d_col_row()
@rerun_if_address_is_in_use()
