diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index 3328fe2b9..8e09b6cb2 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,4 +1,5 @@
+from .gemini_plugin import GeminiPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ['Plugin', 'TorchDDPPlugin']
+__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin']
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
new file mode 100644
index 000000000..c3c9d007d
--- /dev/null
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -0,0 +1,338 @@
+import random
+import warnings
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.cluster import DistCoordinator
+from colossalai.gemini.memory_tracer import MemStats
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import _convert_to_coloparam
+
+from .plugin_base import Plugin
+
+__all__ = ['GeminiPlugin']
+
+
+def convert_to_colo_param(module: nn.Module) -> None:
+    """Convert the module's parameters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
+
+    Args:
+        module (nn.Module): Module to be converted.
+    """
+    converted_modules = set()    # handle shared modules
+    converted_params = dict()    # mapping from torch.Tensor to ColoParameter, used to handle shared references
+
+    def convert_recursively(m: nn.Module):
+        for child in m.children():
+            if child not in converted_modules:
+                converted_modules.add(child)
+                convert_recursively(child)
+
+        for name, p in m.named_parameters(recurse=False):
+            assert not isinstance(p, ColoParameter)
+            if p in converted_params:
+                target = converted_params[p]
+            else:
+                target = _convert_to_coloparam(p, p.device, p.dtype)
+                converted_params[p] = target
+            setattr(m, name, target)
+            target.shared_param_modules.append(m)
+
+    convert_recursively(module)
+
+    # the optimizer should replace params in its param groups as well; this attr should be deleted after replacing to avoid a memory leak
+    module._converted_params = converted_params
+
+
+def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
+    """Replace params in the optimizer's param groups with the converted ColoParameters.
+
+    Args:
+        optimizer (Optimizer): Optimizer whose param groups will be updated.
+        converted_params (dict): Mapping from torch.Tensor to ColoParameter.
+    """
+    for group in optimizer.param_groups:
+        for i, p in enumerate(group['params']):
+            if p in converted_params:
+                group['params'][i] = converted_params[p]
+
+
+class GeminiCheckpointIO(GeneralCheckpointIO):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.coordinator = DistCoordinator()
+
+    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        return super().load_unsharded_model(model, checkpoint, strict=strict)
+
+    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        # the model should be unwrapped in self.save_model via ModelWrapper.unwrap
+        # as there is communication when getting the state dict, this must be called on all processes
+        state_dict = model.state_dict(only_rank_0=True)
+        if self.coordinator.is_master():
+            self.save_checkpoint(state_dict, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        # TODO(ver217): optimizer state dict is sharded
+        super().save_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Save lr scheduler to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+
+class GeminiModel(ModelWrapper):
+
+    def __init__(self, module: nn.Module, gemini_config: dict) -> None:
+        super().__init__(module)
+        # TODO(ver217): only support Gemini now
+        convert_to_colo_param(module)
+        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
+
+    def unwrap(self):
+        # as save/load state dict is coupled with the GeminiDDP, we only return the GeminiDDP model
+        return self.module
+
+
+class GeminiOptimizer(OptimizerWrapper):
+
+    def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
+        replace_param_in_group(optimizer, module.module._converted_params)
+        del module.module._converted_params
+        optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
+        super().__init__(optimizer)
+
+    def backward(self, loss: Tensor, *args, **kwargs):
+        self.optim.backward(loss)
+
+    def clip_grad_by_norm(self,
+                          max_norm: Union[float, int],
+                          norm_type: Union[float, int] = 2,
+                          error_if_nonfinite: bool = False,
+                          *args,
+                          **kwargs) -> Tensor:
+        warnings.warn('Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
+
+    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
+        raise NotImplementedError('Gemini does not support clip_grad_by_value')
+
+
+class GeminiPlugin(Plugin):
+    """
+    Plugin for Gemini.
+
+    Example:
+        >>> from colossalai.booster import Booster
+        >>> from colossalai.booster.plugin import GeminiPlugin
+        >>>
+        >>> model, train_dataset, optimizer, criterion = ...
+        >>> plugin = GeminiPlugin()
+
+        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> booster = Booster(plugin=plugin)
+        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+
+    Args:
+        device (torch.device, optional): device to place the model. Defaults to the current device.
+        placement_policy (str, optional): either "cpu", "cuda" or "auto". Defaults to "cpu".
+        pin_memory (bool, optional): whether to pin memory on CPU. Defaults to False.
+        force_outputs_fp32 (bool, optional): whether to force model outputs to be fp32. Defaults to False.
+        strict_ddp_mode (bool, optional): whether to use strict ddp mode (pure data parallelism without other parallelism). Defaults to False.
+        search_range_mb (int, optional): chunk size search range in megabytes. Defaults to 32.
+        hidden_dim (int, optional): the hidden dimension of the DNN.
+            Users can provide this argument to speed up the chunk size search.
+            If it is unknown before training, a default value of 1024 will be used.
+        min_chunk_size_mb (float, optional): the minimum chunk size in megabytes.
+            If the aggregate size of parameters is still smaller than the minimum chunk size,
+            all parameters will be compacted into one small chunk.
+        memstats (MemStats, optional): memory statistics collected by a runtime memory tracer.
+        gpu_margin_mem_ratio (float, optional): The ratio of GPU memory remaining after the first forward-backward pass
+            which will be used by the hybrid CPU optimizer.
+            This argument has no effect when the `placement_policy` of `GeminiManager` is not "auto".
+            Defaults to 0.0.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
+        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
+        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
+        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
+        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
+        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
+            `clip_grad_norm` yourself when using ZeRO DDP; the ZeRO optimizer will take care of it.
+        norm_type (float, optional): norm_type used for `clip_grad_norm`.
+    """
+
+    def __init__(
+        self,
+        device: Optional[torch.device] = None,
+        placement_policy: str = "cpu",
+        pin_memory: bool = False,
+        force_outputs_fp32: bool = False,
+        strict_ddp_mode: bool = False,
+        search_range_mb: int = 32,
+        hidden_dim: Optional[int] = None,
+        min_chunk_size_mb: float = 32,
+        memstats: Optional[MemStats] = None,
+        gpu_margin_mem_ratio: float = 0.0,
+        initial_scale: float = 2**32,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+        max_norm: float = 0.0,
+        norm_type: float = 2.0,
+    ) -> None:
+
+        assert dist.is_initialized(), \
+            'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.gemini_config = dict(
+            device=(device or get_current_device()),
+            placement_policy=placement_policy,
+            pin_memory=pin_memory,
+            force_outputs_fp32=force_outputs_fp32,
+            strict_ddp_mode=strict_ddp_mode,
+            search_range_mb=search_range_mb,
+            hidden_dim=hidden_dim,
+            min_chunk_size_mb=min_chunk_size_mb,
+            memstats=memstats,
+        )
+        self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
+        self.optim_kwargs = dict(initial_scale=initial_scale,
+                                 growth_factor=growth_factor,
+                                 backoff_factor=backoff_factor,
+                                 growth_interval=growth_interval,
+                                 hysteresis=hysteresis,
+                                 min_scale=min_scale,
+                                 max_scale=max_scale,
+                                 max_norm=max_norm,
+                                 norm_type=norm_type)
+
+    def support_no_sync(self) -> bool:
+        return False
+
+    def control_precision(self) -> bool:
+        return True
+
+    def supported_precisions(self) -> List[str]:
+        return ['fp16']
+
+    def control_device(self) -> bool:
+        return True
+
+    def supported_devices(self) -> List[str]:
+        return ['cuda']
+
+    def prepare_train_dataloader(self,
+                                 dataset,
+                                 batch_size,
+                                 shuffle=False,
+                                 seed=1024,
+                                 drop_last=False,
+                                 pin_memory=False,
+                                 num_workers=0,
+                                 **kwargs):
+        r"""
+        Prepare a dataloader for distributed training. The dataset will be wrapped in a
+        `torch.utils.data.DataLoader` that uses a `torch.utils.data.distributed.DistributedSampler`.
+
+        Note:
+            1. Evaluation datasets should not be passed to this function.
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            batch_size (int): The number of samples per batch on each process.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling, defaults to 1024.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of the dataset is not divisible by
+                the batch size, then the last batch will be smaller, defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
+            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``; more details can be found in
+                    the PyTorch ``DataLoader`` documentation.
+
+        Returns:
+            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+        """
+        _kwargs = kwargs.copy()
+        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(dataset,
+                          batch_size=batch_size,
+                          sampler=sampler,
+                          worker_init_fn=seed_worker,
+                          drop_last=drop_last,
+                          pin_memory=pin_memory,
+                          num_workers=num_workers,
+                          **_kwargs)
+
+    def configure(
+        self,
+        model: nn.Module,
+        optimizer: Optimizer,
+        criterion: Callable = None,
+        dataloader: DataLoader = None,
+        lr_scheduler: LRScheduler = None,
+    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+
+        if not isinstance(model, ModelWrapper):
+            # convert model to sync bn
+            # FIXME(ver217): gemini does not support sync bn
+            # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
+            # This inconsistency of dtype will cause the error.
+            # We have two possible solutions:
+            # 1. keep batch norm always in fp32. This is hard for gemini, as it uses chunks.
+            # 2. patch sync bn or write a new one. This is relatively easy, but we need to test it.
+            # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
+
+            # wrap the model with Gemini
+            model = GeminiModel(model, self.gemini_config)
+
+        if not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs)
+
+        return model, optimizer, criterion, dataloader, lr_scheduler
+
+    def control_checkpoint_io(self) -> bool:
+        return True
+
+    def get_checkpoint_io(self) -> CheckpointIO:
+        return GeminiCheckpointIO()
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
new file mode 100644
index 000000000..7a0d4a15d
--- /dev/null
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -0,0 +1,150 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.kit.model_zoo import model_zoo
+
+
+def check_gemini_plugin(early_stop: bool = True):
+    """check gemini plugin over model zoo
+
+    Args:
+        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
+    """
+    plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
+    booster = Booster(plugin=plugin)
+
+    passed_models = []
+    failed_info = {}    # (model_name, error) pair
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+        # These models lead to CUDA error
+        if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
+                    'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
+            continue
+        # These models are not compatible with gemini
+        if name in [
+                'diffusers_clip_vision_model',
+                'timm_resnet',
+                'timm_beit',
+                'timm_beitv2',
+                'timm_eca_nfnet',
+                'timm_efficientformer',
+                'timm_hrnet_w18_small',
+                'timm_nf_ecaresnet101',
+                'timm_nf_regnet_b0',
+                'timm_skresnet18',
+                'timm_wide_resnet50_2',
+                'timm_convit',
+                'timm_dm_nfnet',
+                'timm_swin_transformer',
+                'torchaudio_conformer',
+                'torchaudio_deepspeech',
+                'torchaudio_wavernn',
+                'torchaudio_tacotron',
+                'deepfm_interactionarch',
+                'deepfm_simpledeepfmnn',
+                'dlrm',
+                'dlrm_interactionarch',
+                'torchvision_googlenet',
+                'torchvision_inception_v3',
+                'torchvision_mobilenet_v3_small',
+                'torchvision_resnet18',
+                'torchvision_resnext50_32x4d',
+                'torchvision_wide_resnet50_2',
+                'torchvision_vit_b_16',
+                'torchvision_convnext_base',
+                'torchvision_swin_s',
+                'transformers_albert',
+                'transformers_albert_for_pretraining',
+                'transformers_bert',
+                'transformers_bert_for_pretraining',
+                'transformers_gpt_double_heads',
+                'torchaudio_hubert_base',
+        ]:
+            continue
+        try:
+            model = model_fn()
+            optimizer = HybridAdam(model.parameters(), lr=1e-3)
+            criterion = lambda x: x.mean()
+            data = data_gen_fn()
+
+            data = {
+                k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+                for k, v in data.items()
+            }
+
+            model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+            for n, p in model.named_parameters():
+                assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
+
+            output = model(**data)
+            output = output_transform_fn(output)
+            output_key = list(output.keys())[0]
+            loss = criterion(output[output_key])
+
+            booster.backward(loss, optimizer)
+            optimizer.step()
+            passed_models.append(name)
+        except Exception as e:
+            failed_info[name] = e
+            if early_stop:
+                raise e
+    if dist.get_rank() == 0:
+        print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
+        print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
+    assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
+
+
+def check_dataloader_sharding():
+    plugin = GeminiPlugin()
+
+    # create a custom dataset with values 0 to 9
+    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
+    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+
+    # get the first batch of data
+    batch = next(iter(train_dataloader))[0].cuda()
+    is_rank_0 = dist.get_rank() == 0
+
+    if is_rank_0:
+        batch_to_compare = batch.clone()
+    else:
+        batch_to_compare = batch
+    # broadcast rank 1's batch so that rank 0 can compare it with its own batch
+    dist.broadcast(batch_to_compare, src=1)
+
+    # compare on rank 0
+    if is_rank_0:
+        assert not torch.equal(batch, batch_to_compare), \
+            'Same number was found across ranks but expected it to be different'
+
+
+def run_dist(rank, world_size, port, early_stop: bool = True):
+    # init dist env
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    check_dataloader_sharding()
+    check_gemini_plugin(early_stop=early_stop)
+
+
+@pytest.mark.skip(reason='Skip gemini plugin test due to OOM')
+@rerun_if_address_is_in_use()
+def test_gemini_plugin(early_stop: bool = True):
+    world_size = 2
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_gemini_plugin(early_stop=False)
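
Usage sketch (not part of the patch): the snippet below mirrors the GeminiPlugin docstring example and the flow exercised in check_gemini_plugin above. The toy model, the random input batch, and launching via colossalai.launch_from_torch are assumptions for illustration; any launcher that initializes torch.distributed satisfies the assert in GeminiPlugin.__init__.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# create the distributed environment first; GeminiPlugin asserts dist.is_initialized()
colossalai.launch_from_torch(config=dict())

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))    # hypothetical toy model
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda output: output.mean()    # same dummy criterion as the test above

plugin = GeminiPlugin(placement_policy='cuda', initial_scale=2**5)
booster = Booster(plugin=plugin)

# boost returns (model, optimizer, criterion, dataloader, lr_scheduler), matching configure()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

data = torch.rand(4, 32).cuda()    # hypothetical input batch
loss = criterion(model(data))
booster.backward(loss, optimizer)    # run backward through the booster, as done in the test above
optimizer.step()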