mirror of https://github.com/hpcaitech/ColossalAI
[booster] implement Gemini plugin (#3352)
* [booster] add gemini plugin * [booster] update docstr * [booster] gemini plugin add coloparam convertor * [booster] fix coloparam convertor * [booster] fix gemini plugin device * [booster] add gemini plugin test * [booster] gemini plugin ignore sync bn * [booster] skip some model * [booster] skip some model * [booster] modify test world size * [booster] modify test world size * [booster] skip testpull/3377/head
parent
1a1d68b053
commit
5f2e34e6c9
@ -1,4 +1,5 @@
|
|||||||
|
from .gemini_plugin import GeminiPlugin
|
||||||
from .plugin_base import Plugin
|
from .plugin_base import Plugin
|
||||||
from .torch_ddp_plugin import TorchDDPPlugin
|
from .torch_ddp_plugin import TorchDDPPlugin
|
||||||
|
|
||||||
__all__ = ['Plugin', 'TorchDDPPlugin']
|
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin']
|
||||||
|
@ -0,0 +1,338 @@
|
|||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.gemini.memory_tracer import MemStats
|
||||||
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
|
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
|
||||||
|
from colossalai.tensor.colo_parameter import ColoParameter
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.utils.model.colo_init_context import _convert_to_coloparam
|
||||||
|
|
||||||
|
from .plugin_base import Plugin
|
||||||
|
|
||||||
|
__all__ = ['GeminiPlugin']
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_colo_param(module: nn.Module) -> None:
|
||||||
|
"""Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): Module to be converted.
|
||||||
|
"""
|
||||||
|
converted_modules = set() # handle shared modules
|
||||||
|
converted_params = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
|
||||||
|
|
||||||
|
def convert_recursively(m: nn.Module):
|
||||||
|
for child in m.children():
|
||||||
|
if child not in converted_modules:
|
||||||
|
converted_modules.add(child)
|
||||||
|
convert_recursively(child)
|
||||||
|
|
||||||
|
for name, p in m.named_parameters(recurse=False):
|
||||||
|
assert not isinstance(p, ColoParameter)
|
||||||
|
if p in converted_params:
|
||||||
|
target = converted_params[p]
|
||||||
|
else:
|
||||||
|
target = _convert_to_coloparam(p, p.device, p.dtype)
|
||||||
|
converted_params[p] = target
|
||||||
|
setattr(m, name, target)
|
||||||
|
target.shared_param_modules.append(m)
|
||||||
|
|
||||||
|
convert_recursively(module)
|
||||||
|
|
||||||
|
# optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
|
||||||
|
module._converted_params = converted_params
|
||||||
|
|
||||||
|
|
||||||
|
def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
|
||||||
|
"""Replace param in optimizer's group with converted ColoParameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (Optimizer): Optimizer to be replaced.
|
||||||
|
converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
|
||||||
|
"""
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
for i, p in enumerate(group['params']):
|
||||||
|
if p in converted_params:
|
||||||
|
group['params'][i] = converted_params[p]
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.coordinator = DistCoordinator()
|
||||||
|
|
||||||
|
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
|
||||||
|
"""
|
||||||
|
Load model from checkpoint with automatic unwrapping.
|
||||||
|
"""
|
||||||
|
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||||
|
return super().load_unsharded_model(model, checkpoint, strict=strict)
|
||||||
|
|
||||||
|
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
|
||||||
|
"""
|
||||||
|
Save model to checkpoint but only on master process.
|
||||||
|
"""
|
||||||
|
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||||
|
# as there is communication when get state dict, this must be called on all processes
|
||||||
|
state_dict = model.state_dict(only_rank_0=True)
|
||||||
|
if self.coordinator.is_master():
|
||||||
|
self.save_checkpoint(state_dict, checkpoint)
|
||||||
|
|
||||||
|
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||||
|
"""
|
||||||
|
Save optimizer to checkpoint but only on master process.
|
||||||
|
"""
|
||||||
|
# TODO(ver217): optimizer state dict is sharded
|
||||||
|
super().save_unsharded_optimizer(optimizer, checkpoint)
|
||||||
|
|
||||||
|
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||||
|
"""
|
||||||
|
Save model to checkpoint but only on master process.
|
||||||
|
"""
|
||||||
|
if self.coordinator.is_master():
|
||||||
|
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiModel(ModelWrapper):
|
||||||
|
|
||||||
|
def __init__(self, module: nn.Module, gemini_config: dict) -> None:
|
||||||
|
super().__init__(module)
|
||||||
|
# TODO(ver217): only support Gemini now
|
||||||
|
convert_to_colo_param(module)
|
||||||
|
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
|
||||||
|
|
||||||
|
def unwrap(self):
|
||||||
|
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
|
||||||
|
return self.module
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiOptimizer(OptimizerWrapper):
|
||||||
|
|
||||||
|
def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
|
||||||
|
replace_param_in_group(optimizer, module.module._converted_params)
|
||||||
|
del module.module._converted_params
|
||||||
|
optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
|
||||||
|
super().__init__(optimizer)
|
||||||
|
|
||||||
|
def backward(self, loss: Tensor, *args, **kwargs):
|
||||||
|
self.optim.backward(loss)
|
||||||
|
|
||||||
|
def clip_grad_by_norm(self,
|
||||||
|
max_norm: Union[float, int],
|
||||||
|
norm_type: Union[float, int] = 2,
|
||||||
|
error_if_nonfinite: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs) -> Tensor:
|
||||||
|
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
|
||||||
|
|
||||||
|
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
|
||||||
|
raise NotImplementedError('Gemini does not support clip_grad_by_value')
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiPlugin(Plugin):
|
||||||
|
"""
|
||||||
|
Plugin for Gemini.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from colossalai.booster import Booster
|
||||||
|
>>> from colossalai.booster.plugin import GeminiPlugin
|
||||||
|
>>>
|
||||||
|
>>> model, train_dataset, optimizer, criterion = ...
|
||||||
|
>>> plugin = GeminiPlugin()
|
||||||
|
|
||||||
|
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
|
||||||
|
>>> booster = Booster(plugin=plugin)
|
||||||
|
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device): device to place the model.
|
||||||
|
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||||
|
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||||
|
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||||
|
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
|
||||||
|
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
|
||||||
|
hidden_dim (int, optional): the hidden dimension of DNN.
|
||||||
|
Users can provide this argument to speed up searching.
|
||||||
|
If users do not know this argument before training, it is ok. We will use a default value 1024.
|
||||||
|
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
|
||||||
|
If the aggregate size of parameters is still samller than the minimum chunk size,
|
||||||
|
all parameters will be compacted into one small chunk.
|
||||||
|
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
|
||||||
|
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||||
|
which will be used when using hybrid CPU optimizer.
|
||||||
|
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
|
||||||
|
Defaults to 0.0.
|
||||||
|
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
||||||
|
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
|
||||||
|
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
|
||||||
|
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
|
||||||
|
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
|
||||||
|
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
|
||||||
|
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
|
||||||
|
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
|
||||||
|
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
|
||||||
|
norm_type (float, optional): norm_type used for `clip_grad_norm`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
placement_policy: str = "cpu",
|
||||||
|
pin_memory: bool = False,
|
||||||
|
force_outputs_fp32: bool = False,
|
||||||
|
strict_ddp_mode: bool = False,
|
||||||
|
search_range_mb: int = 32,
|
||||||
|
hidden_dim: Optional[int] = None,
|
||||||
|
min_chunk_size_mb: float = 32,
|
||||||
|
memstats: Optional[MemStats] = None,
|
||||||
|
gpu_margin_mem_ratio: float = 0.0,
|
||||||
|
initial_scale: float = 2**32,
|
||||||
|
min_scale: float = 1,
|
||||||
|
growth_factor: float = 2,
|
||||||
|
backoff_factor: float = 0.5,
|
||||||
|
growth_interval: int = 1000,
|
||||||
|
hysteresis: int = 2,
|
||||||
|
max_scale: float = 2**32,
|
||||||
|
max_norm: float = 0.0,
|
||||||
|
norm_type: float = 2.0,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
assert dist.is_initialized(
|
||||||
|
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
|
||||||
|
self.rank = dist.get_rank()
|
||||||
|
self.world_size = dist.get_world_size()
|
||||||
|
self.gemini_config = dict(
|
||||||
|
device=(device or get_current_device()),
|
||||||
|
placement_policy=placement_policy,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
force_outputs_fp32=force_outputs_fp32,
|
||||||
|
strict_ddp_mode=strict_ddp_mode,
|
||||||
|
search_range_mb=search_range_mb,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
min_chunk_size_mb=min_chunk_size_mb,
|
||||||
|
memstats=memstats,
|
||||||
|
)
|
||||||
|
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
|
||||||
|
self.optim_kwargs = dict(initial_scale=initial_scale,
|
||||||
|
growth_factor=growth_factor,
|
||||||
|
backoff_factor=backoff_factor,
|
||||||
|
growth_interval=growth_interval,
|
||||||
|
hysteresis=hysteresis,
|
||||||
|
min_scale=min_scale,
|
||||||
|
max_scale=max_scale,
|
||||||
|
max_norm=max_norm,
|
||||||
|
norm_type=norm_type)
|
||||||
|
|
||||||
|
def support_no_sync(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def control_precision(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supported_precisions(self) -> List[str]:
|
||||||
|
return ['fp16']
|
||||||
|
|
||||||
|
def control_device(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supported_devices(self) -> List[str]:
|
||||||
|
return ['cuda']
|
||||||
|
|
||||||
|
def prepare_train_dataloader(self,
|
||||||
|
dataset,
|
||||||
|
batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
seed=1024,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=False,
|
||||||
|
num_workers=0,
|
||||||
|
**kwargs):
|
||||||
|
r"""
|
||||||
|
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||||
|
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. Evaluation datasets should not be passed to this function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||||
|
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||||
|
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||||
|
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||||
|
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||||
|
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||||
|
the batch size, then the last batch will be smaller, defaults to False.
|
||||||
|
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||||
|
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||||
|
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||||
|
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||||
|
"""
|
||||||
|
_kwargs = kwargs.copy()
|
||||||
|
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
|
||||||
|
|
||||||
|
# Deterministic dataloader
|
||||||
|
def seed_worker(worker_id):
|
||||||
|
worker_seed = seed
|
||||||
|
np.random.seed(worker_seed)
|
||||||
|
torch.manual_seed(worker_seed)
|
||||||
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
return DataLoader(dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
drop_last=drop_last,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
num_workers=num_workers,
|
||||||
|
**_kwargs)
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
criterion: Callable = None,
|
||||||
|
dataloader: DataLoader = None,
|
||||||
|
lr_scheduler: LRScheduler = None,
|
||||||
|
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||||
|
|
||||||
|
if not isinstance(model, ModelWrapper):
|
||||||
|
# convert model to sync bn
|
||||||
|
# FIXME(ver217): gemini does not support sync bn
|
||||||
|
# In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
|
||||||
|
# This inconsistency of dtype will cause the error.
|
||||||
|
# We have two possible solutions:
|
||||||
|
# 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks.
|
||||||
|
# 2. patch sync bn or write a new on. This is relatively easy, but we need to test it.
|
||||||
|
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
|
||||||
|
|
||||||
|
# wrap the model with Gemini
|
||||||
|
model = GeminiModel(model, self.gemini_config)
|
||||||
|
|
||||||
|
if not isinstance(optimizer, OptimizerWrapper):
|
||||||
|
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs)
|
||||||
|
|
||||||
|
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||||
|
|
||||||
|
def control_checkpoint_io(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_checkpoint_io(self) -> CheckpointIO:
|
||||||
|
return GeminiCheckpointIO()
|
@ -0,0 +1,150 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import GeminiPlugin
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.tensor.colo_parameter import ColoParameter
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
|
def check_gemini_plugin(early_stop: bool = True):
|
||||||
|
"""check gemini plugin over model zoo
|
||||||
|
|
||||||
|
Args:
|
||||||
|
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
|
||||||
|
"""
|
||||||
|
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
passed_models = []
|
||||||
|
failed_info = {} # (model_name, error) pair
|
||||||
|
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
||||||
|
# These models lead to CUDA error
|
||||||
|
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
|
||||||
|
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
|
||||||
|
continue
|
||||||
|
# These models are not compatible with gemini
|
||||||
|
if name in [
|
||||||
|
'diffusers_clip_vision_model',
|
||||||
|
'timm_resnet',
|
||||||
|
'timm_beit',
|
||||||
|
'timm_beitv2',
|
||||||
|
'timm_eca_nfnet',
|
||||||
|
'timm_efficientformer',
|
||||||
|
'timm_hrnet_w18_small',
|
||||||
|
'timm_nf_ecaresnet101',
|
||||||
|
'timm_nf_regnet_b0',
|
||||||
|
'timm_skresnet18',
|
||||||
|
'timm_wide_resnet50_2',
|
||||||
|
'timm_convit',
|
||||||
|
'timm_dm_nfnet',
|
||||||
|
'timm_swin_transformer',
|
||||||
|
'torchaudio_conformer',
|
||||||
|
'torchaudio_deepspeech',
|
||||||
|
'torchaudio_wavernn',
|
||||||
|
'torchaudio_tacotron',
|
||||||
|
'deepfm_interactionarch',
|
||||||
|
'deepfm_simpledeepfmnn',
|
||||||
|
'dlrm',
|
||||||
|
'dlrm_interactionarch',
|
||||||
|
'torchvision_googlenet',
|
||||||
|
'torchvision_inception_v3',
|
||||||
|
'torchvision_mobilenet_v3_small',
|
||||||
|
'torchvision_resnet18',
|
||||||
|
'torchvision_resnext50_32x4d',
|
||||||
|
'torchvision_wide_resnet50_2',
|
||||||
|
'torchvision_vit_b_16',
|
||||||
|
'torchvision_convnext_base',
|
||||||
|
'torchvision_swin_s',
|
||||||
|
'transformers_albert',
|
||||||
|
'transformers_albert_for_pretraining',
|
||||||
|
'transformers_bert',
|
||||||
|
'transformers_bert_for_pretraining',
|
||||||
|
'transformers_gpt_double_heads',
|
||||||
|
'torchaudio_hubert_base',
|
||||||
|
]:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
model = model_fn()
|
||||||
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
|
criterion = lambda x: x.mean()
|
||||||
|
data = data_gen_fn()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
||||||
|
for k, v in data.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||||
|
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
|
||||||
|
|
||||||
|
output = model(**data)
|
||||||
|
output = output_transform_fn(output)
|
||||||
|
output_key = list(output.keys())[0]
|
||||||
|
loss = criterion(output[output_key])
|
||||||
|
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.step()
|
||||||
|
passed_models.append(name)
|
||||||
|
except Exception as e:
|
||||||
|
failed_info[name] = e
|
||||||
|
if early_stop:
|
||||||
|
raise e
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
|
||||||
|
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
|
||||||
|
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
|
||||||
|
|
||||||
|
|
||||||
|
def check_dataloader_sharding():
|
||||||
|
plugin = GeminiPlugin()
|
||||||
|
|
||||||
|
# create a custom dasetset with 0 to 10
|
||||||
|
dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
|
||||||
|
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
|
||||||
|
|
||||||
|
# get the first batch of data
|
||||||
|
batch = next(iter(train_dataloader))[0].cuda()
|
||||||
|
is_rank_0 = dist.get_rank() == 0
|
||||||
|
|
||||||
|
if is_rank_0:
|
||||||
|
batch_to_compare = batch.clone()
|
||||||
|
else:
|
||||||
|
batch_to_compare = batch
|
||||||
|
# pass to the rank 1 value to rank 0
|
||||||
|
dist.broadcast(batch_to_compare, src=1)
|
||||||
|
|
||||||
|
# compare on rank 0
|
||||||
|
if is_rank_0:
|
||||||
|
assert not torch.equal(batch,
|
||||||
|
batch_to_compare), 'Same number was found across ranks but expected it to be different'
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port, early_stop: bool = True):
|
||||||
|
# init dist env
|
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
check_dataloader_sharding()
|
||||||
|
check_gemini_plugin(early_stop=early_stop)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason='Skip gemini plugin test due to OOM')
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_gemini_plugin(early_stop: bool = True):
|
||||||
|
world_size = 2
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_gemini_plugin(early_stop=False)
|
Loading…
Reference in new issue