[hotfix] fix typo in hybrid parallel io (#4697)

Author: Baizhou Zhang
Committed: 2023-09-12 17:32:19 +08:00 (committed by GitHub)
Parent: 8844691f4b
Commit: d8ceeac14e
3 changed files with 7 additions and 7 deletions


@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -513,7 +513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io
 
     def no_sync(self, model: Module) -> Iterator[None]:
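
For context, a minimal usage sketch of how the corrected class is reached through the plugin after this change. Only the get_checkpoint_io() wiring and the class name come from this diff; the launch call and the tp_size/pp_size values are illustrative assumptions.

import colossalai
from colossalai.booster.plugin import HybridParallelPlugin

# Assumed setup for illustration: a torch.distributed environment launched via
# colossalai; the parallel sizes below are placeholders, not part of this patch.
colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)

# After this fix, get_checkpoint_io() returns a correctly named
# HybridParallelCheckpointIO built from the plugin's dp/pp/tp groups and zero stage.
checkpoint_io = plugin.get_checkpoint_io()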


@@ -1,6 +1,6 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
-from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile
 
 __all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
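
Note that __all__ above already advertised the correctly spelled name, so before this fix the advertised symbol was not actually importable from the package. A minimal check illustrating the effect (not part of the patch):

# On pre-fix versions only the misspelled class exists, so importing the name
# listed in __all__ raises ImportError; after this commit the import succeeds.
try:
    from colossalai.checkpoint_io import HybridParallelCheckpointIO
except ImportError:
    from colossalai.checkpoint_io import HypridParallelCheckpointIO as HybridParallelCheckpointIO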


@@ -39,7 +39,7 @@ except ImportError:
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
-class HypridParallelCheckpointIO(GeneralCheckpointIO):
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
     """
     CheckpointIO for Hybrid Parallel Training.
@@ -136,7 +136,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                 param_id = param_info['param2id'][id(working_param)]
                 original_shape = param_info['param2shape'][id(working_param)]
-                state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+                state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
                                                                                          working_param,
                                                                                          original_shape=original_shape,
                                                                                          dp_group=dp_group,
@@ -189,7 +189,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+        state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
         control_saving = (self.tp_rank == 0)
@@ -385,7 +385,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
         # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-        state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
+        state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
             optimizer,
             use_zero=self.use_zero,
             dp_group=self.dp_group,