From d2e05a99b3a776eb0f438d61b74065c9633c7391 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Wed, 30 Oct 2024 02:54:32 +0000
Subject: [PATCH] [feat] support a no-tensor-parallel Linear in shardformer;
 add tests for the Linear1D backward with and without WeightGradStore

---
 colossalai/shardformer/layer/__init__.py   |   3 +-
 colossalai/shardformer/layer/_operation.py | 106 +++++++++++-
 colossalai/shardformer/layer/linear.py     | 155 +++++++++++++++++-
 .../test_layer/test_linear_1d.py           |  92 ++++++++++-
 4 files changed, 352 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 8882a33c1..613ce73c3 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
 from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
-from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
+from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
@@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
 __all__ = [
     "Embedding1D",
     "VocabParallelEmbedding1D",
+    "Linear1D",
     "Linear1D_Col",
     "Linear1D_Row",
     "GPT2FusedLinearConv1D_Col",
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 4a0800468..46f50ef02 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -154,7 +154,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
             wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
 
         def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            # _grad_output_.t().matmul(_input_)
             return wgrad_gemm_func(_grad_output_.t(), _input_)
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
@@ -236,6 +235,107 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None, None
 
 
+class LinearBase(torch.autograd.Function):
+    """
+    Linear layer baseline (no tensor parallelism): the communication-free counterpart of ``LinearWithAsyncCommunication``.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+        ctx.save_for_backward(input_, weight, bias)
+        ctx.use_bias = bias is not None
+        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.fp8_communication = fp8_communication
+        ctx.use_zbv = use_zbv
+        if bias is not None:
+            output = F.linear(input_, weight, bias)
+        else:
+            output = F.linear(input_, weight)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, bias = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        fp8_communication = ctx.fp8_communication  # kept for interface parity; no communication happens in this path
+        use_zbv = ctx.use_zbv
+
+        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
+            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
+
+        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
+            return wgrad_gemm_func(_grad_output_.t(), _input_)
+
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
+        if use_bias:
+            bias.view(bias.shape)
+
+        total_input = input.contiguous()
+        grad_input = grad_output.matmul(weight)
+        grad_output = grad_output.contiguous()
+        # Convert the tensor shapes to 2D for execution compatibility
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])
+
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if use_zbv:
+                # Defer the dW GEMM: stash (input, grad_output, weight, grad func) in WeightGradStore
+                if grad.dtype == torch.float32:
+                    WeightGradStore.put(
+                        total_input,
+                        grad_output,
+                        weight,
+                        functools.partial(
+                            execute_w_pass_grad_accum,
+                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
+                        ),
+                    )
+                    grad_weight = None
+                elif grad.dtype in (torch.float16, torch.bfloat16):
+                    WeightGradStore.put(
+                        total_input,
+                        grad_output,
+                        weight,
+                        functools.partial(
+                            execute_w_pass_grad_accum,
+                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
+                        ),
+                    )
+                    grad_weight = None
+                else:
+                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+            else:
+                if grad.dtype == torch.float32:
+                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                    grad_weight = None
+                elif grad.dtype in (torch.float16, torch.bfloat16):
+                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                    grad_weight = None
+                else:
+                    grad_weight = grad_output.t().matmul(total_input)
+        else:
+            if use_zbv:
+                WeightGradStore.put(
+                    total_input,
+                    grad_output,
+                    weight,
+                    functools.partial(
+                        execute_w_pass,
+                        wgrad_gemm_func=torch.matmul,
+                    ),
+                )
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)
+
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None
+
+
 def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
     # currently only support one single tensor as output
     group_size = dist.get_world_size(process_group)
@@ -1101,6 +1201,10 @@ def linear_with_async_comm(
     )
 
 
+def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
+
+
 def linear_gather_forward_reducescatter_backward(
     input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index a8a3be63a..cb1496a0b 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -25,6 +25,7 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_reducescatter_backward,
     gather_forward_split_backward,
+    linear_base,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
@@ -35,7 +36,159 @@ from ._operation import (
 from .parallel_module import PaddingParallelModule, ParallelModule
 from .utils import create_randomizer_with_offset
 
-__all__ = ["Linear1D_Col", "Linear1D_Row"]
+__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]
+
+
+class Linear1D(ParallelModule):
+    r"""Linear layer with no parallelism.
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        device (`torch.device`): The device of parameters, defaults to None.
+        gather_output (bool, optional): kept for interface compatibility with the tensor-parallel
+                    linear layers; it has no effect here because the output is never sharded,
+                    defaults to False.
+        seq_parallel_mode (str, optional): kept for interface compatibility; not used by this layer, defaults to None.
+        overlap (optional): kept for interface compatibility with ``Linear1D_Col``; not used by this layer.
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+            which is preserved for kernel fusion, defaults to False
+        weight_initializer (`typing.Callable`):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (`typing.Callable`):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        gather_output: bool = False,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
+        overlap: torch.cuda.Stream = None,
+        skip_bias_add: bool = False,
+        weight: Optional[Parameter] = None,
+        bias_: Optional[Parameter] = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
+        use_zbv: bool = False,
+        **kwargs,
+    ):
+        super().__init__(weight=weight, bias_=bias_, **kwargs)
+
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+        self.gather_output = gather_output
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
+        self.overlap = overlap
+        self.skip_bias_add = skip_bias_add
+        self.device = device
+        self.fp8_communication = fp8_communication
+        self.use_zbv = use_zbv
+
+        if skip_bias_add and not bias:
+            raise ValueError("cannot skip bias addition if bias is None")
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+
+        self.randomizer = create_randomizer_with_offset(seed, process_group=None)
+
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
+        else:
+            assert bias_ is None, "bias_ must be None if weight is None"
+
+        # Parameters.
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+        else:
+            self.bias = None
+
+        if weight is None:
+            # init weights
+            self.reset_parameters(weight_initializer, bias_initializer)
+
+    @staticmethod
+    def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a ``Linear1D`` layer; the native module's weight and bias are reused, not sharded.
+        """
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        linear_1d = Linear1D(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            weight=module.weight,
+            bias_=module.bias,
+            **kwargs,
+        )
+
+        return linear_1d
+
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        with self.randomizer.fork_rng(enable_cpu=True):
+            fan_in, fan_out = self.in_features, self.out_features
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+            if self.bias is not None:
+                bias_initializer(self.bias, fan_in=fan_in)
+
+    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
+        assert (
+            input_.shape[-1] == self.weight.shape[-1]
+        ), "Invalid shapes in Linear1D forward: input={}, weight={}. Expected last dim of input {}.".format(
+            input_.shape, self.weight.shape, self.weight.shape[-1]
+        )
+
+        # No tensor parallelism, so the input is used as-is and no backprop all-reduce is set up.
+        input_parallel = input_
+
+        # Matrix multiply.
+        bias = self.bias if not self.skip_bias_add else None
+        output_parallel = linear_base(
+            input_parallel,
+            self.weight,
+            bias,
+            False,
+            fp8_communication=self.fp8_communication,
+            use_zbv=self.use_zbv,
+        )
+
+        output = output_parallel
+
+        if self.skip_bias_add:
+            return output, self.bias
+        else:
+            return output
 
 
 class Linear1D_Col(ParallelModule):
diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py
index 541aa3251..0556bc986 100644
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -8,7 +8,8 @@ from torch.testing import assert_close
 
 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.pipeline.weight_grad_store import WeightGradStore
+from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
 from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)
 
 
+def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = nn.Linear(32, 128).cuda()
+    with ctx:
+        linear_copy = nn.Linear(32, 128).cuda()
+        linear_base = Linear1D.from_native_module(
+            linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
+        )
+    assert linear_base.weight.shape == torch.Size([128, 32])
+    assert linear_base.bias.shape == torch.Size([128])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    linear.load_state_dict(linear_base.state_dict())
+    linear_base.load_state_dict(linear.state_dict())
+
+    # check computation correctness
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
+    x_for_unshard = x.expand_as(x.clone())
+    x_for_unshard.requires_grad_(True)
+    x_for_shard = x.expand_as(x.clone())
+    x_for_shard.requires_grad_(True)
+
+    # run forward
+    out = linear(x_for_unshard)
+    gather_out = linear_base(x_for_shard)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+    # check the input gradients
+    assert x_for_shard.grad is not None
+    assert x_for_unshard.grad is not None
+    assert_close(x_for_unshard.grad, x_for_shard.grad)
+
+
+def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = nn.Linear(32, 128).cuda()
+    with ctx:
+        linear_copy = nn.Linear(32, 128).cuda()
+        linear_base = Linear1D.from_native_module(
+            linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
+        )
+    assert linear_base.weight.shape == torch.Size([128, 32])
+    assert linear_base.bias.shape == torch.Size([128])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    linear.load_state_dict(linear_base.state_dict())
+    linear_base.load_state_dict(linear.state_dict())
+
+    # check computation correctness
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
+    x_for_unshard = x.expand_as(x.clone())
+    x_for_unshard.requires_grad_(True)
+    x_for_shard = x.expand_as(x.clone())
+    x_for_shard.requires_grad_(True)
+
+    # run forward
+    out = linear(x_for_unshard)
+    gather_out = linear_base(x_for_shard)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    # with use_zbv=True the dW computation is deferred, so the weight grad is still None here
+    assert linear_base.weight.grad is None
+    # flush + pop run the deferred dW computation, after which the weight grads must match
+    WeightGradStore.flush(chunk=0)  # move the cached items into the chunk-0 queue
+    WeightGradStore.pop(chunk=0)
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+    # check the input gradients
+    assert x_for_shard.grad is not None
+    assert x_for_unshard.grad is not None
+    assert_close(x_for_unshard.grad, x_for_shard.grad)
+
+
 def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
@@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
     check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
     check_linear_1d_row(lazy_init, seq_parallel_mode)
     check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
+    check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
+    check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
 
 
 def check_dist_linear(rank, world_size, port):
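
Usage sketch: a minimal example of the new no-tensor-parallel path, mirroring check_linear_with_weight_grad_store above. It shows how Linear1D wraps a plain nn.Linear and, with use_zbv=True, defers the weight-gradient GEMM to WeightGradStore until it is flushed and popped. The function name demo_linear1d_with_zbv and the variable names are illustrative only, and the code assumes a CUDA device inside a worker that has already initialized the distributed runtime (e.g. via colossalai.launch, as check_dist_linear does).

import torch
import torch.nn as nn

from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D


def demo_linear1d_with_zbv() -> None:
    # Wrap a plain nn.Linear; Linear1D reuses the module's weight and bias, nothing is sharded.
    linear = nn.Linear(32, 128).cuda()
    linear_no_tp = Linear1D.from_native_module(linear, use_zbv=True)

    x = torch.rand(2, 4, 32, device="cuda", requires_grad=True)
    out = linear_no_tp(x)

    # backward() computes the input gradient immediately, but the dW GEMM is only
    # recorded in WeightGradStore, so the weight gradient is still empty here.
    out.sum().backward()
    assert linear_no_tp.weight.grad is None

    # flush() moves the cached (input, grad_output, weight, grad_fn) entries into
    # the chunk-0 queue; pop() executes them, which finally populates weight.grad.
    WeightGradStore.flush(chunk=0)
    WeightGradStore.pop(chunk=0)
    assert linear_no_tp.weight.grad is not None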