mirror of https://github.com/hpcaitech/ColossalAI
[moe] fix plugin
parent 6a9164a477
commit 5a9490a46b
@@ -254,7 +254,7 @@ def get_param_info(optim: Optimizer):
     return param_info


-def init_pipeline_optimizer(optim: Optimizer, model: Module):
+def reinitialize_optimizer(optim: Optimizer, model: Module):
     model_params = set(model.parameters())
     new_param_groups = []
     for group in optim.param_groups:
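Only the first few lines of the renamed helper are visible in this hunk. A minimal sketch of what a helper like this typically does, reconstructed from the visible context lines; everything beyond those lines is an assumption, not the actual ColossalAI implementation:

from torch.nn import Module
from torch.optim import Optimizer


def reinitialize_optimizer(optim: Optimizer, model: Module):
    # Parameters the (possibly sharded) wrapped model actually holds on this rank.
    model_params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
        # Keep only parameters that still exist in the wrapped model; drop stale references.
        params = [p for p in group["params"] if p in model_params]
        new_param_groups.append({**group, "params": params})
    # Rebuild the optimizer's param groups in place so later wrappers (AMP / ZeRO)
    # see a parameter set consistent with the model.
    optim.__setstate__({"param_groups": new_param_groups})

The rename reflects that the helper is no longer pipeline-specific: after this commit it is also called when expert parallelism reshards the MoE parameters (see the last hunk below).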
@@ -276,7 +276,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     ):
         self.param_info = param_info
         if use_pipeline:
-            init_pipeline_optimizer(optim, model)
+            reinitialize_optimizer(optim, model)
         self.model = model
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
@@ -497,7 +497,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         if use_pipeline:
-            init_pipeline_optimizer(optim, model)
+            reinitialize_optimizer(optim, model)
         super().__init__(
             optim,
             precision=precision,
@@ -678,7 +678,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
-            init_pipeline_optimizer(optimizer, model)
+            reinitialize_optimizer(optimizer, model)
         super().__init__(
             optimizer=optimizer,
             initial_scale=initial_scale,
@@ -19,7 +19,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
     get_param_info,
-    init_pipeline_optimizer,
+    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster import ProcessGroupMesh
@@ -67,7 +67,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
-            init_pipeline_optimizer(optimizer, model)
+            reinitialize_optimizer(optimizer, model)

         pg_param_list = {
             dp_process_group: [],
@@ -400,12 +400,19 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 dp_group=self.global_dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=use_ddp,  # TODO fix why this failed
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if self.ep_size > 1:
+                # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
+                # but the optimizer is not aware of ep, so we need to update the optimizer
+                reinitialize_optimizer(optimizer, model)
+
             if self.zero_stage == 0:
+                assert self.ep_size > 1
+
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
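The inline comment in this hunk is the core of the fix: with expert parallelism enabled, expert weights are sharded across EP ranks after the model is wrapped, so an optimizer constructed against the unsharded model still references parameters this rank no longer owns. A small self-contained illustration of that mismatch; the toy module and its shard_ method are hypothetical, not ColossalAI code:

import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    # Toy stand-in for an MoE layer with one Linear per expert.
    def __init__(self, num_experts: int = 4, dim: int = 8):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

    def shard_(self, keep: int):
        # Simulate expert-parallel sharding: this rank keeps a single expert.
        self.experts = nn.ModuleList([self.experts[keep]])


model = ToyExperts()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model.shard_(keep=0)  # sharding happens after the optimizer was created

owned = set(model.parameters())
stale = [p for g in optim.param_groups for p in g["params"] if p not in owned]
# The optimizer still tracks the three dropped experts' parameters (weight + bias each),
# which is exactly the mismatch reinitialize_optimizer(optim, model) repairs.
print(f"stale parameter tensors: {len(stale)}")  # prints 6

Against that picture, the new ep_size > 1 branch simply calls reinitialize_optimizer before the optimizer is handed to HybridParallelAMPOptimizer or the ZeRO wrapper, so those wrappers only ever see parameters that exist on this rank.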