From 726541afe2cde5c6f547968a3d232bbb8b3f5f14 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 1 Aug 2023 18:02:49 +0800
Subject: [PATCH] update some modules with the new API version

---
 .../shardformer/layer/qkv_fused_linear.py  | 85 ++++++++++++-------
 colossalai/shardformer/policies/blip2.py   |  2 +-
 colossalai/shardformer/policies/chatglm.py |  2 +-
 colossalai/shardformer/policies/sam.py     |  2 +-
 colossalai/shardformer/policies/whisper.py |  2 +-
 .../test_gpt2_qkv_fused_linear_1d.py       | 38 +++++++--
 .../test_model/test_shard_chatglm.py       |  5 +-
 7 files changed, 88 insertions(+), 48 deletions(-)

diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 1e4b6ecb6..42417f8bc 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -537,10 +537,11 @@ class FusedLinear1D_Col(ParallelModule):
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
                  n_fused: int = 3,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
-
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
@@ -554,36 +555,52 @@ class FusedLinear1D_Col(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
+
         # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
 
         def shard_fn(tensor):
             return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
 
         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
 
-        with torch.no_grad():
-            sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
-        self.weight = customized_distributed_tensor_to_param(sharded_weight)
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
 
         if bias:
-            bias = torch.empty(self.out_features, **factory_kwargs)
-
-            with torch.no_grad():
-                sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
-            self.bias = customized_distributed_tensor_to_param(sharded_bias)
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+            if not is_customized_distributed_tensor(self.bias):
+                with torch.no_grad():
+                    sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
+                customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
         else:
            self.bias = None
 
-        # offset the seed with randomizer index and rank
-        seed = torch.random.initial_seed()
-        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
-        # init weights
-        self.reset_parameters(weight_initializer, bias_initializer)
+        if weight is None:
+            # init weights
+            self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
@@ -613,24 +630,25 @@ class FusedLinear1D_Col(ParallelModule):
                                       bias=bias,
                                       device=device,
                                       process_group=process_group,
+                                      weight=module.weight,
+                                      bias_=module.bias,
                                       *args,
                                       **kwargs)
 
-        # TODO: copy the sharded weights
-        with torch.no_grad():
-            sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-                                                           n_fused=n_fused,
-                                                           process_group=process_group,
-                                                           is_transposed=False)
-        linear_1d.weight.data.copy_(sharded_weight.data)
-
-        if bias:
-            sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-                                                         n_fused=n_fused,
-                                                         process_group=process_group,
-                                                         is_transposed=False)
-            linear_1d.bias.data.copy_(sharded_bias.data)
+        # # TODO: copy the sharded weights
+        # with torch.no_grad():
+        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
+        #                                                    n_fused=n_fused,
+        #                                                    process_group=process_group,
+        #                                                    is_transposed=False)
+        #     linear_1d.weight.data.copy_(sharded_weight.data)
+        #     if bias:
+        #         sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
+        #                                                      n_fused=n_fused,
+        #                                                      process_group=process_group,
+        #                                                      is_transposed=False)
+        #         linear_1d.bias.data.copy_(sharded_bias.data)
 
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
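
For reference, a minimal usage sketch (not part of the patch itself) of the reworked FusedLinear1D_Col API above: from_native_module() now hands the original weight/bias Parameters to the constructor, which shards them in place via distribute_tensor_with_customization and customized_distributed_tensor_to_existing_param instead of allocating fresh tensors and copying shards into them. The module and its sizes below are illustrative placeholders, and the snippet assumes a tensor-parallel process group has already been set up (e.g. via colossalai.launch with 2 ranks).

import torch.nn as nn

from colossalai.shardformer.layer.qkv_fused_linear import FusedLinear1D_Col

# hypothetical module holding a fused QKV projection (out_features == 3 * in_features)
fused_qkv = nn.Linear(48, 144)
parallel_qkv = FusedLinear1D_Col.from_native_module(fused_qkv, process_group=None, n_fused=3)

# the parallel layer should now reuse (and shard in place) the original Parameters,
# mirroring the `linear_copy.weight is linear_conv_col.weight` assertions added to the tests
assert parallel_qkv.weight is fused_qkv.weight
assert parallel_qkv.bias is fused_qkv.bias
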
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 43aa1adc1..a244d70b5 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from ..modeling.blip2 import forward_fn
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['BlipPolicy', 'BlipModelPolicy']
 
diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py
index 46aa3b52a..732a817b0 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 
 import colossalai.shardformer.layer as col_nn
 
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
 
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index e75d63946..ca20fff71 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from ..modeling.sam import forward_fn
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['SamPolicy', 'SamModelPolicy']
 
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 7751bbb5d..2f3565bda 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
     'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index 9eeda93af..b45cd172c 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -1,12 +1,15 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 # This code is copied from https://github.com/huggingface/transformers
@@ -50,9 +53,13 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor
 
 
-def check_gpt2_linear_conv_1d_col():
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
-    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
                                                                    process_group=None,
                                                                    gather_output=True,
                                                                    n_fused=3)
@@ -61,6 +68,8 @@ def check_gpt2_linear_conv_1d_col():
     assert linear.bias.shape == torch.Size([192])
     assert linear_conv_col.weight.shape == torch.Size([48, 96])
     assert linear_conv_col.bias.shape == torch.Size([96])
+    assert linear_copy.weight is linear_conv_col.weight
+    assert linear_copy.bias is linear_conv_col.bias
 
     # ensure weights are reversibly loadable
     linear_conv_col.load_state_dict(linear.state_dict())
@@ -80,13 +89,24 @@ def check_gpt2_linear_conv_1d_col():
     assert_close(target_grad, linear_conv_col.weight.grad)
 
 
-def check_gpt2_linear_conv_1d_row():
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
     linear = Conv1D(192, 48).cuda()
-    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_row.weight.shape == torch.Size([24, 192])
     assert linear_row.bias.shape == torch.Size([192])
+    assert linear_copy.weight is linear_row.weight
+    assert linear_copy.bias is linear_row.bias
+
+    # ensure weights are reversibly loadable
+    linear_row.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_row.state_dict())
 
     # check computation correctness
     x = torch.rand(4, 48).cuda()
@@ -107,14 +127,14 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     # test for linear conv
-    check_gpt2_linear_conv_1d_col()
-    check_gpt2_linear_conv_1d_row()
+    check_linear_conv_1d_col()
+    check_linear_conv_1d_row()
 
 
 @rerun_if_address_is_in_use()
-def test_gpt2_linearconv():
+def test_linearconv():
     spawn(run_dist, nprocs=2)
 
 
 if __name__ == '__main__':
-    test_gpt2_linearconv()
+    test_linearconv()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index a0fa4bd82..36f240a0f 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -84,9 +84,10 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
         model_copy = copy.deepcopy(org_model)
         shard_former = ShardFormer(shard_config=shard_config)
         if name == "transformers_chatglm":
-            sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
+            sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
         else:
-            sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
+            sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
+        sharded_model = sharded_model.cuda()
 
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
         torch.cuda.empty_cache()
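
A hedged sketch (not part of the patch) of the calling pattern the updated tests rely on: shard_former.optimize() now returns a tuple rather than the sharded model alone, the model is moved to the device explicitly afterwards, and LazyInitContext defers parameter materialization until the layers are replaced. build_model and policy are hypothetical placeholders, and the snippet assumes colossalai.launch() has already initialized the distributed environment; the ShardConfig flag mirrors the enable_tensor_parallelism argument used in the test above.

from contextlib import nullcontext

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer


def shard(build_model, policy, lazy_init: bool = False):
    # build either eagerly or lazily, mirroring the parameterized tests above
    ctx = LazyInitContext() if lazy_init else nullcontext()
    with ctx:
        model = build_model()
    shard_former = ShardFormer(shard_config=ShardConfig(enable_tensor_parallelism=True))
    # optimize() now returns a tuple; unpack the sharded model and move it to the GPU afterwards
    sharded_model, _ = shard_former.optimize(model, policy)
    return sharded_model.cuda()
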