diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 011859881..38a66142b 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type

 import torch
 import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
-from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.zero import convert_to_zero_v2
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


 def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

-    :param model: Your model instance
-    :type model: :class:`torch.nn.Module`
+    :param model: Your model instance or a function to build the model
+    :type model: :class:`torch.nn.Module` or Callable
     :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer`
+    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.Optimizer]`
     :param criterion: Your criterion instance
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
@@ -267,10 +268,28 @@ def initialize(model: nn.Module,
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

-    # first sync model across dp ranks
-    model.to(get_current_device())
-    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not moe_env.is_initialized() and not use_zero3:
+    use_zero = hasattr(gpc.config, 'zero')
+    if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        if zero_cfg is not None:
+            cfg_ = zero_cfg.copy()
+        else:
+            cfg_ = {}
+        optimizer_config = zero_cfg.get('optimizer', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
+        # FIXME() throw a warning if using zero with MP
+        if gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
+    else:
+        if isinstance(model, nn.Module):
+            # first sync model across dp ranks
+            model.to(get_current_device())
+        elif isinstance(model, Callable):
+            model = model().to(get_current_device())
+
+    if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
         elif is_using_ddp():
@@ -283,16 +302,15 @@ def initialize(model: nn.Module,

     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
-    zero_cfg = gpc.config.get('zero', None)

-    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")

     # clip grad norm
     clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
     if clip_grad_norm > 0:
-        if zero_cfg is not None:
+        if use_zero and zero_cfg is not None:
             raise ConfigException(
                 "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")

@@ -311,11 +329,6 @@ def initialize(model: nn.Module,
                                                      mode=amp_mode,
                                                      amp_config=cfg_)

-    if zero_cfg is not None:
-        cfg_ = zero_cfg.copy()
-        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
-
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@ def initialize(model: nn.Module,
         # 1. if optimizer is ZERO, then use zero grad handler
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
-        if isinstance(optimizer, ShardedOptimizer):
+        if isinstance(optimizer, ShardedOptimizerV2):
             gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
             if verbose:
                 logger.info(
@@ -392,7 +405,7 @@ def initialize(model: nn.Module,
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

     # check if optimizer is ColossalaiOptimizer
-    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
+    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
         optimizer = ColossalaiOptimizer(optim=optimizer)

     # gradient accumulation
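
Under the new code path, initialize() accepts either a constructed nn.Module or a zero-argument builder, and when a zero block is present the optimizer argument is replaced by a ShardedOptimizerV2 built from that block. The following is a minimal sketch of a call site, not part of this patch; the model builder and config contents are illustrative placeholders that only reuse keys read by the code above.

    # Hypothetical training config (e.g. ./config.py), using the keys read above:
    #
    #     import torch
    #     zero = dict(optimizer=dict(optimizer_type=torch.optim.Adam,
    #                                optimizer_config=dict(lr=1e-3)))

    import colossalai
    import torch
    import torch.nn as nn

    colossalai.launch_from_torch(config='./config.py')

    def build_model() -> nn.Module:
        # Placeholder model; any builder with no required arguments works,
        # since it will be invoked inside ZeroInitContext.
        return nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))

    # With a zero config present, the optimizer passed here is replaced by the
    # ShardedOptimizerV2 built from zero['optimizer'].
    engine, *_ = colossalai.initialize(model=build_model,
                                       optimizer=torch.optim.Adam,
                                       criterion=nn.CrossEntropyLoss())
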
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 95186233a..79bf4c11d 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -1,4 +1,8 @@
+from asyncio.log import logger
 from distutils.command.config import config
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.shard_utils import TensorShardStrategy
 import torch
 import torch.nn as nn
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
 from torch.optim import Optimizer
 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from typing import Callable, Type
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
+
+def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
+    """
+    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
+
+    :param model_builder: Your model instance or a function to build the model
+    :type model_builder: :class:`torch.nn.Module` or Callable
+    :param optimizer_config: Configuration used to build the ZeRO optimizer
+    :type optimizer_config: :class:`dict`
+
+    :return: (model, optimizer)
+    :rtype: Tuple
+    """
+
+    logger = get_dist_logger('convert_to_zero_v2')
+
+    # FIXME() pass shard strategy from config
+    shard_strategy = TensorShardStrategy()
+
+    if isinstance(model_builder, nn.Module):
+        model = model_builder
+    elif isinstance(model_builder, Callable):
+        with ZeroInitContext(convert_fp16='fp16' in gpc.config,
+                             target_device=torch.cuda.current_device(),
+                             shard_strategy=shard_strategy,
+                             shard_param=True):
+            model = model_builder()
+    else:
+        raise TypeError(f"convert_to_zero_v2 does not support model_builder of type {type(model_builder)}")
+
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
+
+    optimizer_class = optimizer_config.get('optimizer_type', None)
+    if optimizer_class is None:
+        raise RuntimeError("Set 'optimizer_type' in the zero optimizer config")
+    logger.info(f'optimizer class is {optimizer_class}')
+
+    cfg = optimizer_config.get('optimizer_config', dict())
+    logger.info(f'optimizer_config is {cfg}')
+
+    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **cfg)
+    return zero_model, zero_optimizer


 def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
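
convert_to_zero_v2 can also be called directly. The following minimal sketch is not part of this patch; the toy model is a placeholder, and it assumes a colossalai.launch*() call has already populated gpc.config, since the ZeroInitContext branch above consults it for the 'fp16' key.

    import torch
    import torch.nn as nn
    from colossalai.zero import convert_to_zero_v2

    # Keys mirror what initialize() forwards as optimizer_config.
    optimizer_config = dict(optimizer_type=torch.optim.Adam,
                            optimizer_config=dict(lr=1e-3))

    # The builder takes no arguments because it is invoked inside ZeroInitContext,
    # so parameters can be sharded as they are created.
    zero_model, zero_optimizer = convert_to_zero_v2(
        model_builder=lambda: nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10)),
        optimizer_config=optimizer_config)
    # zero_model is a ShardedModelV2; zero_optimizer wraps torch.optim.Adam
    # as a ShardedOptimizerV2 with lr=1e-3.
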
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 19f9c343b..a7a24ef64 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
         self.optim.zero_grad(set_to_none=True)
+
+    def sync_grad(self):
+        pass
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
index 41125e3a9..da1f4edf2 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
@@ -19,9 +19,10 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
-                                  cudnn_determinstic=True,
-                                  cudnn_benchmark=False),
+    colossalai.launch(config=dict(
+        zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))),
+        cudnn_determinstic=True,
+        cudnn_benchmark=False),
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
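
The test change above also shows the config migration implied by this patch: the old level/partition_grad block is replaced by a nested optimizer dict naming the optimizer class and its constructor kwargs. Side by side, using only keys that appear in this diff:

    import torch

    # Old style, removed by this patch:
    zero = dict(level=2, partition_grad=True)

    # New style, as used in the updated tests:
    zero = dict(optimizer=dict(optimizer_type=torch.optim.Adam,
                               optimizer_config=dict(lr=1e-3)))
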
diff --git a/tests/test_zero_data_parallel/test_zero_init_v2.py b/tests/test_zero_data_parallel/test_zero_init_v2.py
new file mode 100644
index 000000000..cc6d3d4d3
--- /dev/null
+++ b/tests/test_zero_data_parallel/test_zero_init_v2.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import copy
+from functools import partial
+
+import pytest
+
+import colossalai
+from colossalai.utils import free_port
+
+import torch
+import torch.multiprocessing as mp
+
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+from common import check_sharded_params_padding
+
+
+def run_dist(rank, world_size, port):
+    _config = dict(fp16=dict(mode=None,),
+                   zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
+                             offload_optimizer_config=dict(device='cpu',
+                                                           pin_memory=True,
+                                                           buffer_count=5,
+                                                           fast_init=False),
+                             offload_param_config=dict(device='cpu',
+                                                       pin_memory=True,
+                                                       buffer_count=5,
+                                                       buffer_size=1e8,
+                                                       max_in_cpu=1e9)),
+                   parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
+
+    colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    # FIXME revert back
+    # test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['bert']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
+
+        # adapt to a Callable with empty parameters
+        # def module_builder_new():
+        #     return model_builder(checkpoint=True)
+
+        zero_model = model_builder(checkpoint=True)
+        torch_model = copy.deepcopy(zero_model).cuda()
+
+        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
+                                                               optimizer=optimizer_class,
+                                                               criterion=criterion,
+                                                               train_dataloader=train_dataloader)
+
+        engine.train()
+        torch_optimizer = optimizer_class(torch_model.parameters())
+
+        i = 0
+        for data, label in train_dataloader:
+            if i > 3:
+                break
+
+            data, label = data.cuda(), label.cuda()
+
+            engine.zero_grad()
+            torch_optimizer.zero_grad()
+
+            if criterion:
+                output = engine(data)
+                loss = engine.criterion(output, label)
+
+                torch_output = torch_model(data)
+                torch_loss = engine.criterion(torch_output, label)
+            else:
+                loss = engine(data, label)
+                torch_loss = torch_model(data, label)
+
+            engine.backward(loss)
+            engine.step()
+
+            torch_loss.backward()
+            torch_optimizer.step()
+            i += 1
+
+        check_sharded_params_padding(torch_model, zero_model, loose=True)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_zero_init(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_init(world_size=2)
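
A closing note on the sync_grad() stub added to ShardedOptimizerV2: initialize() selects the ZeROGradientHandler whenever the optimizer is a ShardedOptimizerV2, and that handler is expected to delegate gradient synchronization to the optimizer. The sketch below is an assumption about that contract (the handler class is illustrative, not taken from this diff); the premise is that ShardedModelV2 already reduces gradients during backward, so the optimizer-side hook can safely be a no-op.

    # Illustrative sketch of the assumed handler/optimizer contract.
    class IllustrativeZeROGradientHandler:
        def __init__(self, optimizer):
            self._optimizer = optimizer

        def handle_gradient(self):
            # ShardedModelV2 is assumed to reduce-scatter gradients during
            # backward(), so ShardedOptimizerV2.sync_grad() can be a no-op.
            self._optimizer.sync_grad()
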