[booster] implement Gemini plugin (#3352)

* [booster] add gemini plugin

* [booster] update docstr

* [booster] gemini plugin add coloparam convertor

* [booster] fix coloparam convertor

* [booster] fix gemini plugin device

* [booster] add gemini plugin test

* [booster] gemini plugin ignore sync bn

* [booster] skip some model

* [booster] skip some model

* [booster] modify test world size

* [booster] modify test world size

* [booster] skip test
pull/3377/head
ver217 2 years ago committed by GitHub
parent 1a1d68b053
commit 5f2e34e6c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,5 @@
from .gemini_plugin import GeminiPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin']

@ -0,0 +1,338 @@
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.gemini.memory_tracer import MemStats
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import _convert_to_coloparam
from .plugin_base import Plugin
__all__ = ['GeminiPlugin']
def convert_to_colo_param(module: nn.Module) -> None:
"""Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
Args:
module (nn.Module): Module to be converted.
"""
converted_modules = set() # handle shared modules
converted_params = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
def convert_recursively(m: nn.Module):
for child in m.children():
if child not in converted_modules:
converted_modules.add(child)
convert_recursively(child)
for name, p in m.named_parameters(recurse=False):
assert not isinstance(p, ColoParameter)
if p in converted_params:
target = converted_params[p]
else:
target = _convert_to_coloparam(p, p.device, p.dtype)
converted_params[p] = target
setattr(m, name, target)
target.shared_param_modules.append(m)
convert_recursively(module)
# optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
module._converted_params = converted_params
def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
"""Replace param in optimizer's group with converted ColoParameter.
Args:
optimizer (Optimizer): Optimizer to be replaced.
converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
"""
for group in optimizer.param_groups:
for i, p in enumerate(group['params']):
if p in converted_params:
group['params'][i] = converted_params[p]
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
# as there is communication when get state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
self.save_checkpoint(state_dict, checkpoint)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class GeminiModel(ModelWrapper):
def __init__(self, module: nn.Module, gemini_config: dict) -> None:
super().__init__(module)
# TODO(ver217): only support Gemini now
convert_to_colo_param(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
def unwrap(self):
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
return self.module
class GeminiOptimizer(OptimizerWrapper):
def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
replace_param_in_group(optimizer, module.module._converted_params)
del module.module._converted_params
optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
super().__init__(optimizer)
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
class GeminiPlugin(Plugin):
"""
Plugin for Gemini.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import GeminiPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = GeminiPlugin()
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
def __init__(
self,
device: Optional[torch.device] = None,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.gemini_config = dict(
device=(device or get_current_device()),
placement_policy=placement_policy,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
search_range_mb=search_range_mb,
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
memstats=memstats,
)
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
def support_no_sync(self) -> bool:
return False
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
return ['fp16']
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Note:
1. Evaluation datasets should not be passed to this function.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
# In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
# This inconsistency of dtype will cause the error.
# We have two possible solutions:
# 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks.
# 2. patch sync bn or write a new on. This is relatively easy, but we need to test it.
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

@ -0,0 +1,150 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo
def check_gemini_plugin(early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
passed_models = []
failed_info = {} # (model_name, error) pair
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
continue
# These models are not compatible with gemini
if name in [
'diffusers_clip_vision_model',
'timm_resnet',
'timm_beit',
'timm_beitv2',
'timm_eca_nfnet',
'timm_efficientformer',
'timm_hrnet_w18_small',
'timm_nf_ecaresnet101',
'timm_nf_regnet_b0',
'timm_skresnet18',
'timm_wide_resnet50_2',
'timm_convit',
'timm_dm_nfnet',
'timm_swin_transformer',
'torchaudio_conformer',
'torchaudio_deepspeech',
'torchaudio_wavernn',
'torchaudio_tacotron',
'deepfm_interactionarch',
'deepfm_simpledeepfmnn',
'dlrm',
'dlrm_interactionarch',
'torchvision_googlenet',
'torchvision_inception_v3',
'torchvision_mobilenet_v3_small',
'torchvision_resnet18',
'torchvision_resnext50_32x4d',
'torchvision_wide_resnet50_2',
'torchvision_vit_b_16',
'torchvision_convnext_base',
'torchvision_swin_s',
'transformers_albert',
'transformers_albert_for_pretraining',
'transformers_bert',
'transformers_bert_for_pretraining',
'transformers_gpt_double_heads',
'torchaudio_hubert_base',
]:
continue
try:
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
for n, p in model.named_parameters():
assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
passed_models.append(name)
except Exception as e:
failed_info[name] = e
if early_stop:
raise e
if dist.get_rank() == 0:
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def check_dataloader_sharding():
plugin = GeminiPlugin()
# create a custom dasetset with 0 to 10
dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
# get the first batch of data
batch = next(iter(train_dataloader))[0].cuda()
is_rank_0 = dist.get_rank() == 0
if is_rank_0:
batch_to_compare = batch.clone()
else:
batch_to_compare = batch
# pass to the rank 1 value to rank 0
dist.broadcast(batch_to_compare, src=1)
# compare on rank 0
if is_rank_0:
assert not torch.equal(batch,
batch_to_compare), 'Same number was found across ranks but expected it to be different'
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_dataloader_sharding()
check_gemini_plugin(early_stop=early_stop)
@pytest.mark.skip(reason='Skip gemini plugin test due to OOM')
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gemini_plugin(early_stop=False)
Loading…
Cancel
Save