From c6f6005990b182d7ee34c1fb84762d31ce7d3616 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 21 Jul 2023 14:39:01 +0800 Subject: [PATCH] [checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302) * sharded optimizer checkpoint for gemini plugin * modify test to reduce testing time * update doc * fix bug when keep_gatherd is true under GeminiPlugin --- colossalai/booster/plugin/gemini_plugin.py | 131 +++++++++++++--- .../checkpoint_io/general_checkpoint_io.py | 38 +++-- colossalai/checkpoint_io/utils.py | 38 +++++ colossalai/zero/gemini/gemini_optimizer.py | 140 ++++++++++++++---- docs/source/en/basics/booster_api.md | 5 +- docs/source/en/basics/booster_checkpoint.md | 2 - docs/source/en/basics/booster_plugins.md | 2 - docs/source/zh-Hans/basics/booster_api.md | 5 +- .../zh-Hans/basics/booster_checkpoint.md | 1 - docs/source/zh-Hans/basics/booster_plugins.md | 1 - .../test_gemini_checkpoint_io.py | 4 +- .../test_gemini_torch_compability.py | 6 +- 12 files changed, 289 insertions(+), 84 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6191f271c..7b6e17337 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,3 +1,4 @@ +import gc import logging import os import warnings @@ -12,11 +13,19 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict +from colossalai.checkpoint_io.utils import ( + get_model_base_filenames, + get_optimizer_base_filenames, + get_shard_filename, + load_shard_state_dict, + save_state_dict, + save_state_dict_shards, +) from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini import ZeroOptimizer from colossalai.zero.gemini.memory_tracer import MemStats from .dp_plugin_base import DPPluginBase @@ -37,7 +46,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ Save sharded model to checkpoint but only on master process. The model should be unwrapped in self.load_model via ModelWrapper.unwrap. - As there is communication when getting state dict, this must be called on all processes. + As there is communication when getting state dict, model.state_dict() must be called on all processes. """ state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): @@ -54,7 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ Save unsharded optimizer state dict to checkpoint. After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. - As there is communication when getting state dict, this must be called on all processes. + As there is communication when getting state dict, optimizer.state_dict() must be called on all processes. The saving process will only be executed by master rank. """ state_dict = optimizer.state_dict() @@ -76,7 +85,8 @@ class GeminiCheckpointIO(GeneralCheckpointIO): max_shard_size: int = 1024, use_safetensors: bool = False): """ - Save sharded model + Save sharded model. + As there is communication when getting state dict, model.state_dict() must be called on all processes. 
""" if os.path.isfile(checkpoint_path): logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") @@ -86,28 +96,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO): state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) - total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) - for idx, shard_pair in enumerate(state_dict_shard): - if not self.coordinator.is_master(): - continue - shard = shard_pair[0] - shard_file = get_shard_filename(weights_name, idx) - total_size = total_size + shard_pair[1] - for key in shard.keys(): - index_file.append_weight_map(key, shard_file) - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) - save_state_dict(shard, checkpoint_file_path, use_safetensors) - - index_file.append_meta_data("total_size", total_size) + # Save shards of optimizer states. + is_master = self.coordinator.is_master() + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors) # only save the index file on the master rank if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") def load_sharded_model(self, model: GeminiDDP, @@ -115,7 +121,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): strict: bool = False, use_safetensors: bool = False): """ - load shard model, load model from multiple files + Load shard model, load model from multiple files. """ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) @@ -125,16 +131,93 @@ class GeminiCheckpointIO(GeneralCheckpointIO): Save sharded optimizer state dict to checkpoint folder. As there is communication when getting state dict, this must be called on all processes. """ + + # If optimizer is wrapped, unwrap it. + if isinstance(optimizer, OptimizerWrapper): + optimizer = optimizer.unwrap() + + assert isinstance(optimizer, ZeroOptimizer) + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + Path(checkpoint).mkdir(parents=True, exist_ok=True) - super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = optimizer.get_param_groups_for_saving() + torch.save(param_groups, group_file_path) + + # States are broken into shards within max_shard_size. + state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) + + # Save shards of optimizer states. 
+ is_master = self.coordinator.is_master() + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=is_master, + use_safetensors=False) + + # Wrap up index file. Only save it on master rank. + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): """ Loading sharded optimizer from checkpoint folder, with index file given. For each process, only loading optimizer states of parameters it controls. """ - # TODO(Baizhou): To be implemented. - pass + + if not os.path.isfile(checkpoint_index_file): + logging.error(f"Provided path ({checkpoint_index_file}) should be a file") + + # If optimizer is wrapped, unwrap it. + if isinstance(optimizer, OptimizerWrapper): + optimizer = optimizer.unwrap() + + assert isinstance(optimizer, ZeroOptimizer) + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + + # Load param_groups. + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory.') + saved_param_groups = torch.load(param_group_path) + optimizer.load_param_groups(saved_param_groups) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + # Load optimizer states from shard files under checkpoint path. + # For each file, only load the states managed by current process. + for shard_file in checkpoint_files: + state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + optimizer.load_param_states(state_dict_shard) + del state_dict_shard + gc.collect() + + optimizer.optimizer_loading_epilogue() + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) class GeminiModel(ModelWrapper): diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index e1d906694..83e4bdcc8 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -5,6 +5,7 @@ from functools import reduce from pathlib import Path from typing import Iterator, Optional, OrderedDict, Tuple +import torch.distributed as dist import torch.nn as nn from torch.optim import Optimizer @@ -16,7 +17,6 @@ from .utils import ( get_model_base_filenames, get_optimizer_base_filenames, get_shard_filename, - has_index_file, is_safetensors_available, load_param_groups_into_optimizer, load_shard_state_dict, @@ -25,6 +25,7 @@ from .utils import ( load_states_into_optimizer, save_param_groups, save_state_dict, + save_state_dict_shards, shard_model_checkpoint, shard_optimizer_checkpoint, sharded_optimizer_loading_epilogue, @@ -122,15 +123,13 @@ class GeneralCheckpointIO(CheckpointIO): save_param_groups(state_dict, group_file_path) # Save shards of optimizer states. 
- total_size = 0 - for idx, shard_pair in enumerate(sharded_state): - shard, current_size = shard_pair - shard_file = get_shard_filename(states_name, idx) - total_size = total_size + current_size - for key in shard.keys(): - index_file.append_weight_map(key, shard_file) - checkpoint_file_path = os.path.join(checkpoint, shard_file) - save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + # In general cases, is_master is set to True to get the right behavior. + total_size = save_state_dict_shards(sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + use_safetensors=False) # Wrap up index file. index_file.append_meta_data("total_size", total_size) @@ -172,18 +171,17 @@ class GeneralCheckpointIO(CheckpointIO): # shard checkpoint state_dict = model.state_dict() state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size) - weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) - total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) - for idx, shard_pair in enumerate(state_dict_shard): - shard = shard_pair[0] - shard_file = get_shard_filename(weights_name, idx) - total_size = total_size + shard_pair[1] - for key in shard.keys(): - index_file.append_weight_map(key, shard_file) - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) - save_state_dict(shard, checkpoint_file_path, use_safetensors) + + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, + use_safetensors=use_safetensors) index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 19e28c3f7..8837776ae 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,4 +1,5 @@ # coding=utf-8 +import os import re from collections import abc as container_abcs from collections import defaultdict @@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): return unwrapped_optim +def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + use_safetensors: bool = False) -> int: + ''' + Save sharded state dict only on master rank, this method can be used by both model and optimizer states. + Args: + sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. + checkpoint (str): The path of checkpoint directory as string. + index_file (CheckpointIndexFile): The index file object to be updated. + base_filename (str): Decides the prefix of filenames of shards. + is_master (bool): Whether current rank is master. + use_safetensors (bool): Whether to use safetensors to save checkpoint. 
+ + Returns: + int: the total size of shards + ''' + + total_size = 0 + for idx, shard_pair in enumerate(sharded_state_dict): + if not is_master: + continue + shard, current_size = shard_pair + shard_file = get_shard_filename(base_filename, idx) + total_size = total_size + current_size + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint, shard_file) + + # Only save on master rank. + save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + + return total_size + + def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 99aff6f1c..7d0db6b1f 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -3,7 +3,7 @@ import copy import gc import math import warnings -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple import torch import torch.distributed as dist @@ -11,8 +11,10 @@ from torch.nn import Parameter from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin +from colossalai.checkpoint_io.utils import calculate_tensor_size from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam +from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.utils import disposable, get_current_device, is_ddp_ignored from .chunk import Chunk, ChunkManager @@ -360,10 +362,12 @@ class ZeroOptimizer(ColossalaiOptimizer): begin_in_chunk, end_in_chunk = self.param_to_range[fake_param] chunk_offset = begin_in_chunk - shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset + if chunk.keep_gathered: + shard_offset = 0 + else: + shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset shard_size = end_in_chunk - begin_in_chunk assert chunk_offset >= 0 and shard_offset >= 0 - return chunk_offset, shard_offset, shard_size def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: @@ -427,7 +431,8 @@ class ZeroOptimizer(ColossalaiOptimizer): dtype=torch.float32, requires_grad=False).cpu() else: - collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu() + state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() + collected_states[state_name] = torch.reshape(state_tensor, param.shape) return collected_states # Check whether the param with given id is managed by current process. @@ -536,6 +541,31 @@ class ZeroOptimizer(ColossalaiOptimizer): target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size]) next_state_offset += shard_size + def get_param_groups_for_saving(self) -> list: + ''' + Return the param_groups in Pytorch format when saving to checkpoint. + ''' + + param_groups = copy.deepcopy(self.param_groups_backup) + + # To be compatible with pytorch checkpointing, + # store extra hyperparameters used by pytorch Adam optimizer. 
+ torch_special_hyperparameters = { + 'amsgrad': False, + 'maximize': False, + 'foreach': None, + 'capturable': False, + 'differentiable': False, + 'fused': False + } + + for group in param_groups: + for k, v in torch_special_hyperparameters.items(): + if k not in group: + group[k] = v + + return param_groups + def state_dict(self, only_rank_0: bool = True) -> dict: """ Args: @@ -555,21 +585,7 @@ class ZeroOptimizer(ColossalaiOptimizer): so it should be called only when memory resources are abundant. """ state_dict = {} - state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup) - - torch_special_hyperparameters = { - 'amsgrad': False, - 'maximize': False, - 'foreach': None, - 'capturable': False, - 'differentiable': False, - 'fused': False - } - - for group in state_dict['param_groups']: - for k, v in torch_special_hyperparameters.items(): - if k not in group: - group[k] = v + state_dict['param_groups'] = self.get_param_groups_for_saving() # Collect optimizer states. state_dict['state'] = dict() @@ -634,8 +650,24 @@ class ZeroOptimizer(ColossalaiOptimizer): del v # clean loaded states self.optim.state[fake_param].update(updated_states) + def load_param_states(self, param_states: dict): + """Loads param states from a state_dict. The param_states can be complete or sharded. + During loading, filter out the part of states not considered by current process. + + Args: + param_states (dict): A mapping from param_id to its states. + """ + for param_id, states in param_states.items(): + if param_id in self.id_to_fake_params: + self.load_single_param_states(param_id, states) + + def optimizer_loading_epilogue(self): + # Epilogue when loading state_dict to pytorch optimizer. + self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. + self.optim.defaults.setdefault('differentiable', False) + def load_state_dict(self, state_dict: dict): - """Loads optimizer state from whole optimizer state_dict. + """Loads optimizer state from complete optimizer state_dict. During loading, filter out the part of states not considered by current process. Args: @@ -643,17 +675,71 @@ class ZeroOptimizer(ColossalaiOptimizer): from a call to :meth:`state_dict`. """ assert 'param_groups' in state_dict + assert 'state' in state_dict self.load_param_groups(state_dict['param_groups']) + self.load_param_states(state_dict['state']) + self.optimizer_loading_epilogue() - state = state_dict['state'] + def state_shard(self, + prefix: str = '', + max_shard_size: int = 1024, + only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]: + """Returns dictionaries containing shards of optimizer states one by one. + The max size of each dictionary shard is specified by ``max_shard_size``. - for param_id, param_states in state.items(): - if param_id in self.id_to_fake_params: - self.load_single_param_states(param_id, param_states) + Args: + prefix (str, optional): the prefix for states. Default to ''. + max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024. + only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected + only on rank 0, dafault to True. - # Epilogue for pytorch optimizer. - self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. - self.optim.defaults.setdefault('differentiable', False) + Yields: + Iterator[OrderedDict]: A generator of state dict shard of optimizer states. 
+ """ + + current_block = {} + current_block_size = 0 + + for param_id in self.id_to_real_params.keys(): + + dist.barrier() + state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + + ret_block = None + ret_block_size = 0 + + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. + if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + if not isDTensor: + + if current_block_size + state_size > max_shard_size and current_block_size > 0: + ret_block = current_block + ret_block_size = current_block_size + current_block = {} + current_block_size = 0 + + current_block[param_id] = state + current_block_size += state_size + + if ret_block != None: + yield ret_block, ret_block_size + + yield current_block, current_block_size class GeminiAdamOptimizer(ZeroOptimizer): diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md index 22d5ee818..1e75c343c 100644 --- a/docs/source/en/basics/booster_api.md +++ b/docs/source/en/basics/booster_api.md @@ -21,10 +21,13 @@ Plugin is an important component that manages parallel configuration (eg: The ge **_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that ZeRO with chunk-based memory management. -**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines. +**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines. **_LowLevelZeroPlugin:_** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs. + +**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp. + ### API of booster {{ autodoc:colossalai.booster.Booster }} diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md index adc0af60b..b2840fe87 100644 --- a/docs/source/en/basics/booster_checkpoint.md +++ b/docs/source/en/basics/booster_checkpoint.md @@ -21,8 +21,6 @@ Model must be boosted by `colossalai.booster.Booster` before loading. It will de ## Optimizer Checkpoint -> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet. - {{ autodoc:colossalai.booster.Booster.save_optimizer }} Optimizer must be boosted by `colossalai.booster.Booster` before saving. diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 5e2586b83..c5c45abce 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -51,8 +51,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme {{ autodoc:colossalai.booster.plugin.GeminiPlugin }} -> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. 
This will be fixed in the future. - ### Torch DDP Plugin More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md index 1df821ce7..b2235b73b 100644 --- a/docs/source/zh-Hans/basics/booster_api.md +++ b/docs/source/zh-Hans/basics/booster_api.md @@ -24,10 +24,13 @@ Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了 **_GeminiPlugin:_** GeminiPlugin 插件封装了 gemini 加速解决方案,即基于块内存管理的 ZeRO 优化方案。 -**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了 DDP 加速方案,实现了模型级别的数据并行,可以跨多机运行。 +**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了Pytorch的DDP加速方案,实现了模型级别的数据并行,可以跨多机运行。 **_LowLevelZeroPlugin:_** LowLevelZeroPlugin 插件封装了零冗余优化器的 1/2 阶段。阶段 1:切分优化器参数,分发到各并发进程或并发 GPU 上。阶段 2:切分优化器参数及梯度,分发到各并发进程或并发 GPU 上。 +**_TorchFSDPPlugin:_** TorchFSDPPlugin封装了 Pytorch的FSDP加速方案,可以用于零冗余优化器数据并行(ZeroDP)的训练。 + + ### Booster 接口 diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md index d75f18c90..4ed049dcf 100644 --- a/docs/source/zh-Hans/basics/booster_checkpoint.md +++ b/docs/source/zh-Hans/basics/booster_checkpoint.md @@ -21,7 +21,6 @@ ## 优化器 Checkpoint -> ⚠ 尚不支持以分片方式保存优化器 Checkpoint。 {{ autodoc:colossalai.booster.Booster.save_optimizer }} diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index 5bd88b679..0f355c439 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -51,7 +51,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 {{ autodoc:colossalai.booster.plugin.GeminiPlugin }} -> ⚠ 该插件现在只能加载自己保存的且具有相同进程数的优化器 Checkpoint。这将在未来得到解决。 ### Torch DDP 插件 diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 0235ff2e2..7b664419b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b @clear_cache_before_run() @parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('shard', [False]) +@parameterize('shard', [False, True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int): @@ -117,7 +117,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index b34e3e3a1..464fccb39 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -19,7 +19,7 @@ from tests.kit.model_zoo import model_zoo @clear_cache_before_run() -@parameterize('shard', [False]) +@parameterize('shard', [False, True]) @parameterize('model_name', ['transformers_gpt']) def exam_torch_load_from_gemini(shard: bool, model_name: str): @@ -83,7 +83,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): @clear_cache_before_run() -@parameterize('shard', [False]) +@parameterize('shard', [False, True]) @parameterize('model_name', ['transformers_gpt']) def 
exam_gemini_load_from_torch(shard: bool, model_name: str): @@ -165,7 +165,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size)
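
For readers skimming this patch, the end-to-end flow it enables looks roughly like the sketch below: boost a model and a HybridAdam optimizer with GeminiPlugin, run one step so optimizer states exist, then save and reload sharded checkpoints through the Booster API. This is a minimal illustration, not part of the patch: the `shard` and `size_per_shard` keyword arguments and the directory-style checkpoint paths are assumptions based on the modified tests and the Booster docs referenced above, and exact signatures or path handling may differ across ColossalAI versions. It must be launched on every rank (e.g. via `colossalai run` or `torchrun`), since both `model.state_dict()` and `optimizer.state_shard()` communicate across processes.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam


def main():
    # Initialize the distributed environment on every rank.
    colossalai.launch_from_torch(config={})

    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)).cuda()
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    plugin = GeminiPlugin(placement_policy='cuda')
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # One dummy training step so the optimizer actually has states to checkpoint.
    x = torch.randn(8, 32).cuda()
    y = torch.randint(0, 2, (8,)).cuda()
    loss = criterion(model(x), y)
    booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()

    # Save sharded checkpoints: every rank must make these calls because state
    # collection involves communication, but only the master rank writes files.
    booster.save_model(model, 'ckpt/model', shard=True, size_per_shard=32)
    booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=True, size_per_shard=32)

    # Load back: each rank only loads the optimizer states of parameters it manages.
    # Depending on the version, the load path may need to point at the index file
    # inside the checkpoint directory instead of the directory itself.
    booster.load_model(model, 'ckpt/model')
    booster.load_optimizer(optimizer, 'ckpt/optimizer')


if __name__ == '__main__':
    main()

The parametrized tests above exercise the same pair of code paths: `shard=False` goes through the unsharded `save_optimizer`/`load_optimizer` branch, while `shard=True` exercises the new `save_sharded_optimizer`/`load_sharded_optimizer` implementation added by this patch.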