[pipeline] update shardformer policy

pull/4445/head
Authored by ver217 on 2023-07-05 14:16:55 +08:00; committed by Hongxin Liu
parent 90a65ea682
commit 59f6f573f1
5 changed files with 84 additions and 8 deletions

colossalai/shardformer/policies/basepolicy.py

@@ -2,9 +2,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Type, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager
from ..shard.shard_config import ShardConfig
@@ -71,9 +75,8 @@ class Policy(ABC):
"""
def __init__(self) -> None:
self.shard_config = None
self.model = None
self.shard_config = None
self.shard_config: Optional[ShardConfig] = None
self.model: Optional[Module] = None
def set_model(self, model: nn.Module) -> None:
r"""
@@ -94,6 +97,12 @@
self.shard_config = shard_config
self.config_sanity_check()
@property
def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
if self.shard_config is not None:
return self.shard_config.pipeline_stage_manager
return None
@abstractmethod
def config_sanity_check(self):
"""
@@ -151,3 +160,19 @@ class Policy(ABC):
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
return policy
def get_held_layers(self) -> List[Module]:
"""Get layers that should be held in current stage. This method should be implemented by subclass.
Returns:
List[Module]: List of layers that should be held by the current pipeline stage
"""
raise NotImplementedError
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""Get parameters that should be shared across stages. This method should be implemented by subclass.
Returns:
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []
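Taken together, the new `pipeline_stage_manager` property and the `get_held_layers` / `get_shared_params` hooks give a concrete policy one place to describe its pipeline layout. Below is a minimal sketch of how a subclass might fill them in; the model attributes (`embed`, `blocks`, `head`) and the stage-manager helpers (`stage`, `num_stages`, `is_first_stage()`, `is_last_stage()`) are assumptions for illustration only and are not part of this diff (the import path matches the sharder's `from ..policies.basepolicy import Policy` below).

```python
from typing import Dict, List

from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.policies.basepolicy import Policy  # the class modified above


class ToyTransformerPolicy(Policy):
    """Hypothetical policy for a model shaped like embed -> blocks -> head."""

    # config_sanity_check() and the other abstract hooks are omitted in this sketch.

    def get_held_layers(self) -> List[Module]:
        stage_manager = self.pipeline_stage_manager  # new property from this commit
        module = self.model
        if stage_manager is None:
            return [module]  # no pipeline parallelism: the whole model is held

        held: List[Module] = []
        if stage_manager.is_first_stage():            # assumed helper
            held.append(module.embed)                 # assumed attribute
        # Split the transformer blocks evenly across pipeline stages.
        per_stage = len(module.blocks) // stage_manager.num_stages
        start = stage_manager.stage * per_stage       # assumed attributes
        held.extend(module.blocks[start:start + per_stage])
        if stage_manager.is_last_stage():             # assumed helper
            held.append(module.head)                  # assumed attribute
        return held

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        stage_manager = self.pipeline_stage_manager
        if stage_manager is None:
            return []
        # Tie the input embedding (first stage) to the output head (last stage),
        # following the format shown in the docstring above.
        return [{0: self.model.embed.weight,
                 stage_manager.num_stages - 1: self.model.head.weight}]
```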

colossalai/shardformer/shard/shard_config.py

@@ -1,8 +1,11 @@
from dataclasses import dataclass
from typing import Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.pipeline.stage_manager import PipelineStageManager
__all__ = ['ShardConfig']
@@ -13,11 +16,13 @@ class ShardConfig:
Args:
tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism; defaults to None, which means the global process group is used.
pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager; defaults to None, which means pipeline parallelism is not used.
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
"""
tensor_parallel_process_group: ProcessGroup = None
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
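For context, here is a sketch of how the extended config might be built for 2-way tensor parallelism combined with 2 pipeline stages. The `ProcessGroupMesh` / `PipelineStageManager` constructor arguments and the `get_group_along_axis` call are assumptions about the surrounding ColossalAI APIs, not something this diff defines, and the snippet presumes `torch.distributed` has already been initialized on 4 ranks (e.g. via `colossalai.launch`).

```python
from colossalai.cluster import ProcessGroupMesh                      # assumed location
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard.shard_config import ShardConfig

# A 2 x 2 device mesh: axis 0 = pipeline stages, axis 1 = tensor parallelism.
pg_mesh = ProcessGroupMesh(2, 2)
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)       # assumed signature

shard_config = ShardConfig(
    tensor_parallel_process_group=pg_mesh.get_group_along_axis(1),   # assumed helper
    pipeline_stage_manager=stage_manager,  # new field consumed by Policy.pipeline_stage_manager
    enable_fused_normalization=True,
)
```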

colossalai/shardformer/shard/sharder.py

@@ -1,11 +1,15 @@
from typing import Any, Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from colossalai.lazy import LazyTensor
from .._utils import getattr_, setattr_
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
from .shard_config import ShardConfig
from .utils import set_tensors_to_none
__all__ = ['ModelSharder', 'shard_model']
@@ -25,15 +29,18 @@ class ModelSharder(object):
self.policy = get_autopolicy(self.model) if policy is None else policy
self.shard_config = shard_config
def shard(self) -> None:
def shard(self) -> List[Dict[int, Tensor]]:
r"""
Shard the model according to the policy and return the parameters shared across pipeline stages.
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self._preprocess()
self._release_unheld_layers()
self._replace_module()
self._materialize()
self._postprocess()
return self.policy.get_shared_params()
def _preprocess(self) -> None:
self.model = self.policy.preprocess()
@@ -172,3 +179,23 @@ class ModelSharder(object):
)
setattr_(org_layer, suffix, replace_layer)
def _release_unheld_layers(self) -> None:
r"""
Release the layers not held by the current pipeline stage by setting their parameters and buffers to None
"""
if self.shard_config and self.shard_config.pipeline_stage_manager:
held_layers = self.policy.get_held_layers()
set_tensors_to_none(self.model, exclude=set(held_layers))
def _materialize(self) -> None:
r"""
Materialize the model if lazy initialization is used
"""
for p in self.model.parameters():
if isinstance(p, LazyTensor):
p.materialize()
for b in self.model.buffers():
if isinstance(b, LazyTensor):
b.materialize()
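The ordering in `shard()` matters: unheld layers are released before materialization, so under lazy initialization their parameters are dropped while they are still `LazyTensor`s and never allocate real storage. A rough standalone illustration of that idea follows; `LazyInitContext` is assumed to be the public entry point of `colossalai.lazy` (only `LazyTensor` appears in this diff).

```python
import torch.nn as nn

from colossalai.lazy import LazyInitContext, LazyTensor  # LazyInitContext is assumed

# Build the model lazily: parameters are LazyTensors with no real storage yet.
with LazyInitContext():
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# Pretend this stage only holds the first Linear: drop the second one's tensors,
# which is what _release_unheld_layers does via set_tensors_to_none.
model[1].weight = None
model[1].bias = None

# Materialize what is left, mirroring _materialize above. Only the held layer
# ever gets real memory.
for p in model.parameters():
    if isinstance(p, LazyTensor):
        p.materialize()
```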

colossalai/shardformer/shard/shardformer.py

@@ -42,5 +42,5 @@ class ShardFormer:
policy (`Policy`): the custom policy for sharding
"""
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
sharder.shard()
return model
shared_params = sharder.shard()
return model, shared_params
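From the caller's side, the sharding entry point now hands back both the model and the shared parameters. Below is a sketch of the updated call site, assuming the enclosing method is `ShardFormer.optimize` and that the package-level imports and constructor keyword used here exist (none of this is shown in the hunk); it also presumes a distributed environment and a registered auto-policy for the model.

```python
from transformers import BertModel

from colossalai.shardformer import ShardConfig, ShardFormer  # assumed re-exports

model = BertModel.from_pretrained("bert-base-uncased")

shard_config = ShardConfig()  # tensor parallelism on by default, no pipeline
shardformer = ShardFormer(shard_config=shard_config)

# Previously the call returned only the model; after this commit the parameters
# shared across pipeline stages are surfaced as well:
model, shared_params = shardformer.optimize(model)

# With pipeline_stage_manager left as None, policies fall back to the default
# get_shared_params(), so this list is typically empty.
print(shared_params)
```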

colossalai/shardformer/shard/utils.py (new file)

@@ -0,0 +1,19 @@
from typing import Set
import torch.nn as nn
def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
"""Set all parameters and buffers of model to None
Args:
model (nn.Module): The model to set
"""
if model in exclude:
return
for child in model.children():
set_tensors_to_none(child, exclude=exclude)
for n, p in model.named_parameters(recurse=False):
setattr(model, n, None)
for n, buf in model.named_buffers(recurse=False):
setattr(model, n, None)
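A small usage sketch of the new helper, runnable with plain PyTorch once `set_tensors_to_none` is importable (the import path below matches this PR's `from .utils import set_tensors_to_none` in the sharder):

```python
import torch.nn as nn

from colossalai.shardformer.shard.utils import set_tensors_to_none

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# Release everything except the second Linear.
set_tensors_to_none(model, exclude={model[1]})

print(model[0].weight)  # None: the parameter slot was cleared
print(model[1].weight)  # still a real Parameter, because model[1] was excluded
```

Note that the mutable default `exclude=set()` is only read, never mutated, so sharing it across calls is harmless here.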