mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] add custom policy in hybrid parallel plugin (#4718)
* add custom policy
* update assert

pull/4741/head
parent 451c3465fb
commit ac2797996b
@@ -22,6 +22,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
 from .pp_plugin_base import PipelinePluginBase
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 class HybridParallelModule(ModelWrapper):
 
     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
-                 ddp_config: dict) -> None:
+                 ddp_config: dict, custom_policy: Policy) -> None:
 
         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group
 
         shardformer = ShardFormer(shard_config)
-        module, self.shared_params = shardformer.optimize(module)
+        if custom_policy is not None:
+            assert isinstance(custom_policy, object)
+        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
 
         # setting process groups for shared parameters
         self.shared_param_process_groups = []
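
With this change, shardformer.optimize() always receives a policy keyword; when custom_policy is None, ShardFormer falls back to its built-in auto-detected policies. Below is a minimal, hypothetical sketch of a pass-through custom policy, assuming the Policy ABC imported above exposes config_sanity_check, preprocess, module_policy, postprocess, get_held_layers and get_shared_params (the exact abstract-method set may differ between ColossalAI versions):

import torch.nn as nn

from colossalai.shardformer.policies.base_policy import Policy


class NoOpPolicy(Policy):
    # Hypothetical pass-through policy: leaves the wrapped model unsharded.

    def config_sanity_check(self):
        # No additional constraints on ShardConfig for this sketch.
        pass

    def preprocess(self) -> nn.Module:
        # Return the model unchanged before sharding begins.
        return self.model

    def module_policy(self) -> dict:
        # An empty mapping: no submodule is replaced or sharded.
        return {}

    def postprocess(self) -> nn.Module:
        # Return the model unchanged after sharding finishes.
        return self.model

    def get_held_layers(self) -> list:
        # Only consulted under pipeline parallelism; hold every top-level child.
        return list(self.model.children())

    def get_shared_params(self) -> list:
        # No parameters are shared across pipeline stages in this sketch.
        return []
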
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
     """
 
     def __init__(self,
@@ -302,7 +306,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                  zero_bucket_size_in_m: int = 12,
                  cpu_offload: bool = False,
                  communication_dtype: Optional[torch.dtype] = None,
-                 overlap_communication: bool = True) -> None:
+                 overlap_communication: bool = True,
+                 custom_policy: Policy = None) -> None:
 
         super().__init__()
         assert dist.get_world_size() % (
@@ -326,6 +331,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
+        self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
             assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
@@ -405,7 +411,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
-                                         self.ddp_config)
+                                         self.ddp_config, self.custom_policy)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
                if self.precision in ['fp16', 'bf16']:
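
Taken together, the new keyword travels from HybridParallelPlugin.__init__ through self.custom_policy into HybridParallelModule, where ShardFormer applies it. A hedged usage sketch, assuming a 2-GPU torchrun launch, the launch_from_torch/Booster APIs of that ColossalAI release, and the NoOpPolicy defined above (model and optimizer are placeholders):

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the script is launched with torchrun on 2 GPUs.
colossalai.launch_from_torch(config={})

model = nn.Linear(16, 16)            # placeholder model
optimizer = Adam(model.parameters())

# The custom policy is forwarded to ShardFormer inside HybridParallelModule
# when booster.boost() wraps the model.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, custom_policy=NoOpPolicy())
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)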