diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948d..7cdcfc318 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,10 +3,11 @@ from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm +from .parallel_module import ParallelModule from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule' ] diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 07341ef73..09b22abb1 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch import torch.distributed as dist @@ -13,7 +13,12 @@ from torch.distributed import ProcessGroup from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule @@ -60,6 +65,7 @@ class Embedding1D(ParallelModule): device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = True, + weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -74,18 +80,24 @@ class Embedding1D(ParallelModule): self.embed_kwargs = kwargs self.gather_output = gather_output - # Parameters. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) - sharded_weight = shard_colwise(weight, process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) + # Parameters. 
+ if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, @@ -121,14 +133,10 @@ class Embedding1D(ParallelModule): norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, + weight=module.weight, *args, **kwargs) - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - return embedding def reset_parameters(self, weight_initializer) -> None: @@ -143,7 +151,6 @@ class Embedding1D(ParallelModule): def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - if self.gather_output: output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) return output @@ -188,6 +195,7 @@ class VocabParallelEmbedding1D(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -207,16 +215,23 @@ class VocabParallelEmbedding1D(ParallelModule): self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - # parameter - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) - sharded_weight = shard_rowwise(weight, process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - self.reset_parameters(weight_initializer) + + # parameter + if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -243,15 +258,10 @@ class VocabParallelEmbedding1D(ParallelModule): padding_idx=padding_idx, device=device, process_group=process_group, + weight=module.weight, *args, **kwargs) - with torch.no_grad(): - # shard and slice the weight along the vocabulary(num_embeddings) dimension - # the shape of the weight is (num_embeddings, embedding_dim) - shard_weight = shard_rowwise(module.weight.data, process_group) - vocab_embedding_1d.weight.data.copy_(shard_weight) - return vocab_embedding_1d def reset_parameters(self, weight_initializer) -> None: diff --git 
a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 383d9b3f5..bb36854bd 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- import math -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -15,7 +15,12 @@ from torch.nn.parameter import Parameter from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) from ._operation import ( gather_forward_split_backward, @@ -65,6 +70,8 @@ class Linear1D_Col(ParallelModule): process_group: ProcessGroup = None, gather_output: bool = False, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -80,26 +87,42 @@ class Linear1D_Col(ParallelModule): if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # Parameters. - factory_kwargs = {'device': device, 'dtype': dtype} - - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) - sharded_weight = shard_rowwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - - if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - sharded_bias = shard_colwise(bias, self.process_group) - self.bias = sharded_tensor_to_param(sharded_bias) - else: - self.bias = None - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + + # Parameters. 
+ if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_distributed_tensor(self.bias): + sharded_bias = shard_colwise(self.bias.data, self.process_group) + sharded_tensor_to_existing_param(sharded_bias, self.bias) + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -125,17 +148,11 @@ class Linear1D_Col(ParallelModule): bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - with torch.no_grad(): - # the weight to the linear layer is a transpose - # thus shard on row is equal to shard on column - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - if bias: - sharded_bias = shard_colwise(module.bias.data, process_group) - linear_1d.bias.copy_(sharded_bias) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -198,6 +215,8 @@ class Linear1D_Row(ParallelModule): process_group: ProcessGroup = None, parallel_input: bool = True, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1): @@ -216,27 +235,44 @@ class Linear1D_Row(ParallelModule): if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) - sharded_weight = shard_colwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + + # Parameters. + if weight is None: + # Initialize weight. 
+ factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -262,19 +298,11 @@ class Linear1D_Row(ParallelModule): bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_colwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - - if bias: - linear_1d.bias.copy_(module.bias.data) - return linear_1d def chunk_weight(self): diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 9bb7738c0..0aea29566 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -60,10 +60,8 @@ class FusedLayerNorm(): layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) - with torch.no_grad(): - # copy weight and bias - layernorm.weight.copy_(module.weight) - layernorm.bias.copy_(module.bias) + layernorm.weight = module.weight + layernorm.bias = module.bias return layernorm @@ -101,8 +99,6 @@ class FusedRMSNorm(): rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) - with torch.no_grad(): - # copy weight and bias - rmsnorm.weight.copy_(module.weight) + rmsnorm.weight = module.weight return rmsnorm diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index c94d93069..bcefcf058 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -2,12 +2,11 @@ # -*- encoding: utf-8 -*- import math -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn -import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -16,10 +15,12 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import ( - customized_distributed_tensor_to_param, + customized_distributed_tensor_to_existing_param, distribute_tensor_with_customization, + is_customized_distributed_tensor, + is_distributed_tensor, shard_rowwise, - sharded_tensor_to_param, + sharded_tensor_to_existing_param, ) from ._operation import ( @@ -173,6 
+174,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): gather_output: bool = False, skip_bias_add: bool = False, n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -190,40 +193,56 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight def shard_fn(tensor): return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) def gather_fn(tensor): - return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) + return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) - with torch.no_grad(): - sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) - self.weight = customized_distributed_tensor_to_param(sharded_weight) + if not is_customized_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - - with torch.no_grad(): - sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) - self.bias = customized_distributed_tensor_to_param(sharded_bias) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_customized_distributed_tensor(self.bias): + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, - *args, **kwargs) -> ParallelModule: + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. 
@@ -250,24 +269,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=True) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=True) - linear_1d.bias.data.copy_(sharded_bias.data) - return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -333,6 +339,8 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): process_group: ProcessGroup = None, parallel_input: bool = True, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1): @@ -351,30 +359,46 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + # Divide the weight matrix along the last dimension. self.input_size_per_partition = divide(in_features, self.num_partitions) + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) - sharded_weight = shard_rowwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + if weight is None: + # Initialize weight. 
+ factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only self.chunk_weight() if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -400,19 +424,11 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - linear_1d.bias.copy_(module.bias.data) - return linear_1d def chunk_weight(self): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 1af26f504..0a1a46621 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,13 +1,11 @@ from functools import partial -from types import MethodType -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional import torch import torch.nn as nn from torch import Tensor from torch.nn import CrossEntropyLoss, Module from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MultipleChoiceModelOutput, @@ -28,12 +26,11 @@ from transformers.models.bert.modeling_bert import ( BertLMHeadModel, BertModel, ) -from transformers.utils import ModelOutput, logging +from transformers.utils import logging import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription logger = logging.get_logger(__name__) @@ -177,6 +174,17 @@ class BertPolicy(Policy): target_key=BertLMPredictionHead) return base_policy + def add_lm_prediction_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + method_replacement = { + '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict, + '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict, + } + self.append_or_create_method_replacement(description=method_replacement, + policy=base_policy, + target_key=BertLMPredictionHead) + return 
base_policy + def postprocess(self): return self.model @@ -240,6 +248,7 @@ class BertForPreTrainingPolicy(BertPolicy): def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy) return policy @@ -266,21 +275,13 @@ class BertForPreTrainingPolicy(BertPolicy): model = self.model if self.pipeline_stage_manager: if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: model.bert.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): @@ -291,6 +292,7 @@ class BertLMHeadModelPolicy(BertPolicy): def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) return policy @@ -316,21 +318,13 @@ class BertLMHeadModelPolicy(BertPolicy): bert_model = self.model.bert if self.pipeline_stage_manager: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: bert_model.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): @@ -341,6 +335,7 @@ class BertForMaskedLMPolicy(BertPolicy): def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + mpolicy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) return policy @@ -366,21 +361,13 @@ class BertForMaskedLMPolicy(BertPolicy): bert_model = self.model.bert if self.pipeline_stage_manager: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: bert_model.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - 
setattr_(self.model, v, param) - return self.model - # BertForSequenceClassification class BertForSequenceClassificationPolicy(BertPolicy): @@ -1032,6 +1019,7 @@ def bert_for_masked_lm_forward( stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, ): + # -> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., @@ -1109,7 +1097,7 @@ def bert_for_next_sentence_prediction_forward( stage_index: Optional[List[int]] = None, **kwargs, ): - #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 8afaadefb..b0e454529 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,9 +1,7 @@ import warnings from functools import partial -from types import MethodType from typing import Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn from torch import Tensor @@ -27,7 +25,6 @@ from transformers.utils import logging import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -229,20 +226,10 @@ class BloomForCausalLMPolicy(BloomPolicy): # tie weights return [{ 0: bloom_model.transformer.word_embeddings.weight, - self.stage_manager.num_stages - 1: bloom_model.lm_head.weight + self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} - - for k, v in binding_map.items(): - param = getattr_(self.model, k) - # tie weights - setattr_(self.model, v, param) - return self.model - class BloomForSequenceClassificationPolicy(BloomPolicy): @@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward( all_cross_attentions = None if stage_manager.is_last_stage(): batch_size = hidden_states.shape[0] - #update batch size + # update batch size hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 05178895d..6614a32b5 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -8,7 +8,6 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -56,42 +55,42 @@ class GPT2Policy(Policy): "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - 
SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -99,8 +98,8 @@ class GPT2Policy(Policy): suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - policy=policy, - target_key=GPT2Model) + policy=policy, + target_key=GPT2Model) self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( @@ -115,8 +114,8 @@ class GPT2Policy(Policy): target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Block) return policy def postprocess(self): @@ -227,15 +226,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy): else: return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism \ - and self.pipeline_stage_manager is None: - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): @@ -286,15 +276,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): else: return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism \ - and self.pipeline_stage_manager is None: - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # GPT2ForQuestionAnswering class GPT2ForQuestionAnsweringPolicy(GPT2Policy): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b3757452c..c7cd8182a 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,7 +1,5 @@ -import math from functools import partial -from 
types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import torch import torch.nn as nn @@ -9,14 +7,11 @@ from torch import Tensor from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel -from transformers.utils import ModelOutput, logging +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 1435805d2..bbcc90e00 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,6 +1,5 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -116,19 +115,6 @@ class OPTForCausalLMPolicy(OPTPolicy): target_key=OPTForCausalLM) return policy - def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = { - 'model.decoder.embed_tokens': 'lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - - return self.model - class OPTForSequenceClassificationPolicy(OPTPolicy): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 37864885b..6b8f404f1 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -8,7 +8,6 @@ from colossalai.shardformer.layer import ( ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] @@ -53,7 +52,7 @@ class T5BasePolicy(Policy): ), SubModuleReplacementDescription( suffix="embed_tokens", - target_module=Embedding1D, + target_module=VocabParallelEmbedding1D, ) ]) policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ @@ -165,12 +164,6 @@ class T5BasePolicy(Policy): return policy def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] - - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) return self.model @@ -211,18 +204,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): target_key=T5ForConditionalGeneration) return policy - def postprocess(self): - super().postprocess() - if self.shard_config.enable_tensor_parallelism: - binding_map = {"shared": "lm_head"} - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - - return self.model - class T5EncoderPolicy(T5BasePolicy): @@ -239,14 +220,3 @@ class 
T5EncoderPolicy(T5BasePolicy): policy=base_policy, target_key=T5EncoderModel) return base_policy - - def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = [ - ["shared", "encoder.embed_tokens"], - ] - - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) - return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 5e0b572e2..b32c285bd 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -37,11 +37,13 @@ class ModelSharder(object): self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) self._preprocess() + # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None) + shared_params = self.policy.get_shared_params() self._release_unheld_layers() self._replace_module() self._materialize() self._postprocess() - return self.policy.get_shared_params() + return shared_params def _preprocess(self) -> None: self.model = self.policy.preprocess() diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 95a44e09e..32182faf6 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -235,6 +235,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): return param +def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None: + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + param.data = dtensor + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + def compute_global_numel(dtensor: torch.Tensor) -> int: """ Compute the global number of elements in the distributed tensor. @@ -432,3 +440,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: param.gather_fn = dtensor.gather_fn _hijack_detach_and_clone_for_customized_distributed_tensor(param) return param + + +def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter): + """ + Convert the given customized distributed tensor to an existing parameter. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' 
+ + param.data = dtensor.data + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e65271621..2dae645f7 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -17,3 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi SentencePiece ninja flash_attn>=2.0 +datasets diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 99e494359..d62dba7ea 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -15,11 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + embedding = nn.Embedding(32, 128).cuda() with ctx: - embedding = nn.Embedding(32, 128).cuda() - embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) + embedding_copy = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None) assert embedding_1d.weight.shape == torch.Size([32, 64]) + assert embedding_1d.weight is embedding_copy.weight # ensure state dict is reversibly loadable embedding.load_state_dict(embedding_1d.state_dict()) diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index 2cb6928ed..f9c21b82a 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -14,11 +14,14 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_layernorm(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + norm = nn.LayerNorm(128, 0.00001).cuda() with ctx: - norm = nn.LayerNorm(128, 0.00001).cuda() - norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) + norm_copy = nn.LayerNorm(128, 0.00001).cuda() + norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None) assert norm1d.weight.shape == torch.Size([128]) + assert norm_copy.weight is norm1d.weight + assert norm_copy.bias is norm1d.bias # ensure state dict is reversibly loadable norm.load_state_dict(norm1d.state_dict()) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index da3cd85ec..aa75879e0 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -15,14 +15,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @parameterize('lazy_init', [False, True]) def check_linear_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - + linear = nn.Linear(32, 128).cuda() with ctx: - linear = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + linear_copy = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) assert is_distributed_tensor(linear_col.bias) + assert linear_copy.weight is linear_col.weight + assert linear_copy.bias is linear_col.bias # ensure the shape is 
correct assert linear_col.weight.shape == torch.Size([64, 32]) @@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool): def check_linear_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear = nn.Linear(32, 128).cuda() with ctx: - linear = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_copy = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + linear.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear.state_dict()) # check computation correctness x = torch.rand(4, 32).cuda() @@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool): def check_linear_col_plus_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + with ctx: - linear_1 = nn.Linear(32, 128).cuda() - linear_2 = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + linear_1_copy = nn.Linear(32, 128).cuda() + linear_2_copy = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + + linear_1.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear_1.state_dict()) + linear_2.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness x = torch.rand(4, 32).cuda() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 186b1e821..b45cd172c 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -56,10 +56,10 @@ def rearrange(tensor: torch.Tensor, dim: int): @parameterize('lazy_init', [False, True]) def check_linear_conv_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - + linear = Conv1D(192, 48).cuda() with ctx: - linear = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + linear_copy = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, n_fused=3) @@ -68,6 +68,8 @@ def check_linear_conv_1d_col(lazy_init: bool): assert linear.bias.shape == torch.Size([192]) assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_conv_col.bias.shape == torch.Size([96]) + assert linear_copy.weight is linear_conv_col.weight + assert linear_copy.bias is linear_conv_col.bias # ensure weights are reversibly loadable linear_conv_col.load_state_dict(linear.state_dict()) @@ -91,13 +93,20 @@ def check_linear_conv_1d_col(lazy_init: bool): def check_linear_conv_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear = Conv1D(192, 48).cuda() with ctx: - linear = Conv1D(192, 48).cuda() - linear_row = 
GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_copy = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + # ensure weights are reversibly loadable + linear_row.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_row.state_dict()) # check computation correctness x = torch.rand(4, 48).cuda() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index bf5803496..6d2f08730 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -7,8 +7,7 @@ from torch.testing import assert_close import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D -from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.shardformer.layer import VocabParallelEmbedding1D from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -16,13 +15,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_vocab_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + embedding = nn.Embedding(128, 32).to('cuda') with ctx: - embedding = nn.Embedding(128, 32).to('cuda') - dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) + embedding_copy = nn.Embedding(128, 32).to('cuda') + dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) assert dist_embedding_1d.num_embeddings == 64 assert dist_embedding_1d.embedding_dim == 32 + assert embedding_copy.weight is dist_embedding_1d.weight # ensure state dict is reversibly loadable embedding.load_state_dict(dist_embedding_1d.state_dict()) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 825d6df6b..2320c725d 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,8 +1,10 @@ import copy from contextlib import nullcontext +import torch +from torch.nn import Module + from colossalai.lazy import LazyInitContext -from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, shard_output = output_transform_fn(shard_output) shard_loss = loss_fn(shard_output) return org_output, org_loss, shard_output, shard_loss + + +def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): + org_sd = org_model.state_dict() + shard_sd = sharded_model.state_dict() + for k, v in org_sd.items(): + assert k in shard_sd, f'{name} {k} not in sharded model' + shard_v = shard_sd[k] + assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}' + assert v.dtype == shard_v.dtype, 
f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}' + assert torch.equal(v, shard_v), f'{name} {k} value mismatch' diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 7f179acd7..ea0f12264 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -12,7 +12,7 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -75,6 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index e18168292..fe4686aeb 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -12,7 +12,7 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -75,6 +75,7 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 552c6e2f4..99451b403 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -12,7 +12,7 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -77,6 +77,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git 
a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 4d63a4348..aaeef13ef 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -14,7 +14,7 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -78,6 +78,7 @@ def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_la for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index c008596fe..297affceb 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -15,7 +15,7 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -77,6 +77,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index ccd7d3787..96dfdeb73 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -14,7 +14,7 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache()
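Reviewer note (illustration, not part of the patch): the common thread in these changes is that `from_native_module` now passes the original `weight` / `bias_` parameters into the parallel layers, and `sharded_tensor_to_existing_param` / `customized_distributed_tensor_to_existing_param` rebind the *existing* `nn.Parameter` to sharded data instead of allocating a new parameter and copying into it. Because the parameter object survives conversion, any weight tying in the source model (e.g. word embeddings tied to the LM prediction head) stays intact after sharding, which is why the per-policy `postprocess` re-tying code could be dropped. The sketch below mimics that mechanism with plain PyTorch only; `fake_shard_rowwise` and `shard_into_existing_param` are simplified stand-ins for the real ColossalAI helpers and involve no process group.

```python
import torch
import torch.nn as nn


def fake_shard_rowwise(tensor: torch.Tensor, rank: int = 0, world_size: int = 2) -> torch.Tensor:
    # Stand-in for shard_rowwise: keep only this rank's slice along dim 0.
    chunk = tensor.shape[0] // world_size
    return tensor[rank * chunk:(rank + 1) * chunk].clone()


def shard_into_existing_param(param: nn.Parameter) -> None:
    # Mirrors the idea of sharded_tensor_to_existing_param: rebind the data of the
    # existing Parameter in place rather than creating a fresh Parameter and copying.
    param.data = fake_shard_rowwise(param.data)


embed = nn.Embedding(8, 4)
lm_head = nn.Linear(4, 8, bias=False)
lm_head.weight = embed.weight    # weight tying, as in BERT/GPT-2 LM heads

shard_into_existing_param(embed.weight)

# The Parameter object is unchanged, so the tie survives the (mock) sharding,
# and the sharded weight is visible through both modules.
assert lm_head.weight is embed.weight
assert tuple(embed.weight.shape) == (4, 4)
print("still tied:", lm_head.weight is embed.weight, "| sharded shape:", tuple(embed.weight.shape))
```

This also matches the new identity assertions in the layer tests (e.g. `assert linear_copy.weight is linear_col.weight`), which verify that conversion reuses the original parameter object instead of copying into a newly allocated one.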