From 496cbb0760379ce32a30c8b4542fec6bb1d5c27a Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 17 Mar 2022 13:16:22 +0800
Subject: [PATCH] [hotfix] fix initialize bug with zero (#442)

---
 colossalai/amp/__init__.py                    |  8 +---
 colossalai/amp/apex_amp/__init__.py           |  6 +--
 colossalai/amp/naive_amp/__init__.py          |  2 +-
 colossalai/amp/torch_amp/__init__.py          |  2 +-
 colossalai/initialize.py                      | 15 +++++--
 colossalai/utils/commons/memory.py            |  1 -
 colossalai/zero/__init__.py                   | 24 +++++-----
 .../zero/sharded_model/sharded_model_v2.py    |  1 -
 .../test_cifar_with_data_pipeline_tensor.py   |  4 +-
 tests/test_zero_data_parallel/common.py       | 35 ++++++++++-----
 .../test_sharded_optim_with_sync_bn.py        |  2 +-
 ...st_zero_init_v2.py => test_zero_engine.py} | 45 ++++++++++++------
 12 files changed, 87 insertions(+), 58 deletions(-)
 rename tests/test_zero_data_parallel/{test_zero_init_v2.py => test_zero_engine.py} (59%)

diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py
index 5a30e67fb..e01776613 100644
--- a/colossalai/amp/__init__.py
+++ b/colossalai/amp/__init__.py
@@ -11,17 +11,13 @@ from .apex_amp import convert_to_apex_amp
 from .naive_amp import convert_to_naive_amp
 
 
-def convert_to_amp(model: nn.Module,
-                   optimizer: Optimizer,
-                   criterion: _Loss,
-                   mode: AMP_TYPE,
-                   amp_config: Config = None):
+def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
     """A helper function to wrap training components with Torch AMP modules
 
     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param criterion: your loss function object
     :type criterion: :class:`torch.nn.modules.loss._Loss`
     :param mode: amp mode
diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py
index 23585ede7..678bcd9dd 100644
--- a/colossalai/amp/apex_amp/__init__.py
+++ b/colossalai/amp/apex_amp/__init__.py
@@ -3,15 +3,13 @@ import torch.nn as nn
 from torch.optim import Optimizer
 
 
-def convert_to_apex_amp(model: nn.Module,
-                        optimizer: Optimizer,
-                        amp_config):
+def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     """A helper function to wrap training components with Apex AMP modules
 
     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param amp_config: configuration for nvidia apex
     :type amp_config: :class:`colossalai.context.Config` or dict
 
diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py
index 2390c199e..c7956651f 100644
--- a/colossalai/amp/naive_amp/__init__.py
+++ b/colossalai/amp/naive_amp/__init__.py
@@ -12,7 +12,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param amp_config: configuration for naive mode amp
     :type amp_config: :class:`colossalai.context.Config` or dict
 
diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py
index 9c9976a5d..b3efcb63a 100644
--- a/colossalai/amp/torch_amp/__init__.py
+++ b/colossalai/amp/torch_amp/__init__.py
@@ -15,7 +15,7 @@ def convert_to_torch_amp(model: nn.Module,
     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param criterion: your loss function object
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 38a66142b..94409619f 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -268,6 +268,7 @@ def initialize(model: Union[Callable, nn.Module],
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
 
+    # zero
     use_zero = hasattr(gpc.config, 'zero')
     if use_zero:
         zero_cfg = gpc.config.get('zero', None)
@@ -275,10 +276,13 @@ def initialize(model: Union[Callable, nn.Module],
             cfg_ = zero_cfg.copy()
         else:
             cfg_ = {}
-        optimizer_config = zero_cfg.get('optimzer', None)
-        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+        optimizer_config = zero_cfg.get('optimizer_config', None)
+        model_config = zero_cfg.get('model_config', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model,
+                                              model_config=model_config,
+                                              optimizer_config=optimizer_config)
 
-        logger.info("Initializing ZeRO model and optimzer finished!", ranks=[0])
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
         #FIXME() throw a warning if using zero with MP
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
@@ -289,6 +293,11 @@ def initialize(model: Union[Callable, nn.Module],
         elif isinstance(model, Callable):
             model = model().to(get_current_device())
 
+        # optimizer may be an optimizer class rather than an instance
+        if isinstance(optimizer, Callable):
+            logger.warning("Initializing a non-ZeRO model with an optimizer class")
+            optimizer = optimizer(model.parameters())
+
     if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
diff --git a/colossalai/utils/commons/memory.py b/colossalai/utils/commons/memory.py
index 9754ae6d2..374871d67 100644
--- a/colossalai/utils/commons/memory.py
+++ b/colossalai/utils/commons/memory.py
@@ -1,4 +1,3 @@
-import imp
 import torch
 from colossalai.utils import get_current_device
 
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 79bf4c11d..1d02c09bb 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 
 
-def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
+def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
 
@@ -35,28 +35,26 @@ def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedMod
     # FIXME() pass shard strategy from config
     shard_strategy = TensorShardStrategy()
 
+    logger.info(f'optimizer_config is {optimizer_config}')
+    if optimizer_config is None:
+        optimizer_config = dict()
+    logger.info(f'model_config is {model_config}')
+    if model_config is None:
+        model_config = dict()
+
     if isinstance(model_builder, nn.Module):
         model = model_builder
     elif isinstance(model_builder, Callable):
         with ZeroInitContext(convert_fp16='fp16' in gpc.config,
                              target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
-                             shard_param=True):
+                             shard_param=model_config.get('shard_param', True)):
             model = model_builder()
     else:
         raise TypeError(f"convert_to_zero_v2 dose not support model_builder of type {type(convert_to_zero_v2)}")
 
-    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
-
-    optimizer_class = optimizer_config.get('optimizer_type', None)
-    if optimizer_class is None:
-        raise RuntimeError("Set optimizer_class in zero_config")
-    logger.info(f'optimizer class is {optimizer_class}')
-
-    cfg = optimizer_config.get('optimizer_config', None)
-    logger.info(f'optimizer_config is {cfg}')
-
-    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy, **model_config)
+    zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config)
     return zero_model, zero_optimizer
 
 
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index c33aa9599..060246392 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,5 +1,4 @@
 import functools
-from asyncio.log import logger
 from collections import OrderedDict
 from typing import Any, Optional
 
diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
index 27d1a5e21..2ab907072 100644
--- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
+++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
@@ -10,7 +10,7 @@ from colossalai.amp.amp_type import AMP_TYPE
 from colossalai.builder import build_pipeline_model
 from colossalai.engine.schedule import PipelineSchedule
 from colossalai.logging import get_dist_logger
-from colossalai.nn import Accuracy, LinearWarmupLR
+from colossalai.nn import LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, free_port, get_dataloader
@@ -19,7 +19,7 @@ from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 
-BATCH_SIZE = 16
+BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
 CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 4671dc3a5..3236c54e5 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -2,23 +2,38 @@ from functools import partial
 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
 from colossalai.zero.sharded_model import ShardedModelV2
 
-LOGGER = get_dist_logger()
+LOGGER = get_dist_logger('zero_test')
 
-_ZERO_OPTIMIZER_CONFIG = dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))
-_ZERO_OFFLOAD_OPTIMIZER_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False)
-_ZERO_OFFLOAD_PARAM_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, buffer_size=1e8, max_in_cpu=1e9)
+MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None)))
+
+_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
+                          fp32_reduce_scatter=False,
+                          offload_config=None,
+                          gradient_predivide_factor=1.0,
+                          shard_param=True,
+                          use_memory_tracer=False)
+
+_ZERO_OPTIMIZER_CONFIG = dict(
+    optimizer_class=torch.optim.Adam,
+    cpu_offload=False,
+    initial_scale=2**32,
+    min_scale=1,
+    growth_factor=2,
+    backoff_factor=0.5,
+    growth_interval=1000,
+    hysteresis=2,
+    max_scale=2**32,
+)
 
 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
-                                optimzer=_ZERO_OPTIMIZER_CONFIG,
-                                offload_optimizer_config=_ZERO_OFFLOAD_OPTIMIZER_CONFIG,
-                                offload_param_config=_ZERO_OFFLOAD_PARAM_CONFIG,
+                                model_config=_ZERO_MODEL_CONFIG,
+                                optimizer_config=_ZERO_OPTIMIZER_CONFIG,
                             ),
                             parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
 
 
@@ -72,8 +87,8 @@ def check_grads(model, zero_model, loose=False):
 def check_params(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         zero_p = zero_p.clone().to(p.device)
-        assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
+        # assert p.dtype == zero_p.dtype
+        assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}"
 
 
 def check_grads_padding(model, zero_model, loose=False):
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
index 2eecc4802..a50881504 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
@@ -19,7 +19,7 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    zero_config = dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)))
+    zero_config = dict(optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3))
     colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
                       rank=rank,
                       world_size=world_size,
diff --git a/tests/test_zero_data_parallel/test_zero_init_v2.py b/tests/test_zero_data_parallel/test_zero_engine.py
similarity index 59%
rename from tests/test_zero_data_parallel/test_zero_init_v2.py
rename to tests/test_zero_data_parallel/test_zero_engine.py
index f7696eef5..cdd2bbc5e 100644
--- a/tests/test_zero_data_parallel/test_zero_init_v2.py
+++ b/tests/test_zero_data_parallel/test_zero_engine.py
@@ -3,19 +3,22 @@
 
 import copy
 from functools import partial
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 
 import pytest
 
 import colossalai
 from colossalai.utils import free_port
 import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
 from tests.components_to_test.registry import non_distributed_component_funcs
 
-from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG
+from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG, MP_PARALLEL_CONFIG, check_params
 
 
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=ZERO_PARALLEL_CONFIG,
+def run_dist(rank, world_size, port, parallel_config):
+    colossalai.launch(config=parallel_config,
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
@@ -27,22 +30,21 @@ def run_dist(rank, world_size, port):
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
 
-        # adapt to a Callbale with empty parameters
-        # def module_builder_new():
-        #     return model_builder(checkpoint=True)
-
-        zero_model = model_builder(checkpoint=True)
-        torch_model = copy.deepcopy(zero_model).cuda()
-        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
+        colo_model = model_builder(checkpoint=True)
+        torch_model = copy.deepcopy(colo_model).cuda()
+        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                                optimizer=optimizer_class,
                                                                criterion=criterion,
                                                                train_dataloader=train_dataloader)
         engine.train()
         torch_optimizer = optimizer_class(torch_model.parameters())
 
+        if dist.get_world_size() > 1:
+            torch_model = DDP(torch_model)
+
         i = 0
         for data, label in train_dataloader:
-            if i > 3:
+            if i > 4:
                 break
 
             data, label = data.cuda(), label.cuda()
@@ -67,15 +69,28 @@ def run_dist(rank, world_size, port):
             torch_optimizer.step()
             i += 1
 
-        check_sharded_params_padding(torch_model, zero_model, loose=True)
+        # for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
+        #     assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"
+
+        if parallel_config == MP_PARALLEL_CONFIG:
+            check_params(torch_model, colo_model, loose=True)
+        elif isinstance(colo_model, ShardedModelV2):
+            check_sharded_params_padding(torch_model, colo_model, loose=True)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_mp_engine(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
+    mp.spawn(run_func, nprocs=world_size)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_zero_init(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+def test_zero_engine(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_engine(world_size=4)
+    test_zero_engine(world_size=4)
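
After this patch, the `zero` section of a launch config is split into a `model_config` dict, unpacked into ShardedModelV2, and an `optimizer_config` dict, unpacked into ShardedOptimizerV2 (with `optimizer_class` carried inside `optimizer_config` instead of the removed `optimzer`/`optimizer_type` keys). Below is a minimal sketch of the new layout, mirroring the configs in tests/test_zero_data_parallel/common.py and test_sharded_optim_with_sync_bn.py; the model builder is a hypothetical stand-in for any Callable that returns an nn.Module.

    import torch
    import torch.nn as nn

    # Config keys mirror tests/test_zero_data_parallel/common.py after this
    # patch: convert_to_zero_v2 unpacks `model_config` into ShardedModelV2
    # and `optimizer_config` into ShardedOptimizerV2.
    CONFIG = dict(
        fp16=dict(mode=None,),
        zero=dict(
            model_config=dict(shard_param=True),
            optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3),
        ),
        parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
    )

    def build_model() -> nn.Module:
        # Hypothetical builder: the ZeRO path expects a Callable so that
        # parameters can be sharded as they are created inside ZeroInitContext.
        return nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))

    # Inside a process launched with colossalai.launch(config=CONFIG, ...),
    # initialize() now reads zero_cfg['model_config'] and
    # zero_cfg['optimizer_config'] (the old 'optimzer' key is gone):
    #
    #   engine, *_ = colossalai.initialize(build_model,
    #                                      optimizer=torch.optim.Adam,
    #                                      criterion=nn.CrossEntropyLoss())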