import warnings from functools import partial from typing import Callable, Iterator, List, Optional, Tuple, Union import torch import torch.nn as nn from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO __all__ = ['LowLevelZeroPlugin'] def _convert_floating_point(x, dtype: torch.dtype = torch.float16): if isinstance(x, torch.Tensor) and torch.is_floating_point(x): return x.to(dtype) return x SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now warnings.warn( 'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.') checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor) def load_optimizer(self, optimizer: Optimizer, checkpoint: str): warnings.warn( 'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.') checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' super().load_optimizer(optimizer, checkpoint) class LowLevelZeroModel(ModelWrapper): def __init__(self, module: nn.Module, stage: int, precision: str) -> None: super().__init__(module) self.dtype = None if precision == 'fp16': self.dtype = torch.float16 elif precision == 'bf16': self.dtype = torch.bfloat16 module = zero_model_wrapper(module, zero_stage=stage) if self.dtype is not None: module = module.to(self.dtype) module = module.to(get_current_device()) self.module = module self.convert_fn = None if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) def forward(self, *args, **kwargs): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) return super().forward(*args, **kwargs) class LowLevelZeroOptimizer(OptimizerWrapper): def __init__(self, module: nn.Module, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict, verbose: bool = False) -> None: optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs, verbose=verbose) super().__init__(optimizer) def backward(self, loss: Tensor, *args, **kwargs): self.optim.backward(loss) def clip_grad_by_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2, error_if_nonfinite: bool = False, *args, **kwargs) -> Tensor: warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm') def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: raise NotImplementedError('LowLevelZero does not support clip_grad_by_value') class LowLevelZeroPlugin(DPPluginBase): """ Plugin for low level zero. Example: >>> from colossalai.booster import Booster >>> from colossalai.booster.plugin import LowLevelZeroPlugin >>> >>> model, train_dataset, optimizer, criterion = ... >>> plugin = LowLevelZeroPlugin() >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) >>> booster = Booster(plugin=plugin) >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) Args: strage (int, optional): ZeRO stage. Defaults to 1. precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'. initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. reduce_bucket_size_in_m (int, optional): grad reduce bucket size in M. Defaults to 12. communication_dtype (torch.dtype, optional): communication dtype. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True. cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False. verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. """ def __init__( self, stage: int = 1, precision: str = 'fp16', initial_scale: float = 2**32, min_scale: float = 1, growth_factor: float = 2, backoff_factor: float = 0.5, growth_interval: int = 1000, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, reduce_bucket_size_in_m: int = 12, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, cpu_offload: bool = False, verbose: bool = False, ) -> None: super().__init__() assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' self.stage = stage self.precision = precision self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, communication_dtype=communication_dtype, overlap_communication=overlap_communication, cpu_offload=cpu_offload) self.optim_kwargs = dict(initial_scale=initial_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, hysteresis=hysteresis, min_scale=min_scale, max_scale=max_scale, max_norm=max_norm, norm_type=norm_type) self.verbose = verbose def support_no_sync(self) -> bool: return False def control_precision(self) -> bool: return True def supported_precisions(self) -> List[str]: return SUPPORTED_PRECISION def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: return ['cuda'] def configure( self, model: nn.Module, optimizer: Optimizer, criterion: Callable = None, dataloader: DataLoader = None, lr_scheduler: LRScheduler = None, ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.stage, self.precision) if not isinstance(optimizer, OptimizerWrapper): optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler def control_checkpoint_io(self) -> bool: return True def get_checkpoint_io(self) -> CheckpointIO: return LowLevelZeroCheckpointIO() def no_sync(self, model: nn.Module) -> Iterator[None]: raise NotImplementedError