diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 155f72dc6..016323ae7 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,5 +1,6 @@
 import random
 from contextlib import nullcontext
+from functools import partial
 from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
@@ -10,6 +11,7 @@ from torch.nn import Module, SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
@@ -27,32 +29,49 @@ from .pp_plugin_base import PipelinePluginBase
 
 DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
 
 
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
+    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
+        return x.to(dtype)
+    return x
+
+
 class HybridParallelModule(ModelWrapper):
 
     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
                  ddp_config: dict) -> None:
+
         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group
+
         shardformer = ShardFormer(shard_config)
         module, self.shared_params = shardformer.optimize(module)
-        # TODO(ver217): add input type cast
+
+        # setting process groups for shared parameters
         self.shared_param_process_groups = []
         for shared_param in self.shared_params:
             if len(shared_param) > 0:
                 self.shared_param_process_groups.append(
                     self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
+
+        # setting mixed_precision
+        self.mixed_precision = None
         if precision == 'fp16':
-            module = module.half().cuda()
+            self.mixed_precision = torch.float16
         elif precision == 'bf16':
-            module = module.to(dtype=torch.bfloat16).cuda()
-        else:
-            module = module.cuda()    # train without AMP
+            self.mixed_precision = torch.bfloat16
+        if self.mixed_precision is not None:
+            module = module.to(self.mixed_precision)
+        module = module.cuda()
+        # setting input type cast when using mixed precision
+        self.convert_fn = None
+        if self.mixed_precision is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
+
+        # setting ddp configs
         if use_ddp:
-            # convert model to sync bn
             module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
-            # wrap the model with PyTorch DDP
             module = DDP(module, process_group=dp_group, **ddp_config)
@@ -78,6 +97,12 @@ class HybridParallelModule(ModelWrapper):
                 dist.all_reduce(p.grad, group=self.dp_group)
                 p.grad.div_(self.dp_group.size())
 
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
     def unwrap(self):
         module = super().unwrap()
         if isinstance(module, DDP):
@@ -180,7 +205,6 @@ class HybridParallelPlugin(PipelinePluginBase):
             Defaults to 'fp16'.
         zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
             When set to 0, ZeRO will not be used. Defaults to 0.
-        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
            Currently all the optimization methods include fused normalization, flash attention and JIT.
            Defaults to False.
@@ -196,12 +220,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
         max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
         max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
-        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True.
-        bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
-        find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
-        check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
-        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
-        static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
+        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
+        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
""" def __init__(self, @@ -209,7 +237,6 @@ class HybridParallelPlugin(PipelinePluginBase): pp_size: int, precision: str = 'fp16', zero_stage: int = 0, - cpu_offload: bool = False, enable_all_optimization: bool = False, enable_fused_normalization: bool = False, enable_flash_attention: bool = False, @@ -224,12 +251,16 @@ class HybridParallelPlugin(PipelinePluginBase): hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0, - broadcast_buffers=True, - bucket_cap_mb=25, - find_unused_parameters=False, - check_reduction=False, - gradient_as_bucket_view=False, - static_graph=False) -> None: + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True) -> None: super().__init__() assert dist.get_world_size() % ( @@ -239,8 +270,6 @@ class HybridParallelPlugin(PipelinePluginBase): if enable_sequence_parallelism: assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' - # TODO(ver217): support zero - assert zero_stage == 0, 'zero is not support yet' self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -282,11 +311,18 @@ class HybridParallelPlugin(PipelinePluginBase): ) self.ddp_config = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, + bucket_cap_mb=ddp_bucket_cap_mb, find_unused_parameters=find_unused_parameters, check_reduction=check_reduction, gradient_as_bucket_view=gradient_as_bucket_view, static_graph=static_graph) + + self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2)) + self.max_norm = max_norm @property @@ -337,15 +373,16 @@ class HybridParallelPlugin(PipelinePluginBase): model, use_pipeline=self.enable_pipeline_parallelism) else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
                optimizer = HybridParallelZeroOptimizer(optimizer,
                                                         model,
                                                         use_pipeline=self.enable_pipeline_parallelism,
-                                                        partition_grad=(self.zero_stage == 2),
-                                                        cpu_offload=self.cpu_offload,
                                                         dp_process_group=self.dp_group,
                                                         tp_process_group=self.tp_group,
                                                         verbose=True,
                                                         clip_grad_norm=self.max_norm,
+                                                        **self.zero_config,
                                                         **self.amp_config)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
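
The input cast added in HybridParallelModule.forward can be illustrated on its own. The sketch below is a standalone illustration, not part of the patch, and the tensor names in it are made up: it applies the same _convert_floating_point helper through tree_map, so floating-point tensors anywhere in the args/kwargs pytree are cast to the configured mixed-precision dtype, while integer tensors and non-tensor values pass through unchanged.

from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # mirror of the helper added in this patch
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


convert_fn = partial(_convert_floating_point, dtype=torch.float16)

# illustrative kwargs with mixed leaf types (names are hypothetical)
kwargs = {
    'pixel_values': torch.randn(2, 3, 224, 224),      # float32 tensor -> cast to float16
    'input_ids': torch.randint(0, 100, (2, 16)),      # integer tensor -> left untouched
    'output_hidden_states': True,                     # non-tensor value -> left untouched
}
kwargs = tree_map(convert_fn, kwargs)

assert kwargs['pixel_values'].dtype == torch.float16
assert kwargs['input_ids'].dtype == torch.int64
assert kwargs['output_hidden_states'] is True
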
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 49de9cc03..c96701704 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -56,9 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 1e-4, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
-    if stage_manager is None or stage_manager.is_first_stage():
-        #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
-        #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
         check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
@@ -101,6 +99,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_bert_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index af014a858..bd87be8b7 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grad
     row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
     col_layer_for_check = ['h[0].self_attention.dense']
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-5
         else:
@@ -101,6 +101,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_bloom_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index 210f775b5..64732e06b 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grad
     row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
     col_layer_for_check = ['encoder.layers[0].self_attention.dense']
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-3
         else:
@@ -125,6 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_chatglm_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 0e29f1dd9..c776a80d8 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -56,7 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
 
     # check grad
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
         else:
@@ -120,6 +120,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'use_lazy_init': True,
     'enable_sequence_parallelism': True,
     'precision': 'fp32',
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index ec5578a76..7140c4666 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -60,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grad
     row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
     col_layer_for_check = ['layers[0].self_attn.o_proj']
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-4
         else:
@@ -135,6 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_llama_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 2fb14903b..e6faafdae 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -58,7 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grad
     row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']    # 'decoder.embed_tokens'
     col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-3
         else:
@@ -127,6 +127,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_opt_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 234ce812a..599f5a80d 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -55,12 +55,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
 
-    # check weights and gradients
+    # check grad
     if test_config['precision'] == 'fp32':
         atol, rtol = 1e-5, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
 
     # check weights after optimizer.step()
@@ -110,6 +110,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_t5_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index b9d303841..b27add24c 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grad
     row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
     col_layer_for_check = ['encoder.layer[0].attention.output.dense']
-    if stage_manager is None or stage_manager.is_first_stage():
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-5, 1e-3
         else:
@@ -124,6 +124,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_vit_test(test_config):
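
Taken together with the new test configs (tp_size=2, zero_stage=2, fp16, initial_scale=1), driving the plugin after this change would look roughly like the sketch below. This is a hedged usage sketch rather than code from the patch: it assumes a torchrun-style launch on at least 4 GPUs so that dp_size > 1 (required by the new assert when zero_stage > 0), the Booster API and the plugin's prepare_dataloader helper, and a small BERT model from transformers as a stand-in for a real workload.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from torch.utils.data import TensorDataset
from transformers import BertConfig, BertForSequenceClassification

# assumed launch: torchrun --nproc_per_node=4 this_script.py
colossalai.launch_from_torch(config={})

model = BertForSequenceClassification(BertConfig(num_hidden_layers=2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# toy dataset of (input_ids, attention_mask, labels)
dataset = TensorDataset(torch.randint(0, 1000, (64, 32)),
                        torch.ones(64, 32, dtype=torch.long),
                        torch.randint(0, 2, (64,)))

plugin = HybridParallelPlugin(tp_size=2,                  # tensor parallelism via Shardformer
                              pp_size=1,
                              zero_stage=2,               # ZeRO-2: shard optimizer states and gradients
                              precision='fp16',           # fp32 is rejected when zero_stage > 0
                              initial_scale=1,
                              zero_bucket_size_in_m=12,   # 12M-element gradient reduce buckets
                              overlap_communication=True,
                              cpu_offload=False)
booster = Booster(plugin=plugin)
dataloader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

model.train()
for input_ids, attention_mask, labels in dataloader:
    # floating-point inputs are cast to fp16 by the new HybridParallelModule.forward;
    # integer ids pass through _convert_floating_point unchanged
    outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask.cuda(), labels=labels.cuda())
    booster.backward(outputs.loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()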