diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 2c8cb6ba1..92bab29ec 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -254,7 +254,7 @@ def get_param_info(optim: Optimizer):
     return param_info
 
 
-def init_pipeline_optimizer(optim: Optimizer, model: Module):
+def reinitialize_optimizer(optim: Optimizer, model: Module):
     model_params = set(model.parameters())
     new_param_groups = []
     for group in optim.param_groups:
@@ -276,7 +276,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     ):
         self.param_info = param_info
         if use_pipeline:
-            init_pipeline_optimizer(optim, model)
+            reinitialize_optimizer(optim, model)
         self.model = model
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
@@ -497,7 +497,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         if use_pipeline:
-            init_pipeline_optimizer(optim, model)
+            reinitialize_optimizer(optim, model)
         super().__init__(
             optim,
             precision=precision,
@@ -678,7 +678,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
-            init_pipeline_optimizer(optimizer, model)
+            reinitialize_optimizer(optimizer, model)
         super().__init__(
             optimizer=optimizer,
             initial_scale=initial_scale,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 2cfdd000a..3d4250ac8 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -19,7 +19,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
     get_param_info,
-    init_pipeline_optimizer,
+    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster import ProcessGroupMesh
@@ -67,7 +67,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
-            init_pipeline_optimizer(optimizer, model)
+            reinitialize_optimizer(optimizer, model)
 
         pg_param_list = {
             dp_process_group: [],
@@ -400,12 +400,19 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 dp_group=self.global_dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=use_ddp,  # TODO fix why this failed
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if self.ep_size > 1:
+                # if ep is enabled, the number of (moe) parameters changes since they are sharded among ep groups,
+                # but the optimizer is not aware of ep, so we need to update the optimizer
+                reinitialize_optimizer(optimizer, model)
+
             if self.zero_stage == 0:
+                assert self.ep_size > 1
+
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
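
Note: the first hunk only shows the opening lines of the renamed helper. For context, here is a minimal sketch of what reinitialize_optimizer plausibly does, assuming it simply filters each param group down to the parameters that still exist on the (EP/pipeline-sharded) local module and rebuilds the groups in place; the filtering line and the __setstate__ call are assumptions, since the tail of the function is not part of this diff.

from torch.nn import Module
from torch.optim import Optimizer


def reinitialize_optimizer(optim: Optimizer, model: Module):
    # Parameters that actually live on this rank's (possibly sharded) module.
    model_params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
        # Keep only parameters the local module still owns; per-group
        # hyperparameters (lr, weight_decay, ...) are carried over unchanged.
        params = [p for p in group["params"] if p in model_params]
        new_param_groups.append({**group, "params": params})
    # Replace the optimizer's param groups; torch.optim.Optimizer.__setstate__
    # updates the instance __dict__ with the given state.
    optim.__setstate__({"param_groups": new_param_groups})

This is why the MoE plugin calls the helper before wrapping the optimizer: once expert parameters are sharded across EP groups, a plain optimizer built on the unsharded model would still reference the old, full-size parameters, so its param groups have to be rebuilt against the sharded module first.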