[hotfix] fix initialize bug with zero (#442)

pull/445/head
Jiarui Fang 2022-03-17 13:16:22 +08:00 committed by GitHub
parent 725a39f4bd
commit 496cbb0760
12 changed files with 87 additions and 58 deletions

View File

@@ -11,17 +11,13 @@ from .apex_amp import convert_to_apex_amp
 from .naive_amp import convert_to_naive_amp


-def convert_to_amp(model: nn.Module,
-                   optimizer: Optimizer,
-                   criterion: _Loss,
-                   mode: AMP_TYPE,
-                   amp_config: Config = None):
+def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
     """A helper function to wrap training components with Torch AMP modules

     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param criterion: your loss function object
     :type criterion: :class:`torch.nn.modules.loss._Loss`
     :param mode: amp mode

View File

@@ -3,15 +3,13 @@ import torch.nn as nn
 from torch.optim import Optimizer


-def convert_to_apex_amp(model: nn.Module,
-                        optimizer: Optimizer,
-                        amp_config):
+def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     """A helper function to wrap training components with Apex AMP modules

     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param amp_config: configuration for nvidia apex
     :type amp_config: :class:`colossalai.context.Config` or dict

View File

@@ -12,7 +12,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param amp_config: configuration for naive mode amp
     :type amp_config: :class:`colossalai.context.Config` or dict

View File

@@ -15,7 +15,7 @@ def convert_to_torch_amp(model: nn.Module,
     :param model: your model object
     :type model: :class:`torch.nn.Module`
     :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimzer`
+    :type optimizer: :class:`torch.optim.Optimizer`
     :param criterion: your loss function object
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes

View File

@@ -268,6 +268,7 @@ def initialize(model: Union[Callable, nn.Module],
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

+    # zero
     use_zero = hasattr(gpc.config, 'zero')
     if use_zero:
         zero_cfg = gpc.config.get('zero', None)
@@ -275,10 +276,13 @@
             cfg_ = zero_cfg.copy()
         else:
             cfg_ = {}
-        optimizer_config = zero_cfg.get('optimzer', None)
-        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
-        logger.info("Initializing ZeRO model and optimzer finished!", ranks=[0])
+        optimizer_config = zero_cfg.get('optimizer_config', None)
+        model_config = zero_cfg.get('model_config', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model,
+                                              model_config=model_config,
+                                              optimizer_config=optimizer_config)
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])

         #FIXME() throw a warning if using zero with MP
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
@@ -289,6 +293,11 @@
         elif isinstance(model, Callable):
             model = model().to(get_current_device())

+        # optimizer maybe a optimizer_cls
+        logger.warning("Initializing an non ZeRO model with optimizer class")
+        if isinstance(optimizer, Callable):
+            optimizer = optimizer(model.parameters())
+
     if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
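
Taken together, the hunks above change what initialize() expects: with ZeRO enabled, the zero section of the config is now read through separate model_config and optimizer_config keys (replacing the misspelled optimzer key that caused the bug), and on the non-ZeRO path an optimizer class is now accepted and instantiated with model.parameters(). A minimal sketch of a matching zero config section, using only keys that appear in this commit's updated test configs (values are illustrative, not a recommended recipe):

    import torch

    # 'zero' section of the colossalai config as read by the patched initialize()
    zero = dict(
        model_config=dict(shard_param=True),                      # forwarded to convert_to_zero_v2 -> ShardedModelV2
        optimizer_config=dict(optimizer_class=torch.optim.Adam,   # forwarded to ShardedOptimizerV2
                              lr=1e-3),
    )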

View File

@@ -1,4 +1,3 @@
-import imp
 import torch
 from colossalai.utils import get_current_device

View File

@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger


-def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
+def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
@@ -35,28 +35,26 @@ def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
     # FIXME() pass shard strategy from config
     shard_strategy = TensorShardStrategy()

+    logger.info(f'optimizer_config is {optimizer_config}')
+    if optimizer_config is None:
+        optimizer_config = dict()
+    logger.info(f'model_config is {model_config}')
+    if model_config is None:
+        model_config = dict()
+
     if isinstance(model_builder, nn.Module):
         model = model_builder
     elif isinstance(model_builder, Callable):
         with ZeroInitContext(convert_fp16='fp16' in gpc.config,
                              target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
-                             shard_param=True):
+                             shard_param=model_config.get('shard_param', True)):
             model = model_builder()
     else:
         raise TypeError(f"convert_to_zero_v2 dose not support model_builder of type {type(convert_to_zero_v2)}")

-    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
-
-    optimizer_class = optimizer_config.get('optimizer_type', None)
-    if optimizer_class is None:
-        raise RuntimeError("Set optimizer_class in zero_config")
-    logger.info(f'optimizer class is {optimizer_class}')
-    cfg = optimizer_config.get('optimizer_config', None)
-    logger.info(f'optimizer_config is {cfg}')
-
-    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy, **model_config)
+    zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config)
     return zero_model, zero_optimizer
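
For reference, a hedged sketch of calling the updated helper directly with its new three-argument signature (normally initialize() does this for you; the import path is an assumption, a working distributed context is required, and the tiny model builder is a placeholder):

    import torch
    import torch.nn as nn
    from colossalai.zero import convert_to_zero_v2   # import path assumed

    # model_config is unpacked into ShardedModelV2, optimizer_config into ShardedOptimizerV2
    zero_model, zero_optimizer = convert_to_zero_v2(
        model_builder=lambda: nn.Linear(16, 16),     # placeholder model builder
        model_config=dict(shard_param=True),
        optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3))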

View File

@@ -1,5 +1,4 @@
 import functools
-from asyncio.log import logger
 from collections import OrderedDict
 from typing import Any, Optional

View File

@@ -10,7 +10,7 @@ from colossalai.amp.amp_type import AMP_TYPE
 from colossalai.builder import build_pipeline_model
 from colossalai.engine.schedule import PipelineSchedule
 from colossalai.logging import get_dist_logger
-from colossalai.nn import Accuracy, LinearWarmupLR
+from colossalai.nn import LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, free_port, get_dataloader
@@ -19,7 +19,7 @@ from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10

-BATCH_SIZE = 16
+BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
 CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),

View File

@@ -2,23 +2,38 @@ from functools import partial

 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
 from colossalai.zero.sharded_model import ShardedModelV2

-LOGGER = get_dist_logger()
+LOGGER = get_dist_logger('zero_test')

-_ZERO_OPTIMIZER_CONFIG = dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))
-_ZERO_OFFLOAD_OPTIMIZER_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False)
-_ZERO_OFFLOAD_PARAM_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, buffer_size=1e8, max_in_cpu=1e9)
+MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None)))
+
+_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
+                          fp32_reduce_scatter=False,
+                          offload_config=None,
+                          gradient_predivide_factor=1.0,
+                          shard_param=True,
+                          use_memory_tracer=False)
+
+_ZERO_OPTIMIZER_CONFIG = dict(
+    optimizer_class=torch.optim.Adam,
+    cpu_offload=False,
+    initial_scale=2**32,
+    min_scale=1,
+    growth_factor=2,
+    backoff_factor=0.5,
+    growth_interval=1000,
+    hysteresis=2,
+    max_scale=2**32,
+)

 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
-                                optimzer=_ZERO_OPTIMIZER_CONFIG,
-                                offload_optimizer_config=_ZERO_OFFLOAD_OPTIMIZER_CONFIG,
-                                offload_param_config=_ZERO_OFFLOAD_PARAM_CONFIG,
+                                model_config=_ZERO_MODEL_CONFIG,
+                                optimizer_config=_ZERO_OPTIMIZER_CONFIG,
                             ),
                             parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
@@ -72,8 +87,8 @@ def check_grads(model, zero_model, loose=False):
 def check_params(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         zero_p = zero_p.clone().to(p.device)
-        assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
+        # assert p.dtype == zero_p.dtype
+        assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}"


 def check_grads_padding(model, zero_model, loose=False):

View File

@@ -19,7 +19,7 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
    # randomness of convolution layers will be disabled
-    zero_config = dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)))
+    zero_config = dict(optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3))
     colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
                       rank=rank,
                       world_size=world_size,

View File

@@ -3,19 +3,22 @@
 import copy
 from functools import partial

+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+
 import pytest

 import colossalai
 from colossalai.utils import free_port
 import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
 from tests.components_to_test.registry import non_distributed_component_funcs
-from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG
+from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG, MP_PARALLEL_CONFIG, check_params


-def run_dist(rank, world_size, port):
-    colossalai.launch(config=ZERO_PARALLEL_CONFIG,
+def run_dist(rank, world_size, port, parallel_config):
+    colossalai.launch(config=parallel_config,
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
@@ -27,22 +30,21 @@ def run_dist(rank, world_size, port):
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

-        # adapt to a Callbale with empty parameters
-        # def module_builder_new():
-        #     return model_builder(checkpoint=True)
-
-        zero_model = model_builder(checkpoint=True)
-        torch_model = copy.deepcopy(zero_model).cuda()
-        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
+        colo_model = model_builder(checkpoint=True)
+        torch_model = copy.deepcopy(colo_model).cuda()
+        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                                optimizer=optimizer_class,
                                                                criterion=criterion,
                                                                train_dataloader=train_dataloader)
         engine.train()
         torch_optimizer = optimizer_class(torch_model.parameters())
+        if dist.get_world_size() > 1:
+            torch_model = DDP(torch_model)

         i = 0
         for data, label in train_dataloader:
-            if i > 3:
+            if i > 4:
                 break
             data, label = data.cuda(), label.cuda()
@@ -67,15 +69,28 @@ def run_dist(rank, world_size, port):
             torch_optimizer.step()
             i += 1

-        check_sharded_params_padding(torch_model, zero_model, loose=True)
+        # for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
+        #     assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"
+        if parallel_config == MP_PARALLEL_CONFIG:
+            check_params(torch_model, colo_model, loose=True)
+        elif isinstance(colo_model, ShardedModelV2):
+            check_sharded_params_padding(torch_model, colo_model, loose=True)


+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_mp_engine(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
+    mp.spawn(run_func, nprocs=world_size)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_zero_init(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+def test_zero_engine(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_init(world_size=2)
+    test_zero_engine(world_size=4)