mirror of https://github.com/hpcaitech/ColossalAI
added buffer sync to naive amp model wrapper (#291)
parent 8d653af408
commit e17e54e32a
@@ -3,12 +3,15 @@
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import Union, List, Any, Dict
from typing import Any
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp

from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.optimizer import ColossalaiOptimizer
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ._fp16_optimizer import FP16Optimizer

@@ -49,10 +52,30 @@ class NaiveAMPModel(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True):
                 output_to_fp32: bool = True,
                 parallel_mode: ParallelMode = ParallelMode.DATA,
                 sync_buffer: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32
        self._sync_buf = sync_buffer

        if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
            self._process_group = gpc.get_group(parallel_mode)
            self._world_size = gpc.get_world_size(parallel_mode)
        else:
            self._process_group = None
            self._world_size = 1
            self._sync_buf = False
        self._first_eval_run = False

    @property
    def sync_buffer(self):
        return self._sync_buf

    @sync_buffer.setter
    def sync_buffer(self, state: bool):
        self._sync_buf = state

    def _convert_to_fp16(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:

@@ -64,7 +87,46 @@ class NaiveAMPModel(nn.Module):
            input_ = input_.float()
        return input_

    def _reduce_module_buffer(self):
        """
        All-reduce the buffers (e.g. running stats of batch normalization) across
        data parallel ranks so that all ranks produce consistent results
        when given the same input.
        """
        buf_list = []

        # find valid buffers
        for buf in self.model.buffers():
            if buf is not None:
                buf_list.append(buf)

        # reduce buffers across data parallel ranks
        if buf_list:
            coalesced_buf = _flatten_dense_tensors(buf_list)
            coalesced_buf.div_(self._world_size)
            dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
            unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
            for old, new in zip(buf_list, unflattened_buf_list):
                old.copy_(new)

    def eval(self):
        self.model.eval()

        # we only sync buffers in the first eval iteration
        # so that subsequent eval iterations can run without communication
        self._first_eval_run = True

    def forward(self, *args, **kwargs):
        # reducing buffers after forward would lead to an error,
        # as we cannot change the variables needed for gradient computation after forward,
        # so we sync buffers before forward
        if (self.training or self._first_eval_run) and self._sync_buf:
            with torch.no_grad():
                self._reduce_module_buffer()

        if self._first_eval_run:
            self._first_eval_run = False

        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
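
A minimal usage sketch of the wrapper with the new arguments (not part of the diff). It assumes a data parallel group has already been initialized via colossalai.launch and uses the import path shown in initialize.py below; the toy model is only illustrative.

import torch
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.amp.naive_amp import NaiveAMPModel

# wrap a small model that owns buffers (BatchNorm running stats)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda()
amp_model = NaiveAMPModel(model,
                          output_to_fp32=True,
                          parallel_mode=ParallelMode.DATA,
                          sync_buffer=True)

# in training mode (or the first eval forward), buffers are all-reduced
# across the data parallel group before the fp16 forward pass
amp_model.train()
out = amp_model(torch.rand(4, 3, 32, 32).cuda())

# buffer sync can be disabled, e.g. when torch DDP already broadcasts buffers,
# which is exactly what initialize() does later in this diff
amp_model.sync_buffer = False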
@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc

@@ -23,8 +24,7 @@ from colossalai.engine import Engine
from colossalai.global_variables import moe_env
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device,
                              is_using_ddp, is_using_pp, is_using_sequence,
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                              sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook

@@ -39,21 +39,12 @@ def get_default_parser():
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host',
                        type=str,
                        help='the master address for distributed training')
    parser.add_argument('--port',
                        type=int,
                        help='the master port for distributed training')
    parser.add_argument('--host', type=str, help='the master address for distributed training')
    parser.add_argument('--port', type=int, help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--rank', type=int, help='rank for the default process group')
    parser.add_argument('--local_rank',
                        type=int,
                        help='local rank on the node')
    parser.add_argument('--backend',
                        type=str,
                        default='nccl',
                        help='backend for distributed communication')
    parser.add_argument('--local_rank', type=int, help='local rank on the node')
    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
    return parser

@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],
    if verbose:
        logger = get_dist_logger()
        logger.info(f'Distributed environment is initialized, '
        logger.info(
            f'Distributed environment is initialized, '
            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
            f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
            f'tensor parallel size: {gpc.tensor_parallel_size}',
            ranks=[0])


def launch_from_slurm(config: Union[str, Path, Config, Dict],

@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
    # print config
    if verbose:
        logger.info(f"\n========== Your Config ========\n"
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n", ranks=[0])
            f"================================\n",
            ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', True)

@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(
            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # first sync model across dp ranks
    model.to(get_current_device())

@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
    if zero_cfg is not None:
        cfg_ = zero_cfg.copy()
        level = cfg_.pop('level')
        model, optimizer = convert_to_zero(model=model,
                                           optimizer=optimizer,
                                           level=level,
                                           zero_config=cfg_
                                           )
        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)

@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                        "added even though not specified in the configuration",
                        ranks=[0])
    elif is_using_sequence():
        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
        model = DDP(model,
                    process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                    device_ids=[torch.cuda.current_device()])
        if verbose:
            logger.info(
                'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
            logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                        ranks=[0])
    elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
        if verbose:
            logger.info(
                'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
            logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
    elif is_using_ddp():
        gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
        if verbose:
            logger.info(
                "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
                "Data parallel training is detected when using pipeline parallel, "
                "DataParallelGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
    # add pipeline parallel gradient handler, if pipeline shared module is detected

@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
            )

    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
    # to avoid duplicated buffer synchronization
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False

    if gradient_handler_cfg is None:
        gradient_handlers = None

@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer


def convert_to_zero(model: nn.Module,
                    optimizer: Optimizer,
                    level: int,
                    zero_config: dict):
def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
    """
    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading

@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
    assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
    if level in [1, 2]:
        if level == 2:
            assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
            if 'partition_grad' in zero_config:
                assert zero_config['partition_grad'], \
                    'Sharded Optimizer requires partition_grad to be True'
            else:
                zero_config['partition_grad'] = True
        model = NaiveAMPModel(model, output_to_fp32=True)
        optimizer = ShardedOptimizer(model.parameters(), *zero_config)
        optimizer = ShardedOptimizer(optimizer, **zero_config)
    else:
        model = ShardedModel(module=model, **zero_config)
    return model, optimizer


__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
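
A short usage sketch of convert_to_zero (not part of the diff), based only on the signature and the level-2 branch shown above; the zero_config keys mirror the test configuration further below, and no other options are assumed.

import torch
import torchvision.models as models
from colossalai.zero import convert_to_zero

model = models.resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# level 2 also shards gradients, hence partition_grad=True; the call returns
# the model wrapped in NaiveAMPModel and the optimizer wrapped in ShardedOptimizer
model, optimizer = convert_to_zero(model=model,
                                   optimizer=optimizer,
                                   level=2,
                                   zero_config=dict(partition_grad=True))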
@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from torchvision.models import resnet50
import torch.distributed as dist


def run_dist(rank, world_size, port):
    # configure cudnn deterministic mode so that
    # the randomness of convolution layers is disabled
    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
                                  cudnn_deterministic=True,
                                  cudnn_benchmark=False),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    model = resnet50()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, *args = colossalai.initialize(model, optimizer, criterion)

    # train for dummy iterations
    engine.train()
    for _ in range(2):
        data = torch.rand(4, 3, 128, 128).cuda().half()
        label = torch.randint(0, 10, size=(4,)).cuda()
        engine.zero_grad()
        out = engine(data)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()

    # test
    # make sure the batch norm stats are synchronized
    # so that, given the same input, the model produces the same
    # output on different ranks
    engine.eval()
    data = torch.rand(4, 3, 128, 128).cuda().half()
    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))

    # predict
    out = engine(data)

    # test if results are equal
    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
    tensor_list.insert(rank, out)
    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))

    assert torch.all(tensor_list[0] == tensor_list[1]), \
        'expected the output from different ranks to be the same, but got different values'


@pytest.mark.dist
def test_sharded_optim_with_sync_bn():
    """
    This test makes sure that buffers are synchronized between ranks
    when using ZeRO. An example of a module buffer is the running stats of
    a BatchNormalization layer, i.e. its mean and var.

    If the buffers are not synchronized, the model will produce different
    outputs even though the input and parameters are the same. This is not
    desirable when running inference.
    """
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_with_sync_bn()
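
The assertion in run_dist compares only the first two gathered tensors, which is enough for world_size = 2. A hedged sketch of how the same check could be generalized to larger world sizes (assuming the same all_gather result as in the test); the helper name is illustrative only:

import torch

def assert_all_ranks_equal(tensor_list):
    # every gathered output should be bitwise identical to rank 0's output
    reference = tensor_list[0]
    for rank_id, other in enumerate(tensor_list[1:], start=1):
        assert torch.equal(reference, other), \
            f'output on rank {rank_id} differs from rank 0'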