From c068ef0fa0777c57cb756dbba61ce9ca49e5f5b6 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 11 Jul 2024 18:59:59 +0800 Subject: [PATCH 1/8] [zero] support all-gather overlap (#5898) * [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api --- .../booster/plugin/hybrid_parallel_plugin.py | 1 + .../booster/plugin/low_level_zero_plugin.py | 50 +++++++++++++++++-- colossalai/zero/low_level/low_level_optim.py | 50 ++++++++++++------- colossalai/zero/low_level/zero_hook.py | 33 ++++++++++++ examples/language/performance_evaluator.py | 4 +- .../test_zero/test_low_level/test_grad_acc.py | 4 ++ .../test_zero/test_low_level/test_zero1_2.py | 2 + 7 files changed, 119 insertions(+), 25 deletions(-) create mode 100644 colossalai/zero/low_level/zero_hook.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 485833398..6f27fa641 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -677,6 +677,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, + overlap_allgather=False, ) def sync_dp_grads(self): diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 7b5aec2aa..b9b2c57dc 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -2,6 +2,7 @@ import enum import logging import os import warnings +from contextlib import nullcontext from functools import partial from pathlib import Path from types import MethodType @@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO @@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str) -> None: + def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): self.convert_fn = None if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) + self.overlap_communication = overlap_communication + if overlap_communication: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) def forward(self, *args, **kwargs): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) - return super().forward(*args, **kwargs) + ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext() + with ctx: + return 
super().forward(*args, **kwargs) + + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): @@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() super().load_unsharded_model(model, checkpoint, strict) model.update_master_params() @@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): load_sub_module: bool = True, ): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + return super().save_sharded_model( + model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): from peft import PeftModel assert isinstance(model, ModelWrapper), "Please boost the model before saving!" 
+ model._force_wait_all_gather() peft_model = model.unwrap() assert isinstance( peft_model, PeftModel @@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase): reduce_bucket_size_in_m: int = 12, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, + overlap_allgather: bool = False, cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, @@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload=cpu_offload, master_weights=master_weights, ) + self.overlap_allgather = overlap_allgather self.lora_enabled = False self.verbose = verbose @@ -431,11 +473,11 @@ class LowLevelZeroPlugin(DPPluginBase): self.add_lora_params_to_optimizer(model, optimizer) if not isinstance(model, ModelWrapper): - model = LowLevelZeroModel(model, self.precision) + model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather) # TODO: Support Galore + ZeRO zero_stage = self.stage - zero_optim_kwargs = {**self.zero_optim_kwargs} + zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather} dp_size = dist.get_world_size() # Replace with the distributed implementation if exists diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bdc91b51f..6ff235b96 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, TensorBucket +from .zero_hook import set_all_gather_handle, wait_all_gather_handle class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): @@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): dp_process_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, master_weights: bool = True, # master weights + overlap_allgather: bool = False, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # communication params self._overlap_communication = overlap_communication + self._overlap_allgather = overlap_allgather self._reduce_bucket_size = reduce_bucket_size self._communication_dtype = communication_dtype @@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # record the padding size of each param self._padding_map = dict() + # padded working param is all-gather buffer and it shares the same memory with working param + self._working_param_to_padded_working_param = dict() # mapping working param and master param self.master_to_working_param = dict() @@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): with torch.no_grad(): if padding_size > 0: padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) - # reset working params' ptr when no master weights - if self._master_weights == False: - param.data = padding_param[: param.numel()].view(param.shape) + # # reset working params' ptr when no master weights + # if self._master_weights == False: + param.data = padding_param[: param.numel()].view(param.shape) else: padding_param = param.data.view(-1) + self._working_param_to_padded_working_param[param] = padding_param splited_params = padding_param.split( padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size @@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # use 
fp32 when master_weights is True if self._master_weights is True: - splited_param_current_rank = splited_params.detach().float().to(device) + splited_param_current_rank = splited_params.detach().clone().float().to(device) else: splited_param_current_rank = splited_params @@ -549,22 +555,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = real_working_params[group_id][idx] param_to_gather = master_param.to(device).to(self._dtype) pg = self.param_to_pg[working_param] - if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: - buffer_tensor = torch.empty_like( - torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))]) - ) - dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg) - working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param)) - continue - try: - self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) - except RuntimeError: - self.pg_to_tensor_bucket[pg].all_gather(pg) - self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + padded_working_param = self._working_param_to_padded_working_param[working_param] + if self._overlap_allgather: + handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + set_all_gather_handle(working_param, handle) + else: + if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: + dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + continue + try: + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + except RuntimeError: + self.pg_to_tensor_bucket[pg].all_gather(pg) + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] - for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): - if not tensor_bucket.is_empty(): - tensor_bucket.all_gather(pg) + if not self._overlap_allgather: + for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): + if not tensor_bucket.is_empty(): + tensor_bucket.all_gather(pg) def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: r""" @@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: grad_store = self.pid_to_grad_store[param_id] return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id) + + def _force_wait_all_gather(self): + for param in self._working_param_to_padded_working_param.keys(): + wait_all_gather_handle(param) diff --git a/colossalai/zero/low_level/zero_hook.py b/colossalai/zero/low_level/zero_hook.py new file mode 100644 index 000000000..20f9ef31a --- /dev/null +++ b/colossalai/zero/low_level/zero_hook.py @@ -0,0 +1,33 @@ +from typing import List + +from torch._tensor import Tensor + +from colossalai.tensor.param_op_hook import ColoParamOpHook + +_ALL_GATHER_HANDLE = "_all_gather_handle" + + +def wait_all_gather_handle(p): + if hasattr(p, _ALL_GATHER_HANDLE): + handle = getattr(p, _ALL_GATHER_HANDLE) + handle.wait() + delattr(p, _ALL_GATHER_HANDLE) + + +def set_all_gather_handle(p, handle): + setattr(p, _ALL_GATHER_HANDLE, handle) + + +class ZeroOpHook(ColoParamOpHook): + def pre_forward(self, params: List[Tensor]) -> None: + for p in params: + wait_all_gather_handle(p) + + def post_forward(self, params: List[Tensor]) -> None: + pass + + def 
pre_backward(self, params: List[Tensor]) -> None: + pass + + def post_backward(self, params: List[Tensor]) -> None: + pass diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 6b8daf37d..ca4a02cd2 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -113,13 +113,13 @@ class PerformanceEvaluator: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - get_accelerator().synchronize() + # get_accelerator().synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - get_accelerator().synchronize() + # get_accelerator().synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index ed12bb72d..94db70ca5 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc(): zero1_optimizer.step() zero2_optimizer.step() + zero1_optimizer._force_wait_all_gather() + zero2_optimizer._force_wait_all_gather() + # check updated param for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert not hasattr(z1p, "_all_gather_handle") assert torch.equal(z1p.data, z2p.data) diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 8df35bdaa..c376c50e0 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): # torch ddp step torch_optimizer.step() + zero_optimizer._force_wait_all_gather() + # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): loose_close(p, z1p, dtype=dtype) From 45c49dde96613427f8ccd1f6c7f9b48fd303256e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stephan=20K=C3=B6?= Date: Mon, 15 Jul 2024 12:05:06 +0800 Subject: [PATCH 2/8] [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446) * Remove unnecessary calls to deepcopy * Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. * Fix documentation of DimSpec's difference method --- colossalai/tensor/d_tensor/sharding_spec.py | 87 +++++++++++--------- colossalai/tensor/sharding_spec.py | 89 ++++++++++++--------- 2 files changed, 103 insertions(+), 73 deletions(-) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 16a4f248b..76d85a112 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Dict, List from ..utils import merge_same_dim_mesh_list @@ -23,10 +22,11 @@ class DimSpec: Otherwise, the element in shard_list means the data will be sharded in that dimension. 
""" + _DIFFERENCE_DICT = None + def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list - self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -39,24 +39,43 @@ class DimSpec: target += str(dim) return target - def _convert_str_to_shard_list(self, str_spec): + @property + def difference_dict(self): """ - Convert str_spec into shard_list. + Returns the difference dict, and lazily initializes it when needed - Argument: - str_spec(str): dim spec in str type. + Return: + difference_dict(Dict[Tuple[int, int], Union[int, float, str]]): + difference dict """ + if self._DIFFERENCE_DICT is None: + self._DIFFERENCE_DICT = self._build_difference_2d_dict() - if str_spec == "R": - return [] - if str_spec == "S0": - return [0] - if str_spec == "S1": - return [1] - if str_spec == "S01": - return [0, 1] + return self._DIFFERENCE_DICT - def build_difference_2d_dict(self): + def dim_diff(self, other): + """ + The difference between two DimSpec. + + Argument: + other(DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two DimSpec. + + Example: + dim_spec = DimSpec([0]) + other_dim_spec = DimSpec([0, 1]) + print(dim_spec.dim_diff(other_dim_spec)) + + Output: + 5 + """ + difference = self.difference_dict[(str(self), str(other))] + return difference + + @classmethod + def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. @@ -67,9 +86,8 @@ class DimSpec: difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) - source_shard_list = self._convert_str_to_shard_list(source_spec) - target_shard_list = self._convert_str_to_shard_list(target_spec) + source_shard_list = cls._convert_str_to_shard_list(source_spec) + target_shard_list = cls._convert_str_to_shard_list(target_spec) # source same as target if source_shard_list == target_shard_list: @@ -112,30 +130,27 @@ class DimSpec: else: difference = NAN - difference_dict[spec_pair] = difference + difference_dict[(source_spec, target_spec)] = difference - self.difference_dict = difference_dict + return difference_dict - def dim_diff(self, other): + @staticmethod + def _convert_str_to_shard_list(str_spec): """ - The difference between two _DimSpec. + Convert str_spec into shard_list. Argument: - other(_DimSpec): the dim spec to compare with. - - Return: - difference(int): the difference between two _DimSpec. - - Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + str_spec(str): dim spec in str type. """ - difference = self.difference_dict[(str(self), str(other))] - return difference + + if str_spec == "R": + return [] + if str_spec == "S0": + return [0] + if str_spec == "S1": + return [1] + if str_spec == "S01": + return [0, 1] class ShardingSpec: diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index b78ef6d97..fb42afab7 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,5 +1,4 @@ import operator -from copy import deepcopy from functools import reduce import torch @@ -27,10 +26,11 @@ class _DimSpec: Otherwise, the element in shard_list means the data will be sharded in that dimension. 
""" + _DIFFERENCE_DICT = None + def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list - self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -43,27 +43,46 @@ class _DimSpec: target += str(dim) return target - def _convert_str_to_shard_list(self, str_spec): + @property + def difference_dict(self): """ - Convert str_spec into shard_list. + Returns the difference dict, and lazily initializes it when needed - Argument: - str_spec(str): dim spec in str type. + Return: + difference_dict(Dict[Tuple[int, int], Union[int, float, str]]): + difference dict """ + if self._DIFFERENCE_DICT is None: + self._DIFFERENCE_DICT = self._build_difference_2d_dict() - if str_spec == "R": - return [] - if str_spec == "S0": - return [0] - if str_spec == "S1": - return [1] - if str_spec == "S01": - return [0, 1] + return self._DIFFERENCE_DICT - def build_difference_2d_dict(self): + def difference(self, other): + """ + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + """ + difference = self.difference_dict[(str(self), str(other))] + return difference + + @classmethod + def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to - compute the difference between DimSpec pairs. + compute the difference between _DimSpec pairs. """ source_spec_list = ["R", "S0", "S1", "S01"] @@ -71,9 +90,8 @@ class _DimSpec: difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) - source_shard_list = self._convert_str_to_shard_list(source_spec) - target_shard_list = self._convert_str_to_shard_list(target_spec) + source_shard_list = cls._convert_str_to_shard_list(source_spec) + target_shard_list = cls._convert_str_to_shard_list(target_spec) # source same as target if source_shard_list == target_shard_list: @@ -116,30 +134,27 @@ class _DimSpec: else: difference = NAN - difference_dict[spec_pair] = difference + difference_dict[(source_spec, target_spec)] = difference - self.difference_dict = difference_dict + return difference_dict - def difference(self, other): + @staticmethod + def _convert_str_to_shard_list(str_spec): """ - The difference between two _DimSpec. + Convert str_spec into shard_list. Argument: - other(_DimSpec): the dim spec to compare with. - - Return: - difference(int): the difference between two _DimSpec. - - Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + str_spec(str): dim spec in str type. 
""" - difference = self.difference_dict[(str(self), str(other))] - return difference + + if str_spec == "R": + return [] + if str_spec == "S0": + return [0] + if str_spec == "S1": + return [1] + if str_spec == "S01": + return [0, 1] class ShardingSpecException(Exception): From 1c961b20f33a5213c9feb9e5634e0a6f7cae0ca7 Mon Sep 17 00:00:00 2001 From: Guangyao Zhang Date: Mon, 15 Jul 2024 13:58:06 +0800 Subject: [PATCH 3/8] [ShardFormer] fix qwen2 sp (#5903) --- colossalai/shardformer/modeling/qwen2.py | 6 +- .../test_model/test_shard_qwen2.py | 99 ++++++++++--------- 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index da78dfc0b..55822b150 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Tuple, Union import torch @@ -513,7 +514,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": query_states = all_to_all_comm(query_states, sp_group) @@ -698,9 +698,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No next_decoder_cache = None if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) for decoder_layer in self.layers: if output_hidden_states: diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index 160f9c53b..c87415b75 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -135,51 +135,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - ], -) -def run_qwen2_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2") - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) - - clear_layout_converter() - Randomizer.reset_index() - torch.cuda.empty_cache() - - -@parameterize( - "test_config", - [ - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp16", - "zero_stage": 1, - "initial_scale": 1, - }, { # Ulysess + Flash attention "tp_size": 1, "pp_size": 2, @@ -242,6 +197,54 @@ def run_qwen2_test(test_config): "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + 
"use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_qwen2_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config}") + raise e + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config}") + raise e clear_layout_converter() Randomizer.reset_index() From 2e28c793cea175fbb9d727d6f056634801a77d24 Mon Sep 17 00:00:00 2001 From: Guangyao Zhang Date: Thu, 4 Jul 2024 10:53:09 +0800 Subject: [PATCH 4/8] [compatibility] support torch 2.2 (#5875) * Support Pytorch 2.2.2 * keep build_on_pr file and update .compatibility --- .compatibility | 1 + colossalai/tensor/d_tensor/layout_converter.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.compatibility b/.compatibility index d90a74b58..7ecced624 100644 --- a/.compatibility +++ b/.compatibility @@ -1 +1,2 @@ 2.1.0-12.1.0 +2.2.2-12.1.0 diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index c2cf73181..0f0150d90 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -473,7 +473,7 @@ class LayoutConverter(metaclass=SingletonMeta): for process_group in used_process_groups: try: dist.get_rank(process_group) - except RuntimeError as e: + except (ValueError, RuntimeError) as e: # If the group is not registered, it means it has been deleted if str(e) == ( f"Group {process_group} is not registered, please create group with torch.distributed.new_group API" From 530283dba034b20c8f3562a661995e38926f3e80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Thu, 4 Jul 2024 10:53:58 +0800 Subject: [PATCH 5/8] fix object_to_tensor usage when torch>=2.3.0 (#5820) --- colossalai/pipeline/p2p.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index ed190eb08..b7b284213 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -91,7 +91,11 @@ def _broadcast_object_list( my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. 
if my_rank == src: - if Version(torch.__version__) >= Version("1.13.0"): + if Version(torch.__version__) >= Version("2.3.0"): + tensor_list, size_list = zip( + *[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list] + ) + elif Version(torch.__version__) >= Version("1.13.0"): tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list]) else: tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) @@ -276,7 +280,11 @@ def _send_recv_serialization_object( send_object_tensor = None send_object_size_tensor = None if object is not None and send_dst is not None: - if Version(torch.__version__) >= Version("1.13.0"): + if Version(torch.__version__) >= Version("2.3.0"): + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor( + object, device=current_device, group=send_group + ) + elif Version(torch.__version__) >= Version("1.13.0"): send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) else: send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object) From 27a72f0de1be2c4f4d087e6581e321129f0f38db Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 11 Jul 2024 16:43:18 +0800 Subject: [PATCH 6/8] [misc] support torch2.3 (#5893) * [misc] support torch2.3 * [devops] update compatibility ci * [devops] update compatibility ci * [devops] add debug * [devops] add debug * [devops] add debug * [devops] add debug * [devops] remove debug * [devops] remove debug --- .compatibility | 1 + .../compatiblity_test_on_dispatch.yml | 32 +++++------------- .github/workflows/compatiblity_test_on_pr.yml | 33 +++++-------------- .../compatiblity_test_on_schedule.yml | 33 ++++--------------- requirements/requirements.txt | 2 +- 5 files changed, 27 insertions(+), 74 deletions(-) diff --git a/.compatibility b/.compatibility index 7ecced624..4f808740b 100644 --- a/.compatibility +++ b/.compatibility @@ -1,2 +1,3 @@ 2.1.0-12.1.0 2.2.2-12.1.0 +2.3.0-12.1.0 diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 3eee564c2..1a458d7bb 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -55,41 +55,27 @@ jobs: steps: - name: Install dependencies run: | - pip install -U pip setuptools==68.2.2 wheel --user - - uses: actions/checkout@v2 - with: - repository: hpcaitech/TensorNVMe - ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - path: TensorNVMe - - name: Install tensornvme - run: | - cd TensorNVMe apt update && apt install -y cmake - pip install -r requirements.txt - DISABLE_URING=1 pip install -v . + pip install -U pip setuptools==68.2.2 wheel --user + - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Download cub for CUDA 10.2 - run: | - CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') - # check if it is CUDA 10.2 - # download cub - if [ "$CUDA_VERSION" = "10.2" ]; then - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ - fi - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . 
- pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt + + - name: Install tensornvme + run: | + DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git + - name: Unit Testing run: | PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index b418c843e..770f4b933 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -49,42 +49,27 @@ jobs: steps: - name: Install dependencies run: | - pip install -U pip setuptools==68.2.2 wheel --user - - uses: actions/checkout@v2 - with: - repository: hpcaitech/TensorNVMe - ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - path: TensorNVMe - - name: Install tensornvme - run: | - cd TensorNVMe apt update && apt install -y cmake - pip install -r requirements.txt - DISABLE_URING=1 pip install -v . + pip install -U pip setuptools==68.2.2 wheel --user + - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Download cub for CUDA 10.2 - run: | - CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') - - # check if it is CUDA 10.2 - # download cub - if [ "$CUDA_VERSION" = "10.2" ]; then - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ - fi - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . - pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt + + - name: Install tensornvme + run: | + DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git + - name: Unit Testing run: | PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 8d98e775c..c6455604f 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -43,47 +43,28 @@ jobs: steps: - name: Install dependencies run: | + apt update && apt install -y cmake pip install -U pip setuptools==68.2.2 wheel --user - - uses: actions/checkout@v2 - with: - repository: hpcaitech/TensorNVMe - ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - path: TensorNVMe - - - name: Install tensornvme - run: | - cd TensorNVMe - apt update && apt install -y cmake - pip install -r requirements.txt - DISABLE_URING=1 pip install -v . 
- uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Download cub for CUDA 10.2 - run: | - CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') - - # check if it is CUDA 10.2 - # download cub - if [ "$CUDA_VERSION" = "10.2" ]; then - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ - fi - - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . - pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt + + - name: Install tensornvme + run: | + DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git - name: Unit Testing run: | PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b54d1cf91..651eb66e8 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.1.0,<2.3.0 +torch>=2.1.0,<=2.3.0 safetensors einops pydantic From 73494de57773cfc804f729234bf3611b65f13447 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 17 Jul 2024 17:29:59 +0800 Subject: [PATCH 7/8] [release] update version (#5912) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 1d0ba9ea1..267577d47 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.0 +0.4.1 From e86127925aca92467cbdc58bbea9920a2565b82c Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 18 Jul 2024 15:33:03 +0800 Subject: [PATCH 8/8] [plugin] support all-gather overlap for hybrid parallel (#5919) * [plugin] fixed all-gather overlap support for hybrid parallel --- .../booster/plugin/hybrid_parallel_plugin.py | 31 ++++++++++++++++--- .../booster/plugin/low_level_zero_plugin.py | 16 +++++----- .../hybrid_parallel_checkpoint_io.py | 4 +++ examples/language/llama/benchmark.py | 3 +- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6f27fa641..2c8cb6ba1 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -2,7 +2,7 @@ import ctypes import random import warnings from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from copy import deepcopy from functools import partial from types import MethodType @@ -33,8 +33,11 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor.api import is_distributed_tensor +from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle from .pp_plugin_base import PipelinePluginBase @@ -61,6 +64,7 @@ class 
HybridParallelModule(ModelWrapper, AMPModelMixin): use_ddp: bool, ddp_config: dict, custom_policy: Policy, + overlap_allgather: bool = False, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.shard_config = shard_config @@ -69,6 +73,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.sp_group = sp_group self.use_dpp = use_ddp self.require_grad_sync = True + self.overlap_allgather = overlap_allgather shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -106,6 +111,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): module = DDP(module, process_group=dp_group, **ddp_config) super().__init__(module) + if overlap_allgather: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) def sync_shared_params(self): for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): @@ -197,7 +208,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) - return super().forward(*args, **kwargs) + with self._wait_all_gather(): + return super().forward(*args, **kwargs) def unwrap(self): module = super().unwrap() @@ -205,6 +217,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): module = module.module return module + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) + + def _wait_all_gather(self): + return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: @@ -650,6 +669,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): tp_process_group: Optional[ProcessGroup] = None, # if using tp pp_process_group: Optional[ProcessGroup] = None, # if using pp forced_dtype: Optional[torch.dtype] = None, + overlap_allgather: bool = False, ): self.model = model self.param_info = param_info @@ -677,7 +697,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, - overlap_allgather=False, + overlap_allgather=overlap_allgather, ) def sync_dp_grads(self): @@ -993,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase): make_vocab_size_divisible_by: int = 64, dp_outside: bool = True, overlap_p2p: bool = True, + overlap_allgather: bool = False, ) -> None: super().__init__() assert ( @@ -1144,6 +1165,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, ) self.max_norm = max_norm @@ -1221,6 +1243,7 @@ class HybridParallelPlugin(PipelinePluginBase): use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, + overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if zero_stage == 0: @@ -1303,7 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase): # so we disable it, performing manual reduction instead. 
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() - with ctx: + with ctx, model._wait_all_gather(): outputs = self.schedule.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index b9b2c57dc..1a6547796 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -62,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None: + def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -76,8 +76,8 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): self.convert_fn = None if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) - self.overlap_communication = overlap_communication - if overlap_communication: + self.overlap_allgather = overlap_allgather + if overlap_allgather: self.op_hook = ZeroOpHook() for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: @@ -88,7 +88,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) - ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext() + ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() with ctx: return super().forward(*args, **kwargs) @@ -356,8 +356,8 @@ class LowLevelZeroPlugin(DPPluginBase): partition_grad=(stage == 2), cpu_offload=cpu_offload, master_weights=master_weights, + overlap_allgather=overlap_allgather, ) - self.overlap_allgather = overlap_allgather self.lora_enabled = False self.verbose = verbose @@ -473,11 +473,13 @@ class LowLevelZeroPlugin(DPPluginBase): self.add_lora_params_to_optimizer(model, optimizer) if not isinstance(model, ModelWrapper): - model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather) + model = LowLevelZeroModel( + model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + ) # TODO: Support Galore + ZeRO zero_stage = self.stage - zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather} + zero_optim_kwargs = {**self.zero_optim_kwargs} dp_size = dist.get_world_size() # Replace with the distributed implementation if exists diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 61c9d1438..b7097e432 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): """ assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() model = model.unwrap() if os.path.isfile(checkpoint): @@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): This argument should be manually set to False since params on same device might be stored in different files. """ assert isinstance(model, ModelWrapper), "Please boost the model before loading!" 
+ model._force_wait_all_gather() model_before_wrapping = model # backup for model before wrapping model = model.unwrap() @@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() model = model.unwrap() if self.dp_rank != 0: @@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model._force_wait_all_gather() strict = False model_before_wrapping = model model = model.unwrap() diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 2b7bd50b8..e530e2d6a 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -98,6 +98,7 @@ def main(): parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") + parser.add_argument("--overlap_allgather", action="store_true") args = parser.parse_args() colossalai.launch_from_torch() @@ -199,9 +200,9 @@ def main(): enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", - dp_outside=False, overlap_p2p=args.overlap, enable_metadata_cache=not args.no_cache, + overlap_allgather=args.overlap_allgather, **hybrid_kwargs, ) elif args.plugin == "3d_cpu":
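
Usage sketch for the overlap_allgather flag introduced in this series: after optimizer.step(), the all-gather that refills the padded working parameters is launched asynchronously, and ZeroOpHook waits on each parameter's handle the first time that parameter is used in the next forward pass (checkpoint I/O calls _force_wait_all_gather() first). The example below is a minimal, hypothetical driver and not part of the patches: MyModel and dataloader are placeholders, it assumes a multi-GPU launch via torchrun, and the plugin arguments follow the signature shown in the diff above.

    # Minimal sketch of enabling the new all-gather overlap (ZeRO stage 1).
    # Assumes torchrun with 2+ GPUs; MyModel and dataloader are placeholders.
    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch()

    plugin = LowLevelZeroPlugin(
        stage=1,
        precision="fp16",
        overlap_communication=True,  # existing flag: overlap gradient reduction with backward
        overlap_allgather=True,      # new flag: issue param all-gathers asynchronously in step()
    )
    booster = Booster(plugin=plugin)

    model = MyModel()  # placeholder module
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer, *_ = booster.boost(model, optimizer)

    for batch in dataloader:  # placeholder dataloader yielding dicts of tensors
        loss = model(**batch).loss
        booster.backward(loss, optimizer)
        optimizer.step()       # async all-gather handles are attached to working params here
        optimizer.zero_grad()  # ZeroOpHook waits on each handle at that param's next forward use

The same flag is exposed on HybridParallelPlugin in the last patch, where it is only applied to the wrapped module when zero_stage > 0.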