[moe] fix plugin

Branch: moe_sp
Author: hxwang
Date: 2024-07-02 09:09:00 +00:00
Parent: 6a9164a477
Commit: 5a9490a46b
2 changed files with 14 additions and 7 deletions

File: colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -254,7 +254,7 @@ def get_param_info(optim: Optimizer):
     return param_info
 
 
-def init_pipeline_optimizer(optim: Optimizer, model: Module):
+def reinitialize_optimizer(optim: Optimizer, model: Module):
     model_params = set(model.parameters())
     new_param_groups = []
     for group in optim.param_groups:
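Only the first lines of the renamed helper are visible in this hunk. For reference, a minimal sketch of how such a helper typically completes, assuming it simply filters each param group down to the parameters the local (sharded) model still owns; the exact body in the repository may differ:

from torch.nn import Module
from torch.optim import Optimizer

def reinitialize_optimizer(optim: Optimizer, model: Module):
    # Parameters that actually exist on this rank after pipeline/EP sharding.
    model_params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
        # Keep only the params the local model still holds; other group
        # settings (lr, weight_decay, ...) are carried over unchanged.
        params = [p for p in group["params"] if p in model_params]
        new_param_groups.append({**group, "params": params})
    # Swap in the rebuilt param groups so optimizer state matches the local shards.
    optim.__setstate__({"param_groups": new_param_groups})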
@@ -276,7 +276,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     ):
         self.param_info = param_info
         if use_pipeline:
-            init_pipeline_optimizer(optim, model)
+            reinitialize_optimizer(optim, model)
         self.model = model
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
@@ -497,7 +497,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         if use_pipeline:
-            init_pipeline_optimizer(optim, model)
+            reinitialize_optimizer(optim, model)
         super().__init__(
             optim,
             precision=precision,
@@ -678,7 +678,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
-            init_pipeline_optimizer(optimizer, model)
+            reinitialize_optimizer(optimizer, model)
         super().__init__(
             optimizer=optimizer,
             initial_scale=initial_scale,

File: MoE hybrid parallel plugin (defines MoeHybridParallelZeroOptimizer and MoeHybridParallelPlugin)

@@ -19,7 +19,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
     get_param_info,
-    init_pipeline_optimizer,
+    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster import ProcessGroupMesh
@@ -67,7 +67,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
-            init_pipeline_optimizer(optimizer, model)
+            reinitialize_optimizer(optimizer, model)
 
         pg_param_list = {
             dp_process_group: [],
@@ -400,12 +400,19 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 dp_group=self.global_dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=use_ddp,  # TODO fix why this failed
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if self.ep_size > 1:
+                # if ep is enabled, the number of (MoE) parameters changes since they are sharded among ep groups,
+                # but the optimizer is not aware of ep, so we need to update the optimizer
+                reinitialize_optimizer(optimizer, model)
+
             if self.zero_stage == 0:
+                assert self.ep_size > 1
+
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
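The added branch is the substance of the fix: with expert parallelism (self.ep_size > 1) the MoE experts are sharded across the EP group when the model is wrapped, so the parameter tensors the wrapped model exposes are no longer the same objects (or the same count) that the user's optimizer was built from. A toy, self-contained illustration of the stale-parameter problem the call guards against; the model and the fake sharding step below are hypothetical, only reinitialize_optimizer comes from this diff:

import torch
from torch import nn

# Stand-in for an expert layer before EP sharding (hypothetical example).
model = nn.Linear(8, 8)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Pretend EP sharding replaced the weight with a smaller local shard:
# the module now owns a brand-new Parameter object.
model.weight = nn.Parameter(model.weight.detach()[:4].clone())

local_params = set(model.parameters())
stale = [p for g in optim.param_groups for p in g["params"] if p not in local_params]
print(len(stale))  # 1: the optimizer still tracks a tensor the model no longer owns

# reinitialize_optimizer(optim, model) rebuilds optim.param_groups from the live
# model, which is what the new `if self.ep_size > 1:` branch does before the
# optimizer is wrapped by the hybrid-parallel optimizer classes.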