import warnings
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO

__all__ = ['LowLevelZeroPlugin']


def _convert_to_fp16(x):
    # Cast floating-point tensors to fp16; leave every other object untouched.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.half()
    return x


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        """
        Save the optimizer state. The state is sharded across processes, so each rank
        saves its own shard under ``{checkpoint}.rank{rank}``.
        """
        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
        warnings.warn('LowLevelZeroPlugin does not support saving the full optimizer checkpoint yet. '
                      'Saving it on every process instead.')
        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        warnings.warn('LowLevelZeroPlugin can only load an optimizer checkpoint saved by itself '
                      'with the same number of processes.')
        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
        super().load_optimizer(optimizer, checkpoint)


class LowLevelZeroModel(ModelWrapper):

    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
        super().__init__(module)
        self.convert_inputs = (precision == 'fp16')
        module = zero_model_wrapper(module, zero_stage=stage)
        if precision == 'fp16':
            module = module.half()
        module = module.to(get_current_device())
        self.module = module

    def forward(self, *args, **kwargs):
        if self.convert_inputs:
            # Cast floating-point inputs to fp16 so they match the half-precision parameters.
            args = tree_map(_convert_to_fp16, args)
            kwargs = tree_map(_convert_to_fp16, kwargs)
        return super().forward(*args, **kwargs)


class LowLevelZeroOptimizer(OptimizerWrapper):

    def __init__(self,
                 module: nn.Module,
                 optimizer: Optimizer,
                 zero_optim_config: dict,
                 optim_kwargs: dict,
                 verbose: bool = False) -> None:
        optimizer = zero_optim_wrapper(module,
                                       optimizer,
                                       optim_config=zero_optim_config,
                                       **optim_kwargs,
                                       verbose=verbose)
        super().__init__(optimizer)

    def backward(self, loss: Tensor, *args, **kwargs):
        self.optim.backward(loss)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> None:
        warnings.warn('LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')


class LowLevelZeroPlugin(DPPluginBase):
    """
    Plugin for low-level ZeRO (stage 1/2).

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import LowLevelZeroPlugin
        >>>
        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = LowLevelZeroPlugin()
        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
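
        A typical training step then looks like the following (an illustrative sketch:
        the loop, `inputs` and `labels` are user code, not part of the plugin API):

        >>> for inputs, labels in train_dataloader:
        >>>     outputs = model(inputs)
        >>>     loss = criterion(outputs, labels)
        >>>     booster.backward(loss, optimizer)
        >>>     optimizer.step()
        >>>     optimizer.zero_grad()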

    Args:
        stage (int, optional): ZeRO stage. Defaults to 1.
        precision (str, optional): Precision. Supports 'fp16' and 'fp32'. Defaults to 'fp16'.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
            `clip_grad_norm` yourself when using ZeRO DDP; the ZeRO optimizer takes care of it. Defaults to 0.0.
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Defaults to 2.0.
        reduce_bucket_size_in_m (int, optional): Gradient reduce bucket size in MB. Defaults to 12.
        communication_dtype (torch.dtype, optional): Communication dtype. If not specified, the dtype of the
            parameters will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation. Defaults to True.
        cpu_offload (bool, optional): Whether to offload gradients, master weights and optimizer states to CPU.
            Defaults to False.
        verbose (bool, optional): Verbose mode. Debug info including grad overflow will be printed. Defaults to False.
    """

    def __init__(
        self,
        stage: int = 1,
        precision: str = 'fp16',
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        reduce_bucket_size_in_m: int = 12,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        cpu_offload: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        assert stage in (1, 2), 'LowLevelZeroPlugin only supports stage 1/2 training'
        assert precision in ('fp16', 'fp32'), 'LowLevelZeroPlugin only supports fp16/fp32 training'

        self.stage = stage
        self.precision = precision
        # reduce_bucket_size_in_m is given in MB and converted to bytes here.
        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
                                      communication_dtype=communication_dtype,
                                      overlap_communication=overlap_communication,
                                      cpu_offload=cpu_offload)
        self.optim_kwargs = dict(initial_scale=initial_scale,
                                 growth_factor=growth_factor,
                                 backoff_factor=backoff_factor,
                                 growth_interval=growth_interval,
                                 hysteresis=hysteresis,
                                 min_scale=min_scale,
                                 max_scale=max_scale,
                                 max_norm=max_norm,
                                 norm_type=norm_type)
        self.verbose = verbose

    def support_no_sync(self) -> bool:
        return False

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return ['fp16', 'fp32']

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ['cuda']
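
    # `configure` applies the ZeRO wrappers: the model is wrapped first (ZeRO model wrapper,
    # optional fp16 cast, move to the current device), then the optimizer is wrapped with the
    # ZeRO optimizer built from the same module. Already-wrapped objects pass through unchanged.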

    def configure(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable = None,
        dataloader: DataLoader = None,
        lr_scheduler: LRScheduler = None,
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.stage, self.precision)

        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
                                              self.verbose)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return LowLevelZeroCheckpointIO()

    def no_sync(self, model: nn.Module) -> Iterator[None]:
        raise NotImplementedError