mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
336 lines
14 KiB
336 lines
14 KiB
import logging
|
|
import os
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from types import MethodType
|
|
from typing import Callable, Iterator, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
from torch.utils._pytree import tree_map
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
|
|
from colossalai.checkpoint_io.utils import (
|
|
get_optimizer_base_filenames,
|
|
get_shard_filename,
|
|
load_param_groups_into_optimizer,
|
|
load_shard_state_dict,
|
|
load_states_into_optimizer,
|
|
save_param_groups,
|
|
save_state_dict,
|
|
sharded_optimizer_loading_epilogue,
|
|
)
|
|
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
|
from colossalai.utils import get_current_device
|
|
from colossalai.zero import LowLevelZeroOptimizer
|
|
|
|
from .dp_plugin_base import DPPluginBase
|
|
from .torch_ddp_plugin import TorchDDPCheckpointIO
|
|
|
|
# Names exported by `from ... import *`: only the plugin class is part of the public API.
__all__ = ["LowLevelZeroPlugin"]
|
|
|
|
|
|
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
|
|
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
|
return x.to(dtype)
|
|
return x
|
|
|
|
|
|
# Precision modes accepted by LowLevelZeroPlugin (validated in its __init__ and returned by
# supported_precisions()). Only "fp16" and "bf16" make LowLevelZeroModel cast the module;
# "fp32" leaves the module's dtype untouched.
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
|
|
|
|
|
|
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    """Model wrapper used by ``LowLevelZeroPlugin``.

    On construction the wrapped module is cast to the half-precision dtype selected by
    ``precision`` ("fp16" -> ``torch.float16``, "bf16" -> ``torch.bfloat16``; any other value,
    e.g. "fp32", keeps the module's dtype) and then moved to the current device. When a
    half-precision dtype is active, floating-point tensors in the forward arguments are cast
    to that dtype before delegating to ``ModelWrapper.forward``.
    """

    def __init__(self, module: nn.Module, precision: str) -> None:
        super().__init__(module)
        half_precision_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
        # ``None`` means "do not cast" (e.g. fp32 training).
        self.dtype = half_precision_dtypes.get(precision)
        if self.dtype is not None:
            module = module.to(self.dtype)
        self.module = module.to(get_current_device())
        # Leaf-wise input converter; only needed when the module runs in half precision.
        self.convert_fn = (
            None if self.dtype is None else partial(_convert_floating_point, dtype=self.dtype)
        )

    def forward(self, *args, **kwargs):
        """Cast floating-point inputs (if a converter is configured) and run the wrapped model."""
        if self.convert_fn is not None:
            args, kwargs = tree_map(self.convert_fn, args), tree_map(self.convert_fn, kwargs)
        return super().forward(*args, **kwargs)
|
|
|
|
|
|
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
    """Checkpoint IO for models and optimizers boosted by ``LowLevelZeroPlugin``.

    Optimizer states are saved through the ZeRO optimizer's own ``state_dict`` /
    ``state_dict_shard`` methods (which communicate across ranks), and only the master rank
    writes files. When loading sharded optimizer states, each saved state tensor is split
    evenly across ranks and every rank keeps only its own slice. Model loading is delegated to
    ``TorchDDPCheckpointIO``, followed by ``model.update_master_params()`` (injected by
    ``LowLevelZeroPlugin.configure``) so the optimizer can refresh its master parameters.
    """

    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
        """Save optimizer to checkpoint but only on master process.

        Args:
            optimizer (OptimizerWrapper): Optimizer to save state_dict
            checkpoint (str): Path to save checkpoint
            gather_dtensor (bool): Whether to gather_dtensor, not used
        """
        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
        # The `state_dict` of LowLevelZeroOptimizer involves communication, so every rank must
        # call it; if only the master rank collected the state_dict, the collectives would not match.
        state_dict = optimizer.state_dict()
        if self.coordinator.is_master():
            save_state_dict(state_dict, checkpoint, use_safetensors=False)

    def save_sharded_optimizer(
        self,
        optimizer: OptimizerWrapper,
        checkpoint: str,
        gather_dtensor: bool = False,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
    ):
        """
        Save sharded Zero-optimizer checkpoint under the given checkpointing path.
        The following files will be created under the path:
        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
        - A group file (pytorch_optim_group.bin) recording information of param_groups
        - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way

        All ranks must call this method (gathering the shards involves the ZeRO optimizer on every
        rank); only the master rank writes files.

        Args:
            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
            checkpoint (str): Path to save optimizer state_dict
            gather_dtensor (bool): Whether to gather_dtensor, not used
            prefix (Optional[str]): Prefix of file to save
            size_per_shard (int): Max size of each file that stores state tensors; forwarded to
                ``optimizer.state_dict_shard`` as ``max_shard_size``.
        """
        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        # Only 'param_groups' is taken from the wrapped optimizer's state_dict; the per-parameter
        # states are produced shard by shard by the low-level ZeRO optimizer below.
        state_dict = optimizer.optim.state_dict()
        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)

        # Prepare file names and the index file.
        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
        index_file = CheckpointIndexFile(checkpoint)

        # Store the param groups in their own file. Only the master writes it, like the shards and
        # the index below, so ranks never write the same path concurrently on shared storage.
        index_file.append_meta_data("param_groups", param_group_file)
        if self.coordinator.is_master():
            group_file_path = os.path.join(checkpoint, param_group_file)
            save_param_groups(state_dict, group_file_path)

        # Save shards of optimizer states. Every rank iterates the shard generator (it is produced
        # by the ZeRO optimizer, which presumably communicates like `state_dict`), but only the
        # master rank writes the shard files.
        total_size = 0
        for idx, (shard, current_size) in enumerate(sharded_state):
            shard_file = get_shard_filename(states_name, idx)
            total_size += current_size
            for param_id in shard.keys():
                index_file.append_weight_map(str(param_id), shard_file)

            if self.coordinator.is_master():
                checkpoint_file_path = os.path.join(checkpoint, shard_file)
                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

        # Wrap up the index file.
        index_file.append_meta_data("total_size", total_size)
        if self.coordinator.is_master():
            index_file.write_index_file(save_index_file)
        logging.info(
            "The optimizer is going to be split to checkpoint shards. "
            "You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )

    def _partition_for_rank(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return this rank's equal-sized slice of ``tensor``.

        The tensor is flattened, right-padded with zeros to a multiple of the world size, split into
        ``world_size`` equal chunks, and a detached copy of the chunk at index ``rank`` is returned.
        """
        world_size = self.coordinator.world_size
        padding_size = (world_size - tensor.numel() % world_size) % world_size
        with torch.no_grad():
            flat = tensor.flatten()
            if padding_size > 0:
                flat = torch.nn.functional.pad(flat, [0, padding_size])
            chunks = flat.split(flat.numel() // world_size)
            return chunks[self.coordinator.rank].detach().clone()

    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
        """Load sharded optimizer with the given path to index file.

        Args:
            optimizer (OptimizerWrapper): Optimizer to load state_dict
            index_file_path (str): Path to the index file
            prefix (str): Not used.

        Raises:
            RuntimeError: If the index file does not reference a param group file.
        """
        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
        optimizer = optimizer.unwrap()

        # Read the checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)

        # Load param_groups; the returned id_map is used by load_states_into_optimizer below.
        param_group_path = ckpt_index_file.get_param_group_filename()
        if param_group_path is None:
            raise RuntimeError(
                f"Invalid index file path {index_file_path} for an optimizer. "
                "Lacking param group file under current directory."
            )
        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)

        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
        for shard_file in checkpoint_files:
            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
            # Partition every tensor state except 'step' so this rank keeps only its own slice.
            for state in state_dict.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor) and k != "step":
                        state[k] = self._partition_for_rank(v)
            load_states_into_optimizer(optimizer, state_dict, id_map)

        sharded_optimizer_loading_epilogue(optimizer)

    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
        """Load an unsharded model checkpoint, then refresh the optimizer's master params."""
        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
        super().load_unsharded_model(model, checkpoint, strict)
        model.update_master_params()

    def load_sharded_model(
        self,
        model: ModelWrapper,
        checkpoint_index_file: Path,
        strict: bool = False,
        use_safetensors: bool = False,
        load_sub_module: bool = True,
    ):
        """Load a sharded model checkpoint, then refresh the optimizer's master params."""
        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
        model.update_master_params()
|
|
|
|
|
|
class LowLevelZeroPlugin(DPPluginBase):
    """
    Plugin for low level zero.

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import LowLevelZeroPlugin
        >>>
        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = LowLevelZeroPlugin()
        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

    Args:
        stage (int, optional): ZeRO stage, either 1 or 2. Stage 2 additionally partitions gradients. Defaults to 1.
        precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. Do not call clip_grad_norm yourself
            when using ZeRO DDP; the ZeRO optimizer takes care of clip_grad_norm. Defaults to 0.0.
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Only 2.0 is supported. Defaults to 2.0.
        reduce_bucket_size_in_m (int, optional): grad reduce bucket size in MiB (multiplied by 1024 * 1024).
            Defaults to 12.
        communication_dtype (torch.dtype, optional): communication dtype. If not specified, the dtype of param
            will be used. Defaults to None.
        overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True.
        cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu.
            Defaults to False.
        verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
    """

    def __init__(
        self,
        stage: int = 1,
        precision: str = "fp16",
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        reduce_bucket_size_in_m: int = 12,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        cpu_offload: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        assert stage in (1, 2), "LowLevelZeroPlugin only supports stage 1/2 training"
        assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
        assert norm_type == 2.0, "LowLevelZeroPlugin only supports norm_type=2.0 now"
        self.stage = stage
        self.precision = precision
        # Keyword arguments forwarded to LowLevelZeroOptimizer in `configure`.
        self.zero_optim_kwargs = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
            clip_grad_norm=max_norm,
            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            partition_grad=(stage == 2),
        )
        self.verbose = verbose

        # Set the class name with the stage, for better error messages.
        # NOTE: this renames the class object itself, so it affects every instance of this class,
        # and the most recently constructed plugin determines the name.
        setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

    def support_no_sync(self) -> bool:
        """`no_sync` is only supported for ZeRO stage 1."""
        return self.stage == 1

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return SUPPORTED_PRECISION

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ["cuda"]

    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        """Wrap the model (and the optimizer, if given) for low level ZeRO training.

        Objects that are already wrapped are passed through unchanged; ``criterion``, ``dataloader``
        and ``lr_scheduler`` are returned as-is.
        """
        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)

        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
            )
            # Inject update_master_params into the model (used by LowLevelZeroCheckpointIO after loading).
            # Binding the optimizer's bound method to `model` makes `model.update_master_params()`
            # call `optimizer.update_master_params(model)`.
            model.update_master_params = MethodType(optimizer.update_master_params, model)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return LowLevelZeroCheckpointIO()

    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        """Return the wrapped optimizer's `no_sync()` context (see `support_no_sync`: stage 1 only)."""
        assert isinstance(optimizer, LowLevelZeroOptimizer)
        return optimizer.optim.no_sync()
|