from copy import copy
from typing import Dict, Optional

import torch
import torch.nn as nn

from .gemini import GeminiDDP


def zero_model_wrapper(
    model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False
):
"""This wrapper function is used to wrap your training model for ZeRO DDP.
Example:
>>> with ColoInitContext():
>>> my_model = Bert()
>>> my_optim = SGD(my_model.parameters(), lr = 1e-3)
>>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
>>> zero_optim = zero_optim_wrapper(zero_model, my_optim)
Args:
model (nn.Module): The model used in ZeRO DDP.
zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in ZeRO's paper.
https://arxiv.org/abs/1910.02054
gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled
when the stage is set to 3. You can set the arguments of `GeminiDDP` in the gemini_config.
Here is an example where we set the device of the model, the placement policy of Gemini, and the
size of hidden dimension to help Gemini find out a unified chunk size.
Example:
>>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
>>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
"""
    assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"

    if gemini_config is None:
        gemini_config = dict()

    if zero_stage in [1, 2]:
        wrapped_model = model
    else:
        wrapped_model = GeminiDDP(model, **gemini_config, verbose=verbose)

    setattr(wrapped_model, "_colo_zero_stage", zero_stage)

    return wrapped_model
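

# For stages 1 and 2 the model object is returned unchanged and sharding is handled by
# `zero_optim_wrapper`; only stage 3 wraps the model in `GeminiDDP`. A minimal stage-3
# usage sketch (`MyModel` is a hypothetical nn.Module; assumes CUDA is available):
#
#   gemini_config = dict(device=torch.cuda.current_device(), placement_policy='auto')
#   gemini_model = zero_model_wrapper(MyModel(), zero_stage=3, gemini_config=gemini_config)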


def zero_optim_wrapper(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    initial_scale: float = 2**16,
    growth_factor: float = 2,
    backoff_factor: float = 0.5,
    growth_interval: int = 1000,
    hysteresis: int = 2,
    min_scale: float = 1,
    max_scale: float = 2**32,
    max_norm: float = 0.0,
    norm_type: float = 2.0,
    optim_config: Optional[Dict] = None,
    verbose: bool = False,
):
"""This wrapper function is used to wrap your training optimizer for ZeRO DDP.
Args:
model (nn.Module): Your model wrapped by `zero_model_wrapper`
optimizer (torch.optim.Optimizer): Your initialized optimizer
initial_scale (float, optional): initial_scale used by DynamicGradScaler.
min_scale (float, optional): min_scale used by DynamicGradScaler.
growth_factor (float, optional): growth_factor used by DynamicGradScaler.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
growth_interval (float, optional): growth_interval used by DynamicGradScaler.
hysteresis (float, optional): hysteresis used by DynamicGradScaler.
max_scale (int, optional): max_scale used by DynamicGradScaler.
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
optim_config (dict, optional): The configuration used for the ZeRO optimizer.
Example:
>>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
>>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
verbose (bool, optional): Whether to print the verbose info.
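
    Example:

        >>> # a minimal sketch: assumes the distributed environment has been launched and
        >>> # `zero_model` was produced by `zero_model_wrapper`; `Adam` and `lr` are illustrative
        >>> my_optim = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
        >>> zero_optim = zero_optim_wrapper(zero_model, my_optim, max_norm=1.0)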
"""
    assert hasattr(model, "_colo_zero_stage"), "You should use `zero_model_wrapper` first"
    zero_stage = getattr(model, "_colo_zero_stage")
    assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"

    if optim_config is None:
        config_dict = dict()
    else:
        config_dict = copy(optim_config)

    config_dict["initial_scale"] = initial_scale
    config_dict["growth_factor"] = growth_factor
    config_dict["backoff_factor"] = backoff_factor
    config_dict["growth_interval"] = growth_interval
    config_dict["hysteresis"] = hysteresis
    config_dict["min_scale"] = min_scale
    config_dict["max_scale"] = max_scale

    if zero_stage in [1, 2]:
        from colossalai.zero.low_level import LowLevelZeroOptimizer

        config_dict["partition_grad"] = zero_stage == 2
        config_dict["clip_grad_norm"] = max_norm
        return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
    else:
        from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer

        config_dict["clipping_norm"] = max_norm
        return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)
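

# End-to-end usage sketch (illustrative only; `MyModel` is a hypothetical nn.Module and the
# distributed environment is assumed to have been launched, e.g. via colossalai.launch_from_torch):
#
#   model = zero_model_wrapper(MyModel(), zero_stage=2)
#   optim = torch.optim.Adam(model.parameters(), lr=1e-3)
#   zero_optim = zero_optim_wrapper(
#       model,
#       optim,
#       max_norm=1.0,
#       optim_config=dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True),
#   )
#   # In the training loop, call zero_optim.backward(loss) instead of loss.backward() so the
#   # wrapper can reduce gradients into their partitions, then call zero_optim.step().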