mirror of https://github.com/hpcaitech/ColossalAI
[refactory] refactory the initialize method for new zero design (#431)
parent 4f85b687cf
commit 640a6cd304
@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
 
 import torch
 import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
-from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.zero import convert_to_zero_v2
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
 
 
 def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)
 
 
-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
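With this change, initialize() also accepts a model-building callable and an optimizer class instead of ready-made instances. The following is a minimal sketch, not part of the commit: build_model and the criterion are hypothetical, and it assumes colossalai.launch has already been called with a `zero` config of the shape used in the tests further below.

import torch
import torch.nn as nn
import colossalai

def build_model() -> nn.Module:
    # any callable returning an nn.Module; under the ZeRO path it is built via convert_to_zero_v2
    return nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

# the optimizer is passed as a class; its kwargs (e.g. lr) come from the zero config
engine, *_ = colossalai.initialize(model=build_model,
                                   optimizer=torch.optim.Adam,
                                   criterion=nn.CrossEntropyLoss())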
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
 
-    :param model: Your model instance
-    :type model: :class:`torch.nn.Module`
+    :param model: Your model instance or a function to build the model
+    :type model: :class:`torch.nn.Module` or Callable
     :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer`
+    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
     :param criterion: Your criterion instance
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
    :param train_dataloader: Dataloader for training
@@ -267,10 +268,28 @@
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
 
-    # first sync model across dp ranks
-    model.to(get_current_device())
-    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not moe_env.is_initialized() and not use_zero3:
+    use_zero = hasattr(gpc.config, 'zero')
+    if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        if zero_cfg is not None:
+            cfg_ = zero_cfg.copy()
+        else:
+            cfg_ = {}
+        optimizer_config = zero_cfg.get('optimzer', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
+        # FIXME() throw a warning if using zero with MP
+        if gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
+    else:
+        if isinstance(model, nn.Module):
+            # first sync model across dp ranks
+            model.to(get_current_device())
+        elif isinstance(model, Callable):
+            model = model().to(get_current_device())
+
+    if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
         elif is_using_ddp():
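The new branch reads the ZeRO settings from gpc.config.zero, where the optimizer is described by a class plus its keyword arguments; the 'optimzer' key spelling is taken verbatim from this commit. A minimal config sketch matching what the updated tests below pass to colossalai.launch:

import torch

# sketch of the config dict consumed by the use_zero branch above;
# the values (Adam, lr=1e-3) mirror the updated tests in this commit
config = dict(
    zero=dict(
        optimzer=dict(                           # key spelling as in the commit
            optimizer_type=torch.optim.Adam,     # optimizer class, not an instance
            optimizer_config=dict(lr=1e-3),      # kwargs forwarded to the optimizer
        ),
    ),
)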
@@ -283,16 +302,15 @@
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
-    zero_cfg = gpc.config.get('zero', None)
 
-    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")
 
     # clip grad norm
     clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
     if clip_grad_norm > 0:
-        if zero_cfg is not None:
+        if use_zero and zero_cfg is not None:
             raise ConfigException(
                 "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
 
@@ -311,11 +329,6 @@
                                                      mode=amp_mode,
                                                      amp_config=cfg_)
 
-    if zero_cfg is not None:
-        cfg_ = zero_cfg.copy()
-        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
-
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@
         # 1. if optimizer is ZERO, then use zero grad handler
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
-        if isinstance(optimizer, ShardedOptimizer):
+        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
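As the hunk above shows, a ShardedOptimizerV2 now selects the ZeRO gradient handler automatically when no gradient_handler entry is configured. For reference, the explicit equivalent in a dict-style config would be the one-line sketch below (an illustration, not part of the commit):

# explicit equivalent of the automatic choice made for ShardedOptimizerV2
gradient_handler = [dict(type='ZeROGradientHandler')]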
@@ -392,7 +405,7 @@
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
 
     # check if optimizer is ColossalaiOptimizer
-    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
+    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
         optimizer = ColossalaiOptimizer(optim=optimizer)
 
     # gradient accumulation
@@ -1,4 +1,8 @@
+from asyncio.log import logger
 from distutils.command.config import config
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.shard_utils import TensorShardStrategy
 import torch
 import torch.nn as nn
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
 from torch.optim import Optimizer
 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from typing import Callable, Type
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
+
+def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
+    """
+    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
+
+    :param model_builder: Your model object or a callable that builds it
+    :type model_builder: :class:`torch.nn.Module` or Callable
+    :param optimizer_config: Your optimizer-related config
+    :type optimizer_config: :class:`dict`
+
+    :return: (model, optimizer)
+    :rtype: Tuple
+    """
+
+    logger = get_dist_logger('convert_to_zero_v2')
+
+    # FIXME() pass shard strategy from config
+    shard_strategy = TensorShardStrategy()
+
+    if isinstance(model_builder, nn.Module):
+        model = model_builder
+    elif isinstance(model_builder, Callable):
+        with ZeroInitContext(convert_fp16='fp16' in gpc.config,
+                             target_device=torch.cuda.current_device(),
+                             shard_strategy=shard_strategy,
+                             shard_param=True):
+            model = model_builder()
+    else:
+        raise TypeError(f"convert_to_zero_v2 does not support model_builder of type {type(model_builder)}")
+
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
+
+    optimizer_class = optimizer_config.get('optimizer_type', None)
+    if optimizer_class is None:
+        raise RuntimeError("Set optimizer_class in zero_config")
+    logger.info(f'optimizer class is {optimizer_class}')
+
+    cfg = optimizer_config.get('optimizer_config', None)
+    logger.info(f'optimizer_config is {cfg}')
+
+    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
+    return zero_model, zero_optimizer
 
 
 def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
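A hedged usage sketch of convert_to_zero_v2 as added above; it assumes colossalai.launch has already set up the distributed context and a CUDA device, and the builder and optimizer settings are hypothetical, only illustrating the expected argument shapes.

import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero_v2

def build_model() -> nn.Module:
    # hypothetical builder; on the Callable path it runs inside ZeroInitContext so parameters are sharded at construction
    return nn.Linear(128, 10)

optimizer_config = dict(optimizer_type=torch.optim.Adam,   # optimizer class, resolved via 'optimizer_type'
                        optimizer_config=dict(lr=1e-3))    # kwargs forwarded to ShardedOptimizerV2
zero_model, zero_optimizer = convert_to_zero_v2(model_builder=build_model,
                                                optimizer_config=optimizer_config)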
@@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by whether grad is None
         self.optim.zero_grad(set_to_none=True)
+
+    def sync_grad(self):
+        pass
@@ -19,9 +19,10 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
-                                  cudnn_determinstic=True,
-                                  cudnn_benchmark=False),
+    colossalai.launch(config=dict(
+        zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))),
+        cudnn_determinstic=True,
+        cudnn_benchmark=False),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import copy
+from functools import partial
+import pytest
+
+import colossalai
+from colossalai.utils import free_port
+
+import torch
+import torch.multiprocessing as mp
+
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+from common import check_sharded_params_padding
+
+
+def run_dist(rank, world_size, port):
+    _config = dict(fp16=dict(mode=None,),
+                   zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
+                             offload_optimizer_config=dict(device='cpu',
+                                                           pin_memory=True,
+                                                           buffer_count=5,
+                                                           fast_init=False),
+                             offload_param_config=dict(device='cpu',
+                                                       pin_memory=True,
+                                                       buffer_count=5,
+                                                       buffer_size=1e8,
+                                                       max_in_cpu=1e9)),
+                   parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
+
+    colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    # FIXME revert back
+    # test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['bert']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
+
+        # adapt to a Callable with empty parameters
+        # def module_builder_new():
+        #     return model_builder(checkpoint=True)
+
+        zero_model = model_builder(checkpoint=True)
+        torch_model = copy.deepcopy(zero_model).cuda()
+        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
+                                                               optimizer=optimizer_class,
+                                                               criterion=criterion,
+                                                               train_dataloader=train_dataloader)
+        engine.train()
+        torch_optimizer = optimizer_class(torch_model.parameters())
+
+        i = 0
+        for data, label in train_dataloader:
+            if i > 3:
+                break
+
+            data, label = data.cuda(), label.cuda()
+
+            engine.zero_grad()
+            torch_optimizer.zero_grad()
+
+            if criterion:
+                output = engine(data)
+                loss = engine.criterion(output, label)
+
+                torch_output = torch_model(data, label)
+                torch_loss = engine.criterion(torch_output, label)
+            else:
+                loss = engine(data, label)
+                torch_loss = torch_model(data, label)
+
+            engine.backward(loss)
+            engine.step()
+
+            torch_loss.backward()
+            torch_optimizer.step()
+            i += 1
+
+        check_sharded_params_padding(torch_model, zero_model, loose=True)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_zero_init(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_init(world_size=2)