diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 37b5062e8..8528de75c 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): target_module=NopadBaichuanMLP, ), SubModuleReplacementDescription( - suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3} + suffix="self_attn.W_pack", + target_module=FusedLinear1D_Col, + kwargs={"split_sizes": [self.model.config.hidden_size] * 3}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 8882a33c1..684993de6 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule -from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", @@ -34,4 +34,5 @@ __all__ = [ "RingAttention", "get_pad_info", "all_to_all_comm", + "FusedLinear1D_Row", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index aec823567..1d7a1f104 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function): ctx.gather_dim = gather_dim ctx.fp8_communication = fp8_communication world_size = dist.get_world_size(process_group) - bsz, _, _ = input_.shape + bsz = input_.shape[0] # using all_to_all_single when batch size is 1 if bsz == 1: @@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function): gather_dim = ctx.scatter_dim fp8_communication = ctx.fp8_communication world_size = dist.get_world_size(process_group) - bsz, _, _ = grad_output.shape + bsz = grad_output.shape[0] if bsz == 1: return_grad = _all_to_all_single( diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d77dd4965..52b0e79c6 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) - elif self.seq_parallel_mode == "split_gather": - output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + if self.seq_parallel_mode == "split_gather": + output_parallel = F.linear(input_, self.weight) output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) @@ -445,8 +442,8 @@ class Linear1D_Row(ParallelModule): ring=True, ) else: - output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group) + output_parallel = F.linear(input_, self.weight) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6fd689908..a1e25ff3a 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -24,7 +25,9 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import ( - gather_forward_split_backward, + gather_forward_reducescatter_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, @@ -44,21 +47,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col" def split_fused_qkv_in_gpt2_style( - qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False + qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False ): """ The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. Args: qkv (torch.Tensor): The fused qkv tensor. - n_fused (int): The number items fused together, defaults to 3 (query, key and value). + split_sizes (List[int]): The sizes of the split tensor. process_group (ProcessGroup): The process group for distributed communication. is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). """ # get the number of slice for the fused qkv rank = dist.get_rank(group=process_group) world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_fused) + order = torch.arange(world_size * len(split_sizes)) + new_split_sizes = [] + for sz in split_sizes: + assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}" + new_split_sizes.extend([sz // world_size] * world_size) # split the fused qkv # from @@ -66,9 +73,9 @@ def split_fused_qkv_in_gpt2_style( # to # [Q1, Q2, K1, K2, V1, V2] if is_transposed: - weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + weight_chunks = torch.split(qkv, new_split_sizes, dim=-1) else: - weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0) + weight_chunks = torch.split(qkv, new_split_sizes, dim=0) # rearrange the slice into the final order # from @@ -85,18 +92,23 @@ def split_fused_qkv_in_gpt2_style( def gather_fused_qkv_in_gpt2_style( - qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False + qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False ): """ The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. Args: qkv (torch.Tensor): The fused qkv tensor. - n_fused (int): The number items fused together, defaults to 3 (query, key and value). + split_sizes (List[int]): The sizes of the split tensor. process_group (ProcessGroup): The process group for distributed communication. is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). """ world_size = dist.get_world_size(group=process_group) + new_split_sizes = [] + for sz in split_sizes: + assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}" + new_split_sizes.append(sz // world_size) + new_split_sizes = new_split_sizes * world_size # gather the tensors # from @@ -121,13 +133,13 @@ def gather_fused_qkv_in_gpt2_style( # to # [Q1, Q2, K1, K2, V1, V2] if is_transposed: - weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1) else: - weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0) + weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0) reordered_chunk_list = [] - for i in range(n_fused): - reordered_chunk_list.extend(weight_chunks[i::n_fused]) + for i in range(len(split_sizes)): + reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)]) if is_transposed: reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) @@ -136,6 +148,42 @@ def gather_fused_qkv_in_gpt2_style( return reordered_gather_weight +class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup): + ctx.split_sizes = split_sizes + ctx.process_group = process_group + return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True) + + @staticmethod + def backward(ctx, grad_output): + grad_output = gather_fused_qkv_in_gpt2_style( + grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True + ) + return grad_output, None, None + + +def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup): + return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group) + + +class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup): + ctx.split_sizes = split_sizes + ctx.process_group = process_group + return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True) + + @staticmethod + def backward(ctx, grad_output): + grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True) + return grad_output, None, None + + +def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup): + return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group) + + class GPT2FusedLinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. @@ -145,10 +193,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): Args: in_features (int): size of each input sample. out_features (int): size of each output sample. + split_sizes (List[int]): The sizes of the split tensor. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. device (`torch.device`): The device of parameters, defaults to None. - n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. gather_output (bool, optional): If true, call all-gather on output and make Y available @@ -169,6 +217,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self, in_features: int, out_features: int, + split_sizes: List[int], bias: bool = True, dtype: torch.dtype = None, device: torch.device = None, @@ -178,7 +227,6 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): seq_parallel_mode: str = None, overlap: bool = False, skip_bias_add: bool = False, - n_fused: int = 3, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), @@ -195,11 +243,15 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device - self.n_fused = n_fused + self.split_sizes = split_sizes self.process_group = process_group self.async_communication = async_communication self.fp8_communication = fp8_communication + assert ( + sum(split_sizes) == out_features + ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})." + if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -223,10 +275,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.weight = weight def shard_fn(tensor): - return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) + return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True) def gather_fn(tensor): - return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) + return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True) if not is_customized_distributed_tensor(self.weight): with torch.no_grad(): @@ -252,7 +304,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): @staticmethod def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]], + split_sizes: List[int], + *args, + **kwargs, ) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. @@ -260,7 +316,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): Args: module (`nn.Linear`): The module to be converted. process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. - n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. + split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight. """ LazyInitContext.materialize(module) # get the attributes @@ -291,6 +347,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, + split_sizes=split_sizes, *args, **kwargs, ) @@ -354,9 +411,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward( - output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication - ) + output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group) else: output = output_parallel @@ -605,10 +660,10 @@ class FusedLinear1D_Col(ParallelModule): Args: in_features (int): size of each input sample. out_features (int): size of each output sample. + split_sizes (List[int]): The sizes of the split tensor. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. device (`torch.device`): The device of parameters, defaults to None. - n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output @@ -628,14 +683,16 @@ class FusedLinear1D_Col(ParallelModule): self, in_features: int, out_features: int, + split_sizes: List[int], bias: bool = True, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - async_communication: bool = False, gather_output: bool = False, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, - n_fused: int = 3, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), @@ -647,13 +704,19 @@ class FusedLinear1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device - self.n_fused = n_fused + self.split_sizes = split_sizes self.process_group = process_group - self.async_communication = async_communication self.fp8_communication = fp8_communication + assert ( + sum(split_sizes) == out_features + ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})." + if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -677,10 +740,10 @@ class FusedLinear1D_Col(ParallelModule): self.weight = weight def shard_fn(tensor): - return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) + return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False) def gather_fn(tensor): - return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) + return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False) if not is_customized_distributed_tensor(self.weight): with torch.no_grad(): @@ -706,7 +769,11 @@ class FusedLinear1D_Col(ParallelModule): @staticmethod def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs + module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]], + split_sizes: List[int], + *args, + **kwargs, ) -> ParallelModule: r""" Convert a fused `torch.nn.linear` layer to a parallelized linear layer. @@ -714,7 +781,7 @@ class FusedLinear1D_Col(ParallelModule): Args: module (`nn.Linear`): The module to be converted. process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. - n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight. + split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight. """ LazyInitContext.materialize(module) @@ -737,25 +804,11 @@ class FusedLinear1D_Col(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, - n_fused=n_fused, + split_sizes=split_sizes, *args, **kwargs, ) - # # TODO: copy the sharded weights - # with torch.no_grad(): - # sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, - # n_fused=n_fused, - # process_group=process_group, - # is_transposed=False) - # linear_1d.weight.data.copy_(sharded_weight.data) - - # if bias: - # sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, - # n_fused=n_fused, - # process_group=process_group, - # is_transposed=False) - # linear_1d.bias.data.copy_(sharded_bias.data) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -772,19 +825,30 @@ class FusedLinear1D_Col(ParallelModule): input_.shape, self.weight.shape, self.weight.shape[-1] ) # Set up backprop all-reduce. - # input_parallel = reduce_backward(input_, self.process_group) input_parallel = input_ # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel_mode == "split_gather": + input_parallel = gather_forward_reducescatter_backward( + input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + ) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication + ) + elif self.seq_parallel_mode == "ring": + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True + ) + else: + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + ) if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward( - output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication - ) + output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group) else: output = output_parallel @@ -792,3 +856,201 @@ class FusedLinear1D_Col(ParallelModule): return output, self.bias else: return output + + +class FusedLinear1D_Row(ParallelModule): + r"""Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None. + seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + split_sizes: List[int], + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, + ): + super().__init__() + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.split_sizes = split_sizes + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication + + assert ( + sum(split_sizes) == in_features + ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})." + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + def shard_fn(tensor): + return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True) + + def gather_fn(tensor): + return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True) + + if not is_customized_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs + ) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + linear_1d = FusedLinear1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + split_sizes=split_sizes, + **kwargs, + ) + + return linear_1d + + @torch.no_grad() + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + bias = self.bias.cuda() + dist.broadcast(bias, src=src_rank, group=self.process_group) + bias = bias.to(origin_device) + self.bias.copy_(bias) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + input_ = input_ + else: + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions + ) + input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group) + + if self.seq_parallel_mode == "split_gather": + output_parallel = F.linear(input_, self.weight) + output = reducescatter_forward_gather_backward( + output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + ) + elif self.seq_parallel_mode == "ring": + output = linear_reducescatter_forward_gather_backward( + input_, + self.weight, + process_group=self.process_group, + dim=self.seq_parallel_dim, + ring=True, + ) + else: + output_parallel = F.linear(input_, self.weight) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index da798f6a0..2e73d5c2a 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -71,7 +71,7 @@ class BlipPolicy(Policy): suffix="self_attn.qkv", target_module=col_nn.FusedLinear1D_Col, kwargs={ - "n_fused": 3, + "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, }, ), diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d9233be9a..faacf91b2 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -92,7 +92,7 @@ class GPT2Policy(Policy): suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ - "n_fused": 3, + "split_sizes": [self.model.config.hidden_size] * 3, "seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, @@ -107,7 +107,7 @@ class GPT2Policy(Policy): suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ - "n_fused": 1, + "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size], "seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 674fe5e58..a94cc9119 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -42,7 +42,7 @@ class SamPolicy(Policy): suffix="attn.qkv", target_module=col_nn.FusedLinear1D_Col, kwargs={ - "n_fused": 3, + "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, }, ), diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 5aa8584a0..923075e0e 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -41,21 +41,6 @@ class Conv1D(nn.Module): return x -def rearrange(tensor: torch.Tensor, dim: int): - tensor = tensor.clone() - world_size = 2 - order = torch.arange(world_size * 3) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) - rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] - rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) - return rearanged_tensor - - def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() @@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, - n_fused=3, + split_sizes=[64] * 3, overlap=overlap, ) @@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] ) gather_out = linear_conv_col(x_for_shard) - assert_close(rearrange(out, -1), gather_out) + assert_close(out, gather_out) # check backward correctness out.sum().backward() gather_out.sum().backward() - target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True) assert_close(target_grad, linear_conv_col.weight.grad) diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index dc14fd591..fccba564f 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -2,13 +2,12 @@ import os from contextlib import nullcontext import torch -import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -class Conv1D(nn.Module): - """ - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). - - Basically works like a linear layer but the weights are transposed. - - Args: - nf (`int`): The number of output features. - nx (`int`): The number of input features. - """ - - def __init__(self, nf, nx): - super().__init__() - self.nf = nf - self.weight = nn.Parameter(torch.empty(nx, nf)) - self.bias = nn.Parameter(torch.zeros(nf)) - nn.init.normal_(self.weight, std=0.02) - - def forward(self, x): - size_out = x.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(size_out) - return x - - -def rearrange(tensor: torch.Tensor, dim: int): - tensor = tensor.clone() - world_size = 2 - order = torch.arange(world_size * 3) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) - rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] - rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) - return rearanged_tensor - - @parameterize("lazy_init", [False, True]) -def check_linear_conv_1d_col(lazy_init: bool): +def check_linear_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - linear = Conv1D(192, 48).cuda() + linear = nn.Linear(8, 80).cuda() with ctx: - linear_copy = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, n_fused=3 + linear_copy = nn.Linear(8, 80).cuda() + linear_col = FusedLinear1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16] ) - assert linear.weight.shape == torch.Size([48, 192]) - assert linear.bias.shape == torch.Size([192]) - assert linear_conv_col.weight.shape == torch.Size([48, 96]) - assert linear_conv_col.bias.shape == torch.Size([96]) - assert linear_copy.weight is linear_conv_col.weight - assert linear_copy.bias is linear_conv_col.bias + assert linear.weight.shape == torch.Size([80, 8]) + assert linear.bias.shape == torch.Size([80]) + assert linear_col.weight.shape == torch.Size([40, 8]) + assert linear_col.bias.shape == torch.Size([40]) + assert linear_copy.weight is linear_col.weight + assert linear_copy.bias is linear_col.bias # ensure weights are reversibly loadable - linear_conv_col.load_state_dict(linear.state_dict()) - linear.load_state_dict(linear_conv_col.state_dict()) + linear_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_col.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(4, 8).cuda() out = linear(x) - gather_out = linear_conv_col(x) - assert_close(rearrange(out, 1), gather_out) + gather_out = linear_col(x) + assert_close(out, gather_out) # check backward correctness out.sum().backward() gather_out.sum().backward() - target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) - assert_close(target_grad, linear_conv_col.weight.grad) + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False) + assert_close(target_grad, linear_col.weight.grad) @parameterize("lazy_init", [False, True]) -def check_linear_conv_1d_row(lazy_init: bool): +def check_linear_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - linear = Conv1D(192, 48).cuda() + linear = nn.Linear(80, 8).cuda() with ctx: - linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_copy = nn.Linear(80, 8).cuda() + linear_row = FusedLinear1D_Row.from_native_module( + linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False + ) - assert linear.weight.shape == torch.Size([48, 192]) - assert linear_row.weight.shape == torch.Size([24, 192]) - assert linear_row.bias.shape == torch.Size([192]) + assert linear.weight.shape == torch.Size([8, 80]) + assert linear_row.weight.shape == torch.Size([8, 40]) + assert linear_row.bias.shape == torch.Size([8]) assert linear_copy.weight is linear_row.weight assert linear_copy.bias is linear_row.bias @@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool): linear.load_state_dict(linear_row.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(4, 80).cuda() out = linear(x) gather_out = linear_row(x) assert_close(out, gather_out) @@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool): out.sum().backward() gather_out.sum().backward() - rank = dist.get_rank() - target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True) assert_close(target_grad, linear_row.weight.grad) +@parameterize("lazy_init", [False, True]) +def check_linear_1d_col_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear1 = nn.Linear(8, 80).cuda() + linear2 = nn.Linear(80, 8).cuda() + with ctx: + linear1_copy = nn.Linear(8, 80).cuda() + linear2_copy = nn.Linear(80, 8).cuda() + linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16]) + linear_row = FusedLinear1D_Row.from_native_module( + linear2_copy, + process_group=None, + split_sizes=[32, 32, 16], + ) + # ensure weights are reversibly loadable + linear_col.load_state_dict(linear1.state_dict()) + linear_row.load_state_dict(linear2.state_dict()) + + # check computation correctness + x = torch.rand(4, 8).cuda() + target_out = linear2(linear1(x)) + out = linear_row(linear_col(x)) + assert_close(out, target_out) + + # check backward correctness + target_out.sum().backward() + out.sum().backward() + + target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False) + assert_close(target_grad1, linear_col.weight.grad) + target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True) + assert_close(target_grad2, linear_row.weight.grad) + + def run_dist(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # test for linear conv - check_linear_conv_1d_col() - check_linear_conv_1d_row() + check_linear_1d_col() + check_linear_1d_row() + check_linear_1d_col_row() @rerun_if_address_is_in_use()