mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] update shardformer policy
parent
90a65ea682
commit
59f6f573f1
|
@ -2,9 +2,13 @@
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
|
||||||
from ..shard.shard_config import ShardConfig
|
from ..shard.shard_config import ShardConfig
|
||||||
|
|
||||||
|
@ -71,9 +75,8 @@ class Policy(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.shard_config = None
|
self.shard_config: Optional[ShardConfig] = None
|
||||||
self.model = None
|
self.model: Optional[Module] = None
|
||||||
self.shard_config = None
|
|
||||||
|
|
||||||
def set_model(self, model: nn.Module) -> None:
|
def set_model(self, model: nn.Module) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
@ -94,6 +97,12 @@ class Policy(ABC):
|
||||||
self.shard_config = shard_config
|
self.shard_config = shard_config
|
||||||
self.config_sanity_check()
|
self.config_sanity_check()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
|
||||||
|
if self.shard_config is not None:
|
||||||
|
return self.shard_config.pipeline_stage_manager
|
||||||
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def config_sanity_check(self):
|
def config_sanity_check(self):
|
||||||
"""
|
"""
|
||||||
|
@ -151,3 +160,19 @@ class Policy(ABC):
|
||||||
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
|
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get layers that should be held in current stage. This method should be implemented by subclass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Module]: List of layers that should be hold in current stage
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
"""Get parameters that should be shared across stages. This method should be implemented by subclass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
|
||||||
__all__ = ['ShardConfig']
|
__all__ = ['ShardConfig']
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,11 +16,13 @@ class ShardConfig:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
|
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
|
||||||
|
pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline.
|
||||||
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
|
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
|
||||||
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
|
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
|
||||||
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
|
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
|
||||||
"""
|
"""
|
||||||
tensor_parallel_process_group: ProcessGroup = None
|
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
||||||
|
pipeline_stage_manager: Optional[PipelineStageManager] = None
|
||||||
enable_tensor_parallelism: bool = True
|
enable_tensor_parallelism: bool = True
|
||||||
enable_fused_normalization: bool = False
|
enable_fused_normalization: bool = False
|
||||||
enable_all_optimization: bool = False
|
enable_all_optimization: bool = False
|
||||||
|
|
|
@ -1,11 +1,15 @@
|
||||||
from typing import Any, Callable, Dict, List, Union
|
from typing import Any, Callable, Dict, List, Union
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from colossalai.lazy import LazyTensor
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..policies.autopolicy import get_autopolicy
|
from ..policies.autopolicy import get_autopolicy
|
||||||
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
|
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
|
||||||
from .shard_config import ShardConfig
|
from .shard_config import ShardConfig
|
||||||
|
from .utils import set_tensors_to_none
|
||||||
|
|
||||||
__all__ = ['ModelSharder', 'shard_model']
|
__all__ = ['ModelSharder', 'shard_model']
|
||||||
|
|
||||||
|
@ -25,15 +29,18 @@ class ModelSharder(object):
|
||||||
self.policy = get_autopolicy(self.model) if policy is None else policy
|
self.policy = get_autopolicy(self.model) if policy is None else policy
|
||||||
self.shard_config = shard_config
|
self.shard_config = shard_config
|
||||||
|
|
||||||
def shard(self) -> None:
|
def shard(self) -> List[Dict[int, Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Shard the model according to the policy
|
Shard the model according to the policy
|
||||||
"""
|
"""
|
||||||
self.policy.set_model(self.model)
|
self.policy.set_model(self.model)
|
||||||
self.policy.set_shard_config(self.shard_config)
|
self.policy.set_shard_config(self.shard_config)
|
||||||
self._preprocess()
|
self._preprocess()
|
||||||
|
self._release_unheld_layers()
|
||||||
self._replace_module()
|
self._replace_module()
|
||||||
|
self._materialize()
|
||||||
self._postprocess()
|
self._postprocess()
|
||||||
|
return self.policy.get_shared_params()
|
||||||
|
|
||||||
def _preprocess(self) -> None:
|
def _preprocess(self) -> None:
|
||||||
self.model = self.policy.preprocess()
|
self.model = self.policy.preprocess()
|
||||||
|
@ -172,3 +179,23 @@ class ModelSharder(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
setattr_(org_layer, suffix, replace_layer)
|
setattr_(org_layer, suffix, replace_layer)
|
||||||
|
|
||||||
|
def _release_unheld_layers(self) -> None:
|
||||||
|
r"""
|
||||||
|
Release the unheld layers in the model
|
||||||
|
"""
|
||||||
|
if self.shard_config and self.shard_config.pipeline_stage_manager:
|
||||||
|
held_layers = self.policy.get_held_layers()
|
||||||
|
set_tensors_to_none(self.model, exclude=set(held_layers))
|
||||||
|
|
||||||
|
def _materialize(self) -> None:
|
||||||
|
r"""
|
||||||
|
Materialize the model if lazy initialization is used
|
||||||
|
"""
|
||||||
|
for p in self.model.parameters():
|
||||||
|
if isinstance(p, LazyTensor):
|
||||||
|
p.materialize()
|
||||||
|
|
||||||
|
for b in self.model.buffers():
|
||||||
|
if isinstance(b, LazyTensor):
|
||||||
|
b.materialize()
|
||||||
|
|
|
@ -42,5 +42,5 @@ class ShardFormer:
|
||||||
policy (`Policy`): the custom policy for sharding
|
policy (`Policy`): the custom policy for sharding
|
||||||
"""
|
"""
|
||||||
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
|
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
|
||||||
sharder.shard()
|
shared_params = sharder.shard()
|
||||||
return model
|
return model, shared_params
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
|
||||||
|
"""Set all parameters and buffers of model to None
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): The model to set
|
||||||
|
"""
|
||||||
|
if model in exclude:
|
||||||
|
return
|
||||||
|
for child in model.children():
|
||||||
|
set_tensors_to_none(child, exclude=exclude)
|
||||||
|
for n, p in model.named_parameters(recurse=False):
|
||||||
|
setattr(model, n, None)
|
||||||
|
for n, buf in model.named_buffers(recurse=False):
|
||||||
|
setattr(model, n, None)
|
Loading…
Reference in New Issue