ColossalAI/colossalai/zero/wrapper.py

118 lines
4.8 KiB
Python
Raw Normal View History

from copy import copy
from typing import Dict, Optional
import torch
import torch.nn as nn
from .gemini import GeminiDDP
def zero_model_wrapper(
    model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False
):
    """Wrap a training model for ZeRO DDP and tag it with the chosen ZeRO stage.

    For stages 1 and 2 the model is returned as-is; for stage 3 it is wrapped in
    `GeminiDDP`. In both cases the returned object gets a `_colo_zero_stage`
    attribute, which `zero_optim_wrapper` reads to pick the matching optimizer.

    Example:

        >>> with ColoInitContext():
        >>>     my_model = Bert()
        >>> my_optim = SGD(my_model.parameters(), lr = 1e-3)
        >>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
        >>> zero_optim = zero_optim_wrapper(zero_model, my_optim)

    Args:
        model (nn.Module): The model used in ZeRO DDP.
        zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in ZeRO's paper.
            https://arxiv.org/abs/1910.02054
        gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled
            when the stage is set to 3. You can set the arguments of `GeminiDDP` in the gemini_config.
            It is ignored for stages 1 and 2.
            Here is an example where we set the device of the model, the placement policy of Gemini, and the
            size of hidden dimension to help Gemini find out a unified chunk size.

            Example:

                >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
                >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
        verbose (bool, optional): Forwarded to `GeminiDDP` as its `verbose` keyword (stage 3 only).
            Do not also put a `verbose` key in `gemini_config`; the duplicate keyword raises `TypeError`.

    Returns:
        The original `model` object for stages 1 and 2, or the `GeminiDDP` wrapper for stage 3.

    Raises:
        AssertionError: If `zero_stage` is not 1, 2 or 3.
    """
    assert zero_stage in (1, 2, 3), "The stage of ZeRO should be 1, 2 or 3"

    if zero_stage in (1, 2):
        # Stages 1 and 2 are handled entirely by the optimizer wrapper, so the model is left untouched.
        wrapped_model = model
    else:
        gemini_kwargs = {} if gemini_config is None else gemini_config
        wrapped_model = GeminiDDP(model, **gemini_kwargs, verbose=verbose)

    # Record the stage on the returned object so `zero_optim_wrapper` can recover it later.
    wrapped_model._colo_zero_stage = zero_stage
    return wrapped_model
def zero_optim_wrapper(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    initial_scale: float = 2**16,
    growth_factor: float = 2,
    backoff_factor: float = 0.5,
    growth_interval: int = 1000,
    hysteresis: int = 2,
    min_scale: float = 1,
    max_scale: float = 2**32,
    max_norm: float = 0.0,
    norm_type: float = 2.0,
    optim_config: Optional[Dict] = None,
    verbose: bool = False,
):
    """Wrap a training optimizer for ZeRO DDP.

    The ZeRO stage is read from the `_colo_zero_stage` attribute that
    `zero_model_wrapper` sets on the model. Stages 1 and 2 produce a
    `LowLevelZeroOptimizer`; stage 3 produces a `GeminiOptimizer`.

    The grad-scaler arguments below always overwrite same-named keys in
    `optim_config`; the caller's `optim_config` dict itself is not modified
    (a shallow copy is used).

    Args:
        model (nn.Module): Your model wrapped by `zero_model_wrapper`
        optimizer (torch.optim.Optimizer): Your initialized optimizer
        initial_scale (float, optional): initial_scale used by DynamicGradScaler.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler.
        min_scale (float, optional): min_scale used by DynamicGradScaler.
        max_scale (float, optional): max_scale used by DynamicGradScaler.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
            clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Only 2.0 is accepted.
        optim_config (dict, optional): The configuration used for the ZeRO optimizer.
            Example:

                >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
                >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
        verbose (bool, optional): Whether to print the verbose info.

    Returns:
        A `LowLevelZeroOptimizer` (stage 1 or 2) or a `GeminiOptimizer` (stage 3) wrapping `optimizer`.

    Raises:
        AssertionError: If `model` was not produced by `zero_model_wrapper`, or if `norm_type` is not 2.0.
    """
    assert hasattr(model, "_colo_zero_stage"), "You should use `zero_model_wrapper` first"
    zero_stage = model._colo_zero_stage

    assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"

    # Work on a shallow copy so the caller's dict is never mutated.
    config_dict = {} if optim_config is None else copy(optim_config)
    config_dict.update(
        initial_scale=initial_scale,
        growth_factor=growth_factor,
        backoff_factor=backoff_factor,
        growth_interval=growth_interval,
        hysteresis=hysteresis,
        min_scale=min_scale,
        max_scale=max_scale,
    )

    # The optimizer classes are imported lazily, inside the branch that needs them, as in the
    # original code (NOTE(review): presumably to avoid import cycles / unneeded imports - confirm).
    if zero_stage in (1, 2):
        from colossalai.zero.low_level import LowLevelZeroOptimizer

        # Stage 2 additionally partitions gradients; stage 1 does not.
        config_dict["partition_grad"] = zero_stage == 2
        config_dict["clip_grad_norm"] = max_norm
        return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)

    from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer

    # The Gemini optimizer takes the clipping threshold under a different keyword.
    config_dict["clipping_norm"] = max_norm
    return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)