diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 1f5345015..e071563c0 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn from torch import Tensor +from torch.nn import Parameter from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor @@ -99,8 +100,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: the converted tensor """ - cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor + cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor tensor.__class__ = cls_to_become + if cls_to_become is Parameter: + # to fit UninitializedParameter + delattr(tensor, '_is_param') tensor.data = target tensor.requires_grad = target.requires_grad # subclass of torch.Tensor does not have tolist() method @@ -198,10 +202,10 @@ class LazyTensor(torch.Tensor): def clean(self) -> None: """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. """ - self._factory_method = None - self._op_buffer = None - self._materialized_data = None - self._meta_data = None + delattr(self, '_factory_method') + delattr(self, '_op_buffer') + delattr(self, '_materialized_data') + delattr(self, '_meta_data') @staticmethod def _replace_with_materialized(x): @@ -350,20 +354,19 @@ class LazyTensor(torch.Tensor): def factory_fn(): # if self is materialized, return self new_tensor = self.materialize() if type(self) is LazyTensor else self - copied = new_tensor.detach().clone() - if new_tensor.requires_grad: - copied.requires_grad_() - return copied + return _copy_tensor(new_tensor, new_tensor.requires_grad) if self._materialized_data is not None: # self is early materialized - copied = self._materialized_data.detach().clone() - if self.requires_grad: - copied.requires_grad_() + copied = _copy_tensor(self._materialized_data, self.requires_grad) target = LazyTensor(lambda: None, concrete_data=copied) else: target = LazyTensor(factory_fn, meta_data=self._meta_data) + if isinstance(self, Parameter): + # hack isinstance check of parameter + target._is_param = True + memo[id(self)] = target return target @@ -408,6 +411,10 @@ class LazyTensor(torch.Tensor): def __hash__(self): return id(self) + def __rpow__(self, other): + dtype = torch.result_type(self, other) + return torch.tensor(other, dtype=dtype, device=self.device)**self + class LazyInitContext: """Context manager for lazy initialization. Enables initializing the model without allocating real memory. @@ -536,7 +543,7 @@ class LazyInitContext: @staticmethod def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: - """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. + """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: module (nn.Module): Target ``nn.Module`` @@ -553,7 +560,7 @@ class LazyInitContext: device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False) -> nn.Module: - """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. + """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. 
Args: module (nn.Module): Target ``nn.Module`` @@ -625,3 +632,9 @@ def _is_int_tuple(args) -> bool: if not isinstance(x, int): return False return True + + +def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor: + copied = tensor.data.clone() + copied.requires_grad = requires_grad + return copied diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index db39a457b..07341ef73 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -9,8 +9,8 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter +from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param @@ -95,6 +95,7 @@ class Embedding1D(ParallelModule): r""" Build a 1D parallelized Embedding from a native nn.Embedding module. """ + LazyInitContext.materialize(module) # get the attributes num_embedding = module.num_embeddings embedding_dim = module.embedding_dim @@ -223,6 +224,7 @@ class VocabParallelEmbedding1D(ParallelModule): r""" Convert a native pytorch embedding module to a parallel module. """ + LazyInitContext.materialize(module) # get the origin attributes num_embeddings = module.num_embeddings embedding_dim = module.embedding_dim @@ -243,6 +245,7 @@ class VocabParallelEmbedding1D(ParallelModule): process_group=process_group, *args, **kwargs) + with torch.no_grad(): # shard and slice the weight along the vocabulary(num_embeddings) dimension # the shape of the weight is (num_embeddings, embedding_dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 26ba5883c..a8439f303 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -12,6 +12,7 @@ from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param @@ -106,6 +107,7 @@ class Linear1D_Col(ParallelModule): r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.in_features out_features = module.out_features @@ -242,6 +244,7 @@ class Linear1D_Row(ParallelModule): r""" Convert a native PyTorch linear layer to a parallelized linear layer. 
""" + LazyInitContext.materialize(module) # get the attributes in_features = module.in_features out_features = module.out_features diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index b27307154..9bb7738c0 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn +from colossalai.lazy import LazyInitContext + __all__ = ['FusedLayerNorm', 'FusedRMSNorm'] FAST_LAYERNORM_SUPPORTED_SIZE = [ @@ -35,6 +37,7 @@ class FusedLayerNorm(): raise ImportError( 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel') + LazyInitContext.materialize(module) # get the attributes of the module normalized_shape = module.normalized_shape eps = module.eps @@ -84,6 +87,7 @@ class FusedRMSNorm(): 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' ) + LazyInitContext.materialize(module) # to check if it is huggingface LlamaRMSNorm if module.__class__.__name__ == "LlamaRMSNorm": normalized_shape = module.weight.shape[0] diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 9d51670c6..c94d93069 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -12,6 +12,7 @@ from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import ( @@ -231,6 +232,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.weight.shape[0] out_features = module.weight.shape[1] @@ -380,6 +382,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.weight.shape[0] out_features = module.weight.shape[1] @@ -428,9 +431,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) origin_device = self.bias.device - self.bias = self.bias.cuda() + self.bias.data = self.bias.cuda() dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) + self.bias.data = self.bias.to(origin_device) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. 
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index e18cb6ece..b80475e05 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -46,11 +46,12 @@ class BertPolicy(Policy): Reshape the Embedding layer to make the embedding dimension divisible by world_size """ # TODO: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -229,10 +230,11 @@ class BertForPreTrainingPolicy(BertPolicy): return [] def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model @@ -269,10 +271,11 @@ class BertLMHeadModelPolicy(BertPolicy): return [] def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model @@ -288,10 +291,11 @@ class BertForMaskedLMPolicy(BertPolicy): return module_policy def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 8d6f07d4a..662ff5b49 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -17,11 +17,12 @@ class BloomPolicy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + 
self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -128,16 +129,13 @@ class BloomForCausalLMPolicy(BloomPolicy): return policy def postprocess(self): - binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} + if self.shard_config.enable_tensor_parallelism: + binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - - if not isinstance(param, nn.Parameter): - param = nn.Parameter(param) - - # tie weights - setattr_(self.model, v, param) + for k, v in binding_map.items(): + param = getattr_(self.model, k) + # tie weights + setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 598f393c0..8f9d90e67 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -21,11 +21,12 @@ class GPT2Policy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -142,10 +143,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy): return module_policy def postprocess(self): - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model @@ -172,10 +174,11 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): return module_policy def postprocess(self): - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 391938b27..b10e07560 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -15,13 +15,14 @@ class LlamaPolicy(Policy): pass def preprocess(self): - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + 
self.model.resize_token_embeddings(new_vocab_size) return self.model diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index c4c6cde01..1435805d2 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -19,11 +19,12 @@ class OPTPolicy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -116,14 +117,15 @@ class OPTForCausalLMPolicy(OPTPolicy): return policy def postprocess(self): - binding_map = { - 'model.decoder.embed_tokens': 'lm_head', - } + if self.shard_config.enable_tensor_parallelism: + binding_map = { + 'model.decoder.embed_tokens': 'lm_head', + } - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight return self.model diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 6167e8161..37864885b 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -24,11 +24,12 @@ class T5BasePolicy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -164,11 +165,12 @@ class T5BasePolicy(Policy): return policy def postprocess(self): - binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] + if self.shard_config.enable_tensor_parallelism: + binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) return self.model @@ -211,13 +213,13 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): def postprocess(self): super().postprocess() + if self.shard_config.enable_tensor_parallelism: + binding_map = {"shared": "lm_head"} - binding_map = {"shared": "lm_head"} - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = 
getattr_(self.model, v) + dst_mod.weight = src_mod.weight return self.model @@ -239,11 +241,12 @@ class T5EncoderPolicy(T5BasePolicy): return base_policy def postprocess(self): - binding_map = [ - ["shared", "encoder.embed_tokens"], - ] + if self.shard_config.enable_tensor_parallelism: + binding_map = [ + ["shared", "encoder.embed_tokens"], + ] - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ca2f46a18..56eb76973 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Union import torch.nn as nn from torch import Tensor -from colossalai.lazy import LazyTensor +from colossalai.lazy import LazyInitContext from .._utils import getattr_, setattr_ from ..policies.auto_policy import get_autopolicy @@ -192,10 +192,4 @@ class ModelSharder(object): r""" Materialize the model if lazy initialization is used """ - for p in self.model.parameters(): - if isinstance(p, LazyTensor): - p.materialize() - - for b in self.model.buffers(): - if isinstance(b, LazyTensor): - b.materialize() + LazyInitContext.materialize(self.model) diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 8a6aa42a4..99e494359 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -1,15 +1,22 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import Embedding1D -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_embedding_1d(): - embedding = nn.Embedding(32, 128).cuda() +@parameterize('lazy_init', [False, True]) +def check_embedding_1d(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + embedding = nn.Embedding(32, 128).cuda() embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) assert embedding_1d.weight.shape == torch.Size([32, 64]) diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index fc6d894c4..2cb6928ed 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -1,14 +1,21 @@ +from contextlib import nullcontext + import torch import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import FusedLayerNorm -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_layernorm(): - norm = nn.LayerNorm(128, 0.00001).cuda() +@parameterize('lazy_init', [False, True]) +def check_layernorm(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + norm = nn.LayerNorm(128, 0.00001).cuda() norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) assert norm1d.weight.shape == torch.Size([128]) diff --git 
a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index da3bdc1d7..da3cd85ec 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -1,16 +1,23 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.tensor.d_tensor import is_distributed_tensor -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_linear_1d_col(): - linear = nn.Linear(32, 128).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_1d_col(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) # ensure that the parameters are distributed @@ -50,8 +57,12 @@ def check_linear_1d_col(): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_1d_row(): - linear = nn.Linear(32, 128).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_1d_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear = nn.Linear(32, 128).cuda() linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) assert linear_row.weight.shape == torch.Size([128, 16]) @@ -83,9 +94,13 @@ def check_linear_1d_row(): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_col_plus_row(): - linear_1 = nn.Linear(32, 128).cuda() - linear_2 = nn.Linear(128, 32).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_col_plus_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 681c4f6dd..186b1e821 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -1,12 +1,15 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # This code is copied from https://github.com/huggingface/transformers @@ -50,8 +53,12 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_linear_conv_1d_col(): - linear = Conv1D(192, 48).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_conv_1d_col(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + 
with ctx: + linear = Conv1D(192, 48).cuda() linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, @@ -80,8 +87,12 @@ def check_linear_conv_1d_col(): assert_close(target_grad, linear_conv_col.weight.grad) -def check_linear_conv_1d_row(): - linear = Conv1D(192, 48).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_conv_1d_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear = Conv1D(192, 48).cuda() linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 8991d9b30..bf5803496 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -1,15 +1,23 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_vocab_embedding_1d(): - embedding = nn.Embedding(128, 32).to('cuda') +@parameterize('lazy_init', [False, True]) +def check_vocab_embedding_1d(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + embedding = nn.Embedding(128, 32).to('cuda') dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e03014f3f..f83cfcd49 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,19 +1,24 @@ import copy +from contextlib import nullcontext +from colossalai.lazy import LazyInitContext from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): - # create new model - org_model = model_fn().cuda() - +def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + # create new model + org_model = model_fn() + model_copy = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism) - model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) - return org_model, sharded_model.cuda() + return org_model.cuda(), sharded_model.cuda() def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py 
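The updated `build_model` helper above is what the model-level tests below exercise; a short sketch of the lazy path it now supports (`model_zoo` and `build_model` are the test modules' existing imports, and `model_fn` is a model factory from the test model zoo):

sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    # with use_lazy_init=True, org_model is materialized in place while the deep copy
    # stays lazy; its parameters only become real tensors when ShardFormer replaces
    # each layer (or when the sharder's final LazyInitContext.materialize call runs)
    org_model, sharded_model = build_model(model_fn,
                                           enable_fused_normalization=False,
                                           enable_tensor_parallelism=True,
                                           use_lazy_init=True)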
index 1afedb707..7f179acd7 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -67,12 +67,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_fused_normalization', [False, True]) +@parameterize('enable_tensor_parallelism', [False, True]) +@parameterize('use_lazy_init', [False, True]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index a33896522..e18168292 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ee7737687..96c4b90a8 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, 
enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 74b5fdd18..4d63a4348 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b..c008596fe 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 0762dc09e..ccd7d3787 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + 
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index f29c8d6f6..2b6933246 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import pytest import torch import torch.distributed as dist @@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -def check_shardformer_with_ddp(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') +@parameterize('lazy_init', [True, False]) +def check_shardformer_with_ddp(lazy_init: bool): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') @@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port): shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True) shardformer = ShardFormer(shard_config=shard_config) + ctx = LazyInitContext() if lazy_init else nullcontext() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create and shard model - model = model_fn().cuda() + with ctx: + model = model_fn().cuda() sharded_model, _ = shardformer.optimize(model) # add ddp @@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port): torch.cuda.empty_cache() +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_shardformer_with_ddp() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_gpt2(): - spawn(check_shardformer_with_ddp, 4) + spawn(run_dist, 4) if __name__ == "__main__": test_gpt2() - test_gpt2()
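Tying the test changes back to the `lazy_init.py` changes at the top of this diff: the new `__deepcopy__` path marks copied parameters with `_is_param` so they still pass `isinstance(..., nn.Parameter)` checks while lazy, and `_convert_cls` removes that flag once the tensor becomes a real `Parameter`. A minimal sketch of the behaviour this enables (the toy `nn.Linear` is just for illustration):

import copy

import torch.nn as nn

from colossalai.lazy import LazyInitContext

with LazyInitContext():
    model = nn.Linear(8, 8)            # weight and bias are LazyTensors, no real memory yet
    model_copy = copy.deepcopy(model)  # the copy stays lazy; its parameters keep _is_param

LazyInitContext.materialize(model)       # original now holds real nn.Parameters
LazyInitContext.materialize(model_copy)  # the lazy copy materializes independently

assert isinstance(model.weight, nn.Parameter)
assert isinstance(model_copy.weight, nn.Parameter)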