From e86127925aca92467cbdc58bbea9920a2565b82c Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 18 Jul 2024 15:33:03 +0800
Subject: [PATCH] [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

---
 .../booster/plugin/hybrid_parallel_plugin.py  | 31 ++++++++++++++++---
 .../booster/plugin/low_level_zero_plugin.py   | 16 +++++-----
 .../hybrid_parallel_checkpoint_io.py          |  4 +++
 examples/language/llama/benchmark.py          |  3 +-
 4 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 6f27fa641..2c8cb6ba1 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -2,7 +2,7 @@ import ctypes
 import random
 import warnings
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from copy import deepcopy
 from functools import partial
 from types import MethodType
@@ -33,8 +33,11 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
 from .pp_plugin_base import PipelinePluginBase
 
@@ -61,6 +64,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
+        overlap_allgather: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config
@@ -69,6 +73,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.sp_group = sp_group
         self.use_dpp = use_ddp
         self.require_grad_sync = True
+        self.overlap_allgather = overlap_allgather
 
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -106,6 +111,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = DDP(module, process_group=dp_group, **ddp_config)
 
         super().__init__(module)
+        if overlap_allgather:
+            self.op_hook = ZeroOpHook()
+            for p in module.parameters():
+                if p.requires_grad and type(p) is not ColoParameter:
+                    p.__class__ = ColoParameter
+                    p.__init__(p, requires_grad=True)
 
     def sync_shared_params(self):
         for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
@@ -197,7 +208,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+        with self._wait_all_gather():
+            return super().forward(*args, **kwargs)
 
     def unwrap(self):
         module = super().unwrap()
@@ -205,6 +217,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = module.module
         return module
 
+    def _force_wait_all_gather(self):
+        for p in self.module.parameters():
+            wait_all_gather_handle(p)
+
+    def _wait_all_gather(self):
+        return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+
 
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -650,6 +669,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
+        overlap_allgather: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -677,7 +697,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
-            overlap_allgather=False,
+            overlap_allgather=overlap_allgather,
         )
 
     def sync_dp_grads(self):
@@ -993,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
         overlap_p2p: bool = True,
+        overlap_allgather: bool = False,
     ) -> None:
         super().__init__()
         assert (
@@ -1144,6 +1165,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 cpu_offload=cpu_offload,
                 partition_grad=(self.zero_stage == 2),
                 forced_dtype=PRECISION_TORCH_TYPE[precision],
+                overlap_allgather=overlap_allgather,
             )
 
         self.max_norm = max_norm
@@ -1221,6 +1243,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
+                overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:
@@ -1303,7 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
 
-        with ctx:
+        with ctx, model._wait_all_gather():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index b9b2c57dc..1a6547796 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -62,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
 
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -76,8 +76,8 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.convert_fn = None
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-        self.overlap_communication = overlap_communication
-        if overlap_communication:
+        self.overlap_allgather = overlap_allgather
+        if overlap_allgather:
             self.op_hook = ZeroOpHook()
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
@@ -88,7 +88,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
+        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
         with ctx:
             return super().forward(*args, **kwargs)
 
@@ -356,8 +356,8 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            overlap_allgather=overlap_allgather,
         )
-        self.overlap_allgather = overlap_allgather
         self.lora_enabled = False
         self.verbose = verbose
 
@@ -473,11 +473,13 @@ class LowLevelZeroPlugin(DPPluginBase):
             self.add_lora_params_to_optimizer(model, optimizer)
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)
+            model = LowLevelZeroModel(
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+            )
 
         # TODO: Support Galore + ZeRO
         zero_stage = self.stage
-        zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
+        zero_optim_kwargs = {**self.zero_optim_kwargs}
         dp_size = dist.get_world_size()
 
         # Replace with the distributed implementation if exists
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 61c9d1438..b7097e432 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
 
+        model._force_wait_all_gather()
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 This argument should be manually set to False since params on same device might be stored in different files.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model._force_wait_all_gather()
 
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()
@@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         model = model.unwrap()
 
         if self.dp_rank != 0:
@@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
 
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         strict = False
         model_before_wrapping = model
         model = model.unwrap()
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 2b7bd50b8..e530e2d6a 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -98,6 +98,7 @@ def main():
     parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--overlap_allgather", action="store_true")
     args = parser.parse_args()
 
     colossalai.launch_from_torch()
@@ -199,9 +200,9 @@ def main():
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",
-            dp_outside=False,
             overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
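
Usage sketch for the new flag, assuming a multi-GPU torchrun launch; the toy model, optimizer, and parallel sizes below are illustrative placeholders, not part of the patch:

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()

# The flag only takes effect with ZeRO enabled (zero_stage > 0): the plugin forwards it to
# HybridParallelModule and HybridParallelZeroOptimizer so the all-gather of ZeRO-sharded
# parameters can overlap with forward computation.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    zero_stage=1,
    precision="bf16",
    overlap_allgather=True,
)
booster = Booster(plugin=plugin)

# Placeholder model and optimizer; a real training script would build its own.
model = nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(4, 1024, device="cuda")
loss = model(x).mean()
booster.backward(loss, optimizer)
optimizer.step()

The benchmark script exposes the same switch through its new --overlap_allgather argument.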