[refactor] refactor the initialize method for the new ZeRO design (#431)

pull/420/head^2
Jiarui Fang 2022-03-16 19:29:37 +08:00 committed by GitHub
parent 4f85b687cf
commit 640a6cd304
5 changed files with 184 additions and 24 deletions


@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
 import torch
 import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
-from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.zero import convert_to_zero_v2
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


 def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

-    :param model: Your model instance
-    :type model: :class:`torch.nn.Module`
+    :param model: Your model instance or a function to build the model
+    :type model: :class:`torch.nn.Module` or Callable
     :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer`
+    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.Optimizer]`
     :param criterion: Your criterion instance
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
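
For reference, a minimal sketch of how the new signature can be driven: both a model-building callable and an optimizer class are accepted now. The toy builder, layer sizes, learning rate and launch values below are illustrative placeholders, not part of this commit; note that when a zero section is present in the config, the optimizer actually used is constructed from its optimizer_type/optimizer_config entries inside convert_to_zero_v2, not from the optimizer argument.

import colossalai
import torch
import torch.nn as nn
from colossalai.utils import free_port


def build_model():
    # deferring construction lets the ZeRO init context shard parameters as they are created
    return nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))


# illustrative single-process launch; the zero section mirrors the test configs in this commit
colossalai.launch(config=dict(zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam,
                                                       optimizer_config=dict(lr=1e-3)))),
                  rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=build_model,           # a Callable is accepted in addition to an nn.Module instance
    optimizer=torch.optim.Adam,  # an optimizer class is accepted in addition to an instance
    criterion=nn.CrossEntropyLoss())
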
@@ -267,10 +268,28 @@ def initialize(model: nn.Module,
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

-    # first sync model across dp ranks
-    model.to(get_current_device())
-    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not moe_env.is_initialized() and not use_zero3:
+    use_zero = hasattr(gpc.config, 'zero')
+    if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        if zero_cfg is not None:
+            cfg_ = zero_cfg.copy()
+        else:
+            cfg_ = {}
+        optimizer_config = zero_cfg.get('optimizer', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
+        # FIXME: throw a warning if using ZeRO with MP
+        if gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
+    else:
+        if isinstance(model, nn.Module):
+            # first sync model across dp ranks
+            model.to(get_current_device())
+        elif isinstance(model, Callable):
+            model = model().to(get_current_device())
+
+    if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
         elif is_using_ddp():
@@ -283,16 +302,15 @@ def initialize(model: nn.Module,

     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
-    zero_cfg = gpc.config.get('zero', None)

-    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")

     # clip grad norm
     clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
     if clip_grad_norm > 0:
-        if zero_cfg is not None:
+        if use_zero and zero_cfg is not None:
             raise ConfigException(
                 "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
@@ -311,11 +329,6 @@ def initialize(model: nn.Module,
                                              mode=amp_mode,
                                              amp_config=cfg_)

-    if zero_cfg is not None:
-        cfg_ = zero_cfg.copy()
-        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
-
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@ def initialize(model: nn.Module,
         # 1. if optimizer is ZERO, then use zero grad handler
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
-        if isinstance(optimizer, ShardedOptimizer):
+        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
@@ -392,7 +405,7 @@ def initialize(model: nn.Module,
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

     # check if optimizer is ColossalaiOptimizer
-    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
+    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
         optimizer = ColossalaiOptimizer(optim=optimizer)

     # gradient accumulation


@@ -1,4 +1,8 @@
+from asyncio.log import logger
 from distutils.command.config import config
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.shard_utils import TensorShardStrategy
 import torch
 import torch.nn as nn
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
 from torch.optim import Optimizer
 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from typing import Callable, Tuple, Type
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
+
+def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
+    """
+    A helper function to integrate the model and optimizer with the ZeRO optimizer and off-loading
+
+    :param model_builder: Your model instance or a function to build the model
+    :type model_builder: :class:`torch.nn.Module` or Callable
+    :param optimizer_config: Configuration of the ZeRO optimizer
+    :type optimizer_config: :class:`dict`
+
+    :return: (model, optimizer)
+    :rtype: Tuple
+    """
+
+    logger = get_dist_logger('convert_to_zero_v2')
+
+    # FIXME: pass the shard strategy from the config
+    shard_strategy = TensorShardStrategy()
+
+    if isinstance(model_builder, nn.Module):
+        model = model_builder
+    elif isinstance(model_builder, Callable):
+        with ZeroInitContext(convert_fp16='fp16' in gpc.config,
+                             target_device=torch.cuda.current_device(),
+                             shard_strategy=shard_strategy,
+                             shard_param=True):
+            model = model_builder()
+    else:
+        raise TypeError(f"convert_to_zero_v2 does not support model_builder of type {type(model_builder)}")
+
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
+
+    optimizer_class = optimizer_config.get('optimizer_type', None)
+    if optimizer_class is None:
+        raise RuntimeError("Set optimizer_type in the zero optimizer config")
+    logger.info(f'optimizer class is {optimizer_class}')
+
+    cfg = optimizer_config.get('optimizer_config', None)
+    logger.info(f'optimizer_config is {cfg}')
+
+    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **(cfg or {}))
+    return zero_model, zero_optimizer
+
+
 def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
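
As a usage sketch of the helper above: only the nested optimizer_type and optimizer_config entries of the zero optimizer section are read here. This assumes colossalai.launch has already been called (convert_to_zero_v2 consults gpc.config and builds on the current CUDA device); the toy builder and the Adam settings are illustrative, not part of this commit.

import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero_v2


def build_toy_model():
    # stand-in for a user model builder; any nn.Module or zero-argument callable works
    return nn.Linear(32, 10)


# the optimizer block of the zero config section, as forwarded by initialize()
optimizer_config = dict(optimizer_type=torch.optim.Adam,
                        optimizer_config=dict(lr=1e-3))

zero_model, zero_optimizer = convert_to_zero_v2(model_builder=build_toy_model,
                                                optimizer_config=optimizer_config)
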


@@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by whether grad is None
         self.optim.zero_grad(set_to_none=True)
+
+    def sync_grad(self):
+        pass


@@ -19,9 +19,10 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
-                                  cudnn_determinstic=True,
-                                  cudnn_benchmark=False),
+    colossalai.launch(config=dict(
+        zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))),
+        cudnn_determinstic=True,
+        cudnn_benchmark=False),
                       rank=rank,
                       world_size=world_size,
                       host='localhost',


@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import copy
+from functools import partial
+
+import pytest
+
+import colossalai
+from colossalai.utils import free_port
+import torch
+import torch.multiprocessing as mp
+from tests.components_to_test.registry import non_distributed_component_funcs
+from common import check_sharded_params_padding
+
+
+def run_dist(rank, world_size, port):
+    _config = dict(fp16=dict(mode=None,),
+                   zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
+                             offload_optimizer_config=dict(device='cpu',
+                                                           pin_memory=True,
+                                                           buffer_count=5,
+                                                           fast_init=False),
+                             offload_param_config=dict(device='cpu',
+                                                       pin_memory=True,
+                                                       buffer_count=5,
+                                                       buffer_size=1e8,
+                                                       max_in_cpu=1e9)),
+                   parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
+
+    colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    # FIXME: revert back
+    # test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['bert']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
+
+        # adapt to a Callable with empty parameters
+        # def module_builder_new():
+        #     return model_builder(checkpoint=True)
+
+        zero_model = model_builder(checkpoint=True)
+        torch_model = copy.deepcopy(zero_model).cuda()
+        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
+                                                               optimizer=optimizer_class,
+                                                               criterion=criterion,
+                                                               train_dataloader=train_dataloader)
+
+        engine.train()
+        torch_optimizer = optimizer_class(torch_model.parameters())
+
+        i = 0
+        for data, label in train_dataloader:
+            if i > 3:
+                break
+
+            data, label = data.cuda(), label.cuda()
+
+            engine.zero_grad()
+            torch_optimizer.zero_grad()
+
+            if criterion:
+                output = engine(data)
+                loss = engine.criterion(output, label)
+
+                torch_output = torch_model(data)
+                torch_loss = criterion(torch_output, label)
+            else:
+                loss = engine(data, label)
+                torch_loss = torch_model(data, label)
+
+            engine.backward(loss)
+            engine.step()
+
+            torch_loss.backward()
+            torch_optimizer.step()
+            i += 1
+
+        check_sharded_params_padding(torch_model, zero_model, loose=True)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_zero_init(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_init(world_size=2)