mirror of https://github.com/hpcaitech/ColossalAI
[moe] fix plugin
parent
6a9164a477
commit
5a9490a46b
|
@ -254,7 +254,7 @@ def get_param_info(optim: Optimizer):
|
|||
return param_info
|
||||
|
||||
|
||||
def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
def reinitialize_optimizer(optim: Optimizer, model: Module):
|
||||
model_params = set(model.parameters())
|
||||
new_param_groups = []
|
||||
for group in optim.param_groups:
|
||||
|
@ -276,7 +276,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
reinitialize_optimizer(optim, model)
|
||||
self.model = model
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
|
@ -497,7 +497,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
reinitialize_optimizer(optim, model)
|
||||
super().__init__(
|
||||
optim,
|
||||
precision=precision,
|
||||
|
@ -678,7 +678,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
self.tp_pg = tp_process_group
|
||||
self.pp_pg = pp_process_group
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optimizer, model)
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
initial_scale=initial_scale,
|
||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
|
|||
HybridParallelNaiveOptimizer,
|
||||
HybridParallelPlugin,
|
||||
get_param_info,
|
||||
init_pipeline_optimizer,
|
||||
reinitialize_optimizer,
|
||||
)
|
||||
from colossalai.checkpoint_io import MoECheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
|
@ -67,7 +67,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
self.tp_pg = tp_process_group
|
||||
self.pp_pg = pp_process_group
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optimizer, model)
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
|
||||
pg_param_list = {
|
||||
dp_process_group: [],
|
||||
|
@ -400,12 +400,19 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
dp_group=self.global_dp_group,
|
||||
tp_group=self.tp_group,
|
||||
sp_group=self.sp_group,
|
||||
use_ddp=use_ddp,
|
||||
use_ddp=use_ddp, # TODO fix why this failed
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.ep_size > 1:
|
||||
# if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
|
||||
# but the optimizer is not aware of ep, so we need to update the optimizer
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
|
||||
if self.zero_stage == 0:
|
||||
assert self.ep_size > 1
|
||||
|
||||
if self.precision in ["fp16", "bf16"]:
|
||||
optimizer = HybridParallelAMPOptimizer(
|
||||
optimizer,
|
||||
|
|
Loading…
Reference in New Issue