from copy import copy
from typing import Dict, Optional

import torch
import torch.nn as nn

from .gemini_parallel import GeminiDDP


def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
    """Wrap a training model for ZeRO DDP.

    For stages 1 and 2 the model itself is returned unchanged (apart from the
    stage tag described below) and ``gemini_config`` is not used. For stage 3
    the model is wrapped in ``GeminiDDP``. In every case the returned object is
    tagged with a ``_colo_zero_stage`` attribute, which ``zero_optim_wrapper``
    reads to choose the matching optimizer.

    Example:

        >>> with ColoInitContext():
        >>>     my_model = Bert()
        >>> my_optim = SGD(my_model.parameters(), lr = 1e-3)
        >>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
        >>> zero_optim = zero_optim_wrapper(zero_model, my_optim)

    Args:
        model (nn.Module): The model used in ZeRO DDP.
        zero_stage (int, optional): The stage of ZeRO DDP. Must be 1, 2 or 3.
            You can find more information in ZeRO's paper.
            https://arxiv.org/abs/1910.02054
        gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`.
            `GeminiDDP` is enabled when the stage is set to 3.
            You can set the arguments of `GeminiDDP` in the gemini_config.
            Here is an example where we set the device of the model, the placement
            policy of Gemini, and the size of hidden dimension to help Gemini find
            out a unified chunk size.

            Example:

                >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
                >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)

    Returns:
        The input ``model`` (stages 1 and 2) or ``GeminiDDP(model, **gemini_config)``
        (stage 3), with ``_colo_zero_stage`` set to ``zero_stage``.

    Raises:
        AssertionError: If ``zero_stage`` is not 1, 2 or 3.
    """
    assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"

    if gemini_config is None:
        gemini_config = dict()

    if zero_stage in [1, 2]:
        # Stages 1 and 2 need no model-side wrapping; only the optimizer changes.
        wrapped_model = model
    else:
        wrapped_model = GeminiDDP(model, **gemini_config)

    # Tag the model so that `zero_optim_wrapper` can pick the matching optimizer.
    setattr(wrapped_model, "_colo_zero_stage", zero_stage)

    return wrapped_model


def zero_optim_wrapper(model: nn.Module,
                       optimizer: torch.optim.Optimizer,
                       initial_scale: float = 2**16,
                       growth_factor: float = 2,
                       backoff_factor: float = 0.5,
                       growth_interval: int = 1000,
                       hysteresis: int = 2,
                       min_scale: float = 1,
                       max_scale: float = 2**32,
                       max_norm: float = 0.0,
                       norm_type: float = 2.0,
                       optim_config: Optional[Dict] = None):
    """Wrap a training optimizer for ZeRO DDP.

    The ZeRO stage is read from the ``_colo_zero_stage`` attribute that
    `zero_model_wrapper` sets on the model. Stages 1 and 2 return a
    ``LowLevelZeroOptimizer``; stage 3 returns a ``ZeroOptimizer``.

    The grad-scaler arguments below are written into the optimizer configuration
    after ``optim_config`` has been copied, so they take precedence over any
    same-named keys in ``optim_config``. The caller's ``optim_config`` dictionary
    itself is not modified (a shallow copy is used).

    Args:
        model (nn.Module): Your model wrapped by `zero_model_wrapper`
        optimizer (torch.optim.Optimizer): Your initialized optimizer
        initial_scale (float, optional): initial_scale used by DynamicGradScaler.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler.
        min_scale (float, optional): min_scale used by DynamicGradScaler.
        max_scale (float, optional): max_scale used by DynamicGradScaler.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
            clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Only 2.0 is currently accepted.
        optim_config (dict, optional): The configuration used for the ZeRO optimizer.

            Example:

                >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
                >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)

    Returns:
        ``LowLevelZeroOptimizer(optimizer, **config)`` for stages 1 and 2, or
        ``ZeroOptimizer(optimizer, model, **config)`` for stage 3.

    Raises:
        AssertionError: If the model was not prepared by `zero_model_wrapper`,
            or if ``norm_type`` is not 2.0.
    """
    # The message names `zero_model_wrapper`, the function in this module that
    # actually sets `_colo_zero_stage`.
    assert hasattr(model, "_colo_zero_stage"), "You should use `zero_model_wrapper` first"
    zero_stage = getattr(model, "_colo_zero_stage")

    assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"

    # Work on a copy so the caller's dictionary is never mutated.
    config_dict = dict() if optim_config is None else copy(optim_config)

    # Explicit keyword arguments override same-named entries from `optim_config`.
    config_dict.update(
        initial_scale=initial_scale,
        growth_factor=growth_factor,
        backoff_factor=backoff_factor,
        growth_interval=growth_interval,
        hysteresis=hysteresis,
        min_scale=min_scale,
        max_scale=max_scale,
    )

    if zero_stage in [1, 2]:
        # Imported lazily, as in the original code.
        from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer

        # Gradients are partitioned only for stage 2.
        config_dict['partition_grad'] = zero_stage == 2
        config_dict['clip_grad_norm'] = max_norm
        return LowLevelZeroOptimizer(optimizer, **config_dict)
    else:
        from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

        config_dict['clipping_norm'] = max_norm
        return ZeroOptimizer(optimizer, model, **config_dict)