mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)

* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gathered is true under GeminiPlugin

branch: pull/4305/head
parent fc5cef2c79
commit c6f6005990
@@ -1,3 +1,4 @@
+import gc
 import logging
 import os
 import warnings

@@ -12,11 +13,19 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import (
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
+    get_shard_filename,
+    load_shard_state_dict,
+    save_state_dict,
+    save_state_dict_shards,
+)
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini import ZeroOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
 from .dp_plugin_base import DPPluginBase

@@ -37,7 +46,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save sharded model to checkpoint but only on master process.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
-        As there is communication when getting state dict, this must be called on all processes.
+        As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():

@@ -54,7 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
-        As there is communication when getting state dict, this must be called on all processes.
+        As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
         The saving process will only be executed by master rank.
         """
         state_dict = optimizer.state_dict()

@@ -76,7 +85,8 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
-        Save sharded model
+        Save sharded model.
+        As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")

@@ -86,28 +96,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
-        total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
-        for idx, shard_pair in enumerate(state_dict_shard):
-            if not self.coordinator.is_master():
-                continue
-            shard = shard_pair[0]
-            shard_file = get_shard_filename(weights_name, idx)
-            total_size = total_size + shard_pair[1]
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-
-            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
-            save_state_dict(shard, checkpoint_file_path, use_safetensors)
-
-        index_file.append_meta_data("total_size", total_size)
+
+        # Save shards of optimizer states.
+        is_master = self.coordinator.is_master()
+        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+                                            checkpoint=checkpoint_path,
+                                            index_file=index_file,
+                                            base_filename=weights_name,
+                                            is_master=is_master,
+                                            use_safetensors=use_safetensors)
+
+        # only save the index file on the master rank
         if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-        logging.info(f"The model is split into checkpoint shards. "
-                     f"You can find where each parameters has been saved in the "
-                     f"index located at {save_index_file}.")
+            logging.info(f"The model is split into checkpoint shards. "
+                         f"You can find where each parameters has been saved in the "
+                         f"index located at {save_index_file}.")
 
     def load_sharded_model(self,
                            model: GeminiDDP,

@@ -115,7 +121,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            strict: bool = False,
                            use_safetensors: bool = False):
         """
-        load shard model, load model from multiple files
+        Load shard model, load model from multiple files.
         """
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 

@@ -125,16 +131,93 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Save sharded optimizer state dict to checkpoint folder.
         As there is communication when getting state dict, this must be called on all processes.
         """
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.unwrap()
+
+        assert isinstance(optimizer, ZeroOptimizer)
+
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
-        super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        param_groups = optimizer.get_param_groups_for_saving()
+        torch.save(param_groups, group_file_path)
+
+        # States are broken into shards within max_shard_size.
+        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+
+        # Save shards of optimizer states.
+        is_master = self.coordinator.is_master()
+        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+                                            checkpoint=checkpoint,
+                                            index_file=index_file,
+                                            base_filename=states_name,
+                                            is_master=is_master,
+                                            use_safetensors=False)
+
+        # Wrap up index file. Only save it on master rank.
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            logging.info(f"The optimizer is going to be split to checkpoint shards. "
+                         f"You can find where each parameters has been saved in the "
+                         f"index located at {save_index_file}.")
 
     def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
         """
-        # TODO(Baizhou): To be implemented.
-        pass
+
+        if not os.path.isfile(checkpoint_index_file):
+            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.unwrap()
+
+        assert isinstance(optimizer, ZeroOptimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+
+        # Load param_groups.
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
+                               Lacking param group file under current directory.')
+        saved_param_groups = torch.load(param_group_path)
+        optimizer.load_param_groups(saved_param_groups)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        # Load optimizer states from shard files under checkpoint path.
+        # For each file, only load the states managed by current process.
+        for shard_file in checkpoint_files:
+            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            optimizer.load_param_states(state_dict_shard)
+            del state_dict_shard
+            gc.collect()
+
+        optimizer.optimizer_loading_epilogue()
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         Save model to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
 
 class GeminiModel(ModelWrapper):
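For orientation, the save path above leaves behind a flat checkpoint directory: one index file, one param-group file, and a series of optimizer-state shards, all tied together by the index. Below is a minimal, hedged sketch of inspecting such a directory with the index-file helpers used in this commit; the file-name patterns in the comments are assumptions, not taken from the diff.

from pathlib import Path

import torch

from colossalai.checkpoint_io import CheckpointIndexFile

# Illustrative layout written by save_sharded_optimizer above (actual names come
# from get_optimizer_base_filenames and may differ):
#   ckpt/                    <- directory passed as `checkpoint`
#     *.index.json           <- index file: weight map plus "param_groups"/"total_size" metadata
#     *group*.bin            <- param_groups saved with torch.save
#     *-00001-of-0000N.bin   <- optimizer-state shards, each a dict {param_id: states}
index = CheckpointIndexFile.from_file(next(Path("ckpt").glob("*.index.json")))

param_group_path = index.get_param_group_filename()
print("param groups:", torch.load(param_group_path))

shard_files, _ = index.get_checkpoint_filenames()
for shard_file in shard_files:
    shard = torch.load(shard_file)    # {param_id: {'step': ..., 'exp_avg': ..., 'exp_avg_sq': ...}}
    print(shard_file, "->", list(shard.keys()))

Loading goes the other way around: load_sharded_optimizer reads the same index, restores param_groups first, then streams each shard through load_param_states so every rank keeps only the states it owns.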
@@ -5,6 +5,7 @@ from functools import reduce
 from pathlib import Path
 from typing import Iterator, Optional, OrderedDict, Tuple
 
 import torch.distributed as dist
 import torch.nn as nn
 from torch.optim import Optimizer

@@ -16,7 +17,6 @@ from .utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    get_shard_filename,
     has_index_file,
     is_safetensors_available,
     load_param_groups_into_optimizer,
     load_shard_state_dict,

@@ -25,6 +25,7 @@ from .utils import (
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    save_state_dict_shards,
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,

@@ -122,15 +123,13 @@ class GeneralCheckpointIO(CheckpointIO):
         save_param_groups(state_dict, group_file_path)
 
         # Save shards of optimizer states.
-        total_size = 0
-        for idx, shard_pair in enumerate(sharded_state):
-            shard, current_size = shard_pair
-            shard_file = get_shard_filename(states_name, idx)
-            total_size = total_size + current_size
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-            checkpoint_file_path = os.path.join(checkpoint, shard_file)
-            save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+        # In general cases, is_master is set to True to get the right behavior.
+        total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
+                                            checkpoint=checkpoint,
+                                            index_file=index_file,
+                                            base_filename=states_name,
+                                            is_master=True,
+                                            use_safetensors=False)
 
         # Wrap up index file.
         index_file.append_meta_data("total_size", total_size)

@@ -172,18 +171,17 @@ class GeneralCheckpointIO(CheckpointIO):
         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
 
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
-        total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
-        for idx, shard_pair in enumerate(state_dict_shard):
-            shard = shard_pair[0]
-            shard_file = get_shard_filename(weights_name, idx)
-            total_size = total_size + shard_pair[1]
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
-            save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
+        # Save shards of optimizer states.
+        # In general cases, is_master is set to True to get the right behavior.
+        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+                                            checkpoint=checkpoint_path,
+                                            index_file=index_file,
+                                            base_filename=weights_name,
+                                            is_master=True,
+                                            use_safetensors=use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
@@ -1,4 +1,5 @@
 # coding=utf-8
+import os
 import re
 from collections import abc as container_abcs
 from collections import defaultdict

@@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
     return unwrapped_optim
 
 
+def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+                           checkpoint: str,
+                           index_file: "CheckpointIndexFile",
+                           base_filename: str,
+                           is_master: bool,
+                           use_safetensors: bool = False) -> int:
+    '''
+    Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
+
+    Args:
+        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+        checkpoint (str): The path of checkpoint directory as string.
+        index_file (CheckpointIndexFile): The index file object to be updated.
+        base_filename (str): Decides the prefix of filenames of shards.
+        is_master (bool): Whether current rank is master.
+        use_safetensors (bool): Whether to use safetensors to save checkpoint.
+
+    Returns:
+        int: the total size of shards
+    '''
+
+    total_size = 0
+    for idx, shard_pair in enumerate(sharded_state_dict):
+        if not is_master:
+            continue
+        shard, current_size = shard_pair
+        shard_file = get_shard_filename(base_filename, idx)
+        total_size = total_size + current_size
+        for key in shard.keys():
+            index_file.append_weight_map(key, shard_file)
+        checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+        # Only save on master rank.
+        save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+
+    return total_size
+
+
 def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
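Both call sites above now funnel through this one helper, so a small usage sketch may help. It is hedged: only names introduced or used in this commit appear, while the directory, base name, and index-file name are illustrative.

import os
from collections import OrderedDict

import torch

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import save_state_dict_shards


def toy_shards():
    # Each yielded pair is (state_dict_shard, shard_size), matching the
    # Iterator[Tuple[OrderedDict, int]] contract in the docstring above.
    yield OrderedDict(weight=torch.zeros(4, 4)), 1
    yield OrderedDict(bias=torch.zeros(4)), 1


checkpoint_dir = "./toy_ckpt"    # illustrative output directory
os.makedirs(checkpoint_dir, exist_ok=True)

index_file = CheckpointIndexFile(checkpoint_dir)
total_size = save_state_dict_shards(sharded_state_dict=toy_shards(),
                                    checkpoint=checkpoint_dir,
                                    index_file=index_file,
                                    base_filename="model.bin",    # illustrative base name
                                    is_master=True,               # single-process example
                                    use_safetensors=False)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file("model.bin.index.json")               # illustrative index name

Because non-master ranks simply skip every shard, the same call can sit on the hot path of a multi-process save without any extra branching at the call site.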
@@ -3,7 +3,7 @@ import copy
 import gc
 import math
 import warnings
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
 
 import torch
 import torch.distributed as dist

@@ -11,8 +11,10 @@ from torch.nn import Parameter
 from torch.optim import Optimizer
 
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager

@@ -360,10 +362,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
 
         begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
         chunk_offset = begin_in_chunk
-        shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
+        if chunk.keep_gathered:
+            shard_offset = 0
+        else:
+            shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
         shard_size = end_in_chunk - begin_in_chunk
         assert chunk_offset >= 0 and shard_offset >= 0
 
         return chunk_offset, shard_offset, shard_size
 
     def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:

@@ -427,7 +431,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                                                                    dtype=torch.float32,
                                                                    requires_grad=False).cpu()
                     else:
-                        collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+                        state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
+                        collected_states[state_name] = torch.reshape(state_tensor, param.shape)
             return collected_states
 
         # Check whether the param with given id is managed by current process.

@@ -536,6 +541,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
             target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
             next_state_offset += shard_size
 
+    def get_param_groups_for_saving(self) -> list:
+        '''
+        Return the param_groups in Pytorch format when saving to checkpoint.
+        '''
+
+        param_groups = copy.deepcopy(self.param_groups_backup)
+
+        # To be compatible with pytorch checkpointing,
+        # store extra hyperparameters used by pytorch Adam optimizer.
+        torch_special_hyperparameters = {
+            'amsgrad': False,
+            'maximize': False,
+            'foreach': None,
+            'capturable': False,
+            'differentiable': False,
+            'fused': False
+        }
+
+        for group in param_groups:
+            for k, v in torch_special_hyperparameters.items():
+                if k not in group:
+                    group[k] = v
+
+        return param_groups
+
     def state_dict(self, only_rank_0: bool = True) -> dict:
         """
         Args:
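As an aside on the helper just added: the list it returns is ordinary PyTorch param_groups data, only padded with the Adam defaults listed above. A purely illustrative single-group example (these values are not taken from the commit, and 'params' is assumed to hold integer param ids):

# Illustrative output of get_param_groups_for_saving() for one HybridAdam group.
param_groups = [{
    'lr': 1e-3,
    'betas': (0.9, 0.999),
    'eps': 1e-8,
    'weight_decay': 0.0,
    'params': [0, 1, 2],
    # Keys padded in for compatibility with torch.optim.Adam checkpoints:
    'amsgrad': False,
    'maximize': False,
    'foreach': None,
    'capturable': False,
    'differentiable': False,
    'fused': False,
}]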
@@ -555,21 +585,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             so it should be called only when memory resources are abundant.
         """
         state_dict = {}
-        state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
-
-        torch_special_hyperparameters = {
-            'amsgrad': False,
-            'maximize': False,
-            'foreach': None,
-            'capturable': False,
-            'differentiable': False,
-            'fused': False
-        }
-
-        for group in state_dict['param_groups']:
-            for k, v in torch_special_hyperparameters.items():
-                if k not in group:
-                    group[k] = v
+        state_dict['param_groups'] = self.get_param_groups_for_saving()
 
         # Collect optimizer states.
         state_dict['state'] = dict()

@@ -634,8 +650,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 del v    # clean loaded states
             self.optim.state[fake_param].update(updated_states)
 
+    def load_param_states(self, param_states: dict):
+        """Loads param states from a state_dict. The param_states can be complete or sharded.
+           During loading, filter out the part of states not considered by current process.
+
+        Args:
+            param_states (dict): A mapping from param_id to its states.
+        """
+        for param_id, states in param_states.items():
+            if param_id in self.id_to_fake_params:
+                self.load_single_param_states(param_id, states)
+
+    def optimizer_loading_epilogue(self):
+        # Epilogue when loading state_dict to pytorch optimizer.
+        self.optim._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+        self.optim.defaults.setdefault('differentiable', False)
+
     def load_state_dict(self, state_dict: dict):
-        """Loads optimizer state from whole optimizer state_dict.
+        """Loads optimizer state from complete optimizer state_dict.
            During loading, filter out the part of states not considered by current process.
 
         Args:

@@ -643,17 +675,71 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 from a call to :meth:`state_dict`.
         """
         assert 'param_groups' in state_dict
         assert 'state' in state_dict
         self.load_param_groups(state_dict['param_groups'])
-        state = state_dict['state']
-
-        for param_id, param_states in state.items():
-            if param_id in self.id_to_fake_params:
-                self.load_single_param_states(param_id, param_states)
-
-        # Epilogue for pytorch optimizer.
-        self.optim._hook_for_profile()    # To support multiprocessing pickle/unpickle.
-        self.optim.defaults.setdefault('differentiable', False)
+        self.load_param_states(state_dict['state'])
+        self.optimizer_loading_epilogue()
+
+    def state_shard(self,
+                    prefix: str = '',
+                    max_shard_size: int = 1024,
+                    only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
+        """Returns dictionaries containing shards of optimizer states one by one.
+        The max size of each dictionary shard is specified by ``max_shard_size``.
+
+        Args:
+            prefix (str, optional): the prefix for states. Default to ''.
+            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
+                                          only on rank 0, dafault to True.
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
+        """
+
+        current_block = {}
+        current_block_size = 0
+
+        for param_id in self.id_to_real_params.keys():
+
+            dist.barrier()
+            state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+
+            ret_block = None
+            ret_block_size = 0
+
+            # A state might contain more than one tensors.
+            # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+            state_size = 0
+            isDTensor = False
+            for state_tensor in state.values():
+
+                # When state_tensor is not of Tensor class,
+                # e.g., a SGD optimizer with momentum set to 0 can have None as state
+                # The calculation of tensor size should be skipped to avoid error.
+                if not isinstance(state_tensor, torch.Tensor):
+                    continue
+
+                # If the states are stored as DTensors, mark isDTensor as true.
+                if is_distributed_tensor(state_tensor):
+                    isDTensor = True
+                state_size += calculate_tensor_size(state_tensor)
+
+            if not isDTensor:
+
+                if current_block_size + state_size > max_shard_size and current_block_size > 0:
+                    ret_block = current_block
+                    ret_block_size = current_block_size
+                    current_block = {}
+                    current_block_size = 0
+
+                current_block[param_id] = state
+                current_block_size += state_size
+
+            if ret_block != None:
+                yield ret_block, ret_block_size
+
+        yield current_block, current_block_size
 
 
 class GeminiAdamOptimizer(ZeroOptimizer):
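The generator added above shards by accumulated size: states are appended to a running block, the block is flushed once the next entry would push it past max_shard_size, and the last partial block is always yielded. A self-contained sketch of that same policy in plain Python, ignoring the distributed collection, DTensor filtering, and MB accounting (all names here are illustrative):

from typing import Dict, Iterator, Tuple


def shard_by_size(states: Dict[int, dict],
                  sizes_mb: Dict[int, float],
                  max_shard_size: float = 1024) -> Iterator[Tuple[dict, float]]:
    """Group per-parameter states into blocks no larger than max_shard_size (MB)."""
    current_block: dict = {}
    current_block_size = 0.0
    for param_id, state in states.items():
        state_size = sizes_mb[param_id]
        # Flush the running block before it would overflow, mirroring state_shard above.
        if current_block_size + state_size > max_shard_size and current_block_size > 0:
            yield current_block, current_block_size
            current_block, current_block_size = {}, 0.0
        current_block[param_id] = state
        current_block_size += state_size
    # The final, possibly undersized block is always emitted, like the trailing yield in state_shard.
    yield current_block, current_block_size


# Three 600 MB states with a 1024 MB cap end up in three single-entry shards.
for block, size in shard_by_size({0: {'exp_avg': '...'}, 1: {'exp_avg': '...'}, 2: {'exp_avg': '...'}},
                                 {0: 600.0, 1: 600.0, 2: 600.0}):
    print(list(block.keys()), size)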
@@ -21,10 +21,13 @@ Plugin is an important component that manages parallel configuration (eg: The ge
 
 **_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that ZeRO with chunk-based memory management.
 
-**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines.
+**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
 
 **_LowLevelZeroPlugin:_** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs.
 
+**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp.
+
+
 ### API of booster
 
 {{ autodoc:colossalai.booster.Booster }}

@@ -21,8 +21,6 @@ Model must be boosted by `colossalai.booster.Booster` before loading. It will de
 
 ## Optimizer Checkpoint
 
-> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet.
-
 {{ autodoc:colossalai.booster.Booster.save_optimizer }}
 
 Optimizer must be boosted by `colossalai.booster.Booster` before saving.
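With the restriction lifted by this commit, saving and reloading a sharded optimizer checkpoint under the Gemini plugin looks roughly like the sketch below. It is a hedged example: the launch call, model, and constructor arguments are illustrative and may differ across ColossalAI versions, while the keyword names follow the test changes shown further down (shard, size_per_shard).

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes a distributed launch (e.g. torchrun) and a CUDA device per process.
colossalai.launch_from_torch(config={})

model = nn.Linear(16, 16).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin(placement_policy='cuda'))
model, optimizer, *_ = booster.boost(model, optimizer)

# ... run a few training steps here so the optimizer actually holds states ...

# Save optimizer states as shards; size_per_shard is a size budget in MB.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=32)

# Later, after boosting a fresh model/optimizer the same way, load them back.
booster.load_optimizer(optimizer, "ckpt/optimizer")

Both calls must run on every rank, since Gemini gathers the states collectively and each process only reloads the shard entries it owns.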
@@ -51,8 +51,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
 
 {{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
 
-> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
-
 ### Torch DDP Plugin
 
 More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).

@@ -24,10 +24,13 @@ Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了
 
 **_GeminiPlugin:_** GeminiPlugin 插件封装了 gemini 加速解决方案,即基于块内存管理的 ZeRO 优化方案。
 
-**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了 DDP 加速方案,实现了模型级别的数据并行,可以跨多机运行。
+**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了Pytorch的DDP加速方案,实现了模型级别的数据并行,可以跨多机运行。
 
 **_LowLevelZeroPlugin:_** LowLevelZeroPlugin 插件封装了零冗余优化器的 1/2 阶段。阶段 1:切分优化器参数,分发到各并发进程或并发 GPU 上。阶段 2:切分优化器参数及梯度,分发到各并发进程或并发 GPU 上。
 
+**_TorchFSDPPlugin:_** TorchFSDPPlugin封装了 Pytorch的FSDP加速方案,可以用于零冗余优化器数据并行(ZeroDP)的训练。
+
+
 ### Booster 接口
 
 <!--TODO: update autodoc -->

@@ -21,7 +21,6 @@
 
 ## 优化器 Checkpoint
 
-> ⚠ 尚不支持以分片方式保存优化器 Checkpoint。
 
 {{ autodoc:colossalai.booster.Booster.save_optimizer }}
 

@@ -51,7 +51,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
 
 {{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
 
-> ⚠ 该插件现在只能加载自己保存的且具有相同进程数的优化器 Checkpoint。这将在未来得到解决。
 
 ### Torch DDP 插件
 

@@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 
 @clear_cache_before_run()
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 @parameterize('size_per_shard', [32])
 def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):

@@ -117,7 +117,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)

@@ -19,7 +19,7 @@ from tests.kit.model_zoo import model_zoo
 
 
 @clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
 

@@ -83,7 +83,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
 
 
 @clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 def exam_gemini_load_from_torch(shard: bool, model_name: str):
 

@@ -165,7 +165,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)