mirror of https://github.com/hpcaitech/ColossalAI

[pipeline] update shardformer policy

parent 90a65ea682
commit 59f6f573f1
@@ -2,9 +2,13 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Module
+
+from colossalai.pipeline.stage_manager import PipelineStageManager

 from ..shard.shard_config import ShardConfig
@@ -71,9 +75,8 @@ class Policy(ABC):
     """

     def __init__(self) -> None:
-        self.shard_config = None
-        self.model = None
-        self.shard_config = None
+        self.shard_config: Optional[ShardConfig] = None
+        self.model: Optional[Module] = None

     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -94,6 +97,12 @@ class Policy(ABC):
         self.shard_config = shard_config
         self.config_sanity_check()

+    @property
+    def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
+        if self.shard_config is not None:
+            return self.shard_config.pipeline_stage_manager
+        return None
+
     @abstractmethod
     def config_sanity_check(self):
         """
@@ -151,3 +160,19 @@ class Policy(ABC):
             policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

         return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get layers that should be held in the current stage. This method should be implemented by subclass.
+
+        Returns:
+            List[Module]: List of layers that should be held in the current stage
+        """
+        raise NotImplementedError
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """Get parameters that should be shared across stages. This method should be implemented by subclass.
+
+        Returns:
+            List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
+        """
+        return []
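Note: a concrete policy for a LLaMA-style decoder might implement the new hook roughly as below. This is a sketch, not part of the commit: the attribute names (embed_tokens, layers, norm, lm_head), the even layer split, and the PipelineStageManager accessors (num_stages, stage, is_first_stage, is_last_stage) are assumptions.

def get_held_layers(self) -> List[Module]:
    module = self.model
    stage_manager = self.pipeline_stage_manager
    held_layers = []
    # split the decoder blocks evenly across pipeline stages
    layers_per_stage = len(module.layers) // stage_manager.num_stages
    start = stage_manager.stage * layers_per_stage
    held_layers.extend(module.layers[start:start + layers_per_stage])
    if stage_manager.is_first_stage():
        held_layers.append(module.embed_tokens)    # embedding lives on stage 0
    if stage_manager.is_last_stage():
        held_layers.append(module.norm)            # final norm and head on the last stage
        held_layers.append(module.lm_head)
    return held_layers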
@@ -1,8 +1,11 @@
 from dataclasses import dataclass
+from typing import Optional

 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+from colossalai.pipeline.stage_manager import PipelineStageManager
+
 __all__ = ['ShardConfig']
@@ -13,11 +16,13 @@ class ShardConfig:

     Args:
         tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
+        pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline.
         enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
         enable_all_optimization (bool): Whether to turn on all optimization, default is False.
     """
-    tensor_parallel_process_group: ProcessGroup = None
+    tensor_parallel_process_group: Optional[ProcessGroup] = None
+    pipeline_stage_manager: Optional[PipelineStageManager] = None
     enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_all_optimization: bool = False
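Note: wiring a stage manager into the config could look like the sketch below. The ProcessGroupMesh/PipelineStageManager constructor arguments are assumptions and may differ by version; the mesh shape (a 2-stage pipeline on 2 ranks) is illustrative only.

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig

colossalai.launch_from_torch(config={})              # distributed init under torchrun

pg_mesh = ProcessGroupMesh(2)                        # assumed 1-D mesh over 2 ranks
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)

shard_config = ShardConfig(
    pipeline_stage_manager=stage_manager,            # new field from this commit
    enable_tensor_parallelism=False,                 # pipeline-only in this sketch
)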
@@ -1,11 +1,15 @@
 from typing import Any, Callable, Dict, List, Union

 import torch.nn as nn
+from torch import Tensor
+
+from colossalai.lazy import LazyTensor

 from .._utils import getattr_, setattr_
 from ..policies.autopolicy import get_autopolicy
 from ..policies.basepolicy import Policy, SubModuleReplacementDescription
 from .shard_config import ShardConfig
+from .utils import set_tensors_to_none

 __all__ = ['ModelSharder', 'shard_model']
@@ -25,15 +29,18 @@ class ModelSharder(object):
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config

-    def shard(self) -> None:
+    def shard(self) -> List[Dict[int, Tensor]]:
         r"""
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
         self._preprocess()
+        self._release_unheld_layers()
         self._replace_module()
+        self._materialize()
         self._postprocess()
+        return self.policy.get_shared_params()

     def _preprocess(self) -> None:
         self.model = self.policy.preprocess()
@@ -172,3 +179,23 @@ class ModelSharder(object):
             )

             setattr_(org_layer, suffix, replace_layer)
+
+    def _release_unheld_layers(self) -> None:
+        r"""
+        Release the unheld layers in the model
+        """
+        if self.shard_config and self.shard_config.pipeline_stage_manager:
+            held_layers = self.policy.get_held_layers()
+            set_tensors_to_none(self.model, exclude=set(held_layers))
+
+    def _materialize(self) -> None:
+        r"""
+        Materialize the model if lazy initialization is used
+        """
+        for p in self.model.parameters():
+            if isinstance(p, LazyTensor):
+                p.materialize()
+
+        for b in self.model.buffers():
+            if isinstance(b, LazyTensor):
+                b.materialize()
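Note: _materialize pairs with lazy initialization. If the model is built under colossalai.lazy's LazyInitContext, parameters stay as LazyTensor placeholders with no storage; releasing unheld layers before materializing means each stage only ever allocates its own weights. A minimal illustration (LazyInitContext usage assumed; not part of the diff):

import torch.nn as nn
from colossalai.lazy import LazyInitContext, LazyTensor

with LazyInitContext():
    model = nn.Linear(8, 8)    # weight and bias are LazyTensors, no real storage yet

# mirrors ModelSharder._materialize: allocate only the tensors that survived release
for p in model.parameters():
    if isinstance(p, LazyTensor):
        p.materialize()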
@@ -42,5 +42,5 @@ class ShardFormer:
             policy (`Policy`): the custom policy for sharding
         """
         sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
-        sharder.shard()
-        return model
+        shared_params = sharder.shard()
+        return model, shared_params
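Note: call sites must now unpack the second return value. The enclosing method's name is not visible in this hunk; assuming it is ShardFormer's optimize entry point, usage becomes roughly:

shard_former = ShardFormer(shard_config=shard_config)
model, shared_params = shard_former.optimize(model)    # previously: model = shard_former.optimize(model)
# shared_params, e.g. [{0: embed_tokens.weight, 3: lm_head.weight}], can later be
# used to synchronize gradients of weights tied across pipeline stages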
@@ -0,0 +1,19 @@
+from typing import Set
+
+import torch.nn as nn
+
+
+def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
+    """Set all parameters and buffers of model to None
+
+    Args:
+        model (nn.Module): The model to set
+    """
+    if model in exclude:
+        return
+    for child in model.children():
+        set_tensors_to_none(child, exclude=exclude)
+    for n, p in model.named_parameters(recurse=False):
+        setattr(model, n, None)
+    for n, buf in model.named_buffers(recurse=False):
+        setattr(model, n, None)
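Note: a toy check of the helper's exclude semantics (the import path colossalai.shardformer.shard.utils is inferred from the sharder's new "from .utils import set_tensors_to_none" line):

import torch.nn as nn
from colossalai.shardformer.shard.utils import set_tensors_to_none

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
set_tensors_to_none(model, exclude={model[1]})    # pretend this stage holds only block 1

assert model[0].weight is None                    # released
assert model[1].weight is not None                # held layer left untouched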