From b528eea0f05162bfedcd06381b953193c2a91b82 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Sun, 29 Jan 2023 17:52:58 +0800
Subject: [PATCH] [zero] add zero wrappers (#2523)

* [zero] add zero wrappers

* change names

* add wrapper functions to init
---
 colossalai/nn/optimizer/zero_optimizer.py     |   3 +-
 colossalai/nn/parallel/__init__.py            |   3 +-
 colossalai/nn/parallel/zero_wrapper.py        | 106 ++++++++++++++++++
 .../zero/sharded_optim/low_level_optim.py     |   9 +-
 .../test_zero/low_level_zero/test_grad_acc.py |  13 +--
 .../test_zero/low_level_zero/test_zero1_2.py  |  12 +-
 .../test_zero/low_level_zero/test_zero_tp.py  |   1 -
 7 files changed, 128 insertions(+), 19 deletions(-)
 create mode 100644 colossalai/nn/parallel/zero_wrapper.py

diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 9f761efdb..402e28ce8 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -65,7 +65,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
-        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
+        assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
+            f"{_AVAIL_OPTIM_LIST}"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py
index 0c369bfce..2afc8f18c 100644
--- a/colossalai/nn/parallel/__init__.py
+++ b/colossalai/nn/parallel/__init__.py
@@ -1,4 +1,5 @@
 from .data_parallel import ColoDDP, ZeroDDP
 from .gemini_parallel import GeminiDDP
+from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper
 
-__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
diff --git a/colossalai/nn/parallel/zero_wrapper.py b/colossalai/nn/parallel/zero_wrapper.py
new file mode 100644
index 000000000..504625e62
--- /dev/null
+++ b/colossalai/nn/parallel/zero_wrapper.py
@@ -0,0 +1,106 @@
+from copy import copy
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from .gemini_parallel import GeminiDDP
+
+
+def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
+    """This wrapper function wraps your training model for ZeRO DDP.
+
+    Example:
+
+    >>> with ColoInitContext():
+    >>>     my_model = Bert()
+    >>> my_optim = SGD(my_model.parameters(), lr=1e-3)
+    >>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
+    >>> zero_optim = zero_optim_wrapper(zero_model, my_optim)
+
+    Args:
+        model (nn.Module): The model used in ZeRO DDP.
+        zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in the ZeRO paper:
+            https://arxiv.org/abs/1910.02054
+        gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled
+            when the stage is set to 3. You can set the arguments of `GeminiDDP` in the gemini_config.
+            Here is an example where we set the device of the model, the placement policy of Gemini, and the
+            size of the hidden dimension to help Gemini find a unified chunk size.
+
+    Example:
+
+    >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
+    >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
+    """
+    setattr(model, "_colo_zero_stage", zero_stage)
+    assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
+
+    if gemini_config is None:
+        gemini_config = dict()
+
+    if zero_stage in [1, 2]:
+        return model
+    else:
+        return GeminiDDP(model, **gemini_config)
+
+
+def zero_optim_wrapper(model: nn.Module,
+                       optimizer: torch.optim.Optimizer,
+                       initial_scale: float = 2**16,
+                       growth_factor: float = 2,
+                       backoff_factor: float = 0.5,
+                       growth_interval: int = 1000,
+                       hysteresis: int = 2,
+                       min_scale: float = 1,
+                       max_scale: float = 2**32,
+                       max_norm: float = 0.0,
+                       norm_type: float = 2.0,
+                       optim_config: Optional[Dict] = None):
+    """This wrapper function wraps your training optimizer for ZeRO DDP.
+
+    Args:
+        model (nn.Module): The model wrapped by `zero_model_wrapper`.
+        optimizer (torch.optim.Optimizer): Your initialized optimizer.
+        initial_scale (float, optional): initial_scale used by DynamicGradScaler.
+        min_scale (float, optional): min_scale used by DynamicGradScaler.
+        growth_factor (float, optional): growth_factor used by DynamicGradScaler.
+        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
+        growth_interval (int, optional): growth_interval used by DynamicGradScaler.
+        hysteresis (int, optional): hysteresis used by DynamicGradScaler.
+        max_scale (float, optional): max_scale used by DynamicGradScaler.
+        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
+            `clip_grad_norm` yourself when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping.
+        norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        optim_config (dict, optional): The configuration used for the ZeRO optimizer.
+    Example:
+
+    >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
+    >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
+    """
+    assert hasattr(model, "_colo_zero_stage"), "You should use `zero_model_wrapper` first"
+    zero_stage = getattr(model, "_colo_zero_stage")
+
+    assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"
+
+    if optim_config is None:
+        config_dict = dict()
+    else:
+        config_dict = copy(optim_config)
+
+    config_dict['initial_scale'] = initial_scale
+    config_dict['growth_factor'] = growth_factor
+    config_dict['backoff_factor'] = backoff_factor
+    config_dict['growth_interval'] = growth_interval
+    config_dict['hysteresis'] = hysteresis
+    config_dict['min_scale'] = min_scale
+    config_dict['max_scale'] = max_scale
+
+    if zero_stage in [1, 2]:
+        from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer
+        config_dict['partition_grad'] = zero_stage == 2
+        config_dict['clip_grad_norm'] = max_norm
+        return LowLevelZeroOptimizer(optimizer, **config_dict)
+    else:
+        from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+        config_dict['clipping_norm'] = max_norm
+        return ZeroOptimizer(optimizer, model, **config_dict)
diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py
index f45b5e200..d174fc6ac 100644
--- a/colossalai/zero/sharded_optim/low_level_optim.py
+++ b/colossalai/zero/sharded_optim/low_level_optim.py
@@ -17,7 +17,6 @@ from ._utils import (
     calculate_global_norm_from_list,
     compute_norm,
     flatten,
-    get_grad_accumulate_object,
     has_inf_or_nan,
     reduce_tensor_dp_group,
     release_param_grad,
@@ -386,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # torch.optim.Optimizer methods
     ################################
 
-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, retain_graph=False, sync_grad=True):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)
 
@@ -402,6 +401,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             torch.cuda.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
 
+        # gradient synchronization
+        if sync_grad:
+            self._sync_grad()
+
     def zero_grad(self, set_to_none=True):
         """
         Set parameter gradients to zero. If set_to_none = True, gradient
@@ -537,7 +540,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # Gradient Synchronization #
     ############################
 
-    def sync_grad(self):
+    def _sync_grad(self):
         # update param already reduced flag
         reduction_states = self._param_store.get_param_reduction_states()
         for tensor, state in reduction_states.items():
diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/low_level_zero/test_grad_acc.py
index 1e157c70a..504df202e 100644
--- a/tests/test_zero/low_level_zero/test_grad_acc.py
+++ b/tests/test_zero/low_level_zero/test_grad_acc.py
@@ -9,7 +9,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -60,16 +59,16 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
+        zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
 
         for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
             if z2p.grad is not None:
                 # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                 assert torch.equal(z1p.grad, z2p.grad)
 
-        zero1_optimizer.sync_grad()
-        zero2_optimizer.sync_grad()
+        zero1_optimizer._sync_grad()
+        zero2_optimizer._sync_grad()
 
     fwd_bwd_func(0, input_data1)
     fwd_bwd_func(1, input_data2)
@@ -124,7 +123,7 @@ def exam_zero_1_grad_acc():
         assert torch.equal(zero_output, torch_output)
 
         # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
+        zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
 
         # torch-ddp backward
         torch_output.sum().backward()
@@ -135,7 +134,7 @@ def exam_zero_1_grad_acc():
                 # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, unscale_grad)
 
-        zero_optimizer.sync_grad()
+        zero_optimizer._sync_grad()
 
     fwd_bwd_func(0, input_data1, True)
     fwd_bwd_func(1, input_data2, False)
diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py
index 494963072..930b61291 100644
--- a/tests/test_zero/low_level_zero/test_zero1_2.py
+++ b/tests/test_zero/low_level_zero/test_zero1_2.py
@@ -78,16 +78,16 @@ def exam_zero_1_2():
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.mean().float())
-        zero2_optimizer.backward(zero2_output.mean().float())
+        zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
+        zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
 
         for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
             if z2p.grad is not None:
                 # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                 assert torch.equal(z1p.grad, z2p.grad)
 
-        zero1_optimizer.sync_grad()
-        zero2_optimizer.sync_grad()
+        zero1_optimizer._sync_grad()
+        zero2_optimizer._sync_grad()
 
         # step
         zero1_optimizer.step()
@@ -146,7 +146,7 @@ def exam_zero_1_torch_ddp():
         half_close(zero_output, torch_output, loose=True)
 
         # zero-dp backward
-        zero_optimizer.backward(zero_output.mean().float())
+        zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
 
         # torch-ddp backward
         torch_output.mean().backward()
@@ -156,7 +156,7 @@ def exam_zero_1_torch_ddp():
             half_close(p.grad, z1p.grad, loose=True)
 
         # zero-dp step
-        zero_optimizer.sync_grad()
+        zero_optimizer._sync_grad()
         zero_optimizer.step()
 
         # torch ddp step
diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py
index ea8e3a0a3..15d3530ff 100644
--- a/tests/test_zero/low_level_zero/test_zero_tp.py
+++ b/tests/test_zero/low_level_zero/test_zero_tp.py
@@ -74,7 +74,6 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
     torch_loss.backward()
     torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
     hybrid_optim.backward(hybrid_loss)
-    hybrid_optim.sync_grad()
 
     torch_optim.step()
     hybrid_optim.step()
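
Usage sketch (illustrative, not part of the patch): the snippet below shows how the two new wrappers are intended to be combined in a training script. It assumes distributed state has already been initialized (for example via colossalai.launch_from_torch); `MyModel` and `train_loader` are hypothetical placeholders, and the ColoInitContext import path follows the ColossalAI layout at the time of this patch and may differ in other versions.

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext

# Build the model under ColoInitContext so Gemini can manage its parameters.
with ColoInitContext(device=torch.cuda.current_device()):
    model = MyModel()    # hypothetical model class

# Stage 3 wraps the model in GeminiDDP; stages 1 and 2 return the model unchanged
# and defer sharding to the wrapped optimizer.
gemini_config = dict(device=torch.cuda.current_device(), placement_policy='auto')
model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
# max_norm is handled by the ZeRO optimizer; do not call clip_grad_norm yourself.
optimizer = zero_optim_wrapper(model, optimizer, max_norm=1.0)

for data, label in train_loader:    # hypothetical dataloader
    optimizer.zero_grad()
    loss = model(data.cuda()).sum()    # placeholder loss
    optimizer.backward(loss)    # scales the loss; for stages 1/2 it also syncs grads by default
    optimizer.step()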