mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
parent a7ae2b5b4c
commit 5e1c93d732
@@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
 def auto_set_accelerator() -> None:
     """
     Automatically check if any accelerator is available.
-    If an accelerator is availabe, set it as the global accelerator.
+    If an accelerator is available, set it as the global accelerator.
     """
     global _ACCELERATOR
@@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase):
         )

     def __del__(self):
-        """Destroy the prcess groups in ProcessGroupMesh"""
+        """Destroy the process groups in ProcessGroupMesh"""
         self.pg_mesh.destroy_mesh_process_groups()

     def support_no_sync(self) -> bool:
@@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.max_norm = max_norm

     def __del__(self):
-        """Destroy the prcess groups in ProcessGroupMesh"""
+        """Destroy the process groups in ProcessGroupMesh"""
         self.pg_mesh.destroy_mesh_process_groups()

     @property
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER, MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpointIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             **_kwargs,
         )

-    def get_checkpoint_io(self) -> MoECheckpintIO:
+
+    def get_checkpoint_io(self) -> MoECheckpointIO:
         if self.checkpoint_io is None:
-            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+            self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         else:
             self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io
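Aside: the method in this hunk follows a cache-on-first-use idiom, building the checkpoint IO object only when it is first requested and reusing it afterwards. A minimal standalone sketch of that idiom follows; CheckpointIOStub and PluginStub are hypothetical placeholders for illustration, not ColossalAI classes.

# Cache-on-first-use sketch mirroring get_checkpoint_io() above (placeholder classes).
class CheckpointIOStub:
    def __init__(self, dp_group, pp_group, tp_group, zero_stage):
        self.groups = (dp_group, pp_group, tp_group)
        self.zero_stage = zero_stage

class PluginStub:
    def __init__(self, dp_group=None, pp_group=None, tp_group=None, zero_stage=0):
        self.dp_group, self.pp_group, self.tp_group = dp_group, pp_group, tp_group
        self.zero_stage = zero_stage
        self.checkpoint_io = None

    def get_checkpoint_io(self):
        # Build the IO object on the first call, then return the cached instance.
        if self.checkpoint_io is None:
            self.checkpoint_io = CheckpointIOStub(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        return self.checkpoint_io

plugin = PluginStub(zero_stage=1)
assert plugin.get_checkpoint_io() is plugin.get_checkpoint_io()  # same cached object on every call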
@@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         pp_group (ProcessGroup): Process group along pipeline parallel dimension.
         tp_group (ProcessGroup): Process group along tensor parallel dimension.
         zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
-        verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
+        verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True.
     """

     def __init__(
@@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})

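For context, the loop above restores optimizer hyperparameters from a checkpoint while keeping the live parameter references of the running optimizer. A self-contained sketch of the same pattern with a plain PyTorch optimizer (hypothetical values, not the commit's code, which operates on a wrapped optimizer):

# Sketch of the param-group restore pattern shown in the hunk above.
import copy
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Pretend these groups came from a checkpoint: snapshot the current groups,
# then change a hyperparameter so the restore is observable.
saved_groups = copy.deepcopy(optimizer.state_dict()["param_groups"])
saved_groups[0]["lr"] = 0.01

updated_groups = []
for old_pg, saved_pg in zip(optimizer.param_groups, saved_groups):
    new_pg = copy.deepcopy(saved_pg)
    new_pg["params"] = old_pg["params"]  # keep the live parameter objects, not the saved indices
    updated_groups.append(new_pg)
optimizer.__dict__.update({"param_groups": updated_groups})

assert optimizer.param_groups[0]["lr"] == 0.01  # hyperparameter restored from the "checkpoint"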
@@ -1,4 +1,4 @@
-from .checkpoint import MoECheckpintIO
+from .checkpoint import MoECheckpointIO
 from .experts import MLPExperts
 from .layers import SparseMLP, apply_load_balance
 from .manager import MOE_MANAGER
@@ -14,7 +14,7 @@ __all__ = [
     "NormalNoiseGenerator",
     "UniformNoiseGenerator",
     "SparseMLP",
-    "MoECheckpintIO",
+    "MoECheckpointIO",
     "MOE_MANAGER",
     "apply_load_balance",
 ]
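Since this hunk renames a public export of colossalai.moe, downstream code has to follow the new spelling. A hedged example of what caller code would look like after the rename (the process groups and zero_stage are assumed to come from the plugin that owns them):

# Old (misspelled) export, removed by this commit:
#   from colossalai.moe import MoECheckpintIO
# New export:
from colossalai.moe import MoECheckpointIO

# Constructor signature as used elsewhere in this diff:
#   MoECheckpointIO(dp_group, pp_group, tp_group, zero_stage)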
@@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import (
 )


-class MoECheckpintIO(HybridParallelCheckpointIO):
+class MoECheckpointIO(HybridParallelCheckpointIO):
     def __init__(
         self,
         dp_group: ProcessGroup,
@@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
         # ep param group
         if len(optimizer.optim.param_groups) > len(saved_groups):