diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 2d347542f..16f3fa14e 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -2,9 +2,13 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
 
 from ..shard.shard_config import ShardConfig
@@ -71,9 +75,8 @@ class Policy(ABC):
     """
 
     def __init__(self) -> None:
-        self.shard_config = None
-        self.model = None
-        self.shard_config = None
+        self.shard_config: Optional[ShardConfig] = None
+        self.model: Optional[Module] = None
 
     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -94,6 +97,12 @@
         self.shard_config = shard_config
         self.config_sanity_check()
 
+    @property
+    def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
+        if self.shard_config is not None:
+            return self.shard_config.pipeline_stage_manager
+        return None
+
     @abstractmethod
     def config_sanity_check(self):
         """
@@ -151,3 +160,19 @@
             policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
 
         return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get layers that should be held in the current stage. This method should be implemented by the subclass.
+
+        Returns:
+            List[Module]: List of layers that should be held in the current stage
+        """
+        raise NotImplementedError
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """Get parameters that should be shared across stages. This method should be implemented by the subclass.
+
+        Returns:
+            List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
+        """
+        return []
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 83c08d275..fba2c27a2 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,8 +1,11 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
 __all__ = ['ShardConfig']
@@ -13,11 +16,13 @@ class ShardConfig:
 
     Args:
         tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
+        pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline.
        enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
         enable_all_optimization (bool): Whether to turn on all optimization, default is False.
     """
-    tensor_parallel_process_group: ProcessGroup = None
+    tensor_parallel_process_group: Optional[ProcessGroup] = None
+    pipeline_stage_manager: Optional[PipelineStageManager] = None
     enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_all_optimization: bool = False
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 201e0a08c..429ca8ed7 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -1,11 +1,15 @@
 from typing import Any, Callable, Dict, List, Union
 
 import torch.nn as nn
+from torch import Tensor
+
+from colossalai.lazy import LazyTensor
 
 from .._utils import getattr_, setattr_
 from ..policies.autopolicy import get_autopolicy
 from ..policies.basepolicy import Policy, SubModuleReplacementDescription
 from .shard_config import ShardConfig
+from .utils import set_tensors_to_none
 
 __all__ = ['ModelSharder', 'shard_model']
@@ -25,15 +29,18 @@
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config
 
-    def shard(self) -> None:
+    def shard(self) -> List[Dict[int, Tensor]]:
         r"""
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
         self._preprocess()
+        self._release_unheld_layers()
         self._replace_module()
+        self._materialize()
         self._postprocess()
+        return self.policy.get_shared_params()
 
     def _preprocess(self) -> None:
         self.model = self.policy.preprocess()
@@ -172,3 +179,23 @@
             )
 
             setattr_(org_layer, suffix, replace_layer)
+
+    def _release_unheld_layers(self) -> None:
+        r"""
+        Release the unheld layers in the model
+        """
+        if self.shard_config and self.shard_config.pipeline_stage_manager:
+            held_layers = self.policy.get_held_layers()
+            set_tensors_to_none(self.model, exclude=set(held_layers))
+
+    def _materialize(self) -> None:
+        r"""
+        Materialize the model if lazy initialization is used
+        """
+        for p in self.model.parameters():
+            if isinstance(p, LazyTensor):
+                p.materialize()
+
+        for b in self.model.buffers():
+            if isinstance(b, LazyTensor):
+                b.materialize()
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index 3fce12463..069a46ca5 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -42,5 +42,5 @@ class ShardFormer:
             policy (`Policy`): the custom policy for sharding
         """
         sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
-        sharder.shard()
-        return model
+        shared_params = sharder.shard()
+        return model, shared_params
diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py
new file mode 100644
index 000000000..2bac37bfe
--- /dev/null
+++ b/colossalai/shardformer/shard/utils.py
@@ -0,0 +1,19 @@
+from typing import Set
+
+import torch.nn as nn
+
+
+def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
+    """Set all parameters and buffers of model to None
+
+    Args:
+        model (nn.Module): The model to set
+    """
+    if model in exclude:
+        return
+    for child in model.children():
+        set_tensors_to_none(child, exclude=exclude)
+    for n, p in model.named_parameters(recurse=False):
+        setattr(model, n, None)
+    for n, buf in model.named_buffers(recurse=False):
+        setattr(model, n, None)