[refactory] refactory the initialize method for new zero design (#431)

pull/420/head^2
Jiarui Fang 2022-03-16 19:29:37 +08:00 committed by GitHub
parent 4f85b687cf
commit 640a6cd304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 184 additions and 24 deletions

View File

@ -5,7 +5,7 @@ import argparse
import os
import pprint
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
import torch
import torch.nn as nn
@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.zero import convert_to_zero_v2
from colossalai.engine.ophooks import BaseOpHook
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
def get_default_parser():
@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose=verbose)
def initialize(model: nn.Module,
optimizer: Optimizer,
def initialize(model: Union[Callable, nn.Module],
optimizer: Union[Type[Optimizer], Optimizer],
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
@ -227,10 +228,10 @@ def initialize(model: nn.Module,
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
:param model: Your model instance
:type model: :class:`torch.nn.Module`
:param model: Your model instance or a function to build the model
:type model: :class:`torch.nn.Module` or Callbale
:param optimizer: Your optimizer instance
:type optimizer: :class:`torch.optim.optimizer.Optimizer`
:type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
:param criterion: Your criterion instance
:type criterion: :class:`torch.nn.modules.loss._Loss`, optional
:param train_dataloader: Dataloader for training
@ -267,10 +268,28 @@ def initialize(model: nn.Module,
if verbose:
logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# first sync model across dp ranks
model.to(get_current_device())
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
if not moe_env.is_initialized() and not use_zero3:
use_zero = hasattr(gpc.config, 'zero')
if use_zero:
zero_cfg = gpc.config.get('zero', None)
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
else:
cfg_ = {}
optimizer_config = zero_cfg.get('optimzer', None)
model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimzer finished!", ranks=[0])
#FIXME() throw a warning if using zero with MP
if gpc.get_world_size(ParallelMode.MODEL) > 1:
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
else:
if isinstance(model, nn.Module):
# first sync model across dp ranks
model.to(get_current_device())
elif isinstance(model, Callable):
model = model().to(get_current_device())
if not moe_env.is_initialized() and not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif is_using_ddp():
@ -283,16 +302,15 @@ def initialize(model: nn.Module,
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)
zero_cfg = gpc.config.get('zero', None)
if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
raise ConfigException(
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
# clip grad norm
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
if clip_grad_norm > 0:
if zero_cfg is not None:
if use_zero and zero_cfg is not None:
raise ConfigException(
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
@ -311,11 +329,6 @@ def initialize(model: nn.Module,
mode=amp_mode,
amp_config=cfg_)
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
level = cfg_.pop('level')
model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
# gradient handler
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
if gradient_handler_cfg is None:
@ -324,7 +337,7 @@ def initialize(model: nn.Module,
# 1. if optimizer is ZERO, then use zero grad handler
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
if isinstance(optimizer, ShardedOptimizer):
if isinstance(optimizer, ShardedOptimizerV2):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
if verbose:
logger.info(
@ -392,7 +405,7 @@ def initialize(model: nn.Module,
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
# check if optimizer is ColossalaiOptimizer
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
optimizer = ColossalaiOptimizer(optim=optimizer)
# gradient accumulation

View File

@ -1,4 +1,8 @@
from asyncio.log import logger
from distutils.command.config import config
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from colossalai.zero.shard_utils import TensorShardStrategy
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
from torch.optim import Optimizer
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
from colossalai.zero.init_ctx import ZeroInitContext
from typing import Callable, Type
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer_config: Your optimizer object
:type optimizer_config: :class:`dict`
:return: (model, optimizer)
:rtype: Tuple
"""
logger = get_dist_logger('convert_to_zero_v2')
# FIXME() pass shard strategy from config
shard_strategy = TensorShardStrategy()
if isinstance(model_builder, nn.Module):
model = model_builder
elif isinstance(model_builder, Callable):
with ZeroInitContext(convert_fp16='fp16' in gpc.config,
target_device=torch.cuda.current_device(),
shard_strategy=shard_strategy,
shard_param=True):
model = model_builder()
else:
raise TypeError(f"convert_to_zero_v2 dose not support model_builder of type {type(convert_to_zero_v2)}")
zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
optimizer_class = optimizer_config.get('optimizer_type', None)
if optimizer_class is None:
raise RuntimeError("Set optimizer_class in zero_config")
logger.info(f'optimizer class is {optimizer_class}')
cfg = optimizer_config.get('optimizer_config', None)
logger.info(f'optimizer_config is {cfg}')
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
return zero_model, zero_optimizer
def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):

View File

@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Because we will judge whether local grad accumulation
# is enabled by wheter grad is None
self.optim.zero_grad(set_to_none=True)
def sync_grad(self):
pass

View File

@ -19,9 +19,10 @@ def run_dist(rank, world_size, port):
# as this model has sync batch normalization
# need to configure cudnn deterministic so that
# randomness of convolution layers will be disabled
colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
cudnn_determinstic=True,
cudnn_benchmark=False),
colossalai.launch(config=dict(
zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))),
cudnn_determinstic=True,
cudnn_benchmark=False),
rank=rank,
world_size=world_size,
host='localhost',

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import pytest
import colossalai
from colossalai.utils import free_port
import torch
import torch.multiprocessing as mp
from tests.components_to_test.registry import non_distributed_component_funcs
from common import check_sharded_params_padding
def run_dist(rank, world_size, port):
_config = dict(fp16=dict(mode=None,),
zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
offload_optimizer_config=dict(device='cpu',
pin_memory=True,
buffer_count=5,
fast_init=False),
offload_param_config=dict(device='cpu',
pin_memory=True,
buffer_count=5,
buffer_size=1e8,
max_in_cpu=1e9)),
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# FIXME revert back
# test_models = ['repeated_computed_layers', 'resnet18', 'bert']
test_models = ['bert']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
# adapt to a Callbale with empty parameters
# def module_builder_new():
# return model_builder(checkpoint=True)
zero_model = model_builder(checkpoint=True)
torch_model = copy.deepcopy(zero_model).cuda()
engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
optimizer=optimizer_class,
criterion=criterion,
train_dataloader=train_dataloader)
engine.train()
torch_optimizer = optimizer_class(torch_model.parameters())
i = 0
for data, label in train_dataloader:
if i > 3:
break
data, label = data.cuda(), label.cuda()
engine.zero_grad()
torch_optimizer.zero_grad()
if criterion:
output = engine(data)
loss = engine.criterion(output, label)
torch_model(data, label)
torch_loss = engine.criterion(output, label)
else:
loss = engine(data, label)
torch_loss = torch_model(data, label)
engine.backward(loss)
engine.step()
torch_loss.backward()
torch_optimizer.step()
i += 1
check_sharded_params_padding(torch_model, zero_model, loose=True)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_init(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_init(world_size=2)