[plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel
pull/5924/head
Hongxin Liu 2024-07-18 15:33:03 +08:00 committed by GitHub
parent 73494de577
commit e86127925a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 12 deletions

View File

@ -2,7 +2,7 @@ import ctypes
import random import random
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from types import MethodType from types import MethodType
@ -33,8 +33,11 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .pp_plugin_base import PipelinePluginBase from .pp_plugin_base import PipelinePluginBase
@ -61,6 +64,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
use_ddp: bool, use_ddp: bool,
ddp_config: dict, ddp_config: dict,
custom_policy: Policy, custom_policy: Policy,
overlap_allgather: bool = False,
) -> None: ) -> None:
self.stage_manager = shard_config.pipeline_stage_manager self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config self.shard_config = shard_config
@ -69,6 +73,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.sp_group = sp_group self.sp_group = sp_group
self.use_dpp = use_ddp self.use_dpp = use_ddp
self.require_grad_sync = True self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
shardformer = ShardFormer(shard_config) shardformer = ShardFormer(shard_config)
if custom_policy is not None: if custom_policy is not None:
@ -106,6 +111,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = DDP(module, process_group=dp_group, **ddp_config) module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module) super().__init__(module)
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)
def sync_shared_params(self): def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
@ -197,6 +208,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
if self.convert_fn is not None: if self.convert_fn is not None:
args = tree_map(self.convert_fn, args) args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs) kwargs = tree_map(self.convert_fn, kwargs)
with self._wait_all_gather():
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
def unwrap(self): def unwrap(self):
@ -205,6 +217,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = module.module module = module.module
return module return module
def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)
def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
def get_param_info(optim: Optimizer): def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes: # Get a backup of necessary information of parameters for future use, which includes:
@ -650,6 +669,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
tp_process_group: Optional[ProcessGroup] = None, # if using tp tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None, forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
): ):
self.model = model self.model = model
self.param_info = param_info self.param_info = param_info
@ -677,7 +697,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
cpu_offload=cpu_offload, cpu_offload=cpu_offload,
dp_process_group=dp_process_group, dp_process_group=dp_process_group,
forced_dtype=forced_dtype, forced_dtype=forced_dtype,
overlap_allgather=False, overlap_allgather=overlap_allgather,
) )
def sync_dp_grads(self): def sync_dp_grads(self):
@ -993,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by: int = 64, make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True, dp_outside: bool = True,
overlap_p2p: bool = True, overlap_p2p: bool = True,
overlap_allgather: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
assert ( assert (
@ -1144,6 +1165,7 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload=cpu_offload, cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2), partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision], forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
) )
self.max_norm = max_norm self.max_norm = max_norm
@ -1221,6 +1243,7 @@ class HybridParallelPlugin(PipelinePluginBase):
use_ddp=use_ddp, use_ddp=use_ddp,
ddp_config=self.ddp_config, ddp_config=self.ddp_config,
custom_policy=self.custom_policy, custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
) )
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0: if zero_stage == 0:
@ -1303,7 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# so we disable it, performing manual reduction instead. # so we disable it, performing manual reduction instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx: with ctx, model._wait_all_gather():
outputs = self.schedule.forward_backward_step( outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs model, data_iter, criterion, optimizer, return_loss, return_outputs
) )

View File

@ -62,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin): class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None: def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
super().__init__(module) super().__init__(module)
self.dtype = None self.dtype = None
if precision == "fp16": if precision == "fp16":
@ -76,8 +76,8 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
self.convert_fn = None self.convert_fn = None
if self.dtype is not None: if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_communication = overlap_communication self.overlap_allgather = overlap_allgather
if overlap_communication: if overlap_allgather:
self.op_hook = ZeroOpHook() self.op_hook = ZeroOpHook()
for p in module.parameters(): for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter: if p.requires_grad and type(p) is not ColoParameter:
@ -88,7 +88,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
if self.convert_fn is not None: if self.convert_fn is not None:
args = tree_map(self.convert_fn, args) args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs) kwargs = tree_map(self.convert_fn, kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext() ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
with ctx: with ctx:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
@ -356,8 +356,8 @@ class LowLevelZeroPlugin(DPPluginBase):
partition_grad=(stage == 2), partition_grad=(stage == 2),
cpu_offload=cpu_offload, cpu_offload=cpu_offload,
master_weights=master_weights, master_weights=master_weights,
overlap_allgather=overlap_allgather,
) )
self.overlap_allgather = overlap_allgather
self.lora_enabled = False self.lora_enabled = False
self.verbose = verbose self.verbose = verbose
@ -473,11 +473,13 @@ class LowLevelZeroPlugin(DPPluginBase):
self.add_lora_params_to_optimizer(model, optimizer) self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather) model = LowLevelZeroModel(
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
)
# TODO: Support Galore + ZeRO # TODO: Support Galore + ZeRO
zero_stage = self.stage zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather} zero_optim_kwargs = {**self.zero_optim_kwargs}
dp_size = dist.get_world_size() dp_size = dist.get_world_size()
# Replace with the distributed implementation if exists # Replace with the distributed implementation if exists

View File

@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
""" """
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap() model = model.unwrap()
if os.path.isfile(checkpoint): if os.path.isfile(checkpoint):
@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
This argument should be manually set to False since params on same device might be stored in different files. This argument should be manually set to False since params on same device might be stored in different files.
""" """
assert isinstance(model, ModelWrapper), "Please boost the model before loading!" assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
model_before_wrapping = model # backup for model before wrapping model_before_wrapping = model # backup for model before wrapping
model = model.unwrap() model = model.unwrap()
@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap() model = model.unwrap()
if self.dp_rank != 0: if self.dp_rank != 0:
@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before loading!" assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
strict = False strict = False
model_before_wrapping = model model_before_wrapping = model
model = model.unwrap() model = model.unwrap()

View File

@ -98,6 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true") parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
args = parser.parse_args() args = parser.parse_args()
colossalai.launch_from_torch() colossalai.launch_from_torch()
@ -199,9 +200,9 @@ def main():
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
microbatch_size=args.mbs, microbatch_size=args.mbs,
precision="bf16", precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap, overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache, enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs, **hybrid_kwargs,
) )
elif args.plugin == "3d_cpu": elif args.plugin == "3d_cpu":