[hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
pull/5429/head
digger yu 2024-03-05 21:52:30 +08:00 committed by GitHub
parent a7ae2b5b4c
commit 5e1c93d732
7 changed files with 13 additions and 12 deletions


@@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
 def auto_set_accelerator() -> None:
     """
     Automatically check if any accelerator is available.
-    If an accelerator is availabe, set it as the global accelerator.
+    If an accelerator is available, set it as the global accelerator.
     """
     global _ACCELERATOR
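
The docstring fixed above states the whole contract of auto_set_accelerator: detect an available accelerator and register it globally. A minimal usage sketch, assuming auto_set_accelerator and get_accelerator are both exported from colossalai.accelerator as this hunk suggests:

# Hedged sketch: auto-detect the available accelerator and register it globally,
# then fetch the global instance for later use.
from colossalai.accelerator import auto_set_accelerator, get_accelerator

auto_set_accelerator()           # sets the global accelerator if one is available
accelerator = get_accelerator()  # returns the global accelerator instance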


@@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase):
         )
 
     def __del__(self):
-        """Destroy the prcess groups in ProcessGroupMesh"""
+        """Destroy the process groups in ProcessGroupMesh"""
         self.pg_mesh.destroy_mesh_process_groups()
 
     def support_no_sync(self) -> bool:


@@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.max_norm = max_norm
 
     def __del__(self):
-        """Destroy the prcess groups in ProcessGroupMesh"""
+        """Destroy the process groups in ProcessGroupMesh"""
         self.pg_mesh.destroy_mesh_process_groups()
 
     @property


@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER, MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpointIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             **_kwargs,
         )
 
-    def get_checkpoint_io(self) -> MoECheckpintIO:
+    def get_checkpoint_io(self) -> MoECheckpointIO:
         if self.checkpoint_io is None:
-            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+            self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         else:
             self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io
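
The method above lazily constructs the renamed MoECheckpointIO, so user code reaches it only through get_checkpoint_io(). A minimal sketch of that flow; the save_moe_model helper is hypothetical, and the save_model call assumes the standard CheckpointIO interface:

# Hedged usage sketch: callers never import MoECheckpointIO directly; the plugin
# builds it on first request via get_checkpoint_io().
def save_moe_model(plugin, model, checkpoint_path: str) -> None:
    checkpoint_io = plugin.get_checkpoint_io()        # a MoECheckpointIO instance after this commit
    checkpoint_io.save_model(model, checkpoint_path)  # standard CheckpointIO method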


@@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         pp_group (ProcessGroup): Process group along pipeline parallel dimension.
         tp_group (ProcessGroup): Process group along tensor parallel dimension.
         zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
-        verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
+        verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True.
     """
 
     def __init__(
@@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})
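
The comment corrected above describes the restore pattern used here: hyperparameters come from the checkpoint, while the live parameter references of the current optimizer are kept. A self-contained illustration of the same idea; merge_param_groups is a hypothetical helper, not ColossalAI API:

import copy

def merge_param_groups(live_groups: list, saved_groups: list) -> list:
    """Restore saved hyperparameters while keeping the current parameter objects."""
    updated = []
    for old_pg, saved_pg in zip(live_groups, saved_groups):
        new_pg = copy.deepcopy(saved_pg)      # lr, weight_decay, ... come from the checkpoint
        new_pg["params"] = old_pg["params"]   # the parameters in the same group shouldn't change
        updated.append(new_pg)
    return updated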


@@ -1,4 +1,4 @@
-from .checkpoint import MoECheckpintIO
+from .checkpoint import MoECheckpointIO
 from .experts import MLPExperts
 from .layers import SparseMLP, apply_load_balance
 from .manager import MOE_MANAGER
@@ -14,7 +14,7 @@ __all__ = [
     "NormalNoiseGenerator",
     "UniformNoiseGenerator",
     "SparseMLP",
-    "MoECheckpintIO",
+    "MoECheckpointIO",
     "MOE_MANAGER",
     "apply_load_balance",
 ]
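
Because the public export in colossalai.moe changes spelling, downstream imports must follow suit. A minimal before/after sketch grounded in the hunks above:

# After this commit, only the corrected spelling is exported.
from colossalai.moe import MoECheckpointIO  # works

# The old misspelled name is no longer defined and would raise ImportError:
# from colossalai.moe import MoECheckpintIO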


@@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import (
 )
 
 
-class MoECheckpintIO(HybridParallelCheckpointIO):
+class MoECheckpointIO(HybridParallelCheckpointIO):
     def __init__(
         self,
         dp_group: ProcessGroup,
@@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
         # ep param group
         if len(optimizer.optim.param_groups) > len(saved_groups):