mirror of https://github.com/hpcaitech/ColossalAI
[zero] add zero wrappers (#2523)
* [zero] add zero wrappers
* change names
* add wrapper functions to init
parent c198c7c0b0
commit b528eea0f0

@@ -65,7 +65,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
-        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
+        assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
+            f"{_AVAIL_OPTIM_LIST}"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager

@@ -1,4 +1,5 @@
 from .data_parallel import ColoDDP, ZeroDDP
 from .gemini_parallel import GeminiDDP
+from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper

-__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']

@@ -0,0 +1,106 @@
+from copy import copy
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from .gemini_parallel import GeminiDDP
+
+
+def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
+    """This wrapper function is used to wrap your training model for ZeRO DDP.
+
+    Example:
+
+    >>> with ColoInitContext():
+    >>>     my_model = Bert()
+    >>> my_optim = SGD(my_model.parameters(), lr = 1e-3)
+    >>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
+    >>> zero_optim = zero_optim_wrapper(zero_model, my_optim)
+
+    Args:
+        model (nn.Module): The model used in ZeRO DDP.
+        zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in ZeRO's paper.
+            https://arxiv.org/abs/1910.02054
+        gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled
+            when the stage is set to 3. You can set the arguments of `GeminiDDP` in the gemini_config.
+            Here is an example where we set the device of the model, the placement policy of Gemini, and the
+            size of the hidden dimension to help Gemini find a unified chunk size.
+
+        Example:
+
+        >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
+        >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
+    """
+    setattr(model, "_colo_zero_stage", zero_stage)
+    assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
+
+    if gemini_config is None:
+        gemini_config = dict()
+
+    if zero_stage in [1, 2]:
+        return model
+    else:
+        return GeminiDDP(model, **gemini_config)
+
+
+def zero_optim_wrapper(model: nn.Module,
+                       optimizer: torch.optim.Optimizer,
+                       initial_scale: float = 2**16,
+                       growth_factor: float = 2,
+                       backoff_factor: float = 0.5,
+                       growth_interval: int = 1000,
+                       hysteresis: int = 2,
+                       min_scale: float = 1,
+                       max_scale: float = 2**32,
+                       max_norm: float = 0.0,
+                       norm_type: float = 2.0,
+                       optim_config: Optional[Dict] = None):
+    """This wrapper function is used to wrap your training optimizer for ZeRO DDP.
+
+    Args:
+        model (nn.Module): Your model wrapped by `zero_model_wrapper`.
+        optimizer (torch.optim.Optimizer): Your initialized optimizer.
+        initial_scale (float, optional): initial_scale used by DynamicGradScaler.
+        min_scale (float, optional): min_scale used by DynamicGradScaler.
+        growth_factor (float, optional): growth_factor used by DynamicGradScaler.
+        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
+        growth_interval (int, optional): growth_interval used by DynamicGradScaler.
+        hysteresis (int, optional): hysteresis used by DynamicGradScaler.
+        max_scale (float, optional): max_scale used by DynamicGradScaler.
+        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
+            clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer takes care of it.
+        norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        optim_config (dict, optional): The configuration used for the ZeRO optimizer.
+        Example:
+
+        >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
+        >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
+    """
+    assert hasattr(model, "_colo_zero_stage"), "You should use `zero_model_wrapper` first"
+    zero_stage = getattr(model, "_colo_zero_stage")
+
+    assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"
+
+    if optim_config is None:
+        config_dict = dict()
+    else:
+        config_dict = copy(optim_config)
+
+    config_dict['initial_scale'] = initial_scale
+    config_dict['growth_factor'] = growth_factor
+    config_dict['backoff_factor'] = backoff_factor
+    config_dict['growth_interval'] = growth_interval
+    config_dict['hysteresis'] = hysteresis
+    config_dict['min_scale'] = min_scale
+    config_dict['max_scale'] = max_scale
+
+    if zero_stage in [1, 2]:
+        from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer
+        config_dict['partition_grad'] = zero_stage == 2
+        config_dict['clip_grad_norm'] = max_norm
+        return LowLevelZeroOptimizer(optimizer, **config_dict)
+    else:
+        from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+        config_dict['clipping_norm'] = max_norm
+        return ZeroOptimizer(optimizer, model, **config_dict)

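Taken together with the docstrings above, a minimal usage sketch of the two wrappers. The tiny `nn.Linear` model, the random batch, and the hyperparameters are illustrative assumptions; it also assumes the distributed environment has already been initialized (e.g. via `colossalai.launch`) and that the wrappers are imported from the package whose `__init__` is patched above:

    import torch
    import torch.nn as nn
    from torch.optim import SGD
    from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper

    model = nn.Linear(32, 32).cuda()          # placeholder model
    optim = SGD(model.parameters(), lr=1e-3)

    # stages 1 and 2 return the module unchanged; stage 3 wraps it in GeminiDDP
    zero_model = zero_model_wrapper(model, zero_stage=2)
    # stages 1 and 2 build a LowLevelZeroOptimizer; stage 3 builds a ZeroOptimizer
    zero_optim = zero_optim_wrapper(zero_model, optim, max_norm=1.0)

    loss = zero_model(torch.randn(8, 32, device='cuda')).sum()
    zero_optim.backward(loss)    # backward goes through the wrapper (loss scaling, clipping)
    zero_optim.step()
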
@@ -17,7 +17,6 @@ from ._utils import (
     calculate_global_norm_from_list,
     compute_norm,
     flatten,
-    get_grad_accumulate_object,
     has_inf_or_nan,
     reduce_tensor_dp_group,
     release_param_grad,

@@ -386,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # torch.optim.Optimizer methods
     ################################

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, retain_graph=False, sync_grad=True):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)

@@ -402,6 +401,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             torch.cuda.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()

+        # gradient synchronization
+        if sync_grad:
+            self._sync_grad()
+
     def zero_grad(self, set_to_none=True):
         """
         Set parameter gradients to zero. If set_to_none = True, gradient

@@ -537,7 +540,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    # Gradient Synchronization #
    ############################

-    def sync_grad(self):
+    def _sync_grad(self):
        # update param already reduced flag
        reduction_states = self._param_store.get_param_reduction_states()
        for tensor, state in reduction_states.items():

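As the test updates below show, a caller that accumulates gradients over several backward passes can now pass `sync_grad=False` and trigger the (now private) synchronization once before stepping. A rough sketch of that pattern, with `zero_model`, `zero_optim`, and `micro_batches` as placeholders:

    # accumulate without synchronizing after every backward
    for micro_batch in micro_batches:
        loss = zero_model(micro_batch).sum()
        zero_optim.backward(loss, sync_grad=False)

    zero_optim._sync_grad()    # single gradient synchronization before the update
    zero_optim.step()

With the default `sync_grad=True`, `backward` performs this synchronization itself, so ordinary single-step training needs no change.
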
@@ -9,7 +9,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
-from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer

@@ -60,16 +59,16 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)

         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
+        zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)

         for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
             if z2p.grad is not None:
                 # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                 assert torch.equal(z1p.grad, z2p.grad)

-        zero1_optimizer.sync_grad()
-        zero2_optimizer.sync_grad()
+        zero1_optimizer._sync_grad()
+        zero2_optimizer._sync_grad()

    fwd_bwd_func(0, input_data1)
    fwd_bwd_func(1, input_data2)

@@ -124,7 +123,7 @@ def exam_zero_1_grad_acc():
         assert torch.equal(zero_output, torch_output)

         # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
+        zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
         # torch-ddp backward
         torch_output.sum().backward()

@@ -135,7 +134,7 @@ def exam_zero_1_grad_acc():
                 # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, unscale_grad)

-        zero_optimizer.sync_grad()
+        zero_optimizer._sync_grad()

    fwd_bwd_func(0, input_data1, True)
    fwd_bwd_func(1, input_data2, False)

@@ -78,16 +78,16 @@ def exam_zero_1_2():
    assert torch.equal(zero1_output, zero2_output)

    # zero-dp backward
-    zero1_optimizer.backward(zero1_output.mean().float())
-    zero2_optimizer.backward(zero2_output.mean().float())
+    zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
+    zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)

    for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
        if z2p.grad is not None:
            # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
            assert torch.equal(z1p.grad, z2p.grad)

-    zero1_optimizer.sync_grad()
-    zero2_optimizer.sync_grad()
+    zero1_optimizer._sync_grad()
+    zero2_optimizer._sync_grad()

    # step
    zero1_optimizer.step()

@@ -146,7 +146,7 @@ def exam_zero_1_torch_ddp():
    half_close(zero_output, torch_output, loose=True)

    # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float())
+    zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)

    # torch-ddp backward
    torch_output.mean().backward()

|
||||||
half_close(p.grad, z1p.grad, loose=True)
|
half_close(p.grad, z1p.grad, loose=True)
|
||||||
|
|
||||||
# zero-dp step
|
# zero-dp step
|
||||||
zero_optimizer.sync_grad()
|
zero_optimizer._sync_grad()
|
||||||
zero_optimizer.step()
|
zero_optimizer.step()
|
||||||
|
|
||||||
# torch ddp step
|
# torch ddp step
|
||||||
|
|
|
@@ -74,7 +74,6 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
    torch_loss.backward()
    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
    hybrid_optim.backward(hybrid_loss)
-    hybrid_optim.sync_grad()

    torch_optim.step()
    hybrid_optim.step()