mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)
* [tp] hotfix linear row * [tp] support uneven split for fused linear * [tp] support sp for fused linear * [tp] fix gpt2 mlp policy * [tp] fix gather fused and add fused linear rowsupercooledith-patch-1
parent
f4daf04270
commit
646b3c5a90
|
@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
|||
target_module=NopadBaichuanMLP,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
|
||||
suffix="self_attn.W_pack",
|
||||
target_module=FusedLinear1D_Col,
|
||||
kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
|
|
|
@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe
|
|||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
|
||||
__all__ = [
|
||||
"Embedding1D",
|
||||
|
@ -34,4 +34,5 @@ __all__ = [
|
|||
"RingAttention",
|
||||
"get_pad_info",
|
||||
"all_to_all_comm",
|
||||
"FusedLinear1D_Row",
|
||||
]
|
||||
|
|
|
@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
|
|||
ctx.gather_dim = gather_dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = input_.shape
|
||||
bsz = input_.shape[0]
|
||||
|
||||
# using all_to_all_single when batch size is 1
|
||||
if bsz == 1:
|
||||
|
@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function):
|
|||
gather_dim = ctx.scatter_dim
|
||||
fp8_communication = ctx.fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = grad_output.shape
|
||||
bsz = grad_output.shape[0]
|
||||
|
||||
if bsz == 1:
|
||||
return_grad = _all_to_all_single(
|
||||
|
|
|
@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule):
|
|||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
@ -445,8 +442,8 @@ class Linear1D_Row(ParallelModule):
|
|||
ring=True,
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reduce_forward(output_parallel, self.process_group)
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
|
||||
if not self.skip_bias_add:
|
||||
if self.bias is not None:
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
@ -24,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
|
|||
)
|
||||
|
||||
from ._operation import (
|
||||
gather_forward_split_backward,
|
||||
gather_forward_reducescatter_backward,
|
||||
linear_gather_forward_reducescatter_backward,
|
||||
linear_reducescatter_forward_gather_backward,
|
||||
linear_with_async_comm,
|
||||
matmul_gather_forward_reducescatter_backward,
|
||||
matmul_with_async_comm,
|
||||
|
@ -44,21 +47,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
|
|||
|
||||
|
||||
def split_fused_qkv_in_gpt2_style(
|
||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
||||
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
|
||||
):
|
||||
"""
|
||||
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
|
||||
|
||||
Args:
|
||||
qkv (torch.Tensor): The fused qkv tensor.
|
||||
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
|
||||
split_sizes (List[int]): The sizes of the split tensor.
|
||||
process_group (ProcessGroup): The process group for distributed communication.
|
||||
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
||||
"""
|
||||
# get the number of slice for the fused qkv
|
||||
rank = dist.get_rank(group=process_group)
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
order = torch.arange(world_size * n_fused)
|
||||
order = torch.arange(world_size * len(split_sizes))
|
||||
new_split_sizes = []
|
||||
for sz in split_sizes:
|
||||
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
|
||||
new_split_sizes.extend([sz // world_size] * world_size)
|
||||
|
||||
# split the fused qkv
|
||||
# from
|
||||
|
@ -66,9 +73,9 @@ def split_fused_qkv_in_gpt2_style(
|
|||
# to
|
||||
# [Q1, Q2, K1, K2, V1, V2]
|
||||
if is_transposed:
|
||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
|
||||
weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
|
||||
else:
|
||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
|
||||
weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
|
||||
|
||||
# rearrange the slice into the final order
|
||||
# from
|
||||
|
@ -85,18 +92,23 @@ def split_fused_qkv_in_gpt2_style(
|
|||
|
||||
|
||||
def gather_fused_qkv_in_gpt2_style(
|
||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
||||
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
|
||||
):
|
||||
"""
|
||||
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
|
||||
|
||||
Args:
|
||||
qkv (torch.Tensor): The fused qkv tensor.
|
||||
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
|
||||
split_sizes (List[int]): The sizes of the split tensor.
|
||||
process_group (ProcessGroup): The process group for distributed communication.
|
||||
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
||||
"""
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
new_split_sizes = []
|
||||
for sz in split_sizes:
|
||||
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
|
||||
new_split_sizes.append(sz // world_size)
|
||||
new_split_sizes = new_split_sizes * world_size
|
||||
|
||||
# gather the tensors
|
||||
# from
|
||||
|
@ -121,13 +133,13 @@ def gather_fused_qkv_in_gpt2_style(
|
|||
# to
|
||||
# [Q1, Q2, K1, K2, V1, V2]
|
||||
if is_transposed:
|
||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
|
||||
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
|
||||
else:
|
||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
|
||||
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
|
||||
|
||||
reordered_chunk_list = []
|
||||
for i in range(n_fused):
|
||||
reordered_chunk_list.extend(weight_chunks[i::n_fused])
|
||||
for i in range(len(split_sizes)):
|
||||
reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
|
||||
|
||||
if is_transposed:
|
||||
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
||||
|
@ -136,6 +148,42 @@ def gather_fused_qkv_in_gpt2_style(
|
|||
return reordered_gather_weight
|
||||
|
||||
|
||||
class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||
ctx.split_sizes = split_sizes
|
||||
ctx.process_group = process_group
|
||||
return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = gather_fused_qkv_in_gpt2_style(
|
||||
grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
|
||||
)
|
||||
return grad_output, None, None
|
||||
|
||||
|
||||
def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||
return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
|
||||
|
||||
|
||||
class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||
ctx.split_sizes = split_sizes
|
||||
ctx.process_group = process_group
|
||||
return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
|
||||
return grad_output, None, None
|
||||
|
||||
|
||||
def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
|
||||
return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
|
||||
|
||||
|
||||
class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
r"""Linear layer with column parallelism.
|
||||
|
||||
|
@ -145,10 +193,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
split_sizes (List[int]): The sizes of the split tensor.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
n_fused (int): The number items fused, defaults to 3 (QKV).
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
|
@ -169,6 +217,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
split_sizes: List[int],
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
|
@ -178,7 +227,6 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
seq_parallel_mode: str = None,
|
||||
overlap: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
|
@ -195,11 +243,15 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self.overlap = overlap
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.n_fused = n_fused
|
||||
self.split_sizes = split_sizes
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
assert (
|
||||
sum(split_sizes) == out_features
|
||||
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
|
@ -223,10 +275,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self.weight = weight
|
||||
|
||||
def shard_fn(tensor):
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||
|
||||
def gather_fn(tensor):
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||
|
||||
if not is_customized_distributed_tensor(self.weight):
|
||||
with torch.no_grad():
|
||||
|
@ -252,7 +304,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
module: nn.Module,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
split_sizes: List[int],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
|
||||
|
@ -260,7 +316,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
Args:
|
||||
module (`nn.Linear`): The module to be converted.
|
||||
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
||||
n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
|
||||
split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
|
@ -291,6 +347,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
split_sizes=split_sizes,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -354,9 +411,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -605,10 +660,10 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
split_sizes (List[int]): The sizes of the split tensor.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
n_fused (int): The number items fused, defaults to 3 (QKV).
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
|
@ -628,14 +683,16 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
split_sizes: List[int],
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
seq_parallel_mode: str = None,
|
||||
seq_parallel_dim: int = 1,
|
||||
overlap: torch.cuda.Stream = None,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
|
@ -647,13 +704,19 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.gather_output = gather_output
|
||||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.overlap = overlap
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.n_fused = n_fused
|
||||
self.split_sizes = split_sizes
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
assert (
|
||||
sum(split_sizes) == out_features
|
||||
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
|
@ -677,10 +740,10 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
self.weight = weight
|
||||
|
||||
def shard_fn(tensor):
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
|
||||
|
||||
def gather_fn(tensor):
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
|
||||
|
||||
if not is_customized_distributed_tensor(self.weight):
|
||||
with torch.no_grad():
|
||||
|
@ -706,7 +769,11 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
|
||||
module: nn.Module,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
split_sizes: List[int],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
|
||||
|
@ -714,7 +781,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
Args:
|
||||
module (`nn.Linear`): The module to be converted.
|
||||
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
||||
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
|
||||
split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
|
||||
|
@ -737,25 +804,11 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
n_fused=n_fused,
|
||||
split_sizes=split_sizes,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# # TODO: copy the sharded weights
|
||||
# with torch.no_grad():
|
||||
# sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
||||
# n_fused=n_fused,
|
||||
# process_group=process_group,
|
||||
# is_transposed=False)
|
||||
# linear_1d.weight.data.copy_(sharded_weight.data)
|
||||
|
||||
# if bias:
|
||||
# sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
||||
# n_fused=n_fused,
|
||||
# process_group=process_group,
|
||||
# is_transposed=False)
|
||||
# linear_1d.bias.data.copy_(sharded_bias.data)
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
|
@ -772,19 +825,30 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
# Set up backprop all-reduce.
|
||||
# input_parallel = reduce_backward(input_, self.process_group)
|
||||
input_parallel = input_
|
||||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = gather_forward_reducescatter_backward(
|
||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -792,3 +856,201 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
return output, self.bias
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
class FusedLinear1D_Row(ParallelModule):
|
||||
r"""Linear layer with row parallelism
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
|
||||
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
split_sizes: List[int],
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
seq_parallel_mode: str = None,
|
||||
seq_parallel_dim: int = 1,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.split_sizes = split_sizes
|
||||
self.parallel_input = parallel_input
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.process_group = process_group
|
||||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
assert (
|
||||
sum(split_sizes) == in_features
|
||||
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
def shard_fn(tensor):
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||
|
||||
def gather_fn(tensor):
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
|
||||
|
||||
if not is_customized_distributed_tensor(self.weight):
|
||||
with torch.no_grad():
|
||||
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
|
||||
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
if weight is None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
linear_1d = FusedLinear1D_Row(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
split_sizes=split_sizes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
if self.process_group is None:
|
||||
src_rank = 0
|
||||
else:
|
||||
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
|
||||
|
||||
origin_device = self.bias.device
|
||||
bias = self.bias.cuda()
|
||||
dist.broadcast(bias, src=src_rank, group=self.process_group)
|
||||
bias = bias.to(origin_device)
|
||||
self.bias.copy_(bias)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
input_ = input_
|
||||
else:
|
||||
assert (
|
||||
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
|
||||
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output = linear_reducescatter_forward_gather_backward(
|
||||
input_,
|
||||
self.weight,
|
||||
process_group=self.process_group,
|
||||
dim=self.seq_parallel_dim,
|
||||
ring=True,
|
||||
)
|
||||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
|
||||
if not self.skip_bias_add:
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
|
|
|
@ -71,7 +71,7 @@ class BlipPolicy(Policy):
|
|||
suffix="self_attn.qkv",
|
||||
target_module=col_nn.FusedLinear1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
|
|
|
@ -92,7 +92,7 @@ class GPT2Policy(Policy):
|
|||
suffix="attn.c_attn",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
"split_sizes": [self.model.config.hidden_size] * 3,
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
|
@ -107,7 +107,7 @@ class GPT2Policy(Policy):
|
|||
suffix="mlp.c_fc",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 1,
|
||||
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
|
|
|
@ -42,7 +42,7 @@ class SamPolicy(Policy):
|
|||
suffix="attn.qkv",
|
||||
target_module=col_nn.FusedLinear1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
|
|
|
@ -41,21 +41,6 @@ class Conv1D(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def rearrange(tensor: torch.Tensor, dim: int):
|
||||
tensor = tensor.clone()
|
||||
world_size = 2
|
||||
order = torch.arange(world_size * 3)
|
||||
new_order = []
|
||||
for i in range(world_size):
|
||||
new_order.append(order[i::world_size])
|
||||
new_order = torch.cat(new_order)
|
||||
|
||||
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
|
||||
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
|
||||
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
|
||||
return rearanged_tensor
|
||||
|
||||
|
||||
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
|
@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
|
|||
process_group=None,
|
||||
gather_output=True,
|
||||
seq_parallel_mode=seq_parallel_mode,
|
||||
n_fused=3,
|
||||
split_sizes=[64] * 3,
|
||||
overlap=overlap,
|
||||
)
|
||||
|
||||
|
@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
|
|||
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
||||
)
|
||||
gather_out = linear_conv_col(x_for_shard)
|
||||
assert_close(rearrange(out, -1), gather_out)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
|
||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||
|
||||
|
||||
|
|
|
@ -2,13 +2,12 @@ import os
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
"""
|
||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
||||
|
||||
Basically works like a linear layer but the weights are transposed.
|
||||
|
||||
Args:
|
||||
nf (`int`): The number of output features.
|
||||
nx (`int`): The number of input features.
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nx):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.weight = nn.Parameter(torch.empty(nx, nf))
|
||||
self.bias = nn.Parameter(torch.zeros(nf))
|
||||
nn.init.normal_(self.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
size_out = x.size()[:-1] + (self.nf,)
|
||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||
x = x.view(size_out)
|
||||
return x
|
||||
|
||||
|
||||
def rearrange(tensor: torch.Tensor, dim: int):
|
||||
tensor = tensor.clone()
|
||||
world_size = 2
|
||||
order = torch.arange(world_size * 3)
|
||||
new_order = []
|
||||
for i in range(world_size):
|
||||
new_order.append(order[i::world_size])
|
||||
new_order = torch.cat(new_order)
|
||||
|
||||
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
|
||||
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
|
||||
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
|
||||
return rearanged_tensor
|
||||
|
||||
|
||||
@parameterize("lazy_init", [False, True])
|
||||
def check_linear_conv_1d_col(lazy_init: bool):
|
||||
def check_linear_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear = nn.Linear(8, 80).cuda()
|
||||
with ctx:
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
|
||||
linear_copy, process_group=None, gather_output=True, n_fused=3
|
||||
linear_copy = nn.Linear(8, 80).cuda()
|
||||
linear_col = FusedLinear1D_Col.from_native_module(
|
||||
linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
|
||||
)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear.bias.shape == torch.Size([192])
|
||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
||||
assert linear_copy.weight is linear_conv_col.weight
|
||||
assert linear_copy.bias is linear_conv_col.bias
|
||||
assert linear.weight.shape == torch.Size([80, 8])
|
||||
assert linear.bias.shape == torch.Size([80])
|
||||
assert linear_col.weight.shape == torch.Size([40, 8])
|
||||
assert linear_col.bias.shape == torch.Size([40])
|
||||
assert linear_copy.weight is linear_col.weight
|
||||
assert linear_copy.bias is linear_col.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_conv_col.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_conv_col.state_dict())
|
||||
linear_col.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_col.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
x = torch.rand(4, 8).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_conv_col(x)
|
||||
assert_close(rearrange(out, 1), gather_out)
|
||||
gather_out = linear_col(x)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
|
||||
assert_close(target_grad, linear_col.weight.grad)
|
||||
|
||||
|
||||
@parameterize("lazy_init", [False, True])
|
||||
def check_linear_conv_1d_row(lazy_init: bool):
|
||||
def check_linear_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear = nn.Linear(80, 8).cuda()
|
||||
with ctx:
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
||||
linear_copy = nn.Linear(80, 8).cuda()
|
||||
linear_row = FusedLinear1D_Row.from_native_module(
|
||||
linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
|
||||
)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||
assert linear_row.bias.shape == torch.Size([192])
|
||||
assert linear.weight.shape == torch.Size([8, 80])
|
||||
assert linear_row.weight.shape == torch.Size([8, 40])
|
||||
assert linear_row.bias.shape == torch.Size([8])
|
||||
assert linear_copy.weight is linear_row.weight
|
||||
assert linear_copy.bias is linear_row.bias
|
||||
|
||||
|
@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
|
|||
linear.load_state_dict(linear_row.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
x = torch.rand(4, 80).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_row(x)
|
||||
assert_close(out, gather_out)
|
||||
|
@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
|
|||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
rank = dist.get_rank()
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
|
||||
assert_close(target_grad, linear_row.weight.grad)
|
||||
|
||||
|
||||
@parameterize("lazy_init", [False, True])
|
||||
def check_linear_1d_col_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear1 = nn.Linear(8, 80).cuda()
|
||||
linear2 = nn.Linear(80, 8).cuda()
|
||||
with ctx:
|
||||
linear1_copy = nn.Linear(8, 80).cuda()
|
||||
linear2_copy = nn.Linear(80, 8).cuda()
|
||||
linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
|
||||
linear_row = FusedLinear1D_Row.from_native_module(
|
||||
linear2_copy,
|
||||
process_group=None,
|
||||
split_sizes=[32, 32, 16],
|
||||
)
|
||||
# ensure weights are reversibly loadable
|
||||
linear_col.load_state_dict(linear1.state_dict())
|
||||
linear_row.load_state_dict(linear2.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 8).cuda()
|
||||
target_out = linear2(linear1(x))
|
||||
out = linear_row(linear_col(x))
|
||||
assert_close(out, target_out)
|
||||
|
||||
# check backward correctness
|
||||
target_out.sum().backward()
|
||||
out.sum().backward()
|
||||
|
||||
target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
|
||||
assert_close(target_grad1, linear_col.weight.grad)
|
||||
target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
|
||||
assert_close(target_grad2, linear_row.weight.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
|
||||
# test for linear conv
|
||||
check_linear_conv_1d_col()
|
||||
check_linear_conv_1d_row()
|
||||
check_linear_1d_col()
|
||||
check_linear_1d_row()
|
||||
check_linear_1d_col_row()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
Loading…
Reference in New Issue