mirror of https://github.com/hpcaitech/ColossalAI
[refactory] refactory the initialize method for new zero design (#431)
parent
4f85b687cf
commit
640a6cd304
|
@ -5,7 +5,7 @@ import argparse
|
|||
import os
|
||||
import pprint
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
|
||||
sync_model_param)
|
||||
from colossalai.zero import convert_to_zero, ShardedOptimizer
|
||||
from colossalai.zero import convert_to_zero_v2
|
||||
from colossalai.engine.ophooks import BaseOpHook
|
||||
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
|
@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
|
|||
verbose=verbose)
|
||||
|
||||
|
||||
def initialize(model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
def initialize(model: Union[Callable, nn.Module],
|
||||
optimizer: Union[Type[Optimizer], Optimizer],
|
||||
criterion: Optional[_Loss] = None,
|
||||
train_dataloader: Optional[Iterable] = None,
|
||||
test_dataloader: Optional[Iterable] = None,
|
||||
|
@ -227,10 +228,10 @@ def initialize(model: nn.Module,
|
|||
"""Core function to wrap the essential training components with our functionality based on the config which is
|
||||
loaded into gpc.config.
|
||||
|
||||
:param model: Your model instance
|
||||
:type model: :class:`torch.nn.Module`
|
||||
:param model: Your model instance or a function to build the model
|
||||
:type model: :class:`torch.nn.Module` or Callbale
|
||||
:param optimizer: Your optimizer instance
|
||||
:type optimizer: :class:`torch.optim.optimizer.Optimizer`
|
||||
:type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
|
||||
:param criterion: Your criterion instance
|
||||
:type criterion: :class:`torch.nn.modules.loss._Loss`, optional
|
||||
:param train_dataloader: Dataloader for training
|
||||
|
@ -267,10 +268,28 @@ def initialize(model: nn.Module,
|
|||
if verbose:
|
||||
logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
|
||||
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
|
||||
if not moe_env.is_initialized() and not use_zero3:
|
||||
use_zero = hasattr(gpc.config, 'zero')
|
||||
if use_zero:
|
||||
zero_cfg = gpc.config.get('zero', None)
|
||||
if zero_cfg is not None:
|
||||
cfg_ = zero_cfg.copy()
|
||||
else:
|
||||
cfg_ = {}
|
||||
optimizer_config = zero_cfg.get('optimzer', None)
|
||||
model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
|
||||
|
||||
logger.info("Initializing ZeRO model and optimzer finished!", ranks=[0])
|
||||
#FIXME() throw a warning if using zero with MP
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
|
||||
else:
|
||||
if isinstance(model, nn.Module):
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
elif isinstance(model, Callable):
|
||||
model = model().to(get_current_device())
|
||||
|
||||
if not moe_env.is_initialized() and not use_zero:
|
||||
if is_using_sequence():
|
||||
sync_model_param(model, ParallelMode.SEQUENCE_DP)
|
||||
elif is_using_ddp():
|
||||
|
@ -283,16 +302,15 @@ def initialize(model: nn.Module,
|
|||
|
||||
# check amp and zero
|
||||
fp16_cfg = gpc.config.get('fp16', None)
|
||||
zero_cfg = gpc.config.get('zero', None)
|
||||
|
||||
if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
|
||||
if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
|
||||
raise ConfigException(
|
||||
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
|
||||
|
||||
# clip grad norm
|
||||
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
|
||||
if clip_grad_norm > 0:
|
||||
if zero_cfg is not None:
|
||||
if use_zero and zero_cfg is not None:
|
||||
raise ConfigException(
|
||||
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
|
||||
|
||||
|
@ -311,11 +329,6 @@ def initialize(model: nn.Module,
|
|||
mode=amp_mode,
|
||||
amp_config=cfg_)
|
||||
|
||||
if zero_cfg is not None:
|
||||
cfg_ = zero_cfg.copy()
|
||||
level = cfg_.pop('level')
|
||||
model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
|
||||
|
||||
# gradient handler
|
||||
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
|
||||
if gradient_handler_cfg is None:
|
||||
|
@ -324,7 +337,7 @@ def initialize(model: nn.Module,
|
|||
# 1. if optimizer is ZERO, then use zero grad handler
|
||||
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
|
||||
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
|
||||
if isinstance(optimizer, ShardedOptimizer):
|
||||
if isinstance(optimizer, ShardedOptimizerV2):
|
||||
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
|
||||
if verbose:
|
||||
logger.info(
|
||||
|
@ -392,7 +405,7 @@ def initialize(model: nn.Module,
|
|||
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
|
||||
|
||||
# check if optimizer is ColossalaiOptimizer
|
||||
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
|
||||
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
|
||||
optimizer = ColossalaiOptimizer(optim=optimizer)
|
||||
|
||||
# gradient accumulation
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
from asyncio.log import logger
|
||||
from distutils.command.config import config
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
|
@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
|
|||
from torch.optim import Optimizer
|
||||
from .sharded_model import ShardedModel
|
||||
from .sharded_optim import ShardedOptimizer
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from typing import Callable, Type
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
|
||||
"""
|
||||
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
|
||||
|
||||
:param model: Your model object
|
||||
:type model: :class:`torch.nn.Module`
|
||||
:param optimizer_config: Your optimizer object
|
||||
:type optimizer_config: :class:`dict`
|
||||
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
"""
|
||||
|
||||
logger = get_dist_logger('convert_to_zero_v2')
|
||||
|
||||
# FIXME() pass shard strategy from config
|
||||
shard_strategy = TensorShardStrategy()
|
||||
|
||||
if isinstance(model_builder, nn.Module):
|
||||
model = model_builder
|
||||
elif isinstance(model_builder, Callable):
|
||||
with ZeroInitContext(convert_fp16='fp16' in gpc.config,
|
||||
target_device=torch.cuda.current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
model = model_builder()
|
||||
else:
|
||||
raise TypeError(f"convert_to_zero_v2 dose not support model_builder of type {type(convert_to_zero_v2)}")
|
||||
|
||||
zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
|
||||
|
||||
optimizer_class = optimizer_config.get('optimizer_type', None)
|
||||
if optimizer_class is None:
|
||||
raise RuntimeError("Set optimizer_class in zero_config")
|
||||
logger.info(f'optimizer class is {optimizer_class}')
|
||||
|
||||
cfg = optimizer_config.get('optimizer_config', None)
|
||||
logger.info(f'optimizer_config is {cfg}')
|
||||
|
||||
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
|
||||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
|
||||
|
|
|
@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# Because we will judge whether local grad accumulation
|
||||
# is enabled by wheter grad is None
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def sync_grad(self):
|
||||
pass
|
||||
|
|
|
@ -19,9 +19,10 @@ def run_dist(rank, world_size, port):
|
|||
# as this model has sync batch normalization
|
||||
# need to configure cudnn deterministic so that
|
||||
# randomness of convolution layers will be disabled
|
||||
colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
|
||||
cudnn_determinstic=True,
|
||||
cudnn_benchmark=False),
|
||||
colossalai.launch(config=dict(
|
||||
zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))),
|
||||
cudnn_determinstic=True,
|
||||
cudnn_benchmark=False),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import check_sharded_params_padding
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
_config = dict(fp16=dict(mode=None,),
|
||||
zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
|
||||
offload_optimizer_config=dict(device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
fast_init=False),
|
||||
offload_param_config=dict(device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
buffer_size=1e8,
|
||||
max_in_cpu=1e9)),
|
||||
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
|
||||
|
||||
colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
# FIXME revert back
|
||||
# test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||
test_models = ['bert']
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
# adapt to a Callbale with empty parameters
|
||||
# def module_builder_new():
|
||||
# return model_builder(checkpoint=True)
|
||||
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
torch_model = copy.deepcopy(zero_model).cuda()
|
||||
engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
|
||||
optimizer=optimizer_class,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
engine.train()
|
||||
torch_optimizer = optimizer_class(torch_model.parameters())
|
||||
|
||||
i = 0
|
||||
for data, label in train_dataloader:
|
||||
if i > 3:
|
||||
break
|
||||
|
||||
data, label = data.cuda(), label.cuda()
|
||||
|
||||
engine.zero_grad()
|
||||
torch_optimizer.zero_grad()
|
||||
|
||||
if criterion:
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
|
||||
torch_model(data, label)
|
||||
torch_loss = engine.criterion(output, label)
|
||||
else:
|
||||
loss = engine(data, label)
|
||||
torch_loss = torch_model(data, label)
|
||||
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
|
||||
torch_loss.backward()
|
||||
torch_optimizer.step()
|
||||
i += 1
|
||||
|
||||
check_sharded_params_padding(torch_model, zero_model, loose=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
def test_zero_init(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_init(world_size=2)
|
Loading…
Reference in New Issue