diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index fc7e0b741..b2ba47f67 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC): dtype: torch.dtype @abstractmethod - def pre_backward(self, loss: Tensor) -> Tensor: + def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor: """Called before backward. Args: diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index af509899c..1539bc01d 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -86,13 +86,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper): group["params"] = master_params self._current_grad_norm: Optional[float] = None - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad) + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index c757a878d..a85d9f808 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper): growth_interval=growth_interval, ) - def backward(self, loss: Tensor, *args, **kwargs) -> None: + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None: scaled_loss = self.scale_loss(loss) - scaled_loss.backward(*args, **kwargs) + scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def step(self, *args, **kwargs) -> Optional[float]: out = self.scaler.step(self.optim, *args, **kwargs) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index e021a9ead..79c9379cc 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.quantization.fp8_hook import FP8Hook @@ -296,7 +296,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): self._current_grad_norm: Optional[float] = None super().__init__(optim) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -315,7 +315,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): # Call the superclass backward method to compute gradients. with self.model._hook_context(): - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -324,7 +324,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -341,7 +341,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -525,7 +525,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): max_norm=max_norm, ) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -543,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): """ # Call the superclass backward method to compute gradients. with self.model._hook_context(): - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -552,7 +552,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -568,7 +568,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): None """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -785,7 +785,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): else: return - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): """ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -801,7 +801,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) + super().backward(loss, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -810,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -826,7 +826,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): None """ # Call the superclass backward_by_grad method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -1030,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -1048,6 +1049,9 @@ class HybridParallelPlugin(PipelinePluginBase): dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + not pp_style == "zbv" or scheduler_nodes is not None + ), f"scheduler_nodes must not be None when using zero bubble pipeline." if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1109,29 +1113,39 @@ class HybridParallelPlugin(PipelinePluginBase): self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert ( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + if pp_style == "zbv": + self.logger.warning( + """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""" + ) self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), + use_zbv=(pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -1141,13 +1155,21 @@ class HybridParallelPlugin(PipelinePluginBase): fp8_communication=fp8_communication, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, fp8_communication=fp8_communication, ) + elif pp_style == "zbv": + self.scheduler = ZeroBubbleVPipeScheduler( + stage_manager=self.stage_manager, + schedule=scheduler_nodes, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": @@ -1263,7 +1285,6 @@ class HybridParallelPlugin(PipelinePluginBase): # Replace with distributed implementation if exists optimizer = cast_to_distributed(optimizer) - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", @@ -1278,6 +1299,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.dp_size == 1 and self.pp_size == 1 ) # sync gradients across DP * SP ranks + # sync gradients across DP * SP ranks # Apply Hybrid ZeRO across DP * SP ranks if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) @@ -1380,7 +1402,7 @@ class HybridParallelPlugin(PipelinePluginBase): ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx, model._hook_context(): - outputs = self.schedule.forward_backward_step( + outputs = self.scheduler.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 2f08d2183..96531a04f 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import cast_to_distributed from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig @@ -212,6 +213,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -285,12 +287,17 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -300,14 +307,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, + use_zbv=(pp_style == "zbv"), ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -316,12 +324,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) + elif pp_style == "zbv": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV" + self.scheduler = ZeroBubbleVPipeScheduler( + schedule=scheduler_nodes, + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + overlap_p2p=overlap_p2p, + ) else: raise NotImplementedError() diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index 2cb63601f..22115a72c 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -49,14 +49,31 @@ class OptimizerWrapper: """ self.optim.zero_grad(*args, **kwargs) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): """ Performs a backward pass on the loss. """ - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - torch.autograd.backward(tensor, grad) + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): + """ + Performs a backward pass for dx or dw, + for dx, we only calculate dx = w*dy here + for dw, we only calculate dw = x*dy here + + Args: + tensor (Tensor): y or loss of current chunk; + grad_tensors (Tensor): dy of current chunk; + input_obj (Tensor): for dx, input_obj is x of current chunk; + for dw, input_obj is w of current chunk; + retain_graph (bool): default to be True, we retain graph in backward_b + """ + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def state_dict(self): """ diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 4754212c1..5d44530e7 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,11 +1,12 @@ from .p2p import PipelineP2PCommunication -from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule +from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler from .stage_manager import PipelineStageManager __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", "PipelineP2PCommunication", "PipelineStageManager", ] diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index b7b284213..8c319aceb 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -432,7 +432,6 @@ def _communicate( overlap_p2p=overlap_p2p, send_first=send_first if send_first != None else True, ) - if metadata_recv is not None: assert isinstance(metadata_recv, P2PMetadata) tree_spec = metadata_recv.tree_spec diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 6845dc237..05dd24e81 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -1,9 +1,11 @@ from .base import PipelineSchedule from .interleaved_pp import InterleavedSchedule from .one_f_one_b import OneForwardOneBackwardSchedule +from .zero_bubble_pp import ZeroBubbleVPipeScheduler __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", ] diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 8f42a9014..ded75d968 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -137,6 +137,16 @@ def retain_grad(x: Any) -> None: x.retain_grad() +def require_grad(x: Any) -> None: + """Call require_grad on a tensor. + + Args: + x (Any): Object to be called. + """ + if isinstance(x, torch.Tensor) and not x.requires_grad: + x.requires_grad_() + + def detach(x: Any) -> Any: """Call detach() on a tensor. @@ -151,6 +161,34 @@ def detach(x: Any) -> Any: return x +def clone(x: Any) -> Any: + """Call clone() on a tensor. + + Args: + x (Any): Object to be called. + + Returns: + Any: The cloned object. + """ + if isinstance(x, torch.Tensor): + return x.clone() + return x + + +def release_tensor_data(x: Any) -> Any: + """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn. + + Args: + x (Any): Object to be called. + + Returns: + Any: The deallocate .data object. + """ + if isinstance(x, torch.Tensor): + return x.data.untyped_storage().resize_(0) + return x + + def merge_batch(data: List[Any], batch_size_dim=0) -> Any: """Merge micro batches into a batch. diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py new file mode 100644 index 000000000..d3e94c0ba --- /dev/null +++ b/colossalai/pipeline/schedule/v_schedule.py @@ -0,0 +1,449 @@ +# Refer from Zero Bubble Pipeline Parallelism. +# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism +# Paper: https://arxiv.org/abs/2401.10241 +# The following applies to all files unless otherwise noted: +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from collections import deque +from dataclasses import dataclass + + +@dataclass(eq=True, frozen=True) +class ScheduledNode: + type: str + chunk: int + stage: int + minibatch: int + start_time: int = 0 + completion_time: int = 0 + rollback: bool = False + + +class PipelineGraph(object): + """PipelineGraph""" + + def __init__( + self, + n_stage, + n_micro, + f_cost, + b_cost, + w_cost, + c_cost, + f_mem, + b_mem, + w_mem, + max_mem=None, + ): + self.n_node = 6 * n_stage * n_micro + self.n_stage = n_stage + self.n_micro = n_micro + self.f_cost = f_cost + self.b_cost = b_cost + self.w_cost = w_cost + self.c_cost = c_cost + self.f_mem = f_mem + self.b_mem = b_mem + self.w_mem = w_mem + self.fbw_cost = [f_cost, b_cost, w_cost] + self.fbw_mem = [f_mem, b_mem, w_mem] + self.max_mem = max_mem or f_mem * self.n_stage * 2 + + def get_id(self, cat, chunk, stage, micro): + return ( + cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro + ) + + def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None): + count = [] + for i in range(self.n_stage): + count.append([0] * 6) + + end_time = [-1] * self.n_node + cur_time = [0] * self.n_stage + mem = [0] * self.n_stage + stage_bubble = [0] * self.n_stage + pending_w = [deque() for _ in range(self.n_stage)] + schedule = [[] for _ in range(self.n_stage)] + stage_str = [" " * i for i in range(self.n_stage)] + + if approved_bubble is None: + approved_bubble = [-1] * self.n_stage + max_approved_bubble = max(approved_bubble) + + def get_max_stage_bubble(stage=-1): + max_stage_bubble = 0 + for bb in stage_bubble: + max_stage_bubble = max(max_stage_bubble, bb) + if stage >= 0: + max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage]) + return max_stage_bubble + + def put_w(stage): + assert len(pending_w[stage]) > 0 + _, chunk_, _ = pending_w[stage].popleft() + put(2, chunk_, stage) + + def put(cat, chunk, stage, assert_cnt=True): + _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat] + _cnt = count[stage][cat * 2 + chunk] + # assert _cnt < self.n_micro + if _cnt >= self.n_micro: + if not assert_cnt: + stage_str[stage] += " " + cur_time[stage] = _tmp # TODO + return + assert False + assert mem[stage] + self.fbw_mem[cat] <= self.max_mem + stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1))) + if cat > 0 or chunk > 0: + last_id = cat * 2 + chunk - 1 + if cat < 2: + assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0 + else: + assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0 + if chunk == 1 and cat < 2: + if stage < self.n_stage - 1: + _fa_id = self.get_id(cat, chunk, stage + 1, _cnt) + assert end_time[_fa_id] >= 0 + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + if chunk == 0 and cat < 2: + if stage > 0: + _fa_id = self.get_id(cat, chunk, stage - 1, _cnt) + assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}" + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + _id = self.get_id(cat, chunk, stage, _cnt) + if count[stage][0] > 0: + stage_bubble[stage] += _tmp - _no_bubble + end_time[_id] = _tmp + cur_time[stage] = _tmp + mem[stage] += self.fbw_mem[cat] + # noinspection PyTypeChecker + schedule[stage].append((cat, chunk, _cnt)) + if cat == 1: + pending_w[stage].append((2, chunk, _cnt)) + count[stage][cat * 2 + chunk] += 1 + + for i in range(self.n_stage): + put(0, 0, i) + for i in range(self.n_stage - 1, -1, -1): + if i == self.n_stage - 1: + put(0, 1, i) + continue + tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost + while ( + mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem + and cur_time[i] + self.fbw_cost[0] <= tmp + and count[i][0] < self.n_micro + ): + for j in range(i + 1): + put(0, 0, j) + put(0, 1, i) + iter_chunk_ = 0 + end_tmp = 0 + for i in range(self.n_stage): + if i == 0: + end_tmp = cur_time[0] + self.fbw_cost[1] + continue + tmp = end_tmp + self.c_cost + while ( + count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1] + or count[i][1] <= count[i - 1][1] < self.n_micro + ): + for j in range(self.n_stage - 1, i - 1, -1): + if count[j][iter_chunk_] < self.n_micro: + put(0, iter_chunk_, j) + iter_chunk_ = 1 - iter_chunk_ + + for _ in range(2 * self.n_micro): + # check mem before putting b + for i in range(self.n_stage): + while mem[i] + self.fbw_mem[1] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + b0_ranks, b1_ranks = [], [] + for i in range(self.n_stage): + if count[i][3] >= count[i][2]: + b0_ranks.append(i) + elif i == self.n_stage - 1: + b1_ranks.append(i) + else: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro: + b1_ranks.append(i) + else: + b0_ranks.append(i) + b_ranks = [] + # put b1 + for i in reversed(b1_ranks): + b_ranks.append((i, 1)) + # put b0 + for i in b0_ranks: + b_ranks.append((i, 0)) + for i, _chunk_ in b_ranks: + fa_id = -1 + if _chunk_ == 1 and i < self.n_stage - 1: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if _chunk_ == 0 and i > 0: + fa_id = self.get_id(1, 0, i - 1, count[i][2]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if _chunk_ == 1: + put_w(i) + elif fill_b: + put_w(i) + put(1, _chunk_, i) + + # put f + for i in range(self.n_stage): + if count[i][1] >= self.n_micro: + continue + put_item = None + if count[i][1] >= count[i][0]: + put_item = 0 + elif i == self.n_stage - 1: + put_item = 1 + else: + if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0: + put_item = 1 + elif count[i][0] < self.n_micro: + if i == 0: + put_item = 0 + elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0: + put_item = 0 + if put_item is None: + continue + # check mem before putting f + while mem[i] + self.fbw_mem[0] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + fa_id = -1 + if put_item == 0 and i > 0: + fa_id = self.get_id(0, 0, i - 1, count[i][0]) + if put_item == 1 and i < self.n_stage - 1: + fa_id = self.get_id(0, 1, i + 1, count[i][1]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if fill_f: + put_w(i) + put(0, put_item, i) + + for i in range(self.n_stage): + while len(pending_w[i]) > 0: + put_w(i) + + max_bubble = get_max_stage_bubble() + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + max_bubble / expected_time + if max_approved_bubble < 0 or max_bubble < max_approved_bubble: + _schedule, _end_time, _max_bubble = self.try_v_schedule( + fill_f=fill_f, + fill_b=fill_b, + approved_bubble=stage_bubble, + ) + if _max_bubble < max_bubble: + return _schedule, _end_time, _max_bubble + return schedule, end_time, max_bubble + + def print_details(self, end_time, print_scaling=1): + for stage in range(self.n_stage): + stage_str = ["."] * int(max(end_time) / print_scaling) + for _cat in range(3): + for _chunk in range(2): + for _micro in range(self.n_micro): + _id = self.get_id(_cat, _chunk, stage, _micro) + if end_time[_id] < 0: + continue + end = int(end_time[_id] / print_scaling) + start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling) + for j in range(start, end): + if j == start or j == end - 1: + stage_str[j] = "FfBbWw"[_cat * 2 + _chunk] + elif j == start + 1: + if _micro >= 10: + stage_str[j] = str(_micro // 10) + else: + stage_str[j] = str(_micro) + elif j == start + 2 and _micro >= 10: + stage_str[j] = str(_micro % 10) + else: + stage_str[j] = "-" + _str = "" + for _c in stage_str: + _str += _c + print(_str) + + def get_v_schedule(self, only_run_time=False): + schedule, end_time, max_bubble = None, None, None + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + for fill_b in [True, False]: + for fill_f in [True, False]: + _schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f) + if max_bubble is None or _max_bubble < max_bubble: + max_bubble = _max_bubble + schedule = _schedule + end_time = _end_time + if only_run_time: + return max_bubble + expected_time + max_bubble / (expected_time + max_bubble) + local_order = [[] for _ in range(self.n_stage)] + comm_id = {} + comm_id_counter = 0 + post_validation_time = 0 + for i in range(self.n_stage - 1, -1, -1): + pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1) + post_validation_time = max( + post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost + ) + # post_validation_time = 0 + for it in ["RECV_", "SEND_", ""]: + if i == 0 and it == "SEND_": + continue + if i == self.n_stage - 1 and it == "RECV_": + continue + # stage_ = i - 1 if it == "RECV_" else i + stage_ = i + local_order[stage_].append( + ScheduledNode( + type=it + "POST_VALIDATION", + chunk=0, + stage=stage_, + minibatch=0, + start_time=post_validation_time, + completion_time=post_validation_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + comm_id_counter += 1 + for i in range(self.n_stage): + for _cat_, _chunk_, _micro_ in schedule[i]: + complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)] + local_order[i].append( + ScheduledNode( + type="FBW"[_cat_], + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=i, + minibatch=_micro_, + start_time=complete_time - self.fbw_cost[_cat_], + completion_time=complete_time, + ) + ) + if _cat_ == 2: # no communication for W + continue + cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD" + + def communicate(send_recv, stage_): + # noinspection PyTypeChecker + local_order[stage_].append( + ScheduledNode( + type=send_recv + cat_str, + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=stage_, + minibatch=_micro_, + start_time=complete_time, + completion_time=complete_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + + if _chunk_ == 1 and i > 0: + communicate("SEND_", i) + communicate("RECV_", i - 1) + if _chunk_ == 0 and i < self.n_stage - 1: + communicate("SEND_", i) + communicate("RECV_", i + 1) + comm_id_counter += 1 + for rank in range(self.n_stage): + # For nodes with the same timestamp on the same stage, communication will be prioritized. + def even_breaker(x: ScheduledNode): + # Compute nodes are always delayed. + if x.type in ["F", "B", "W"]: + return comm_id_counter + # For comm nodes, order by their unique comm id + return comm_id[x] + + local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x)))) + # If a recv with intersects with previous computation, reorder them so that recv + # is executed before computation and hence can be overlapped. + for i in range(len(local_order[rank])): + if ( + i > 0 + and local_order[rank][i - 1].type in {"F", "B", "W"} + and local_order[rank][i].type.startswith("RECV") + and "POST_VALIDATION" not in local_order[rank][i].type + and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time + ): + local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i] + + local_order_with_rollback = [[] for _ in range(self.n_stage)] + for rank in range(self.n_stage): + rollback_comm = set() + if rank > 0: + for node in local_order[rank - 1]: + if node.type == "POST_VALIDATION": + break + if node.type == "SEND_FORWARD": + assert node.chunk == 0 + rollback_comm.add(node.minibatch) + for node in local_order[rank]: + if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm: + rollback = True + rollback_comm.remove(node.minibatch) + else: + rollback = False + local_order_with_rollback[rank].append( + ScheduledNode( + type=node.type, + chunk=node.chunk, + stage=node.stage, + minibatch=node.minibatch, + start_time=node.start_time, + completion_time=node.completion_time, + rollback=rollback, + ) + ) + assert len(rollback_comm) == 0 + + return local_order_with_rollback diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py new file mode 100644 index 000000000..7cec5f003 --- /dev/null +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -0,0 +1,958 @@ +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + +import torch +import torch.cuda +import torch.distributed +from torch.nn import Module, ModuleList +from torch.utils._pytree import tree_flatten, tree_map + +from colossalai.accelerator import get_accelerator +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata +from colossalai.pipeline.schedule.v_schedule import ScheduledNode +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.pipeline.weight_grad_store import WeightGradStore + +from ._utils import ( + clone, + detach, + get_batch_size, + get_micro_batch, + merge_batch, + model_forward, + release_tensor_data, + require_grad, + retain_grad, + to_device, +) +from .base import PipelineSchedule + +AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} + + +def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: + if wait_handles is not None: + for req in wait_handles: + req.wait() + + +class ZeroBubbleVPipeScheduler(PipelineSchedule): + def __init__( + self, + stage_manager: PipelineStageManager, + schedule: List[ScheduledNode], + num_model_chunks: int, + num_microbatch: Optional[int] = None, + microbatch_size: Optional[int] = None, + enable_metadata_cache: bool = True, + overlap_p2p: bool = True, + ): + super().__init__(stage_manager) + # batch info + self.num_microbatch = num_microbatch + self.microbatch_size = microbatch_size + self.num_model_chunks = num_model_chunks + self.batch: Any + self.batch_size: int + self.last_batch_size: Optional[int] = None + self.microbatch_offset: List[int] + + self.schedules = schedule + # TODO: optim post valid + self.do_post_validation = False + + # P2PMeta cache + self.enable_metadata_cache = enable_metadata_cache + + # check send_tensor_metadata, send_grad_metadata + # pp4 as sample, we should follow this meta strategy + # send_tensor_meta(fwd) send_grad_meta(bwd) + # chunk0 | chunk1 chunk0 | chunk 1 + # stage 0 T | F F | T + # stage 1 T | T T | T + # stage 2 T | T T | T + # stage 3 F | T F | T + if stage_manager.is_first_stage(ignore_chunk=True): + self.send_tensor_metadata = [True, False] + self.send_grad_metadata = [False, True] + elif stage_manager.is_last_stage(ignore_chunk=True): + self.send_tensor_metadata = [False, True] + self.send_grad_metadata = [True, False] + else: + self.send_tensor_metadata = [True, True] + self.send_grad_metadata = [True, True] + + # meta cache buffer + self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta] + self.grad_metadata_recv = [None, None] + + # P2P communication + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) + + # init communication map + self.communication_map = { + "SEND_FORWARD": self.send_forward, + "RECV_FORWARD": self.recv_forward, + "SEND_BACKWARD": self.send_backward, + "RECV_BACKWARD": self.recv_backward, + } + + # init buffer + self._free_buffers() + + def _free_buffers(self): + # free local buffer + # two dim array, first dim is the model chunk, second dim is the microbatch queue + + # x & y buffer for schedule b + self.input_tensors = [[], []] + self.output_tensors = [[], []] + + # y & dy buffer for schedule w + self.output_tensors_dw = [[], []] + self.output_tensors_grad_dw = [[], []] + + # buffer for communication + self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_forward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] + self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_backward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] + + # y buffer for local send fwd + self.local_send_forward_buffer = [] + # dy buffer for local send bwd + self.local_send_backward_buffer = [] + + # wait pp buffer + self.wait_handles = [] + + def assert_buffer_empty(self): + # assert buffer is empty at end + assert len(self.input_tensors[0]) == 0 + assert len(self.input_tensors[1]) == 0 + assert len(self.output_tensors[0]) == 0 + assert len(self.output_tensors[1]) == 0 + assert len(self.output_tensors_dw[0]) == 0 + assert len(self.output_tensors_dw[1]) == 0 + assert len(self.output_tensors_grad_dw[0]) == 0 + assert len(self.output_tensors_grad_dw[1]) == 0 + assert len(self.send_forward_buffer[0]) == 0 + assert len(self.send_forward_buffer[1]) == 0 + assert len(self.recv_forward_buffer[0]) == 0 + assert len(self.recv_forward_buffer[1]) == 0 + assert len(self.send_backward_buffer[0]) == 0 + assert len(self.send_backward_buffer[1]) == 0 + assert len(self.recv_backward_buffer[0]) == 0 + assert len(self.recv_backward_buffer[1]) == 0 + assert len(self.local_send_forward_buffer) == 0 + assert len(self.local_send_backward_buffer) == 0 + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + self.batch = batch + self.batch_size = get_batch_size(batch) + + if self.microbatch_size is None: + assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" + self.microbatch_size = self.batch_size // self.num_microbatch + if self.num_microbatch is None: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatch = self.batch_size // self.microbatch_size + + if not self.forward_only: + assert self.last_batch_size is None or self.last_batch_size == self.batch_size + assert self.batch_size == self.microbatch_size * self.num_microbatch + + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + + if self.forward_only: + self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + + self.last_batch_size = self.batch_size + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + assert ( + microbatch_id < self.num_microbatch * self.num_model_chunks + ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" + microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not is_forward: + # Reverse order + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 & is_first_stage + # do nothing; cause u are chunk 0 in first rank, u have no prev rank; + ################# + if self.stage_manager.is_first_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 0 & not is_first_stage + # Recv y from PREV_rank as input + ################# + else: + prev_rank = self.stage_manager.get_prev_rank() + input_tensor, wait_handles = self.comm.recv_forward( + prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] + ) + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) + return wait_handles + + else: + ################ + # chunk = 1 & is_last_stage + # do nothing; cause u get y from local_send_forward_buffer in schedule f + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + # return None, [] + return [] + + ################ + # chunk = 1 & not is_last_stage + # recv y from NEXT_rank as input + ################ + else: + next_rank = self.stage_manager.get_next_rank() + input_tensor, wait_handles = self.comm.recv_forward( + next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] + ) + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) + return wait_handles + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 & is_last_stage + # do nothing; Already get dy from local_send_backward_buffer in schedule b + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 0 & not is_last_stage + # Recv bwd from next stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward( + next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] + ) + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) + return wait_handles + + else: + # bwd chunk1 is left V; + ################ + # chunk = 1 & is_first_stage + # do nothing; get loss from local + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 1 & not first stage + # recv_backward recv bwd from prev stage; + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward( + next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] + ) + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) + return wait_handles + + def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: + """Sends the input tensor to the next stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + Any: The wait handles for the communication. + """ + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 && is_last_stage + # do nothing; hold y on local_send_forward_buffer + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache + return [] + + ################ + # chunk = 0 && not is_last_stage + # self.comm.send_forward send y to NEXT stage + ################ + else: + next_rank = self.stage_manager.get_next_rank() + output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_forward( + output_object=output_tensor, + next_rank=next_rank, + send_metadata=self.send_tensor_metadata[model_chunk_id], + ) + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache + return send_handles + + else: + ################ + # chunk = 1 && is_first_stage + # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache + return [] + + ################ + # chunk = 1 && not is_first_stage + # self.comm.send_forward send y to PREV stage + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_forward( + output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id] + ) + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache + return send_handles + + def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: + """Sends the gradient tensor to the previous stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + Any: The wait handles for the communication. + """ + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 && is_first_stage + # do nothing; cause u are the first chunk in first stage; bwd end + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache + return [] + + ################ + # chunk = 0 && not is_first_stage + # Send dx to PREV stage; + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_backward( + input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id] + ) + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache + return send_handles + + # bwd chunk1 is left V; + else: + ################ + # chunk = 1 && is_last_stage + # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache + return [] + + ################ + # chunk = 1 && not is_last_stage + # Send dx to NEXT stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_backward( + input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id] + ) + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache + return send_handles + + def forward_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + micro_batch: Optional[dict], + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + input_obj (Optional[dict]): x; + criterion (Callable): loss function; + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + # Load input ids, attention mask and labels + # for the first stage, input_obj is None; So,we use micro_batch as input_obj + # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. + # Only attention_mask from micro_batch is used + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + # fwd calculate + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + # last layer in model + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + loss = criterion(output_obj, micro_batch) / self.num_microbatch + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_b_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + # micro_batch: Optional[dict], + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Backward dx step of the pipeline; we calculate "dx = w*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj) + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. + + Returns: + Optional[dict]: dx. + """ + # calculate bwd b step ; only dx = w*dy; + + # Retain the grad on the input_obj. No need retain_grad microbatch + if input_obj is not None: + tree_map(retain_grad, input_obj) + + # x, y, dy list for backward_by_grad; Type: list[tensor]; + input_obj_ = [] + output_obj_ = [] + output_obj_grad_ = [] + + # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. + + # For loss backward; output_obj is loss; output_obj_grad should be None + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + assert output_obj_grad is None + input_obj_, _ = tree_flatten(input_obj) + output_obj_.append(output_obj) # LOSS + output_obj_grad_.append(output_obj_grad) # None + + # For other chunk stage, use input_obj as input_obj_; + else: + input_obj_, _ = tree_flatten(input_obj) + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + + try: + ctx = optimizer.no_sync() + except AttributeError: + ctx = model_chunk.no_sync() + with ctx: + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + # inputs=input_obj_, + retain_graph=False, + ) + # Format output_obj_grad + input_obj_grad = dict() + if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): + pass + else: + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def backward_w_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ): + """Backward dw step of the pipeline; we calculate "dw = x*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. + + Returns: + Nothing need to return; we only calculate dw then update w; + """ + # calculate bwd w step ; only dw = x*dy; + + # y, dy list for w backward_by_grad; Type: list[tensor]; + output_obj_ = [] + output_obj_grad_ = [] + + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # loss backward; output_obj is loss; + output_obj_.append(output_obj) # LOSS + output_obj_grad_.append(None) # None + else: + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + inputs=list(model_chunk.parameters()), + retain_graph=False, + ) + + def schedule_f( + self, + scheduled_node, + model_chunk: torch.nn.ModuleList, + model_chunk_id: int, + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ): + """A complete forward schedule; Include recv fwd --> cal fwd --> send fwd; + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + criterion (Callable): loss function; + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Nothing. + """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + # Step1: recv fwd + if model_chunk_id == 0: + # is first stage; get input from microbatch + if self.stage_manager.is_first_stage(ignore_chunk=True): + input_obj = None # (tensor, wait_handle) + else: + input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] + else: + # is last stage; recv from local + if self.stage_manager.is_last_stage(ignore_chunk=True): + input_obj = self.local_send_forward_buffer.pop(0) + # not last stage; recv from next + else: + input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] + # Here, let input_obj.requires_grad_() + # if input_obj is not None: + if not isinstance(input_obj, torch.Tensor): + tree_map(require_grad, input_obj) + + # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd, + # tree_map(torch.Tensor.requires_grad_, micro_batch) + + # Step2: fwd step + output_obj = self.forward_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + micro_batch=micro_batch, + input_obj=input_obj, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + + # Step3: + # 3-1:detach output; detach output for send fwd; + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not detach bwd LOSS + pass + else: + # detach output + detached_output_obj = tree_map(detach, output_obj) + # 3-2 clone detached_output_obj + detached_output_obj = tree_map(clone, detached_output_obj) + + # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output) + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not release_tensor_data bwd LOSS + pass + else: + # release_tensor_data output + tree_map(release_tensor_data, output_obj) + + # add input and output object for backward b + self.input_tensors[model_chunk_id].append(input_obj) + + # for bwd b&w, we only need the graph(grad_fn) of output_obj + # Do not release_tensor_data loss, release_tensor_data other output_obj; + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + self.output_tensors[model_chunk_id].append(output_obj) + else: + self.output_tensors[model_chunk_id].append(output_obj) + + # add output to send_fwd_buffer + if model_chunk_id == 0: # chunk 0 + # is last stage; send to local_send_forward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_forward_buffer.append(detached_output_obj) + else: + self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + else: # chunk 1 + # is first stage; end of fwd; do nothing + if self.stage_manager.is_first_stage(ignore_chunk=True): + pass + else: + self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + + def schedule_b( + self, + scheduled_node, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + ): + """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + Returns: + Nothing. + """ + # Step1: recv bwd + if model_chunk_id == 0: + # chunk0 is last stage; recv output_grad from local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + output_tensor_grad = self.local_send_backward_buffer.pop(0) + # chunk0 not last stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] + else: + # chunk1, is first stage; recv LOSS from local send bwd buffer + if self.stage_manager.is_first_stage(ignore_chunk=True): + output_tensor_grad = None + # chunk1, not first stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] + + # get input and output object from buffer; + input_obj = self.input_tensors[model_chunk_id].pop(0) + output_obj = self.output_tensors[model_chunk_id].pop(0) + + input_object_grad = self.backward_b_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + optimizer=optimizer, + input_obj=input_obj, + output_obj=output_obj, + output_obj_grad=output_tensor_grad, + ) + + # Step3: send bwd + if model_chunk_id == 0: + # do nothing; end of bwd; + if self.stage_manager.is_first_stage(ignore_chunk=True): + pass + # save input_object_grad to send_backward_buffer + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + else: + # send to local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_backward_buffer.append(input_object_grad) + # send to next + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + WeightGradStore.flush(chunk=model_chunk_id) + + def schedule_w( + self, + scheduled_node, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + ): + """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w); + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + Returns: + Nothing. + """ + WeightGradStore.pop(chunk=model_chunk_id) + + def run_forward_only( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + return_loss: bool = False, + return_outputs: bool = False, + ) -> Dict: + assert self.forward_only + + # prepare batch + self.load_batch(data_iter) + + # prepare accum loss & output + accum_loss = None + + # reset accum loss at fwd end; + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) + + outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None + + # while we still have schedules_node in self.schedules + for it in range(len(self.schedules)): + scheduled_node = self.schedules[it] + + if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: + # communication + communication_func = self.communication_map[scheduled_node.type] + communication_func(scheduled_node.chunk) + if scheduled_node.type == "F": + self.schedule_f( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + # return loss & output + if outputs is not None: + outputs = merge_batch(outputs) + return {"loss": accum_loss, "outputs": outputs} + + def run_forward_backward( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> Dict: + """ + Runs Zerobubble schedule, with communication between pipeline stages. + """ + # prepare batch + self.load_batch(data_iter) + + # prepare accum loss & output + accum_loss = None + + # reset accum loss at fwd end; + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) + + outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None + + # while we still have schedules_node in self.schedules + schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) + for it in range(len(schedule)): + scheduled_node = schedule[it] + if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: + # communication + communication_func = self.communication_map[scheduled_node.type] + wait_handle = communication_func(scheduled_node.chunk) + # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle + if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}: + self.wait_handles.append(wait_handle) + elif scheduled_node.type == "F": + self.schedule_f( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + elif scheduled_node.type == "B": + self.schedule_b( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + optimizer=optimizer, + ) + elif scheduled_node.type == "W": + self.schedule_w( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + optimizer=optimizer, + ) + # wait here to ensure all communication is done + for h in self.wait_handles: + for hh in h: + hh.wait() + # return loss & output + if outputs is not None: + outputs = merge_batch(outputs) + return {"loss": accum_loss, "outputs": outputs} + + def forward_backward_step( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: + """ + Args: + model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + self.forward_only = not torch.is_grad_enabled() + if optimizer is None: + assert self.forward_only, "Optimizer should be passed when doing backward." + + if self.forward_only: + result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) + else: + result = self.run_forward_backward( + model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs + ) + + self.assert_buffer_empty() + return result diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 354f110f0..8ef394ec3 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -26,6 +26,7 @@ class PipelineStageManager: pg_mesh: ProcessGroupMesh, pipeline_axis: int, enable_interleave: bool = False, + use_zbv: bool = False, num_model_chunks: int = 1, num_layers_per_stage: Optional[List[int]] = None, ) -> None: @@ -49,6 +50,7 @@ class PipelineStageManager: next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") self.is_interleave = enable_interleave + self.use_zbv = use_zbv # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks # for shardformer, hold stage indices of model @@ -85,6 +87,16 @@ class PipelineStageManager: num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) stage_indices = [] + if self.use_zbv: + stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]]) + stage_indices.append( + [ + num_layers_per_stage_accumulated[2 * num_stages - stage - 1], + num_layers_per_stage_accumulated[2 * num_stages - stage], + ] + ) + return stage_indices + for model_chunk in range(num_model_chunks): start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] @@ -124,7 +136,11 @@ class PipelineStageManager: if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + # use zero bubble pipeline + if self.use_zbv: + return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1 + else: + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: @@ -207,7 +223,6 @@ class PipelineStageManager: # calculate the num_layers per stage layers_per_stage = [quotient] * num_stages * num_model_chunks - # deal with the rest layers if remainder > 0: start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py new file mode 100644 index 000000000..c51c45085 --- /dev/null +++ b/colossalai/pipeline/weight_grad_store.py @@ -0,0 +1,32 @@ +import queue + + +class WeightGradStore: + + cache = [] + weight_grad_queue = [queue.Queue(), queue.Queue()] + + @classmethod + def put(cls, total_input, grad_output, weight, func): + # func(total_input, grad_output, weight.main_grad) + cls.cache.append((total_input, grad_output, weight, func)) + + @classmethod + def flush(cls, chunk=0): + cls.weight_grad_queue[chunk].put(cls.cache) + cls.cache = [] + + @classmethod + def pop(cls, chunk=0): + # print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}") + if cls.weight_grad_queue[chunk].qsize() > 0: + stored_grads = cls.weight_grad_queue[chunk].get() + for total_input, grad_output, weight, func in stored_grads: + if weight.grad is not None: + func(total_input, grad_output, weight.grad) + # for first bwd; weight.grad is None, assign grad_weight to weight.grad + else: + grad_weight = func(total_input, grad_output) + weight.grad = grad_weight + else: + raise Exception("Pop empty queue.") diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 684993de6..da5363840 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -2,7 +2,7 @@ from ._operation import all_to_all_comm from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D -from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D +from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule @@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLin __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", + "LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row", "GPT2FusedLinearConv1D_Col", diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 5499443b6..8c2e6e7c5 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,7 +1,11 @@ +import functools + import torch import torch.distributed as dist import torch.nn.functional as F +from colossalai.pipeline.weight_grad_store import WeightGradStore + from .utils import is_share_sp_tp try: @@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce ctx.fp8_communication = fp8_communication + ctx.use_zbv = use_zbv if bias is not None: output = F.linear(input_, weight, bias) else: @@ -143,6 +148,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias fp8_communication = ctx.fp8_communication + use_zbv = ctx.use_zbv + + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. if use_bias: @@ -164,24 +176,160 @@ class LinearWithAsyncCommunication(torch.autograd.Function): handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) - grad_weight = None - elif grad.dtype == torch.float16: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + if use_zbv: + # TODO: append input, grad_output_, weight, grad func to WeightGradStore + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) grad_weight = None else: grad_weight = grad_output.t().matmul(total_input) - else: - grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce and not fp8_communication: handle.wait() + return grad_input, grad_weight, grad_bias, None, None, None, None + + +class LinearWithGradAccum(torch.autograd.Function): + """ + Linear layer baseline (no tensor parallel version). + """ + + @staticmethod + def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.async_grad_allreduce = async_grad_allreduce + ctx.use_zbv = use_zbv + if bias is not None: + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + use_zbv = ctx.use_zbv + + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. + if use_bias: + bias.view(bias.shape) + + total_input = input.contiguous() + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + # TODO: append input, grad_output_, weight, grad func to WeightGradStore + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + + grad_bias = grad_output.sum(dim=0) if use_bias else None return grad_input, grad_weight, grad_bias, None, None, None, None @@ -966,12 +1114,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre ) -def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): +def linear_with_async_comm( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False +): return LinearWithAsyncCommunication.apply( - input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv ) +def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False): + return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv) + + def linear_gather_forward_reducescatter_backward( input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False ): diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 0ba23c296..d39d6e997 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -27,13 +27,155 @@ from ._operation import ( linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, + linear_with_grad_accum, reduce_forward, split_forward_gather_backward, ) from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset, is_share_sp_tp -__all__ = ["Linear1D_Col", "Linear1D_Row"] +__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"] + + +class LinearWithGradAccum(ParallelModule): + r"""Linear layer with no parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + use_zbv: bool = False, + **kwargs, + ): + super().__init__(weight=weight, bias_=bias_, **kwargs) + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + self.device = device + self.use_zbv = use_zbv + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + + self.randomizer = create_randomizer_with_offset(seed, process_group=None) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + linear_1d = LinearWithGradAccum( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + + # Set up backprop all-reduce. + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_grad_accum( + input_parallel, + self.weight, + bias, + False, + use_zbv=self.use_zbv, + ) + + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output class Linear1D_Col(ParallelModule): @@ -81,6 +223,7 @@ class Linear1D_Col(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, **kwargs, ): super().__init__(weight=weight, bias_=bias_, **kwargs) @@ -95,6 +238,7 @@ class Linear1D_Col(ParallelModule): self.device = device self.process_group = process_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -209,9 +353,14 @@ class Linear1D_Col(ParallelModule): ) else: output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + input_parallel, + self.weight, + bias, + self.process_group, + True, + fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) - if self.gather_output: # All-gather across the partitions. output = gather_forward_split_backward( @@ -267,6 +416,7 @@ class Linear1D_Row(ParallelModule): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() @@ -282,6 +432,7 @@ class Linear1D_Row(ParallelModule): self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) self.fp8_communication = fp8_communication + self.use_zbv = use_zbv if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9a0da82f5..d1ad84604 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -82,7 +82,7 @@ class LlamaPipelineForwards: elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape[:2] + batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: @@ -191,7 +191,6 @@ class LlamaPipelineForwards: num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx - for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 4f8ec162f..a88db87bc 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule): moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False, + use_zbv: bool = False, ): assert tp_group is not None assert moe_dp_group is not None @@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule): self.ep_rank = dist.get_rank(ep_group) self.ep_group = ep_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv if self.num_experts % self.ep_size != 0: raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") @@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule): if self.tp_group.size() > 1: for expert in held_experts: expert.w1 = Linear1D_Col.from_native_module( - expert.w1, self.tp_group, fp8_communication=self.fp8_communication + expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv ) expert.w3 = Linear1D_Col.from_native_module( - expert.w3, self.tp_group, fp8_communication=self.fp8_communication + expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv ) expert.w2 = Linear1D_Row.from_native_module( - expert.w2, self.tp_group, fp8_communication=self.fp8_communication + expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv ) for p in self.experts.parameters(): @@ -379,7 +381,6 @@ class MixtralPipelineForwards: output_router_logits, use_cache, ) - hidden_states = layer_outputs[0] if use_cache: @@ -399,6 +400,7 @@ class MixtralPipelineForwards: if output_router_logits and past_router_logits is not None: all_router_logits = past_router_logits + all_router_logits + if stage_manager.is_last_stage(): if not return_dict: return tuple( @@ -512,7 +514,6 @@ class MixtralPipelineForwards: hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - loss = None if labels is not None: # Shift so that tokens < n predict n diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 09673d396..63cd49280 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -75,6 +75,8 @@ class BertPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -97,6 +99,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -105,6 +108,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -113,6 +117,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,6 +130,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -138,6 +144,7 @@ class BertPolicy(Policy): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -146,6 +153,97 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[BertEmbeddings] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ] + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_intermediate_forward(), + }, + policy=policy, + target_key=BertIntermediate, + ) + + elif use_zbv: + policy[BertLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ec517da49..e8f9471f9 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -9,6 +9,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, RMSNorm, @@ -60,6 +61,8 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None @@ -102,7 +105,7 @@ class LlamaPolicy(Policy): policy=policy, target_key=LlamaModel, ) - + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: assert ( num_q_heads % tp_size == 0 @@ -126,37 +129,135 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + + # not enable tp, replace layer to LinearWithGradAccum + elif use_zbv: + policy[LlamaDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), ], ) @@ -261,9 +362,10 @@ class LlamaPolicy(Policy): held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) - else: layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): @@ -353,11 +455,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy): """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: + if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv: + return [] llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if ( @@ -379,7 +485,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): from transformers import LlamaForSequenceClassification policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { @@ -391,12 +499,32 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): kwargs=dict( gather_output=True, fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + # enable tp, replace layer to LinearWithGradAccum + elif use_zbv: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, ), ) ] ) } policy.update(new_item) + # to be confirmed if self.pipeline_stage_manager: # set None as default @@ -411,7 +539,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(self.model.score) return held_layers diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 4d16038c1..b4b87df92 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -62,6 +63,8 @@ class MistralPolicy(Policy): if self.tie_weight: embedding_cls = PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -90,6 +93,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +101,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -104,6 +109,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +117,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -118,6 +125,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,6 +133,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -132,6 +141,68 @@ class MistralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + elif use_zbv: + policy[MistralDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 8e2ca5de0..fab437c01 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -7,9 +7,18 @@ from torch import Tensor from torch.nn import Module from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col -from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D -from colossalai.shardformer.layer.linear import Linear1D_Row +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + LinearWithGradAccum, + PaddingEmbedding, + VocabParallelEmbedding1D, +) + +# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D +# from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.mixtral import ( EPMixtralSparseMoeBlock, MixtralPipelineForwards, @@ -52,6 +61,7 @@ class MixtralPolicy(Policy): sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] tp_size = self.shard_config.tensor_parallel_size + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # modified for both SP and TP num_q_heads = self.model.config.num_attention_heads @@ -124,31 +134,92 @@ class MixtralPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="block_sparse_moe.gate", target_module=Linear1D_Col, - kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "gather_output": True, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), ], ) + elif use_zbv: + policy[MixtralDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="block_sparse_moe.gate", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -179,6 +250,7 @@ class MixtralPolicy(Policy): "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ) ], @@ -258,14 +330,30 @@ class MixtralPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + stage_manager.stage_indices = stage_indices + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + # for zbv, when is_first_stage (last fwd), we append norm + # for interleaved, when is_last_stage (last fwd), we also append norm + held_layers.append(module.norm) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) return held_layers @@ -297,6 +385,7 @@ class MixtralModelPolicy(MixtralPolicy): class MixtralForCausalLMPolicy(MixtralPolicy): def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: # add a new item for causal lm @@ -306,9 +395,29 @@ class MixtralForCausalLMPolicy(MixtralPolicy): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) - ] + ], + ) + } + policy.update(new_item) + elif use_zbv: + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ], ) } policy.update(new_item) @@ -327,7 +436,9 @@ class MixtralForCausalLMPolicy(MixtralPolicy): """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers @@ -353,6 +464,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): from transformers import MixtralForSequenceClassification policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification @@ -362,7 +474,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): SubModuleReplacementDescription( suffix="score", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 680ff07fd..fe5fb82ca 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -351,7 +351,7 @@ class GeminiDDP(ModelWrapper): loss.backward() self._post_backward() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False): raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.") @staticmethod diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 9a2e9998d..ca91b4d9f 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -300,12 +300,14 @@ class GeminiOptimizer(OptimizerWrapper): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + def backward_by_grad( + self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False + ): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - grad = self.mix_precision_mixin.pre_backward_by_grad(grad) + grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index db26269b4..74b1817a0 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" @@ -458,7 +458,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ctx = nullcontext() if self._backward_context is None else self._backward_context() with ctx: - loss.backward(retain_graph=retain_graph) + loss.backward(inputs=inputs, retain_graph=retain_graph) if not self.require_grad_sync: return @@ -469,14 +469,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if self._overlap_communication: get_accelerator().synchronize() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad) + torch.autograd.backward( + tensor, + grad, + inputs=inputs, + retain_graph=retain_graph, + ) if not self.require_grad_sync: return diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 0e88fabf1..1e49f0aa8 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig warnings.filterwarnings("ignore") @@ -39,6 +40,7 @@ MODEL_CONFIGS = { ), "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8), "7b": LlamaConfig(max_position_embeddings=4096), + # "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096), "13b": LlamaConfig( hidden_size=5120, intermediate_size=13824, @@ -91,7 +93,7 @@ def main(): parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) - parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument( @@ -106,6 +108,7 @@ def main(): parser.add_argument("--no_cache", action="store_true") parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") + parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p") parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument( "--sp_mode", @@ -137,6 +140,11 @@ def main(): # ============================== # Initialize Booster # ============================== + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + use_empty_init = True if args.plugin == "gemini": plugin = GeminiPlugin( @@ -210,6 +218,24 @@ def main(): fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": + if args.pp_style == "zbv": + mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length + mem_w = -32 * config.hidden_size + mem_b = -mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage=args.pp, + n_micro=args.batch_size // args.mbs, + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f * 1.5, + b_mem=mem_b * 1.5, + w_mem=mem_w * 1.5, + ).get_v_schedule() + else: + scheduler_nodes = None + plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, @@ -227,6 +253,7 @@ def main(): overlap_allgather=args.overlap_allgather, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, + scheduler_nodes=scheduler_nodes, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": @@ -242,7 +269,7 @@ def main(): microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", - overlap_p2p=args.overlap, + overlap_p2p=args.overlap_p2p, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, ) @@ -260,6 +287,7 @@ def main(): config = MODEL_CONFIGS[args.config] else: config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + torch.cuda.manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size @@ -319,7 +347,7 @@ def main(): args.profile, args.ignore_steps, 1, # avoid creating massive log files - save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", nsys=args.nsys, ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: @@ -334,8 +362,12 @@ def main(): return_loss=True, ) loss = outputs["loss"] - if dist.get_rank() == dist.get_world_size() - 1: - print(f"Step {step} loss: {loss}") + if args.pp_style == "zbv": + if coordinator.is_master(): + print(f"Step {step} loss: {loss}") + else: + if coordinator.is_last_process(): + print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad() diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py index bb2a32d01..dbffd0c2a 100644 --- a/examples/language/mixtral/benchmark.py +++ b/examples/language/mixtral/benchmark.py @@ -11,6 +11,7 @@ from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator, get_profile_context from tqdm import tqdm +from transformers import AutoConfig from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM import colossalai @@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig warnings.filterwarnings("ignore") @@ -85,7 +87,7 @@ def main(): parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) - parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument( @@ -129,7 +131,29 @@ def main(): # ============================== # Initialize Booster # ============================== + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + if args.plugin == "3d": + if args.pp_style == "zbv": + mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length + mem_w = -32 * config.hidden_size + mem_b = -mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage=args.pp, + n_micro=args.batch_size // args.mbs, + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + else: + scheduler_nodes = None plugin = MoeHybridParallelPlugin( ep_size=args.ep, tp_size=args.tp, @@ -143,11 +167,13 @@ def main(): enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, + num_microbatches=args.batch_size // args.mbs, precision="bf16", enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, + scheduler_nodes=scheduler_nodes, **hybrid_kwargs, ) else: @@ -183,8 +209,10 @@ def main(): with init_ctx: model = MixtralForCausalLM(config=config).to(torch.bfloat16) + # if args.grad_checkpoint: + # model.gradient_checkpointing_enable() if args.grad_checkpoint: - model.gradient_checkpointing_enable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") @@ -229,8 +257,12 @@ def main(): return_loss=True, ) loss = outputs["loss"] - if dist.get_rank() == dist.get_world_size() - 1: - print(f"Step {step} loss: {loss}") + if args.pp_style == "zbv": + if dist.get_rank() == 0: + print(f"Step {step} loss: {loss}") + else: + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad() diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 65c7e49a2..4bebf6d03 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -21,11 +21,16 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - - # Use CPU tensor to avoid OOM/weird NCCl error - gloo_group = dist.new_group(backend="gloo") - tensor = torch.tensor([x], device="cpu") - dist.all_reduce(tensor, group=gloo_group) + # BUG: RuntimeError: Invalid scalar type when use dist.all_reduce(tensor, group=gloo_group) + # # Use CPU tensor to avoid OOM/weird NCCl error + # gloo_group = dist.new_group(backend="gloo") + # tensor = torch.tensor([x], device="cpu") + # dist.all_reduce(tensor, group=gloo_group) + # tensor = tensor / world_size + # return tensor.item() + + tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float) + dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index e2f71ff89..f79bdeb3a 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index d39c5ea91..722b8fd7c 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py new file mode 100644 index 000000000..a01b75eee --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -0,0 +1,1085 @@ +from contextlib import nullcontext +from copy import deepcopy +from functools import partial +from typing import Tuple + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import assert_loose_close + +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 + + +class MlpModel(nn.Module): + def __init__( + self, + in_dim, + out_dim, + num_layers, + stage_index=None, + stage_mgr: PipelineStageManager = None, + ): + super().__init__() + self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + + def forward( + self, + data: torch.Tensor = None, + hidden_states: torch.Tensor = None, + stage_index=None, + stage_mgr: PipelineStageManager = None, + model_chunk_id: int = None, + ): + if stage_mgr is None: + hidden_states = data + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + else: + # Set not used layer to None + held_layers = self.layers[stage_index[0] : stage_index[1]] + + # fwd end + if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1: + return held_layers(hidden_states) + # fwd start + elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0: + return {"hidden_states": held_layers(data)} + # fwd middle + else: + return {"hidden_states": held_layers(hidden_states)} + + def no_sync(self): + return nullcontext() + + +def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups): + for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): + if key_base == key_pp: + if key_base != "params": + assert val_base == val_pp + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +# 1) Test manual v_schedule with multiple microbatch +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_iter_input(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) + + # schedule list + zbv_schedule = [ + # stage 0 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), + ], + # stage 1 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), + ], + # stage 2 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3), + ], + # stage 3 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3), + ], + ] + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + stage_manager=stage_manager, + num_model_chunks=pp_size, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # loss func + def criterion(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = 4 + num_layers = 8 + in_dim = out_dim = 8 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + + input_base = [t.clone() for t in data_iter] + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + local_chunk.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + local_chunk.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + local_chunk.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + local_chunk.append(sub_model) + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) + + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=local_chunk, + data_iter=iter(data_iter), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base[0]) + loss_base = criterion(output_base) + loss_base.backward() + optimizer_base.step() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # assert weight + ########################## + if rank == 0: + # layer 0 + assert_close(local_chunk[0].weight, model_base.layers[0].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(local_chunk[1].weight, model_base.layers[7].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(local_chunk[0].weight, model_base.layers[1].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(local_chunk[1].weight, model_base.layers[6].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) + if rank == 2: + # layer 2 + assert_close(local_chunk[0].weight, model_base.layers[2].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(local_chunk[1].weight, model_base.layers[5].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) + if rank == 3: + # layer 3 + assert_close(local_chunk[0].weight, model_base.layers[3].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(local_chunk[1].weight, model_base.layers[4].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) + + +# 2) add optimizer base 1) +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 8, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_vschedule_with_optim(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True + ) + + h, a, s = 4096, 32, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatch, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + # max_mem=mem_f * (p * 2 + m_offset), + ) + + zbv_schedule = graph.get_v_schedule() + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + stage_manager=stage_manager, + num_model_chunks=num_model_chunk, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # init loss func + def criterion(x, *args, **kwargs): + x = x["hidden_states"] + return (x * x).mean() + + def criterion_base(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = test_config["batch_size"] + num_layers = 8 + assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" + in_dim = out_dim = 1024 + before_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} + input_base = {k: v.clone() for k, v in data_iter.items()} + model_base = deepcopy(model) + model_pp = deepcopy(model) + layers_per_stage = stage_manager.distribute_layers(len(model.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + + model_pp._forward = model_pp.forward + + model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager) + + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5)) + + after_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=model_pp, + data_iter=iter([data_iter]), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 + + # assert memory + if rank != 0: + # w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + # output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3 + # optim: state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + print( + f" num_microbatch {num_microbatch} rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 * batch_size / 1024**3)}" + ) + assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 * batch_size / 1024**3) + else: + # rank0 will also hold output; + print( + f" num_microbatch {num_microbatch} rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" + ) + assert round((after_pp_step_memory - after_init_memory), 5) <= round( + (in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 + ) + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + # output_base = model_base(input_base["data"]) + output_base = model_base.forward(data=input_base["data"]) + loss_base = criterion_base(output_base) + loss_base.backward() + optimizer_base.step() + + ########################## + # assert loss & output + ########################## + # only chunk 1 stage 0 hold loss and output + if rank == 0: + assert_close(result["loss"], loss_base) + assert_close(result["outputs"]["hidden_states"], output_base) + + # ########################## + # # assert weight & optim state + # ########################## + optim_base_state = optimizer_base.state_dict()["state"] + optim_pp_state = optimizer_pp.state_dict()["state"] + optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] + optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] + + if rank == 0: + # layer 0 + assert_close(model_pp.layers[0].weight, model_base.layers[0].weight) + assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad) + assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"]) + # layer 7 + assert_close(model_pp.layers[7].weight, model_base.layers[7].weight) + assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad) + assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"]) + if rank == 1: + # layer 1 + assert_close(model_pp.layers[1].weight, model_base.layers[1].weight) + assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad) + assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"]) + # layer 6 + assert_close(model_pp.layers[6].weight, model_base.layers[6].weight) + assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad) + assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"]) + if rank == 2: + # layer 2 + assert_close(model_pp.layers[2].weight, model_base.layers[2].weight) + assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad) + assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"]) + # layer 5 + assert_close(model_pp.layers[5].weight, model_base.layers[5].weight) + assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad) + assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"]) + if rank == 3: + # layer 3 + assert_close(model_pp.layers[3].weight, model_base.layers[3].weight) + assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad) + assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"]) + # layer 4 + assert_close(model_pp.layers[4].weight, model_base.layers[4].weight) + assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad) + assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"]) + + # assert optim param_groups + assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) + + +@parameterize( + "config", + [ + (1, 2, 1, 1, 2), + (1, 1, 2, 2, 1), + (1, 2, 1, 2, 1), + (1, 2, 2, 1, 1), + (1, 1, 4, 1, 1), + ], +) +def run_with_booster_moehybridplugin(config: Tuple[int, ...]): + stage, ep_size, pp_size, tp_size, sp_size = config + num_microbatches = pp_size + dist.get_world_size() + rank = dist.get_rank() + dtype, precision = torch.float16, "fp16" + torch.cuda.set_device(dist.get_rank()) + + ######## + # init base model + ######## + assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" + config = MixtralConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, + attn_implementation="flash_attention_2", + ) + + # init model with the same seed + seed_all(10086) + + torch_model = MixtralModel(config).to(dtype).cuda() + # TODO: Support MixtralForCausalLM + # torch_model = MixtralForCausalLM(config).to(dtype).cuda() + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + # init schedule + h, a, s = config.hidden_size, config.num_attention_heads, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + # max_mem=mem_f * (p * 2 + m_offset), + ) + + zbv_schedule = graph.get_v_schedule() + + # init MoeHybridPlugin + plugin = MoeHybridParallelPlugin( + pp_size=pp_size, + num_microbatches=pp_size, + tp_size=tp_size, + sp_size=sp_size, + ep_size=ep_size, + zero_stage=stage, + enable_sequence_parallelism=sp_size > 1, + sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, + overlap_communication=False, + initial_scale=1, + precision=precision, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, + ) + + dp_size = plugin.dp_size + + booster = Booster(plugin=plugin) + + ######## + # init pp model + ######## + + parallel_model = deepcopy(torch_model) + parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1) + parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer) + # create different input along dp axis + seed_all(1453 + rank) + + torch_model.train() + parallel_model.train() + for _ in range(2): + # gen random input + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check + + dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input + dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input + + # run the model with hybrid parallel + if booster.plugin.stage_manager is not None: + # for test with pp + data_iter = iter([{"inputs_embeds": input_embeddings}]) + sharded_output = booster.execute_pipeline( + data_iter, + parallel_model, + lambda x, y: x.last_hidden_state.mean(), + parallel_optimizer, + return_loss=True, + return_outputs=True, + ) + # stage 0 chunk 0 + if ( + booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) + and rank == dist.get_process_group_ranks(plugin.pp_group)[0] + ): + parallel_output = sharded_output["loss"] + else: + parallel_output = torch.tensor(12345.0, device="cuda") + # broadcast along pp axis + dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group) + + else: + # for test without pp + parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean() + parallel_optimizer.backward(parallel_output) + parallel_optimizer.step() + parallel_optimizer.zero_grad() + dist.all_reduce(parallel_output, group=plugin.dp_group) + + # =================================================================================== + # run normal model with all dp(different) inputs + all_inputs = [input_embeddings.clone() for _ in range(dp_size)] + dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + torch_output_sum = 0 + for input_data_ in all_inputs: + torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() + torch_output.backward() + torch_output_sum += torch_output.detach() + # avg dp grads follows zero optimizer + for p in torch_model.parameters(): + if p.grad is not None: + p.grad /= dp_size + torch_optimizer.step() + torch_optimizer.zero_grad() + assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "config", + [ + # Pass + (1, 2, 2, 1), + (1, 2, 1, 2), + (1, 1, 2, 2), + (1, 4, 1, 1), + ], +) +def run_with_booster_hybridplugin(config: Tuple[int, ...]): + stage, pp_size, tp_size, sp_size = config + num_microbatches = pp_size + dist.get_world_size() + rank = dist.get_rank() + dtype, precision = torch.float16, "fp16" + torch.cuda.set_device(dist.get_rank()) + + ######## + # init base model + ######## + assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" + config = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", + ) + + # init model with the same seed + seed_all(10086) + + torch_model = LlamaModel(config).to(dtype).cuda() + # TODO: Support MixtralForCausalLM + # torch_model = MixtralForCausalLM(config).to(dtype).cuda() + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + # init schedule + h, a, s = config.hidden_size, config.num_attention_heads, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ) + + zbv_schedule = graph.get_v_schedule() + + # init HybridParallelPlugin + plugin = HybridParallelPlugin( + pp_size=pp_size, + num_microbatches=pp_size, + tp_size=tp_size, + sp_size=sp_size, + zero_stage=stage, + enable_sequence_parallelism=sp_size > 1, + sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, + overlap_communication=False, + initial_scale=1, + precision=precision, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, + ) + + dp_size = plugin.dp_size + + booster = Booster(plugin=plugin) + + ######## + # init pp model + ######## + + parallel_model = deepcopy(torch_model) + parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1) + parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer) + # create different input along dp axis + seed_all(1453 + rank) + + torch_model.train() + parallel_model.train() + for _ in range(2): + # gen random input + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check + + dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input + dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input + + # run the model with hybrid parallel + if booster.plugin.stage_manager is not None: + # for test with pp + data_iter = iter([{"inputs_embeds": input_embeddings}]) + sharded_output = booster.execute_pipeline( + data_iter, + parallel_model, + lambda x, y: x.last_hidden_state.mean(), + parallel_optimizer, + return_loss=True, + return_outputs=True, + ) + # stage 0 chunk 0 + if ( + booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) + and rank == dist.get_process_group_ranks(plugin.pp_group)[0] + ): + parallel_output = sharded_output["loss"] + else: + parallel_output = torch.tensor(12345.0, device="cuda") + # broadcast along pp axis + dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group) + + else: + # for test without pp + parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean() + parallel_optimizer.backward(parallel_output) + parallel_optimizer.step() + parallel_optimizer.zero_grad() + dist.all_reduce(parallel_output, group=plugin.dp_group) + + # =================================================================================== + # run normal model with all dp(different) inputs + all_inputs = [input_embeddings.clone() for _ in range(dp_size)] + dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + torch_output_sum = 0 + for input_data_ in all_inputs: + torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() + torch_output.backward() + torch_output_sum += torch_output.detach() + # avg dp grads follows zero optimizer + for p in torch_model.parameters(): + if p.grad is not None: + p.grad /= dp_size + torch_optimizer.step() + torch_optimizer.zero_grad() + assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_with_booster_moehybridplugin() + run_with_booster_hybridplugin() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn( + run_dist, + nprocs=4, + ) + + +# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py +if __name__ == "__main__": + test_pp() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 541aa3251..773799c1c 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -8,7 +8,8 @@ from torch.testing import assert_close import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.pipeline.weight_grad_store import WeightGradStore +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) +def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = nn.Linear(32, 128).cuda() + with ctx: + linear_copy = nn.Linear(32, 128).cuda() + linear_base = LinearWithGradAccum.from_native_module( + linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False + ) + assert linear_base.weight.shape == torch.Size([128, 32]) + assert linear_base.bias.shape == torch.Size([128]) + assert linear_copy.weight is linear_base.weight + assert linear_copy.bias is linear_base.bias + + linear.load_state_dict(linear_base.state_dict()) + linear_base.load_state_dict(linear.state_dict()) + + # check computation correctness + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + out = linear(x_for_unshard) + gather_out = linear_base(x_for_shard) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + assert_close(linear.weight.grad, linear_base.weight.grad) + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = nn.Linear(32, 128).cuda() + with ctx: + linear_copy = nn.Linear(32, 128).cuda() + linear_base = LinearWithGradAccum.from_native_module( + linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True + ) + assert linear_base.weight.shape == torch.Size([128, 32]) + assert linear_base.bias.shape == torch.Size([128]) + assert linear_copy.weight is linear_base.weight + assert linear_copy.bias is linear_base.bias + + linear.load_state_dict(linear_base.state_dict()) + linear_base.load_state_dict(linear.state_dict()) + + # check computation correctness + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + out = linear(x_for_unshard) + gather_out = linear_base(x_for_shard) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + # Weight grad is None before we do WeightGradStore pop + assert linear_base.weight.grad is None + # after WeightGradStore pop (dw computation complete), we assert weight grad + WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue + WeightGradStore.pop(chunk=0) + assert_close(linear.weight.grad, linear_base.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap): check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) check_linear_1d_row(lazy_init, seq_parallel_mode) check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap) + check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode) + check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode) def check_dist_linear(rank, world_size, port): diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0f9ec6013..3a8057c1f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -310,8 +310,16 @@ def check_output_hidden_state( ): org_hidden_state = org_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): - sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + if stage_manager: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state + elif stage_manager.is_last_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state else: sharded_hidden_state = sharded_output.last_hidden_state @@ -388,7 +396,6 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e8f791697..b97846408 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -7,6 +7,7 @@ from torch.testing import assert_close import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter @@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -112,12 +113,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + check_flag = False + if ( + (stage_manager is None) + or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) + or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)) + ): + check_flag = True + if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state( org_output, @@ -274,6 +281,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + if test_config.get("pp_style", None) == "zbv": + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = -32 * 32 + mem_b = -mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage=test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue