[moe] fix plugin

moe_sp
hxwang 2024-07-02 09:09:00 +00:00
parent 6a9164a477
commit 5a9490a46b
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8
2 changed files with 14 additions and 7 deletions

View File

@ -254,7 +254,7 @@ def get_param_info(optim: Optimizer):
return param_info
def init_pipeline_optimizer(optim: Optimizer, model: Module):
def reinitialize_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
@ -276,7 +276,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
reinitialize_optimizer(optim, model)
self.model = model
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@ -497,7 +497,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
if use_pipeline:
init_pipeline_optimizer(optim, model)
reinitialize_optimizer(optim, model)
super().__init__(
optim,
precision=precision,
@ -678,7 +678,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
reinitialize_optimizer(optimizer, model)
super().__init__(
optimizer=optimizer,
initial_scale=initial_scale,

View File

@ -19,7 +19,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
get_param_info,
init_pipeline_optimizer,
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
@ -67,7 +67,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
reinitialize_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: [],
@ -400,12 +400,19 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
use_ddp=use_ddp, # TODO fix why this failed
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
# if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
# but the optimizer is not aware of ep, so we need to update the optimizer
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
assert self.ep_size > 1
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,