added buffer sync to naive amp model wrapper (#291)

pull/394/head
Frank Lee 2022-03-02 16:47:17 +08:00
parent 8d653af408
commit e17e54e32a
4 changed files with 191 additions and 46 deletions

View File

@@ -3,12 +3,15 @@
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import Union, List, Any, Dict
from typing import Any
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.optimizer import ColossalaiOptimizer
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ._fp16_optimizer import FP16Optimizer
@@ -43,16 +46,36 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
class NaiveAMPModel(nn.Module):
"""A wrapper class for model to cast the model into fp16 and
"""A wrapper class for model to cast the model into fp16 and
automatically cast the input and output
"""
def __init__(self,
model: nn.Module,
output_to_fp32: bool = True):
output_to_fp32: bool = True,
parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True):
super().__init__()
self.model = model.half()
self._output_to_fp32 = output_to_fp32
self._sync_buf = sync_buffer
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
self._process_group = gpc.get_group(parallel_mode)
self._world_size = gpc.get_world_size(parallel_mode)
else:
self._process_group = None
self._world_size = 1
self._sync_buf = False
self._first_eval_run = False
@property
def sync_buffer(self):
return self._sync_buf
@sync_buffer.setter
def sync_buffer(self, state: bool):
self._sync_buf = state
def _convert_to_fp16(self, input_: Any):
if isinstance(input_, Tensor) and input_.dtype == torch.float32:
@@ -64,7 +87,46 @@ class NaiveAMPModel(nn.Module):
input_ = input_.float()
return input_
def _reduce_module_buffer(self):
"""
All-reduce the buffers (e.g. running stats of batch normalization) across
data parallel ranks so that all the ranks will produce consistent results
when given the same input
"""
buf_list = []
# find valid buffers
for buf in self.model.buffers():
if buf is not None:
buf_list.append(buf)
# reduce buffers across data parallel ranks
if buf_list:
coalesced_buf = _flatten_dense_tensors(buf_list)
coalesced_buf.div_(self._world_size)
dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
for old, new in zip(buf_list, unflattened_buf_list):
old.copy_(new)
def eval(self):
self.model.eval()
# we only sync buffer in the first eval iteration
# so that future eval iterations can be done without communication
self._first_eval_run = True
def forward(self, *args, **kwargs):
# reducing buffers after the forward pass would raise an error,
# as variables needed for gradient computation cannot be modified after forward,
# so we sync buffers before the forward pass
if (self.training or self._first_eval_run) and self._sync_buf:
with torch.no_grad():
self._reduce_module_buffer()
if self._first_eval_run:
self._first_eval_run = False
if args:
args = [self._convert_to_fp16(arg) for arg in args]
if kwargs:

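For illustration, a minimal usage sketch of the wrapper above (not part of this diff); it assumes a ColossalAI context has already been launched with colossalai.launch, and the toy network and variable names are hypothetical.

# Hypothetical usage sketch of the updated NaiveAMPModel wrapper.
# Assumes a ColossalAI distributed context has already been launched,
# so that ParallelMode.DATA resolves to a real process group.
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context import ParallelMode

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# cast the model to fp16; module buffers (e.g. BatchNorm running stats) are
# all-reduced across the data parallel group before the first forward pass
amp_net = NaiveAMPModel(net,
                        output_to_fp32=True,
                        parallel_mode=ParallelMode.DATA,
                        sync_buffer=True)

# the sync can be toggled off later, e.g. when DDP already broadcasts buffers
amp_net.sync_buffer = False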
View File

@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
@@ -23,8 +24,7 @@ from colossalai.engine import Engine
from colossalai.global_variables import moe_env
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device,
is_using_ddp, is_using_pp, is_using_sequence,
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
@@ -39,21 +39,12 @@ def get_default_parser():
"""
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
type=str,
help='the master address for distributed training')
parser.add_argument('--port',
type=int,
help='the master port for distributed training')
parser.add_argument('--host', type=str, help='the master address for distributed training')
parser.add_argument('--port', type=int, help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank',
type=int,
help='local rank on the node')
parser.add_argument('--backend',
type=str,
default='nccl',
help='backend for distributed communication')
parser.add_argument('--local_rank', type=int, help='local rank on the node')
parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
return parser
@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],
if verbose:
logger = get_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
logger.info(
f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}',
ranks=[0])
def launch_from_slurm(config: Union[str, Path, Config, Dict],
@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# print config
if verbose:
logger.info(f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================\n", ranks=[0])
logger.info(
f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================\n",
ranks=[0])
# cudnn
cudnn_benchmark = config.get('cudnn_benchmark', True)
@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.deterministic = cudnn_deterministic
if verbose:
logger.info(
f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# first sync model across dp ranks
model.to(get_current_device())
@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
level = cfg_.pop('level')
model, optimizer = convert_to_zero(model=model,
optimizer=optimizer,
level=level,
zero_config=cfg_
)
model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
# gradient handler
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
"added even though not specified in the configuration",
ranks=[0])
elif is_using_sequence():
model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
model = DDP(model,
process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
device_ids=[torch.cuda.current_device()])
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
ranks=[0])
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
elif is_using_ddp():
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
if verbose:
logger.info(
"Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
"Data parallel training is detected when using pipeline parallel, "
"DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
# add pipeline parallel gradient handler, if pipeline shared module is detected
@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
else:
if not isinstance(gradient_handler_cfg, list):
raise ConfigException(
f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
)
# turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
# to avoid duplicated buffer synchronization
if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
model.module.sync_buffer = False
if gradient_handler_cfg is None:
gradient_handlers = None

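The guard above exists because torch's DistributedDataParallel already broadcasts module buffers from rank 0 on each forward pass when broadcast_buffers is left at its default of True, so a second all-reduce inside NaiveAMPModel would be redundant communication. A standalone sketch of the same check follows (the helper name is hypothetical and not part of this diff).

# Hypothetical sketch mirroring the guard added in initialize(): if the fp16
# wrapper sits inside torch DDP, DDP's own buffer broadcast makes the
# wrapper's all-reduce redundant, so it is switched off.
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.amp.naive_amp import NaiveAMPModel


def disable_duplicate_buffer_sync(model):
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False
    return model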
View File

@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero(model: nn.Module,
optimizer: Optimizer,
level: int,
zero_config: dict):
def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
if level in [1, 2]:
if level == 2:
assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
if 'partition_grad' in zero_config:
assert zero_config['partition_grad'], \
'Sharded Optimizer requires partition_grad to be True'
else:
zero_config['partition_grad'] = True
model = NaiveAMPModel(model, output_to_fp32=True)
optimizer = ShardedOptimizer(model.parameters(), *zero_config)
optimizer = ShardedOptimizer(optimizer, **zero_config)
else:
model = ShardedModel(module=model, **zero_config)
return model, optimizer
__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']

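For illustration, a minimal sketch of the updated convert_to_zero path (not part of this diff); the model and optimizer are hypothetical and a distributed context launched via colossalai.launch is assumed. With level 2 and no explicit setting, partition_grad now defaults to True.

# Hypothetical sketch: integrating a model and optimizer with ZeRO level 2.
# Assumes colossalai.launch(...) has already set up the process groups.
import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero

model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# the model comes back wrapped in NaiveAMPModel (fp16) and the optimizer
# wrapped in ShardedOptimizer; partition_grad is filled in as True
model, optimizer = convert_to_zero(model=model,
                                   optimizer=optimizer,
                                   level=2,
                                   zero_config=dict())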
View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from torchvision.models import resnet50
import torch.distributed as dist
def run_dist(rank, world_size, port):
# need to configure cudnn deterministic so that
# the randomness of convolution layers will be disabled
colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
cudnn_deterministic=True,
cudnn_benchmark=False),
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
model = resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
engine, *args = colossalai.initialize(model, optimizer, criterion)
# train for dummy iterations
engine.train()
for _ in range(2):
data = torch.rand(4, 3, 128, 128).cuda().half()
label = torch.randint(0, 10, size=(4,)).cuda()
engine.zero_grad()
out = engine(data)
loss = engine.criterion(out, label)
engine.backward(loss)
engine.step()
# test
# need to make sure the batch norm stats are synchronized
# so that given the same input, the model will produce the same
# output on different ranks
engine.eval()
data = torch.rand(4, 3, 128, 128).cuda().half()
dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
# predict
out = engine(data)
# test if results are equal
tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
tensor_list.insert(rank, out)
dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
assert torch.all(tensor_list[0] == tensor_list[1]), \
'expected the output from different ranks to be the same, but got different values'
@pytest.mark.dist
def test_sharded_optim_with_sync_bn():
"""
This test is to make sure that buffers are synchronized between ranks
when using ZeRO. An example of module buffer is the running stats of
BatchNormalization layer, i.e. mean and var.
If the buffers are not synchronized, the model will produce different
output even though the input and parameters are the same. This is not
wanted if we are doing predictions.
"""
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_sharded_optim_with_sync_bn()