[hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
digger yu 2024-03-05 21:52:30 +08:00 committed by GitHub
parent a7ae2b5b4c
commit 5e1c93d732
7 changed files with 13 additions and 12 deletions

View File

@@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
 def auto_set_accelerator() -> None:
     """
     Automatically check if any accelerator is available.
-    If an accelerator is availabe, set it as the global accelerator.
+    If an accelerator is available, set it as the global accelerator.
     """
     global _ACCELERATOR
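For context, a minimal usage sketch of the helper touched above; it assumes colossalai.accelerator re-exports these functions, which this diff does not show:

# Hedged sketch, not part of this commit: typical use of the accelerator helpers.
from colossalai.accelerator import auto_set_accelerator, get_accelerator

auto_set_accelerator()           # detect an available accelerator and set it as the global one
accelerator = get_accelerator()  # retrieve the global accelerator that was just set
print(type(accelerator))         # e.g. a CUDA accelerator object on a GPU machine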

View File

@@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase):
         )

     def __del__(self):
-        """Destroy the prcess groups in ProcessGroupMesh"""
+        """Destroy the process groups in ProcessGroupMesh"""
         self.pg_mesh.destroy_mesh_process_groups()

     def support_no_sync(self) -> bool:

View File

@@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.max_norm = max_norm

     def __del__(self):
-        """Destroy the prcess groups in ProcessGroupMesh"""
+        """Destroy the process groups in ProcessGroupMesh"""
         self.pg_mesh.destroy_mesh_process_groups()

     @property

View File

@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER, MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpointIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             **_kwargs,
         )
-    def get_checkpoint_io(self) -> MoECheckpintIO:
+
+    def get_checkpoint_io(self) -> MoECheckpointIO:
         if self.checkpoint_io is None:
-            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+            self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         else:
             self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

View File

@@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         pp_group (ProcessGroup): Process group along pipeline parallel dimension.
         tp_group (ProcessGroup): Process group along tensor parallel dimension.
         zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
-        verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
+        verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True.
     """

     def __init__(
def __init__( def __init__(
@@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})

View File

@@ -1,4 +1,4 @@
-from .checkpoint import MoECheckpintIO
+from .checkpoint import MoECheckpointIO
 from .experts import MLPExperts
 from .layers import SparseMLP, apply_load_balance
 from .manager import MOE_MANAGER
@@ -14,7 +14,7 @@ __all__ = [
     "NormalNoiseGenerator",
     "UniformNoiseGenerator",
     "SparseMLP",
-    "MoECheckpintIO",
+    "MoECheckpointIO",
     "MOE_MANAGER",
     "apply_load_balance",
 ]
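Since the rename also changes the public export above, downstream code now imports the corrected name. A minimal sketch, not part of the diff; the commented constructor call mirrors the plugin hunk earlier, with the process groups and zero_stage assumed to come from an already-initialized distributed setup:

# After this commit the misspelled MoECheckpintIO no longer exists; import the new name.
from colossalai.moe import MoECheckpointIO

# Illustrative only -- dp_group, pp_group, tp_group and zero_stage are not defined here:
# checkpoint_io = MoECheckpointIO(dp_group, pp_group, tp_group, zero_stage)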

View File

@@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import (
 )


-class MoECheckpintIO(HybridParallelCheckpointIO):
+class MoECheckpointIO(HybridParallelCheckpointIO):
     def __init__(
         self,
         dp_group: ProcessGroup,
@@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
         # ep param group
         if len(optimizer.optim.param_groups) > len(saved_groups):