mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)
* [tp] hotfix linear row * [tp] support uneven split for fused linear * [tp] support sp for fused linear * [tp] fix gpt2 mlp policy * [tp] fix gather fused and add fused linear rowsupercooledith-patch-1
parent
f4daf04270
commit
646b3c5a90
|
@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
||||||
target_module=NopadBaichuanMLP,
|
target_module=NopadBaichuanMLP,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
|
suffix="self_attn.W_pack",
|
||||||
|
target_module=FusedLinear1D_Col,
|
||||||
|
kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.o_proj",
|
suffix="self_attn.o_proj",
|
||||||
|
|
|
@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe
|
||||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||||
from .parallel_module import ParallelModule
|
from .parallel_module import ParallelModule
|
||||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Embedding1D",
|
"Embedding1D",
|
||||||
|
@ -34,4 +34,5 @@ __all__ = [
|
||||||
"RingAttention",
|
"RingAttention",
|
||||||
"get_pad_info",
|
"get_pad_info",
|
||||||
"all_to_all_comm",
|
"all_to_all_comm",
|
||||||
|
"FusedLinear1D_Row",
|
||||||
]
|
]
|
||||||
|
|
|
@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
|
||||||
ctx.gather_dim = gather_dim
|
ctx.gather_dim = gather_dim
|
||||||
ctx.fp8_communication = fp8_communication
|
ctx.fp8_communication = fp8_communication
|
||||||
world_size = dist.get_world_size(process_group)
|
world_size = dist.get_world_size(process_group)
|
||||||
bsz, _, _ = input_.shape
|
bsz = input_.shape[0]
|
||||||
|
|
||||||
# using all_to_all_single when batch size is 1
|
# using all_to_all_single when batch size is 1
|
||||||
if bsz == 1:
|
if bsz == 1:
|
||||||
|
@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function):
|
||||||
gather_dim = ctx.scatter_dim
|
gather_dim = ctx.scatter_dim
|
||||||
fp8_communication = ctx.fp8_communication
|
fp8_communication = ctx.fp8_communication
|
||||||
world_size = dist.get_world_size(process_group)
|
world_size = dist.get_world_size(process_group)
|
||||||
bsz, _, _ = grad_output.shape
|
bsz = grad_output.shape[0]
|
||||||
|
|
||||||
if bsz == 1:
|
if bsz == 1:
|
||||||
return_grad = _all_to_all_single(
|
return_grad = _all_to_all_single(
|
||||||
|
|
|
@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule):
|
||||||
handle.wait()
|
handle.wait()
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
if self.seq_parallel_mode is None:
|
if self.seq_parallel_mode == "split_gather":
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
output_parallel = F.linear(input_, self.weight)
|
||||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
|
||||||
elif self.seq_parallel_mode == "split_gather":
|
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
|
||||||
output = reducescatter_forward_gather_backward(
|
output = reducescatter_forward_gather_backward(
|
||||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||||
)
|
)
|
||||||
|
@ -445,8 +442,8 @@ class Linear1D_Row(ParallelModule):
|
||||||
ring=True,
|
ring=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
output_parallel = F.linear(input_, self.weight)
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
|
|
@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
@ -24,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._operation import (
|
from ._operation import (
|
||||||
gather_forward_split_backward,
|
gather_forward_reducescatter_backward,
|
||||||
|
linear_gather_forward_reducescatter_backward,
|
||||||
|
linear_reducescatter_forward_gather_backward,
|
||||||
linear_with_async_comm,
|
linear_with_async_comm,
|
||||||
matmul_gather_forward_reducescatter_backward,
|
matmul_gather_forward_reducescatter_backward,
|
||||||
matmul_with_async_comm,
|
matmul_with_async_comm,
|
||||||
|
@ -44,21 +47,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
|
||||||
|
|
||||||
|
|
||||||
def split_fused_qkv_in_gpt2_style(
|
def split_fused_qkv_in_gpt2_style(
|
||||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
|
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qkv (torch.Tensor): The fused qkv tensor.
|
qkv (torch.Tensor): The fused qkv tensor.
|
||||||
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
|
split_sizes (List[int]): The sizes of the split tensor.
|
||||||
process_group (ProcessGroup): The process group for distributed communication.
|
process_group (ProcessGroup): The process group for distributed communication.
|
||||||
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
||||||
"""
|
"""
|
||||||
# get the number of slice for the fused qkv
|
# get the number of slice for the fused qkv
|
||||||
rank = dist.get_rank(group=process_group)
|
rank = dist.get_rank(group=process_group)
|
||||||
world_size = dist.get_world_size(group=process_group)
|
world_size = dist.get_world_size(group=process_group)
|
||||||
order = torch.arange(world_size * n_fused)
|
order = torch.arange(world_size * len(split_sizes))
|
||||||
|
new_split_sizes = []
|
||||||
|
for sz in split_sizes:
|
||||||
|
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
|
||||||
|
new_split_sizes.extend([sz // world_size] * world_size)
|
||||||
|
|
||||||
# split the fused qkv
|
# split the fused qkv
|
||||||
# from
|
# from
|
||||||
|
@ -66,9 +73,9 @@ def split_fused_qkv_in_gpt2_style(
|
||||||
# to
|
# to
|
||||||
# [Q1, Q2, K1, K2, V1, V2]
|
# [Q1, Q2, K1, K2, V1, V2]
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
|
weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
|
||||||
else:
|
else:
|
||||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
|
weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
|
||||||
|
|
||||||
# rearrange the slice into the final order
|
# rearrange the slice into the final order
|
||||||
# from
|
# from
|
||||||
|
@ -85,18 +92,23 @@ def split_fused_qkv_in_gpt2_style(
|
||||||
|
|
||||||
|
|
||||||
def gather_fused_qkv_in_gpt2_style(
|
def gather_fused_qkv_in_gpt2_style(
|
||||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
|
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qkv (torch.Tensor): The fused qkv tensor.
|
qkv (torch.Tensor): The fused qkv tensor.
|
||||||
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
|
split_sizes (List[int]): The sizes of the split tensor.
|
||||||
process_group (ProcessGroup): The process group for distributed communication.
|
process_group (ProcessGroup): The process group for distributed communication.
|
||||||
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
||||||
"""
|
"""
|
||||||
world_size = dist.get_world_size(group=process_group)
|
world_size = dist.get_world_size(group=process_group)
|
||||||
|
new_split_sizes = []
|
||||||
|
for sz in split_sizes:
|
||||||
|
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
|
||||||
|
new_split_sizes.append(sz // world_size)
|
||||||
|
new_split_sizes = new_split_sizes * world_size
|
||||||
|
|
||||||
# gather the tensors
|
# gather the tensors
|
||||||
# from
|
# from
|
||||||
|
@ -121,13 +133,13 @@ def gather_fused_qkv_in_gpt2_style(
|
||||||
# to
|
# to
|
||||||
# [Q1, Q2, K1, K2, V1, V2]
|
# [Q1, Q2, K1, K2, V1, V2]
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
|
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
|
||||||
else:
|
else:
|
||||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
|
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
|
||||||
|
|
||||||
reordered_chunk_list = []
|
reordered_chunk_list = []
|
||||||
for i in range(n_fused):
|
for i in range(len(split_sizes)):
|
||||||
reordered_chunk_list.extend(weight_chunks[i::n_fused])
|
reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
|
||||||
|
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
||||||
|
@ -136,6 +148,42 @@ def gather_fused_qkv_in_gpt2_style(
|
||||||
return reordered_gather_weight
|
return reordered_gather_weight
|
||||||
|
|
||||||
|
|
||||||
|
class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||||
|
ctx.split_sizes = split_sizes
|
||||||
|
ctx.process_group = process_group
|
||||||
|
return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
grad_output = gather_fused_qkv_in_gpt2_style(
|
||||||
|
grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
|
||||||
|
)
|
||||||
|
return grad_output, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||||
|
return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
|
||||||
|
|
||||||
|
|
||||||
|
class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||||
|
ctx.split_sizes = split_sizes
|
||||||
|
ctx.process_group = process_group
|
||||||
|
return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
|
||||||
|
return grad_output, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||||
|
return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
|
||||||
|
|
||||||
|
|
||||||
class GPT2FusedLinearConv1D_Col(ParallelModule):
|
class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
r"""Linear layer with column parallelism.
|
r"""Linear layer with column parallelism.
|
||||||
|
|
||||||
|
@ -145,10 +193,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
Args:
|
Args:
|
||||||
in_features (int): size of each input sample.
|
in_features (int): size of each input sample.
|
||||||
out_features (int): size of each output sample.
|
out_features (int): size of each output sample.
|
||||||
|
split_sizes (List[int]): The sizes of the split tensor.
|
||||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
device (`torch.device`): The device of parameters, defaults to None.
|
device (`torch.device`): The device of parameters, defaults to None.
|
||||||
n_fused (int): The number items fused, defaults to 3 (QKV).
|
|
||||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||||
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
|
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
|
||||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||||
|
@ -169,6 +217,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
self,
|
self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
|
split_sizes: List[int],
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
|
@ -178,7 +227,6 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
seq_parallel_mode: str = None,
|
seq_parallel_mode: str = None,
|
||||||
overlap: bool = False,
|
overlap: bool = False,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
n_fused: int = 3,
|
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
bias_: Optional[Parameter] = None,
|
bias_: Optional[Parameter] = None,
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
|
@ -195,11 +243,15 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
self.overlap = overlap
|
self.overlap = overlap
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.device = device
|
self.device = device
|
||||||
self.n_fused = n_fused
|
self.split_sizes = split_sizes
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.async_communication = async_communication
|
self.async_communication = async_communication
|
||||||
self.fp8_communication = fp8_communication
|
self.fp8_communication = fp8_communication
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(split_sizes) == out_features
|
||||||
|
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
|
||||||
|
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
raise ValueError("cannot skip bias addition if bias is None")
|
raise ValueError("cannot skip bias addition if bias is None")
|
||||||
|
|
||||||
|
@ -223,10 +275,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
|
||||||
def shard_fn(tensor):
|
def shard_fn(tensor):
|
||||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||||
|
|
||||||
def gather_fn(tensor):
|
def gather_fn(tensor):
|
||||||
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||||
|
|
||||||
if not is_customized_distributed_tensor(self.weight):
|
if not is_customized_distributed_tensor(self.weight):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -252,7 +304,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(
|
def from_native_module(
|
||||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
module: nn.Module,
|
||||||
|
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||||
|
split_sizes: List[int],
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
) -> ParallelModule:
|
) -> ParallelModule:
|
||||||
r"""
|
r"""
|
||||||
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
|
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
|
||||||
|
@ -260,7 +316,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
Args:
|
Args:
|
||||||
module (`nn.Linear`): The module to be converted.
|
module (`nn.Linear`): The module to be converted.
|
||||||
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
||||||
n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
|
split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
|
||||||
"""
|
"""
|
||||||
LazyInitContext.materialize(module)
|
LazyInitContext.materialize(module)
|
||||||
# get the attributes
|
# get the attributes
|
||||||
|
@ -291,6 +347,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
weight=module.weight,
|
weight=module.weight,
|
||||||
bias_=module.bias,
|
bias_=module.bias,
|
||||||
|
split_sizes=split_sizes,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -354,9 +411,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = gather_forward_split_backward(
|
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
|
||||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
|
|
||||||
|
@ -605,10 +660,10 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
Args:
|
Args:
|
||||||
in_features (int): size of each input sample.
|
in_features (int): size of each input sample.
|
||||||
out_features (int): size of each output sample.
|
out_features (int): size of each output sample.
|
||||||
|
split_sizes (List[int]): The sizes of the split tensor.
|
||||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
device (`torch.device`): The device of parameters, defaults to None.
|
device (`torch.device`): The device of parameters, defaults to None.
|
||||||
n_fused (int): The number items fused, defaults to 3 (QKV).
|
|
||||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
to all GPUs, otherwise, every GPU will have its output
|
||||||
|
@ -628,14 +683,16 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
self,
|
self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
|
split_sizes: List[int],
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
async_communication: bool = False,
|
|
||||||
gather_output: bool = False,
|
gather_output: bool = False,
|
||||||
|
seq_parallel_mode: str = None,
|
||||||
|
seq_parallel_dim: int = 1,
|
||||||
|
overlap: torch.cuda.Stream = None,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
n_fused: int = 3,
|
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
bias_: Optional[Parameter] = None,
|
bias_: Optional[Parameter] = None,
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
|
@ -647,13 +704,19 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
|
self.seq_parallel_mode = seq_parallel_mode
|
||||||
|
self.seq_parallel_dim = seq_parallel_dim
|
||||||
|
self.overlap = overlap
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.device = device
|
self.device = device
|
||||||
self.n_fused = n_fused
|
self.split_sizes = split_sizes
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.async_communication = async_communication
|
|
||||||
self.fp8_communication = fp8_communication
|
self.fp8_communication = fp8_communication
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(split_sizes) == out_features
|
||||||
|
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
|
||||||
|
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
raise ValueError("cannot skip bias addition if bias is None")
|
raise ValueError("cannot skip bias addition if bias is None")
|
||||||
|
|
||||||
|
@ -677,10 +740,10 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
|
||||||
def shard_fn(tensor):
|
def shard_fn(tensor):
|
||||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
|
||||||
|
|
||||||
def gather_fn(tensor):
|
def gather_fn(tensor):
|
||||||
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
|
||||||
|
|
||||||
if not is_customized_distributed_tensor(self.weight):
|
if not is_customized_distributed_tensor(self.weight):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -706,7 +769,11 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(
|
def from_native_module(
|
||||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
|
module: nn.Module,
|
||||||
|
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||||
|
split_sizes: List[int],
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
) -> ParallelModule:
|
) -> ParallelModule:
|
||||||
r"""
|
r"""
|
||||||
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
|
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
|
||||||
|
@ -714,7 +781,7 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
Args:
|
Args:
|
||||||
module (`nn.Linear`): The module to be converted.
|
module (`nn.Linear`): The module to be converted.
|
||||||
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
||||||
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
|
split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
|
||||||
"""
|
"""
|
||||||
LazyInitContext.materialize(module)
|
LazyInitContext.materialize(module)
|
||||||
|
|
||||||
|
@ -737,25 +804,11 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
weight=module.weight,
|
weight=module.weight,
|
||||||
bias_=module.bias,
|
bias_=module.bias,
|
||||||
n_fused=n_fused,
|
split_sizes=split_sizes,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# # TODO: copy the sharded weights
|
|
||||||
# with torch.no_grad():
|
|
||||||
# sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
|
||||||
# n_fused=n_fused,
|
|
||||||
# process_group=process_group,
|
|
||||||
# is_transposed=False)
|
|
||||||
# linear_1d.weight.data.copy_(sharded_weight.data)
|
|
||||||
|
|
||||||
# if bias:
|
|
||||||
# sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
|
||||||
# n_fused=n_fused,
|
|
||||||
# process_group=process_group,
|
|
||||||
# is_transposed=False)
|
|
||||||
# linear_1d.bias.data.copy_(sharded_bias.data)
|
|
||||||
return linear_1d
|
return linear_1d
|
||||||
|
|
||||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||||
|
@ -772,19 +825,30 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||||
)
|
)
|
||||||
# Set up backprop all-reduce.
|
# Set up backprop all-reduce.
|
||||||
# input_parallel = reduce_backward(input_, self.process_group)
|
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
if self.seq_parallel_mode == "split_gather":
|
||||||
|
input_parallel = gather_forward_reducescatter_backward(
|
||||||
|
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
elif self.seq_parallel_mode == "ring":
|
||||||
|
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||||
|
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = gather_forward_split_backward(
|
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
|
||||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
|
|
||||||
|
@ -792,3 +856,201 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
return output, self.bias
|
return output, self.bias
|
||||||
else:
|
else:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class FusedLinear1D_Row(ParallelModule):
|
||||||
|
r"""Linear layer with row parallelism
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_features (int): size of each input sample.
|
||||||
|
out_features (int): size of each output sample.
|
||||||
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
|
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
||||||
|
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||||
|
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
|
||||||
|
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
|
||||||
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
|
which is preserved for kernel fusion, defaults to False
|
||||||
|
weight_initializer (:class:`typing.Callable`, optional):
|
||||||
|
The initializer of weight, defaults to kaiming uniform initializer.
|
||||||
|
bias_initializer (:class:`typing.Callable`, optional):
|
||||||
|
The initializer of bias, defaults to xavier uniform initializer.
|
||||||
|
|
||||||
|
More details about ``initializer`` please refer to
|
||||||
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
split_sizes: List[int],
|
||||||
|
bias: bool = True,
|
||||||
|
dtype: torch.dtype = None,
|
||||||
|
device: torch.device = None,
|
||||||
|
process_group: ProcessGroup = None,
|
||||||
|
seq_parallel_mode: str = None,
|
||||||
|
seq_parallel_dim: int = 1,
|
||||||
|
parallel_input: bool = True,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
weight: Optional[Parameter] = None,
|
||||||
|
bias_: Optional[Parameter] = None,
|
||||||
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
|
fp8_communication: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Keep input parameters
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.split_sizes = split_sizes
|
||||||
|
self.parallel_input = parallel_input
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.process_group = process_group
|
||||||
|
self.seq_parallel_mode = seq_parallel_mode
|
||||||
|
self.seq_parallel_dim = seq_parallel_dim
|
||||||
|
self.num_partitions = dist.get_world_size(self.process_group)
|
||||||
|
self.fp8_communication = fp8_communication
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(split_sizes) == in_features
|
||||||
|
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
|
||||||
|
|
||||||
|
if skip_bias_add and not bias:
|
||||||
|
raise ValueError("cannot skip bias addition if bias is None")
|
||||||
|
|
||||||
|
# offset the seed with randomizer index and rank
|
||||||
|
seed = torch.random.initial_seed()
|
||||||
|
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||||
|
|
||||||
|
# sanity check
|
||||||
|
if weight is not None:
|
||||||
|
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||||
|
else:
|
||||||
|
assert bias_ is None, "bias_ must be None if weight is None"
|
||||||
|
|
||||||
|
# Parameters.
|
||||||
|
if weight is None:
|
||||||
|
# Initialize weight.
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
def shard_fn(tensor):
|
||||||
|
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||||
|
|
||||||
|
def gather_fn(tensor):
|
||||||
|
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||||
|
|
||||||
|
if not is_customized_distributed_tensor(self.weight):
|
||||||
|
with torch.no_grad():
|
||||||
|
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
|
||||||
|
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
if bias_ is None:
|
||||||
|
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||||
|
self.bias = bias_
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
if weight is None:
|
||||||
|
with self.randomizer.fork_rng(enable_cpu=True):
|
||||||
|
self.reset_parameters(weight_initializer, bias_initializer)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(
|
||||||
|
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
|
||||||
|
) -> ParallelModule:
|
||||||
|
r"""
|
||||||
|
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||||
|
"""
|
||||||
|
LazyInitContext.materialize(module)
|
||||||
|
# get the attributes
|
||||||
|
in_features = module.in_features
|
||||||
|
out_features = module.out_features
|
||||||
|
bias = module.bias is not None
|
||||||
|
device = module.weight.device
|
||||||
|
|
||||||
|
# ensure only one process group is passed
|
||||||
|
if isinstance(process_group, (list, tuple)):
|
||||||
|
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||||
|
process_group = process_group[0]
|
||||||
|
|
||||||
|
linear_1d = FusedLinear1D_Row(
|
||||||
|
in_features=in_features,
|
||||||
|
out_features=out_features,
|
||||||
|
bias=bias,
|
||||||
|
device=device,
|
||||||
|
process_group=process_group,
|
||||||
|
weight=module.weight,
|
||||||
|
bias_=module.bias,
|
||||||
|
split_sizes=split_sizes,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return linear_1d
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||||
|
fan_in, fan_out = self.in_features, self.out_features
|
||||||
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
bias_initializer(self.bias, fan_in=fan_in)
|
||||||
|
if self.process_group is None:
|
||||||
|
src_rank = 0
|
||||||
|
else:
|
||||||
|
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
|
||||||
|
|
||||||
|
origin_device = self.bias.device
|
||||||
|
bias = self.bias.cuda()
|
||||||
|
dist.broadcast(bias, src=src_rank, group=self.process_group)
|
||||||
|
bias = bias.to(origin_device)
|
||||||
|
self.bias.copy_(bias)
|
||||||
|
|
||||||
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
|
# Set up backprop all-reduce.
|
||||||
|
if self.parallel_input:
|
||||||
|
assert (
|
||||||
|
input_.shape[-1] == self.weight.shape[-1]
|
||||||
|
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||||
|
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||||
|
)
|
||||||
|
input_ = input_
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
|
||||||
|
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||||
|
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||||
|
)
|
||||||
|
input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
|
||||||
|
|
||||||
|
if self.seq_parallel_mode == "split_gather":
|
||||||
|
output_parallel = F.linear(input_, self.weight)
|
||||||
|
output = reducescatter_forward_gather_backward(
|
||||||
|
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
elif self.seq_parallel_mode == "ring":
|
||||||
|
output = linear_reducescatter_forward_gather_backward(
|
||||||
|
input_,
|
||||||
|
self.weight,
|
||||||
|
process_group=self.process_group,
|
||||||
|
dim=self.seq_parallel_dim,
|
||||||
|
ring=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_parallel = F.linear(input_, self.weight)
|
||||||
|
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||||
|
|
||||||
|
if not self.skip_bias_add:
|
||||||
|
if self.bias is not None:
|
||||||
|
output = output + self.bias
|
||||||
|
return output
|
||||||
|
else:
|
||||||
|
return output, self.bias
|
||||||
|
|
|
@ -71,7 +71,7 @@ class BlipPolicy(Policy):
|
||||||
suffix="self_attn.qkv",
|
suffix="self_attn.qkv",
|
||||||
target_module=col_nn.FusedLinear1D_Col,
|
target_module=col_nn.FusedLinear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_fused": 3,
|
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
|
|
@ -92,7 +92,7 @@ class GPT2Policy(Policy):
|
||||||
suffix="attn.c_attn",
|
suffix="attn.c_attn",
|
||||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_fused": 3,
|
"split_sizes": [self.model.config.hidden_size] * 3,
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"overlap": overlap,
|
"overlap": overlap,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
@ -107,7 +107,7 @@ class GPT2Policy(Policy):
|
||||||
suffix="mlp.c_fc",
|
suffix="mlp.c_fc",
|
||||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_fused": 1,
|
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"overlap": overlap,
|
"overlap": overlap,
|
||||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||||
|
|
|
@ -42,7 +42,7 @@ class SamPolicy(Policy):
|
||||||
suffix="attn.qkv",
|
suffix="attn.qkv",
|
||||||
target_module=col_nn.FusedLinear1D_Col,
|
target_module=col_nn.FusedLinear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_fused": 3,
|
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
|
|
@ -41,21 +41,6 @@ class Conv1D(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def rearrange(tensor: torch.Tensor, dim: int):
|
|
||||||
tensor = tensor.clone()
|
|
||||||
world_size = 2
|
|
||||||
order = torch.arange(world_size * 3)
|
|
||||||
new_order = []
|
|
||||||
for i in range(world_size):
|
|
||||||
new_order.append(order[i::world_size])
|
|
||||||
new_order = torch.cat(new_order)
|
|
||||||
|
|
||||||
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
|
|
||||||
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
|
|
||||||
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
|
|
||||||
return rearanged_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
|
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = Conv1D(192, 48).cuda()
|
||||||
|
@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
|
||||||
process_group=None,
|
process_group=None,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
seq_parallel_mode=seq_parallel_mode,
|
seq_parallel_mode=seq_parallel_mode,
|
||||||
n_fused=3,
|
split_sizes=[64] * 3,
|
||||||
overlap=overlap,
|
overlap=overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
|
||||||
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
)
|
)
|
||||||
gather_out = linear_conv_col(x_for_shard)
|
gather_out = linear_conv_col(x_for_shard)
|
||||||
assert_close(rearrange(out, -1), gather_out)
|
assert_close(out, gather_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
gather_out.sum().backward()
|
gather_out.sum().backward()
|
||||||
|
|
||||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
|
||||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,13 +2,12 @@ import os
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
|
||||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
|
||||||
"""
|
|
||||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
|
||||||
|
|
||||||
Basically works like a linear layer but the weights are transposed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
nf (`int`): The number of output features.
|
|
||||||
nx (`int`): The number of input features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, nf, nx):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.weight = nn.Parameter(torch.empty(nx, nf))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(nf))
|
|
||||||
nn.init.normal_(self.weight, std=0.02)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
size_out = x.size()[:-1] + (self.nf,)
|
|
||||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
|
||||||
x = x.view(size_out)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def rearrange(tensor: torch.Tensor, dim: int):
|
|
||||||
tensor = tensor.clone()
|
|
||||||
world_size = 2
|
|
||||||
order = torch.arange(world_size * 3)
|
|
||||||
new_order = []
|
|
||||||
for i in range(world_size):
|
|
||||||
new_order.append(order[i::world_size])
|
|
||||||
new_order = torch.cat(new_order)
|
|
||||||
|
|
||||||
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
|
|
||||||
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
|
|
||||||
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
|
|
||||||
return rearanged_tensor
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize("lazy_init", [False, True])
|
@parameterize("lazy_init", [False, True])
|
||||||
def check_linear_conv_1d_col(lazy_init: bool):
|
def check_linear_1d_col(lazy_init: bool):
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = nn.Linear(8, 80).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
linear_copy = Conv1D(192, 48).cuda()
|
linear_copy = nn.Linear(8, 80).cuda()
|
||||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
|
linear_col = FusedLinear1D_Col.from_native_module(
|
||||||
linear_copy, process_group=None, gather_output=True, n_fused=3
|
linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert linear.weight.shape == torch.Size([48, 192])
|
assert linear.weight.shape == torch.Size([80, 8])
|
||||||
assert linear.bias.shape == torch.Size([192])
|
assert linear.bias.shape == torch.Size([80])
|
||||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
assert linear_col.weight.shape == torch.Size([40, 8])
|
||||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
assert linear_col.bias.shape == torch.Size([40])
|
||||||
assert linear_copy.weight is linear_conv_col.weight
|
assert linear_copy.weight is linear_col.weight
|
||||||
assert linear_copy.bias is linear_conv_col.bias
|
assert linear_copy.bias is linear_col.bias
|
||||||
|
|
||||||
# ensure weights are reversibly loadable
|
# ensure weights are reversibly loadable
|
||||||
linear_conv_col.load_state_dict(linear.state_dict())
|
linear_col.load_state_dict(linear.state_dict())
|
||||||
linear.load_state_dict(linear_conv_col.state_dict())
|
linear.load_state_dict(linear_col.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 48).cuda()
|
x = torch.rand(4, 8).cuda()
|
||||||
out = linear(x)
|
out = linear(x)
|
||||||
gather_out = linear_conv_col(x)
|
gather_out = linear_col(x)
|
||||||
assert_close(rearrange(out, 1), gather_out)
|
assert_close(out, gather_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
gather_out.sum().backward()
|
gather_out.sum().backward()
|
||||||
|
|
||||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
|
||||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
assert_close(target_grad, linear_col.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
@parameterize("lazy_init", [False, True])
|
@parameterize("lazy_init", [False, True])
|
||||||
def check_linear_conv_1d_row(lazy_init: bool):
|
def check_linear_1d_row(lazy_init: bool):
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = nn.Linear(80, 8).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
linear_copy = Conv1D(192, 48).cuda()
|
linear_copy = nn.Linear(80, 8).cuda()
|
||||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
linear_row = FusedLinear1D_Row.from_native_module(
|
||||||
|
linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
|
||||||
|
)
|
||||||
|
|
||||||
assert linear.weight.shape == torch.Size([48, 192])
|
assert linear.weight.shape == torch.Size([8, 80])
|
||||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
assert linear_row.weight.shape == torch.Size([8, 40])
|
||||||
assert linear_row.bias.shape == torch.Size([192])
|
assert linear_row.bias.shape == torch.Size([8])
|
||||||
assert linear_copy.weight is linear_row.weight
|
assert linear_copy.weight is linear_row.weight
|
||||||
assert linear_copy.bias is linear_row.bias
|
assert linear_copy.bias is linear_row.bias
|
||||||
|
|
||||||
|
@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
|
||||||
linear.load_state_dict(linear_row.state_dict())
|
linear.load_state_dict(linear_row.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 48).cuda()
|
x = torch.rand(4, 80).cuda()
|
||||||
out = linear(x)
|
out = linear(x)
|
||||||
gather_out = linear_row(x)
|
gather_out = linear_row(x)
|
||||||
assert_close(out, gather_out)
|
assert_close(out, gather_out)
|
||||||
|
@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
gather_out.sum().backward()
|
gather_out.sum().backward()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
|
||||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
|
|
||||||
assert_close(target_grad, linear_row.weight.grad)
|
assert_close(target_grad, linear_row.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize("lazy_init", [False, True])
|
||||||
|
def check_linear_1d_col_row(lazy_init: bool):
|
||||||
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
|
linear1 = nn.Linear(8, 80).cuda()
|
||||||
|
linear2 = nn.Linear(80, 8).cuda()
|
||||||
|
with ctx:
|
||||||
|
linear1_copy = nn.Linear(8, 80).cuda()
|
||||||
|
linear2_copy = nn.Linear(80, 8).cuda()
|
||||||
|
linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
|
||||||
|
linear_row = FusedLinear1D_Row.from_native_module(
|
||||||
|
linear2_copy,
|
||||||
|
process_group=None,
|
||||||
|
split_sizes=[32, 32, 16],
|
||||||
|
)
|
||||||
|
# ensure weights are reversibly loadable
|
||||||
|
linear_col.load_state_dict(linear1.state_dict())
|
||||||
|
linear_row.load_state_dict(linear2.state_dict())
|
||||||
|
|
||||||
|
# check computation correctness
|
||||||
|
x = torch.rand(4, 8).cuda()
|
||||||
|
target_out = linear2(linear1(x))
|
||||||
|
out = linear_row(linear_col(x))
|
||||||
|
assert_close(out, target_out)
|
||||||
|
|
||||||
|
# check backward correctness
|
||||||
|
target_out.sum().backward()
|
||||||
|
out.sum().backward()
|
||||||
|
|
||||||
|
target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
|
||||||
|
assert_close(target_grad1, linear_col.weight.grad)
|
||||||
|
target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
|
||||||
|
assert_close(target_grad2, linear_row.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
|
|
||||||
# test for linear conv
|
check_linear_1d_col()
|
||||||
check_linear_conv_1d_col()
|
check_linear_1d_row()
|
||||||
check_linear_conv_1d_row()
|
check_linear_1d_col_row()
|
||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
|
|
Loading…
Reference in New Issue